
Commit

Add rev14 parameters and fixes. (#561)
* Add rev14 parameters

Change-Id: I16f2b1f5820a6cf867b9abb04ffd5c6e6d2d947b

* Fix flaky repr test

Change-Id: I89bcf1494cf72c6ee28f2b52d0345cbb40859862

* format

Change-Id: I81cff23e9ce20cc20b4a0632d557c71f536fd485

* Use client preview

Change-Id: I2d8a4ee2e9e4b6e00a16a9dac1136a2fa18d7a28

* Fix tests

Change-Id: If8fbbba1966aa42601adec877e60d851d4f03b72

* Fix tuned model tests

Change-Id: I5ace9222954be7d903ebbdabab9efc663fa79174

* Fix tests

Change-Id: Ifa610965c5d6c38123080a7e16416ac325418285

* format

Change-Id: I15fd5701dd5c4200461a32c968fa19e375403a7e

* pytype

Change-Id: I08f74d08c4e93bbfdf353370b5dd57d8bf86a637

* pytype

Change-Id: If81b86c176008cd9a99e3b879fbd3af086ec2235

* 3.9 tests

Change-Id: I13e66016327aae0b0f3274e941bc615f379e5669
MarkDaoust authored Sep 23, 2024
1 parent 4f42118 commit 36e001a
Showing 6 changed files with 82 additions and 72 deletions.
29 changes: 26 additions & 3 deletions google/generativeai/types/generation_types.py
@@ -144,17 +144,27 @@ class GenerationConfig:
Note: The default value varies by model, see the
`Model.top_k` attribute of the `Model` returned by the
`genai.get_model` function.
seed:
Optional. Seed used in decoding. If not set, the request uses a randomly generated seed.
response_mime_type:
Optional. Output response mimetype of the generated candidate text.
Supported mimetype:
`text/plain`: (default) Text output.
`text/x-enum`: for use with a string-enum in `response_schema`
`application/json`: JSON response in the candidates.
response_schema:
Optional. Specifies the format of the JSON requested if response_mime_type is
`application/json`.
presence_penalty:
Optional. Penalty applied to the logits of tokens that have already appeared in the response.
frequency_penalty:
Optional. Penalty applied to token logits, scaled by how many times each token has already appeared in the response.
response_logprobs:
Optional. If true, export the `logprobs` results in the response.
logprobs:
Optional. Number of top log probabilities to return at each step of decoding; only meaningful when `response_logprobs` is true.
"""

candidate_count: int | None = None
@@ -163,8 +173,13 @@ class GenerationConfig:
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
seed: int | None = None
response_mime_type: str | None = None
response_schema: protos.Schema | Mapping[str, Any] | type | None = None
presence_penalty: float | None = None
frequency_penalty: float | None = None
response_logprobs: bool | None = None
logprobs: int | None = None


GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
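For orientation (not part of the diff): the new rev14 fields are ordinary dataclass attributes, so they pass straight through the high-level API. A minimal sketch, assuming a configured API key; the model name and prompt are illustrative only:

import google.generativeai as genai

# The rev14 additions are plain keyword arguments on GenerationConfig.
config = genai.types.GenerationConfig(
    temperature=0.7,
    seed=1234,               # fixed seed for reproducible decoding
    presence_penalty=0.5,    # penalize tokens already present in the output
    frequency_penalty=0.5,   # penalize tokens by how often they have appeared
    response_logprobs=True,  # include logprob results in the response
    logprobs=3,              # top log probabilities to return per step
)

model = genai.GenerativeModel("gemini-1.5-flash")  # illustrative model name
response = model.generate_content("Tell me a story", generation_config=config)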
@@ -306,6 +321,7 @@ def _join_code_execution_result(result_1, result_2):


def _join_candidates(candidates: Iterable[protos.Candidate]):
"""Joins stream chunks of a single candidate."""
candidates = tuple(candidates)

index = candidates[0].index # These should all be the same.
@@ -321,6 +337,7 @@ def _join_candidates(candidates: Iterable[protos.Candidate]):


def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]):
"""Joins stream chunks where each chunk is a list of candidate chunks."""
# Assuming that if a candidate ends, it is no longer returned in the list of
# candidates, and that's why candidates have an index
candidates = collections.defaultdict(list)
@@ -344,10 +361,15 @@ def _join_prompt_feedbacks(

def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]):
chunks = tuple(chunks)
if "usage_metadata" in chunks[-1]:
usage_metadata = chunks[-1].usage_metadata
else:
usage_metadata = None

return protos.GenerateContentResponse(
candidates=_join_candidate_lists(c.candidates for c in chunks),
prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
usage_metadata=chunks[-1].usage_metadata,
usage_metadata=usage_metadata,
)
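The `in` guard above relies on proto-plus field-presence checks, so usage metadata is only forwarded when the final chunk actually carried it; otherwise the joined response leaves the field unset. A small sketch of the presence semantics:

chunk = protos.GenerateContentResponse()
assert "usage_metadata" not in chunk  # never set: presence check is False

chunk.usage_metadata.prompt_token_count = 5
assert "usage_metadata" in chunk  # writing a subfield marks the field present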


@@ -541,7 +563,8 @@ def __str__(self) -> str:
_result = _result.replace("\n", "\n ")

if self._error:
_error = f",\nerror=<{self._error.__class__.__name__}> {self._error}"

_error = f",\nerror={repr(self._error)}"
else:
_error = ""

5 changes: 4 additions & 1 deletion google/generativeai/types/model_types.py
@@ -143,7 +143,9 @@ def idecode_time(parent: dict["str", Any], name: str):

def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel:
if isinstance(tuned_model, protos.TunedModel):
tuned_model = type(tuned_model).to_dict(tuned_model) # pytype: disable=attribute-error
tuned_model = type(tuned_model).to_dict(
tuned_model, including_default_value_fields=False
) # pytype: disable=attribute-error
tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None))

base_model = tuned_model.pop("base_model", None)
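`including_default_value_fields=False` is a proto-plus `to_dict` option: unset fields are omitted instead of being serialized as empty strings and zeros, so the decoded dict carries only what the server actually returned. A hedged sketch (the model name is made up):

tm = protos.TunedModel(name="tunedModels/example")
type(tm).to_dict(tm)  # includes defaults: {"name": ..., "display_name": "", ...}
type(tm).to_dict(tm, including_default_value_fields=False)  # {"name": "tunedModels/example"}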
@@ -195,6 +197,7 @@ class TunedModel:
create_time: datetime.datetime | None = None
update_time: datetime.datetime | None = None
tuning_task: TuningTask | None = None
reader_project_numbers: list[int] | None = None

@property
def permissions(self) -> permission_types.Permissions:
2 changes: 1 addition & 1 deletion setup.py
@@ -42,7 +42,7 @@ def get_version():
release_status = "Development Status :: 5 - Production/Stable"

dependencies = [
"google-ai-generativelanguage==0.6.9",
"google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz",
"google-api-core",
"google-api-python-client",
"google-auth>=2.15.0", # 2.15 adds API key auth support
12 changes: 7 additions & 5 deletions tests/test_files.py
@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from google.generativeai.types import file_types

import collections
import datetime
import os
from typing import Iterable, Union
from typing import Iterable, Sequence
import pathlib

import google
@@ -37,12 +38,13 @@ def __init__(self, test):

def create_file(
self,
path: Union[str, pathlib.Path, os.PathLike],
path: str | pathlib.Path | os.PathLike,
*,
mime_type: Union[str, None] = None,
name: Union[str, None] = None,
display_name: Union[str, None] = None,
mime_type: str | None = None,
name: str | None = None,
display_name: str | None = None,
resumable: bool = True,
metadata: Sequence[tuple[str, str]] = (),
) -> protos.File:
self.observed_requests.append(
dict(
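The `Union[...]` → `X | Y` rewrite is only safe on Python 3.9 because of the `from __future__ import annotations` line added at the top of the file: with postponed evaluation (PEP 563), annotations are stored as strings and never evaluated at runtime, so `str | None` never has to construct a union object (which requires 3.10+). A minimal illustration, matching the "3.9 tests" commit:

from __future__ import annotations  # must come before other code in the module

# Parses everywhere; with the future import the annotation stays a string,
# so `str | None` is never evaluated on Python 3.9.
def create_file(path: str | None = None) -> None:
    ...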
53 changes: 37 additions & 16 deletions tests/test_generation.py
@@ -1,4 +1,20 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
import string
import textwrap
from typing_extensions import TypedDict
@@ -22,6 +38,8 @@ class Person(TypedDict):


class UnitTests(parameterized.TestCase):
maxDiff = None

@parameterized.named_parameters(
[
"protos.GenerationConfig",
@@ -416,24 +434,16 @@ def test_join_prompt_feedbacks(self):
],
"role": "assistant",
},
"citation_metadata": {"citation_sources": []},
"index": 0,
"finish_reason": 0,
"safety_ratings": [],
"token_count": 0,
"grounding_attributions": [],
"citation_metadata": {},
},
{
"content": {
"parts": [{"text": "Tell me a story about a magic backpack"}],
"role": "assistant",
},
"index": 1,
"citation_metadata": {"citation_sources": []},
"finish_reason": 0,
"safety_ratings": [],
"token_count": 0,
"grounding_attributions": [],
"citation_metadata": {},
},
{
"content": {
@@ -458,17 +468,16 @@ def test_join_prompt_feedbacks(self):
},
]
},
"finish_reason": 0,
"safety_ratings": [],
"token_count": 0,
"grounding_attributions": [],
},
]

def test_join_candidates(self):
candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS]
result = generation_types._join_candidate_lists(candidate_lists)
self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result])
self.assertEqual(
self.MERGED_CANDIDATES,
[type(r).to_dict(r, including_default_value_fields=False) for r in result],
)

def test_join_chunks(self):
chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
@@ -480,6 +489,10 @@ def test_join_chunks(self):
],
)

chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(
prompt_token_count=5
)

result = generation_types._join_chunks(chunks)

expected = protos.GenerateContentResponse(
@@ -495,10 +508,18 @@ def test_join_chunks(self):
}
],
},
"usage_metadata": {"prompt_token_count": 5},
},
)

self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(expected))
expected = json.dumps(
type(expected).to_dict(expected, including_default_value_fields=False), indent=4
)
result = json.dumps(
type(result).to_dict(result, including_default_value_fields=False), indent=4
)

self.assertEqual(expected, result)
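Serializing both protos to indented JSON before comparing, together with the `maxDiff = None` added above, turns any mismatch into a readable line-by-line string diff rather than one enormous dict repr.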

def test_generate_content_response_iterator_end_to_end(self):
chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
53 changes: 7 additions & 46 deletions tests/test_generative_models.py
@@ -935,8 +935,7 @@ def test_repr_for_streaming_start_to_finish(self):
"citation_metadata": {}
}
],
"prompt_feedback": {},
"usage_metadata": {}
"prompt_feedback": {}
}),
)"""
)
@@ -964,8 +963,7 @@ def test_repr_for_streaming_start_to_finish(self):
"citation_metadata": {}
}
],
"prompt_feedback": {},
"usage_metadata": {}
"prompt_feedback": {}
}),
)"""
)
@@ -998,10 +996,10 @@ def test_repr_error_info_for_stream_prompt_feedback_blocked(self):
}
}),
),
error=<BlockedPromptException> prompt_feedback {
error=BlockedPromptException(prompt_feedback {
block_reason: SAFETY
}
"""
)"""
)
self.assertEqual(expected, result)

@@ -1056,11 +1054,10 @@ def no_throw():
"citation_metadata": {}
}
],
"prompt_feedback": {},
"usage_metadata": {}
"prompt_feedback": {}
}),
),
error=<ValueError> """
error=ValueError()"""
)
self.assertEqual(expected, result)

@@ -1095,43 +1092,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self):
response = chat.send_message("hello2", stream=True)

result = repr(response)
expected = textwrap.dedent(
"""\
response:
GenerateContentResponse(
done=True,
iterator=None,
result=protos.GenerateContentResponse({
"candidates": [
{
"content": {
"parts": [
{
"text": "abc"
}
]
},
"finish_reason": "SAFETY",
"index": 0,
"citation_metadata": {}
}
],
"prompt_feedback": {},
"usage_metadata": {}
}),
),
error=<StopCandidateException> content {
parts {
text: "abc"
}
}
finish_reason: SAFETY
index: 0
citation_metadata {
}
"""
)
self.assertEqual(expected, result)
self.assertIn("StopCandidateException", result)
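Asserting only on the exception name instead of a full multi-line expected repr is the same de-flaking strategy as the `repr(error)` change in `generation_types.py`: the pretty-printed proto payload may vary across client releases, while the exception type stays stable.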

def test_repr_for_multi_turn_chat(self):
# Multi turn chat
