Skip to content

Commit

Permalink
feat: Add generation_config to count_tokens
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688307166
  • Loading branch information
happy-qiao authored and copybara-github committed Oct 22, 2024
1 parent f713417 commit a0df599
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 12 deletions.
15 changes: 15 additions & 0 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,18 @@ def test_count_tokens_from_text(self):
response_with_si_and_tool.total_billable_characters
> response_with_si.total_billable_characters
)
# content + generation_config
response_with_generation_config = model.count_tokens(
content,
generation_config=generative_models.GenerationConfig(
response_schema=_RESPONSE_SCHEMA_STRUCT
),
)
assert (
response_with_generation_config.total_tokens
> response_with_si_and_tool.total_tokens
)
assert (
response_with_generation_config.total_billable_characters
> response_with_si_and_tool.total_billable_characters
)
41 changes: 31 additions & 10 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _validate_generate_content_parameters(
labels: Optional[Dict[str, str]] = None,
) -> None:
"""Validates the parameters for a generate_content call."""
if not contents:
if not contents and not isinstance(contents, str):
raise TypeError("contents must not be empty")

_validate_contents_type_as_valid_sequence(contents)
Expand Down Expand Up @@ -482,14 +482,7 @@ def _prepare_request(

gapic_generation_config: Optional[gapic_content_types.GenerationConfig] = None
if generation_config:
if isinstance(generation_config, gapic_content_types.GenerationConfig):
gapic_generation_config = generation_config
elif isinstance(generation_config, GenerationConfig):
gapic_generation_config = generation_config._raw_generation_config
elif isinstance(generation_config, Dict):
gapic_generation_config = gapic_content_types.GenerationConfig(
**generation_config
)
gapic_generation_config = _to_generation_config(generation_config)

gapic_safety_settings = None
if safety_settings:
Expand Down Expand Up @@ -872,7 +865,11 @@ async def async_generator():
return async_generator()

def count_tokens(
self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
self,
contents: ContentsType,
*,
tools: Optional[List["Tool"]] = None,
generation_config: Optional[GenerationConfigType] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
"""Counts tokens.
Expand All @@ -885,6 +882,7 @@ def count_tokens(
* List[Union[str, Image, Part]],
* List[Content]
tools: A list of tools (functions) that the model can try calling.
generation_config: Parameters for the generate_content method.
Returns:
A CountTokensResponse object that has the following attributes:
Expand All @@ -894,19 +892,22 @@ def count_tokens(
request = self._prepare_request(
contents=contents,
tools=tools,
generation_config=generation_config,
)
return self._gapic_count_tokens(
prediction_resource_name=self._prediction_resource_name,
contents=request.contents,
system_instruction=request.system_instruction,
tools=request.tools,
generation_config=request.generation_config,
)

async def count_tokens_async(
self,
contents: ContentsType,
*,
tools: Optional[List["Tool"]] = None,
generation_config: Optional[GenerationConfigType] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
"""Counts tokens asynchronously.
Expand All @@ -919,6 +920,7 @@ async def count_tokens_async(
* List[Union[str, Image, Part]],
* List[Content]
tools: A list of tools (functions) that the model can try calling.
generation_config: Parameters for the generate_content method.
Returns:
And awaitable for a CountTokensResponse object that has the following attributes:
Expand All @@ -928,12 +930,14 @@ async def count_tokens_async(
request = self._prepare_request(
contents=contents,
tools=tools,
generation_config=generation_config,
)
return await self._gapic_count_tokens_async(
prediction_resource_name=self._prediction_resource_name,
contents=request.contents,
system_instruction=request.system_instruction,
tools=request.tools,
generation_config=request.generation_config,
)

def _gapic_count_tokens(
Expand All @@ -942,13 +946,15 @@ def _gapic_count_tokens(
contents: List[gapic_content_types.Content],
system_instruction: Optional[gapic_content_types.Content] = None,
tools: Optional[List[gapic_tool_types.Tool]] = None,
generation_config: Optional[gapic_content_types.GenerationConfig] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
request = gapic_prediction_service_types.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
return self._prediction_client.count_tokens(request=request)

Expand All @@ -958,13 +964,15 @@ async def _gapic_count_tokens_async(
contents: List[gapic_content_types.Content],
system_instruction: Optional[gapic_content_types.Content] = None,
tools: Optional[List[gapic_tool_types.Tool]] = None,
generation_config: Optional[gapic_content_types.GenerationConfig] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
request = gapic_prediction_service_types.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
return await self._prediction_async_client.count_tokens(request=request)

Expand Down Expand Up @@ -2790,6 +2798,18 @@ def _to_content(
return gapic_content_types.Content(parts=parts, role=role)


def _to_generation_config(
    generation_config: GenerationConfigType,
) -> gapic_content_types.GenerationConfig:
    """Converts a user-supplied generation config to its gapic proto form.

    Args:
        generation_config: One of ``gapic_content_types.GenerationConfig``,
            the SDK ``GenerationConfig`` wrapper, or a dict of
            ``GenerationConfig`` constructor keyword arguments.

    Returns:
        The equivalent ``gapic_content_types.GenerationConfig`` proto.

    Raises:
        TypeError: If ``generation_config`` is not one of the supported types.
    """
    if isinstance(generation_config, gapic_content_types.GenerationConfig):
        # Already the raw proto; pass through unchanged.
        return generation_config
    elif isinstance(generation_config, GenerationConfig):
        # SDK wrapper: unwrap to the underlying proto it holds.
        return generation_config._raw_generation_config
    elif isinstance(generation_config, Dict):
        return gapic_content_types.GenerationConfig(**generation_config)
    else:
        # Previously an unsupported type fell through and implicitly
        # returned None, surfacing as a confusing error downstream.
        # Fail fast instead, matching the file's other validators.
        raise TypeError(
            "generation_config must be a GenerationConfig object or a dict, "
            f"got {type(generation_config).__name__}"
        )


def _append_response(
base_response: GenerationResponse,
new_response: GenerationResponse,
Expand Down Expand Up @@ -3186,6 +3206,7 @@ def _gapic_count_tokens(
contents: List[types_v1.Content],
system_instruction: Optional[types_v1.Content] = None,
tools: Optional[List[types_v1.Tool]] = None,
generation_config: Optional[types_v1.GenerationConfig] = None,
) -> types_v1.CountTokensResponse:
request = types_v1.CountTokensRequest(
endpoint=prediction_resource_name,
Expand Down
28 changes: 26 additions & 2 deletions vertexai/tokenization/_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
Image,
Tool,
PartsType,
GenerationConfigType,
_validate_generate_content_parameters,
_validate_contents_type_as_valid_sequence,
_content_types_to_gapic_contents,
_to_content,
_to_generation_config,
)

from vertexai.tokenization._tokenizer_loading import (
Expand Down Expand Up @@ -172,7 +175,6 @@ def _to_gapic_contents(
_validate_contents_type_as_valid_sequence(contents)
_assert_no_image_contents_type(contents)
gapic_contents = _content_types_to_gapic_contents(contents)
# _assert_text_only_content_types_sequence(gapic_contents)
return gapic_contents


Expand Down Expand Up @@ -325,6 +327,18 @@ def add_function_response(
f"Function response argument contains unsupported types for token counting. Supported fields {counted_function_response}. Got {function_response}."
)

def add_generation_config(
    self, generation_config: gapic_content_types.GenerationConfig
) -> None:
    """Validates a generation config's response schema for token counting.

    Only ``response_schema`` is inspected — no other generation config
    field (temperature, max tokens, etc.) is read here. The schema is run
    through ``self._schema_traverse``, which presumably collects the
    schema's countable text into this accumulator — TODO confirm.

    Args:
        generation_config: A gapic ``GenerationConfig`` proto.

    Raises:
        ValueError: If the traversed schema does not round-trip equal to the
            original, i.e. it contained fields unsupported for counting.
    """
    if generation_config.response_schema:
        counted_generation_config = self._schema_traverse(
            generation_config._pb.response_schema
        )
        # NOTE(review): this compares the traversed *schema*'s proto against
        # the *entire* generation_config proto, which can seemingly never be
        # equal; presumably the intent was to compare against
        # generation_config._pb.response_schema — verify against
        # _schema_traverse's return type and the add_function_response
        # pattern above.
        if counted_generation_config._pb != generation_config._pb:
            raise ValueError(
                f"Generation config argument contains unsupported types for token counting. Supported fields {counted_generation_config}. Got {generation_config}."
            )

def _function_declaration_traverse(
self, function_declaration: gapic_tool_types.FunctionDeclaration
) -> gapic_tool_types.FunctionDeclaration:
Expand Down Expand Up @@ -450,6 +464,7 @@ def count_tokens(
*,
tools: Optional[List["Tool"]] = None,
system_instruction: Optional[PartsType] = None,
generation_config: Optional[GenerationConfigType] = None,
) -> CountTokensResult:
r"""Counts the number of tokens in the text-only contents.
Expand All @@ -472,7 +487,12 @@ def count_tokens(
A CountTokensResult object containing the total number of tokens in
the contents.
"""

_validate_generate_content_parameters(
contents,
tools=tools,
system_instruction=system_instruction,
generation_config=generation_config,
)
text_accumulator = _TextsAccumulator()
if _is_string_inputs(contents):
text_accumulator.add_texts(contents)
Expand All @@ -490,6 +510,10 @@ def count_tokens(
else:
text_accumulator.add_content(_to_content(system_instruction))

if generation_config:
canonical_generation_config = _to_generation_config(generation_config)
text_accumulator.add_generation_config(canonical_generation_config)

return self._sentencepiece_adapter.count_tokens(text_accumulator.get_texts())

def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
Expand Down

0 comments on commit a0df599

Please sign in to comment.