From a0df599900ae5cc10fa33d50d8847fc9f5d7385d Mon Sep 17 00:00:00 2001
From: Qiao Wang
Date: Mon, 21 Oct 2024 16:11:54 -0700
Subject: [PATCH] feat: Add generation_config to count_tokens

PiperOrigin-RevId: 688307166
---
 .../system/vertexai/test_generative_models.py | 15 +++++++
 .../generative_models/_generative_models.py   | 41 ++++++++++++++-----
 vertexai/tokenization/_tokenizers.py          | 28 ++++++++++++-
 3 files changed, 72 insertions(+), 12 deletions(-)

diff --git a/tests/system/vertexai/test_generative_models.py b/tests/system/vertexai/test_generative_models.py
index 2f05d659c8..179d5c10c4 100644
--- a/tests/system/vertexai/test_generative_models.py
+++ b/tests/system/vertexai/test_generative_models.py
@@ -596,3 +596,18 @@ def test_count_tokens_from_text(self):
             response_with_si_and_tool.total_billable_characters
             > response_with_si.total_billable_characters
         )
+        # content + generation_config
+        response_with_generation_config = model.count_tokens(
+            content,
+            generation_config=generative_models.GenerationConfig(
+                response_schema=_RESPONSE_SCHEMA_STRUCT
+            ),
+        )
+        assert (
+            response_with_generation_config.total_tokens
+            > response_with_si_and_tool.total_tokens
+        )
+        assert (
+            response_with_generation_config.total_billable_characters
+            > response_with_si_and_tool.total_billable_characters
+        )
diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py
index 805123bc1d..91c582d387 100644
--- a/vertexai/generative_models/_generative_models.py
+++ b/vertexai/generative_models/_generative_models.py
@@ -163,7 +163,7 @@ def _validate_generate_content_parameters(
     labels: Optional[Dict[str, str]] = None,
 ) -> None:
     """Validates the parameters for a generate_content call."""
-    if not contents:
+    if not contents and not isinstance(contents, str):
         raise TypeError("contents must not be empty")
 
     _validate_contents_type_as_valid_sequence(contents)
@@ -482,14 +482,7 @@ def _prepare_request(
     gapic_generation_config: Optional[gapic_content_types.GenerationConfig] = None
     if generation_config:
-        if isinstance(generation_config, gapic_content_types.GenerationConfig):
-            gapic_generation_config = generation_config
-        elif isinstance(generation_config, GenerationConfig):
-            gapic_generation_config = generation_config._raw_generation_config
-        elif isinstance(generation_config, Dict):
-            gapic_generation_config = gapic_content_types.GenerationConfig(
-                **generation_config
-            )
+        gapic_generation_config = _to_generation_config(generation_config)
 
     gapic_safety_settings = None
     if safety_settings:
@@ -872,7 +865,11 @@ async def async_generator():
         return async_generator()
 
     def count_tokens(
-        self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
+        self,
+        contents: ContentsType,
+        *,
+        tools: Optional[List["Tool"]] = None,
+        generation_config: Optional[GenerationConfigType] = None,
     ) -> gapic_prediction_service_types.CountTokensResponse:
         """Counts tokens.
@@ -885,6 +882,7 @@ def count_tokens(
                 * List[Union[str, Image, Part]],
                 * List[Content]
             tools: A list of tools (functions) that the model can try calling.
+            generation_config: Parameters for the generate_content method.
 
         Returns:
             A CountTokensResponse object that has the following attributes:
@@ -894,12 +892,14 @@
         request = self._prepare_request(
             contents=contents,
             tools=tools,
+            generation_config=generation_config,
         )
         return self._gapic_count_tokens(
             prediction_resource_name=self._prediction_resource_name,
             contents=request.contents,
             system_instruction=request.system_instruction,
             tools=request.tools,
+            generation_config=request.generation_config,
         )
 
     async def count_tokens_async(
@@ -907,6 +907,7 @@
         contents: ContentsType,
         *,
         tools: Optional[List["Tool"]] = None,
+        generation_config: Optional[GenerationConfigType] = None,
     ) -> gapic_prediction_service_types.CountTokensResponse:
         """Counts tokens asynchronously.
@@ -919,6 +920,7 @@
                 * List[Union[str, Image, Part]],
                 * List[Content]
             tools: A list of tools (functions) that the model can try calling.
+            generation_config: Parameters for the generate_content method.
 
         Returns:
             An awaitable for a CountTokensResponse object that has the following attributes:
@@ -928,12 +930,14 @@
         request = self._prepare_request(
             contents=contents,
             tools=tools,
+            generation_config=generation_config,
         )
         return await self._gapic_count_tokens_async(
             prediction_resource_name=self._prediction_resource_name,
             contents=request.contents,
             system_instruction=request.system_instruction,
             tools=request.tools,
+            generation_config=request.generation_config,
         )
 
     def _gapic_count_tokens(
@@ -942,6 +946,7 @@
         contents: List[gapic_content_types.Content],
         system_instruction: Optional[gapic_content_types.Content] = None,
         tools: Optional[List[gapic_tool_types.Tool]] = None,
+        generation_config: Optional[gapic_content_types.GenerationConfig] = None,
    ) -> gapic_prediction_service_types.CountTokensResponse:
         request = gapic_prediction_service_types.CountTokensRequest(
             endpoint=prediction_resource_name,
@@ -949,6 +954,7 @@
             contents=contents,
             system_instruction=system_instruction,
             tools=tools,
+            generation_config=generation_config,
         )
         return self._prediction_client.count_tokens(request=request)
 
@@ -958,6 +964,7 @@ async def _gapic_count_tokens_async(
         contents: List[gapic_content_types.Content],
         system_instruction: Optional[gapic_content_types.Content] = None,
         tools: Optional[List[gapic_tool_types.Tool]] = None,
+        generation_config: Optional[gapic_content_types.GenerationConfig] = None,
    ) -> gapic_prediction_service_types.CountTokensResponse:
         request = gapic_prediction_service_types.CountTokensRequest(
             endpoint=prediction_resource_name,
@@ -965,6 +972,7 @@
             contents=contents,
             system_instruction=system_instruction,
             tools=tools,
+            generation_config=generation_config,
         )
         return await self._prediction_async_client.count_tokens(request=request)
 
@@ -2790,6 +2798,18 @@ def _to_content(
     return gapic_content_types.Content(parts=parts, role=role)
 
 
+def _to_generation_config(
+    generation_config: GenerationConfigType,
+) -> gapic_content_types.GenerationConfig:
+    """Converts a generation config to a gapic_content_types.GenerationConfig object."""
+    if isinstance(generation_config, gapic_content_types.GenerationConfig):
+        return generation_config
+    elif isinstance(generation_config, GenerationConfig):
+        return generation_config._raw_generation_config
+    elif isinstance(generation_config, Dict):
+        return gapic_content_types.GenerationConfig(**generation_config)
+
+
 def _append_response(
     base_response: GenerationResponse,
     new_response: GenerationResponse,
@@ -3186,6 +3206,7 @@ def _gapic_count_tokens(
         contents: List[types_v1.Content],
         system_instruction: Optional[types_v1.Content] = None,
         tools: Optional[List[types_v1.Tool]] = None,
+        generation_config: Optional[types_v1.GenerationConfig] = None,
     ) -> types_v1.CountTokensResponse:
         request = types_v1.CountTokensRequest(
             endpoint=prediction_resource_name,
diff --git a/vertexai/tokenization/_tokenizers.py b/vertexai/tokenization/_tokenizers.py
index 17e2b5b823..45a44a3096 100644
--- a/vertexai/tokenization/_tokenizers.py
+++ b/vertexai/tokenization/_tokenizers.py
@@ -27,9 +27,12 @@
     Image,
     Tool,
     PartsType,
+    GenerationConfigType,
+    _validate_generate_content_parameters,
     _validate_contents_type_as_valid_sequence,
     _content_types_to_gapic_contents,
     _to_content,
+    _to_generation_config,
 )
 
 from vertexai.tokenization._tokenizer_loading import (
@@ -172,7 +175,6 @@ def _to_gapic_contents(
     _validate_contents_type_as_valid_sequence(contents)
     _assert_no_image_contents_type(contents)
     gapic_contents = _content_types_to_gapic_contents(contents)
-    # _assert_text_only_content_types_sequence(gapic_contents)
     return gapic_contents
 
 
@@ -325,6 +327,18 @@ def add_function_response(
                 f"Function response argument contains unsupported types for token counting. Supported fields {counted_function_response}. Got {function_response}."
             )
 
+    def add_generation_config(
+        self, generation_config: gapic_content_types.GenerationConfig
+    ) -> None:
+        if generation_config.response_schema:
+            counted_generation_config = self._schema_traverse(
+                generation_config._pb.response_schema
+            )
+            if counted_generation_config != generation_config._pb.response_schema:
+                raise ValueError(
+                    f"Generation config argument contains unsupported types for token counting. Supported fields {counted_generation_config}. Got {generation_config}."
+                )
+
     def _function_declaration_traverse(
         self, function_declaration: gapic_tool_types.FunctionDeclaration
     ) -> gapic_tool_types.FunctionDeclaration:
@@ -450,6 +464,7 @@ def count_tokens(
         *,
         tools: Optional[List["Tool"]] = None,
         system_instruction: Optional[PartsType] = None,
+        generation_config: Optional[GenerationConfigType] = None,
     ) -> CountTokensResult:
         r"""Counts the number of tokens in the text-only contents.
@@ -472,7 +487,12 @@
             A CountTokensResult object containing the total number of tokens
             in the contents.
         """
-
+        _validate_generate_content_parameters(
+            contents,
+            tools=tools,
+            system_instruction=system_instruction,
+            generation_config=generation_config,
+        )
         text_accumulator = _TextsAccumulator()
         if _is_string_inputs(contents):
             text_accumulator.add_texts(contents)
@@ -490,6 +510,10 @@
             else:
                 text_accumulator.add_content(_to_content(system_instruction))
 
+        if generation_config:
+            canonical_generation_config = _to_generation_config(generation_config)
+            text_accumulator.add_generation_config(canonical_generation_config)
+
        return self._sentencepiece_adapter.count_tokens(text_accumulator.get_texts())
 
     def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
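
Usage sketch (illustrative, not part of the patch): with this change, a response
schema supplied through generation_config is included in the server-side token
count. The project, location, and model name below are hypothetical
placeholders, and the schema stands in for the _RESPONSE_SCHEMA_STRUCT used by
the system test.

    import vertexai
    from vertexai.generative_models import GenerationConfig, GenerativeModel

    # Hypothetical project and location; substitute your own.
    vertexai.init(project="my-project", location="us-central1")

    # Illustrative structured-output schema.
    response_schema = {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    }

    model = GenerativeModel("gemini-1.5-flash-002")  # hypothetical model name
    response = model.count_tokens(
        "List a popular cookie recipe.",
        generation_config=GenerationConfig(response_schema=response_schema),
    )
    # The schema text now contributes to both counts.
    print(response.total_tokens, response.total_billable_characters)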
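The local tokenizer path accepts the same parameter; generation_config fields
that the token counter does not support raise ValueError during accumulation.
A minimal sketch, assuming the preview tokenization entry point and reusing the
hypothetical names from the example above:

    from vertexai.preview.tokenization import get_tokenizer_for_model

    tokenizer = get_tokenizer_for_model("gemini-1.5-flash-002")  # hypothetical
    result = tokenizer.count_tokens(
        "List a popular cookie recipe.",
        generation_config=GenerationConfig(response_schema=response_schema),
    )
    # Counts the contents plus the response schema text, fully offline.
    print(result.total_tokens)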