feat: Add generation_config to count_tokens
PiperOrigin-RevId: 688307166
happy-qiao authored and copybara-github committed Oct 23, 2024
1 parent da76253 commit b2f4a64
Showing 3 changed files with 90 additions and 17 deletions.
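For orientation, here is a minimal usage sketch of the new parameter (not part of the commit; the model name and response schema below are illustrative assumptions):

    from vertexai.generative_models import GenerationConfig, GenerativeModel

    # Assumed model name for illustration; any Gemini model that supports
    # response schemas should behave the same way.
    model = GenerativeModel("gemini-1.5-flash")

    response = model.count_tokens(
        "Why is the sky blue?",
        generation_config=GenerationConfig(
            response_schema={"type": "STRING"}  # hypothetical schema
        ),
    )
    print(response.total_tokens, response.total_billable_characters)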
15 changes: 15 additions & 0 deletions tests/system/vertexai/test_generative_models.py
@@ -596,3 +596,18 @@ def test_count_tokens_from_text(self):
response_with_si_and_tool.total_billable_characters
> response_with_si.total_billable_characters
)
# content + generation_config
response_with_generation_config = model.count_tokens(
content,
generation_config=generative_models.GenerationConfig(
response_schema=_RESPONSE_SCHEMA_STRUCT
),
)
assert (
response_with_generation_config.total_tokens
> response_with_si_and_tool.total_tokens
)
assert (
response_with_generation_config.total_billable_characters
> response_with_si_and_tool.total_billable_characters
)
65 changes: 49 additions & 16 deletions vertexai/generative_models/_generative_models.py
@@ -177,13 +177,7 @@ def _validate_generate_content_parameters(
_validate_safety_settings_type_as_valid_sequence(safety_settings)

if generation_config:
if not isinstance(
generation_config,
(gapic_content_types.GenerationConfig, GenerationConfig, Dict),
):
raise TypeError(
"generation_config must either be a GenerationConfig object or a dictionary representation of it."
)
_validate_generation_config_type(generation_config)

if tools:
_validate_tools_type_as_valid_sequence(tools)
@@ -243,6 +237,17 @@ def _validate_safety_settings_type_as_valid_sequence(
)


def _validate_generation_config_type(
generation_config: GenerationConfigType,
) -> None:
if not isinstance(
generation_config,
(gapic_content_types.GenerationConfig, GenerationConfig, Dict),
):
raise TypeError(
"generation_config must either be a GenerationConfig object or a dictionary representation of it."
)

def _validate_tools_type_as_valid_sequence(tools: List["Tool"]):
for tool in tools:
if not isinstance(tool, (gapic_tool_types.Tool, Tool)):
@@ -482,14 +487,7 @@ def _prepare_request(

gapic_generation_config: Optional[gapic_content_types.GenerationConfig] = None
if generation_config:
if isinstance(generation_config, gapic_content_types.GenerationConfig):
gapic_generation_config = generation_config
elif isinstance(generation_config, GenerationConfig):
gapic_generation_config = generation_config._raw_generation_config
elif isinstance(generation_config, Dict):
gapic_generation_config = gapic_content_types.GenerationConfig(
**generation_config
)
gapic_generation_config = _to_generation_config(generation_config)

gapic_safety_settings = None
if safety_settings:
@@ -872,7 +870,11 @@ async def async_generator():
return async_generator()

def count_tokens(
self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
self,
contents: ContentsType,
*,
tools: Optional[List["Tool"]] = None,
generation_config: Optional[GenerationConfigType] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
"""Counts tokens.
@@ -885,6 +887,7 @@ def count_tokens(
* List[Union[str, Image, Part]],
* List[Content]
tools: A list of tools (functions) that the model can try calling.
generation_config: Parameters for the generate_content method.
Returns:
A CountTokensResponse object that has the following attributes:
@@ -894,19 +897,22 @@
request = self._prepare_request(
contents=contents,
tools=tools,
generation_config=generation_config,
)
return self._gapic_count_tokens(
prediction_resource_name=self._prediction_resource_name,
contents=request.contents,
system_instruction=request.system_instruction,
tools=request.tools,
generation_config=request.generation_config,
)

async def count_tokens_async(
self,
contents: ContentsType,
*,
tools: Optional[List["Tool"]] = None,
generation_config: Optional[GenerationConfigType] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
"""Counts tokens asynchronously.
@@ -919,6 +925,7 @@ async def count_tokens_async(
* List[Union[str, Image, Part]],
* List[Content]
tools: A list of tools (functions) that the model can try calling.
generation_config: Parameters for the generate_content method.
Returns:
    An awaitable for a CountTokensResponse object that has the following attributes:
@@ -928,12 +935,14 @@
request = self._prepare_request(
contents=contents,
tools=tools,
generation_config=generation_config,
)
return await self._gapic_count_tokens_async(
prediction_resource_name=self._prediction_resource_name,
contents=request.contents,
system_instruction=request.system_instruction,
tools=request.tools,
generation_config=request.generation_config,
)

def _gapic_count_tokens(
@@ -942,14 +951,17 @@ def _gapic_count_tokens(
contents: List[gapic_content_types.Content],
system_instruction: Optional[gapic_content_types.Content] = None,
tools: Optional[List[gapic_tool_types.Tool]] = None,
generation_config: Optional[gapic_content_types.GenerationConfig] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
request = gapic_prediction_service_types.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
print("=====>", request)
return self._prediction_client.count_tokens(request=request)

async def _gapic_count_tokens_async(
@@ -958,14 +970,17 @@ async def _gapic_count_tokens_async(
contents: List[gapic_content_types.Content],
system_instruction: Optional[gapic_content_types.Content] = None,
tools: Optional[List[gapic_tool_types.Tool]] = None,
generation_config: Optional[gapic_content_types.GenerationConfig] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
request = gapic_prediction_service_types.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
print("=====>", request)
return await self._prediction_async_client.count_tokens(request=request)

def compute_tokens(
@@ -2790,6 +2805,18 @@ def _to_content(
return gapic_content_types.Content(parts=parts, role=role)


def _to_generation_config(
generation_config: GenerationConfigType,
) -> gapic_content_types.GenerationConfig:
"""Converts generation config to gapic_prediction_service_types.GenerationConfig object."""
if isinstance(generation_config, gapic_content_types.GenerationConfig):
return generation_config
elif isinstance(generation_config, GenerationConfig):
return generation_config._raw_generation_config
elif isinstance(generation_config, Dict):
return gapic_content_types.GenerationConfig(**generation_config)
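
All three accepted input forms normalize to the same gapic message; a quick illustration (the parameter value is arbitrary):

    # These three calls are equivalent:
    _to_generation_config(GenerationConfig(temperature=0.1))
    _to_generation_config(gapic_content_types.GenerationConfig(temperature=0.1))
    _to_generation_config({"temperature": 0.1})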


def _append_response(
base_response: GenerationResponse,
new_response: GenerationResponse,
@@ -3186,14 +3213,17 @@ def _gapic_count_tokens(
contents: List[types_v1.Content],
system_instruction: Optional[types_v1.Content] = None,
tools: Optional[List[types_v1.Tool]] = None,
generation_config: Optional[types_v1.GenerationConfig] = None,
) -> types_v1.CountTokensResponse:
request = types_v1.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
print("=====>", request)
return self._llm_utility_client.count_tokens(request=request)

async def _gapic_count_tokens_async(
Expand All @@ -3202,14 +3232,17 @@ async def _gapic_count_tokens_async(
contents: List[types_v1.Content],
system_instruction: Optional[types_v1.Content] = None,
tools: Optional[List[types_v1.Tool]] = None,
generation_config: Optional[types_v1.GenerationConfig] = None,
) -> types_v1.CountTokensResponse:
request = types_v1.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
print("=====>", request)
return await self._llm_utility_async_client.count_tokens(request=request)

# The compute_tokens methods need to be overridden since the request types differ.
27 changes: 26 additions & 1 deletion vertexai/tokenization/_tokenizers.py
@@ -27,9 +27,13 @@
Image,
Tool,
PartsType,
GenerationConfigType,
_validate_generation_config_type,
_validate_tools_type_as_valid_sequence,
_validate_contents_type_as_valid_sequence,
_content_types_to_gapic_contents,
_to_content,
_to_generation_config,
)

from vertexai.tokenization._tokenizer_loading import (
@@ -172,7 +176,6 @@ def _to_gapic_contents(
_validate_contents_type_as_valid_sequence(contents)
_assert_no_image_contents_type(contents)
gapic_contents = _content_types_to_gapic_contents(contents)
# _assert_text_only_content_types_sequence(gapic_contents)
return gapic_contents


@@ -325,6 +328,18 @@ def add_function_response(
f"Function response argument contains unsupported types for token counting. Supported fields {counted_function_response}. Got {function_response}."
)

def add_generation_config(
    self, generation_config: gapic_content_types.GenerationConfig
) -> None:
    if generation_config.response_schema:
        counted_response_schema = self._schema_traverse(
            generation_config._pb.response_schema
        )
        if counted_response_schema != generation_config._pb.response_schema:
            raise ValueError(
                f"Generation config argument contains unsupported types for token counting. Supported fields {counted_response_schema}. Got {generation_config}."
            )

def _function_declaration_traverse(
self, function_declaration: gapic_tool_types.FunctionDeclaration
) -> gapic_tool_types.FunctionDeclaration:
@@ -450,6 +465,7 @@ def count_tokens(
*,
tools: Optional[List["Tool"]] = None,
system_instruction: Optional[PartsType] = None,
generation_config: Optional[GenerationConfigType] = None,
) -> CountTokensResult:
r"""Counts the number of tokens in the text-only contents.
@@ -472,6 +488,11 @@
A CountTokensResult object containing the total number of tokens in
the contents.
"""
if generation_config:
_validate_generation_config_type(generation_config)

if tools:
_validate_tools_type_as_valid_sequence(tools)

text_accumulator = _TextsAccumulator()
if _is_string_inputs(contents):
@@ -490,6 +511,10 @@
else:
text_accumulator.add_content(_to_content(system_instruction))

if generation_config:
canonical_generation_config = _to_generation_config(generation_config)
text_accumulator.add_generation_config(canonical_generation_config)

return self._sentencepiece_adapter.count_tokens(text_accumulator.get_texts())

def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
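To exercise the tokenizer change end to end, a hedged sketch using the local tokenizer entry point (the model name and schema are assumptions):

    from vertexai.preview import tokenization
    from vertexai.generative_models import GenerationConfig

    tokenizer = tokenization.get_tokenizer_for_model("gemini-1.5-flash")
    result = tokenizer.count_tokens(
        "Why is the sky blue?",
        generation_config=GenerationConfig(response_schema={"type": "STRING"}),
    )
    print(result.total_tokens)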
