diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py
index 23e7fb1d8..8bd0a7736 100644
--- a/google/generativeai/types/generation_types.py
+++ b/google/generativeai/types/generation_types.py
@@ -144,17 +144,27 @@ class GenerationConfig:
             Note: The default value varies by model, see the
             `Model.top_k` attribute of the `Model` returned by the
             `genai.get_model` function.
-
+        seed:
+            Optional. Seed used in decoding. If not set, the request uses a randomly generated seed.
         response_mime_type:
             Optional. Output response mimetype of the generated candidate text.
             Supported mimetype:
                 `text/plain`: (default) Text output.
+                `text/x-enum`: for use with a string-enum in `response_schema`
                 `application/json`: JSON response in the candidates.
         response_schema:
             Optional. Specifies the format of the JSON requested if response_mime_type is
             `application/json`.
+        presence_penalty:
+            Optional. Penalty applied to the next token's logprobs if the token has already been seen in the response.
+        frequency_penalty:
+            Optional. Penalty applied to the next token's logprobs, multiplied by the number of times each token has been seen in the response so far.
+        response_logprobs:
+            Optional. If true, export the `logprobs` results in response.
+        logprobs:
+            Optional. Number of candidates of log probabilities to return at each step of decoding.
     """
 
     candidate_count: int | None = None
@@ -163,8 +173,13 @@ class GenerationConfig:
     temperature: float | None = None
     top_p: float | None = None
     top_k: int | None = None
+    seed: int | None = None
     response_mime_type: str | None = None
     response_schema: protos.Schema | Mapping[str, Any] | type | None = None
+    presence_penalty: float | None = None
+    frequency_penalty: float | None = None
+    response_logprobs: bool | None = None
+    logprobs: int | None = None
 
 
 GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
@@ -306,6 +321,7 @@ def _join_code_execution_result(result_1, result_2):
 
 
 def _join_candidates(candidates: Iterable[protos.Candidate]):
+    """Joins stream chunks of a single candidate."""
     candidates = tuple(candidates)
 
     index = candidates[0].index  # These should all be the same.
@@ -321,6 +337,7 @@ def _join_candidates(candidates: Iterable[protos.Candidate]):
 
 
 def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]):
+    """Joins stream chunks where each chunk is a list of candidate chunks."""
     # Assuming that if a candidate ends, it is no longer returned in the list of
     # candidates, and that's why candidates have an index
     candidates = collections.defaultdict(list)
@@ -344,10 +361,15 @@ def _join_prompt_feedbacks(
 
 def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]):
     chunks = tuple(chunks)
+    if "usage_metadata" in chunks[-1]:
+        usage_metadata = chunks[-1].usage_metadata
+    else:
+        usage_metadata = None
+
     return protos.GenerateContentResponse(
         candidates=_join_candidate_lists(c.candidates for c in chunks),
         prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
-        usage_metadata=chunks[-1].usage_metadata,
+        usage_metadata=usage_metadata,
     )
 
 
@@ -541,7 +563,8 @@ def __str__(self) -> str:
         _result = _result.replace("\n", "\n    ")
 
         if self._error:
-            _error = f",\nerror=<{self._error.__class__.__name__}> {self._error}"
+
+            _error = f",\nerror={repr(self._error)}"
         else:
             _error = ""
 
diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py
index 03922a64e..ff66d6339 100644
--- a/google/generativeai/types/model_types.py
+++ b/google/generativeai/types/model_types.py
@@ -143,7 +143,9 @@ def idecode_time(parent: dict["str", Any], name: str):
 
 def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel:
     if isinstance(tuned_model, protos.TunedModel):
-        tuned_model = type(tuned_model).to_dict(tuned_model)  # pytype: disable=attribute-error
+        tuned_model = type(tuned_model).to_dict(
+            tuned_model, including_default_value_fields=False
+        )  # pytype: disable=attribute-error
     tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None))
 
     base_model = tuned_model.pop("base_model", None)
@@ -195,6 +197,7 @@ class TunedModel:
     create_time: datetime.datetime | None = None
     update_time: datetime.datetime | None = None
     tuning_task: TuningTask | None = None
+    reader_project_numbers: list[int] | None = None
 
     @property
     def permissions(self) -> permission_types.Permissions:
diff --git a/setup.py b/setup.py
index 29841ba1d..0575dcd28 100644
--- a/setup.py
+++ b/setup.py
@@ -42,7 +42,7 @@ def get_version():
 release_status = "Development Status :: 5 - Production/Stable"
 
 dependencies = [
-    "google-ai-generativelanguage==0.6.9",
+    "google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz",
     "google-api-core",
     "google-api-python-client",
     "google-auth>=2.15.0",  # 2.15 adds API key auth support
diff --git a/tests/test_files.py b/tests/test_files.py
index 063f1ce3a..cb48316bd 100644
--- a/tests/test_files.py
+++ b/tests/test_files.py
@@ -12,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 from google.generativeai.types import file_types
 
 import collections
 import datetime
 import os
-from typing import Iterable, Union
+from typing import Iterable, Sequence
 import pathlib
 
 import google
@@ -37,12 +38,13 @@ def __init__(self, test):
 
     def create_file(
         self,
-        path: Union[str, pathlib.Path, os.PathLike],
+        path: str | pathlib.Path | os.PathLike,
         *,
-        mime_type: Union[str, None] = None,
-        name: Union[str, None] = None,
-        display_name: Union[str, None] = None,
+        mime_type: str | None = None,
+        name: str | None = None,
+        display_name: str | None = None,
         resumable: bool = True,
+        metadata: Sequence[tuple[str, str]] = (),
     ) -> protos.File:
         self.observed_requests.append(
             dict(
diff --git a/tests/test_generation.py b/tests/test_generation.py
index 0cc3bfd07..a1461e8b5 100644
--- a/tests/test_generation.py
+++ b/tests/test_generation.py
@@ -1,4 +1,20 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import inspect
+import json
 import string
 import textwrap
 from typing_extensions import TypedDict
@@ -22,6 +38,8 @@ class Person(TypedDict):
 
 
 class UnitTests(parameterized.TestCase):
+    maxDiff = None
+
     @parameterized.named_parameters(
         [
             "protos.GenerationConfig",
@@ -416,12 +434,8 @@ def test_join_prompt_feedbacks(self):
                 ],
                 "role": "assistant",
             },
-            "citation_metadata": {"citation_sources": []},
             "index": 0,
-            "finish_reason": 0,
-            "safety_ratings": [],
-            "token_count": 0,
-            "grounding_attributions": [],
+            "citation_metadata": {},
         },
         {
             "content": {
                 "parts": [{"text": "Tell me a story about a magic backpack"}],
                 "role": "assistant",
             },
             "index": 1,
-            "citation_metadata": {"citation_sources": []},
-            "finish_reason": 0,
-            "safety_ratings": [],
-            "token_count": 0,
-            "grounding_attributions": [],
+            "citation_metadata": {},
         },
         {
             "content": {
@@ -458,17 +468,16 @@ def test_join_prompt_feedbacks(self):
                     },
                 ]
             },
-            "finish_reason": 0,
-            "safety_ratings": [],
-            "token_count": 0,
-            "grounding_attributions": [],
         },
     ]
 
     def test_join_candidates(self):
         candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS]
         result = generation_types._join_candidate_lists(candidate_lists)
-        self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result])
+        self.assertEqual(
+            self.MERGED_CANDIDATES,
+            [type(r).to_dict(r, including_default_value_fields=False) for r in result],
+        )
 
     def test_join_chunks(self):
         chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
@@ -480,6 +489,10 @@
                     ],
                 )
 
+        chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(
+            prompt_token_count=5
+        )
+
         result = generation_types._join_chunks(chunks)
 
         expected = protos.GenerateContentResponse(
@@ -495,10 +508,18 @@
                         }
                     ],
                 },
+                "usage_metadata": {"prompt_token_count": 5},
             },
         )
 
-        self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(expected))
+        expected = json.dumps(
+            type(expected).to_dict(expected, including_default_value_fields=False), indent=4
+        )
+        result = json.dumps(
+            type(result).to_dict(result, including_default_value_fields=False), indent=4
+        )
+
+        self.assertEqual(expected, result)
 
     def test_generate_content_response_iterator_end_to_end(self):
         chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py
index 79c1ac36f..fa69099ba 100644
--- a/tests/test_generative_models.py
+++ b/tests/test_generative_models.py
@@ -935,8 +935,7 @@ def test_repr_for_streaming_start_to_finish(self):
                       "citation_metadata": {}
                     }
                   ],
-                  "prompt_feedback": {},
-                  "usage_metadata": {}
+                  "prompt_feedback": {}
                 }),
            )"""
        )
@@ -964,8 +963,7 @@
                       "citation_metadata": {}
                     }
                   ],
-                  "prompt_feedback": {},
-                  "usage_metadata": {}
+                  "prompt_feedback": {}
                 }),
            )"""
        )
@@ -998,10 +996,10 @@ def test_repr_error_info_for_stream_prompt_feedback_blocked(self):
              }
            }),
        ),
-        error=<BlockedPromptException> prompt_feedback {
+        error=BlockedPromptException(prompt_feedback {
          block_reason: SAFETY
        }
-        """
+        )"""
        )
        self.assertEqual(expected, result)
@@ -1056,11 +1054,10 @@ def no_throw():
                       "citation_metadata": {}
                     }
                   ],
-                  "prompt_feedback": {},
-                  "usage_metadata": {}
+                  "prompt_feedback": {}
                 }),
            ),
-        error=<ValueError> """
+        error=ValueError()"""
        )
        self.assertEqual(expected, result)
@@ -1095,43 +1092,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self):
         response = chat.send_message("hello2", stream=True)
         result = repr(response)
 
-        expected = textwrap.dedent(
-            """\
-            response:
-            GenerateContentResponse(
-                done=True,
-                iterator=None,
-                result=protos.GenerateContentResponse({
-                  "candidates": [
-                    {
-                      "content": {
-                        "parts": [
-                          {
-                            "text": "abc"
-                          }
-                        ]
-                      },
-                      "finish_reason": "SAFETY",
-                      "index": 0,
-                      "citation_metadata": {}
-                    }
-                  ],
-                  "prompt_feedback": {},
-                  "usage_metadata": {}
-                }),
-            ),
-            error=<StopCandidateException> content {
-              parts {
-                text: "abc"
-              }
-            }
-            finish_reason: SAFETY
-            index: 0
-            citation_metadata {
-            }
-            """
-        )
-        self.assertEqual(expected, result)
+        self.assertIn("StopCandidateException", result)
 
     def test_repr_for_multi_turn_chat(self):
         # Multi turn chat
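
Usage sketch (not part of the patch): a minimal example of the new GenerationConfig fields introduced above. The model name is illustrative, and this assumes the preview google-ai-generativelanguage build pinned in setup.py is installed.

    import google.generativeai as genai

    model = genai.GenerativeModel("gemini-1.5-flash")
    response = model.generate_content(
        "Tell me a story about a magic backpack",
        generation_config=genai.GenerationConfig(
            seed=42,                 # fixed seed for more reproducible decoding
            presence_penalty=0.5,    # penalize tokens already present in the response
            frequency_penalty=0.5,   # penalty grows with each repetition of a token
            response_logprobs=True,  # ask the service to return logprobs
            logprobs=3,              # top-3 candidate tokens per decoding step
        ),
    )
    print(response.text)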