feat: Add generation_config to count_tokens #4563

Open · wants to merge 1 commit into base: main
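This PR adds an optional generation_config argument to GenerativeModel.count_tokens / count_tokens_async and to the local Tokenizer.count_tokens, so that generation settings such as response_schema are reflected in the reported token counts. A minimal usage sketch, assuming the public vertexai SDK; the model name and schema below are illustrative, not part of this diff:

from vertexai import generative_models

# vertexai.init(project=..., location=...) is assumed to have been called.
# Illustrative model name and schema (assumptions for this sketch).
model = generative_models.GenerativeModel("gemini-1.5-pro")
response = model.count_tokens(
    "Describe the weather today.",
    generation_config=generative_models.GenerationConfig(
        response_schema={
            "type": "object",
            "properties": {"temperature": {"type": "number"}},
        }
    ),
)
print(response.total_tokens, response.total_billable_characters)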
15 changes: 15 additions & 0 deletions tests/system/vertexai/test_generative_models.py
@@ -596,3 +596,18 @@ def test_count_tokens_from_text(self):
response_with_si_and_tool.total_billable_characters
> response_with_si.total_billable_characters
)
# content + generation_config
response_with_generation_config = model.count_tokens(
content,
generation_config=generative_models.GenerationConfig(
response_schema=_RESPONSE_SCHEMA_STRUCT
),
)
assert (
response_with_generation_config.total_tokens
> response_with_si_and_tool.total_tokens
)
assert (
response_with_generation_config.total_billable_characters
> response_with_si_and_tool.total_billable_characters
)
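The new assertions extend the existing chain: each successive count_tokens call adds request fields, so both totals strictly increase. _RESPONSE_SCHEMA_STRUCT is a fixture defined elsewhere in the test module and is not shown in this diff; a hypothetical struct of the same general shape, for illustration only:

# Hypothetical example; the real _RESPONSE_SCHEMA_STRUCT fixture is defined
# elsewhere in the test module and may differ.
example_response_schema_struct = {
    "type": "object",
    "properties": {
        "location": {"type": "string"},
        "temperature": {"type": "number"},
    },
    "required": ["location"],
}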
65 changes: 49 additions & 16 deletions vertexai/generative_models/_generative_models.py
@@ -177,13 +177,7 @@ def _validate_generate_content_parameters(
_validate_safety_settings_type_as_valid_sequence(safety_settings)

if generation_config:
if not isinstance(
generation_config,
(gapic_content_types.GenerationConfig, GenerationConfig, Dict),
):
raise TypeError(
"generation_config must either be a GenerationConfig object or a dictionary representation of it."
)
_validate_generation_config_type(generation_config)

if tools:
_validate_tools_type_as_valid_sequence(tools)
@@ -243,6 +237,17 @@ def _validate_safety_settings_type_as_valid_sequence(
)


def _validate_generation_config_type(
generation_config: GenerationConfigType,
) -> None:
if not isinstance(
generation_config,
(gapic_content_types.GenerationConfig, GenerationConfig, Dict),
):
raise TypeError(
"generation_config must either be a GenerationConfig object or a dictionary representation of it."
        )


def _validate_tools_type_as_valid_sequence(tools: List["Tool"]):
for tool in tools:
if not isinstance(tool, (gapic_tool_types.Tool, Tool)):
@@ -482,14 +487,7 @@ def _prepare_request(

gapic_generation_config: Optional[gapic_content_types.GenerationConfig] = None
if generation_config:
if isinstance(generation_config, gapic_content_types.GenerationConfig):
gapic_generation_config = generation_config
elif isinstance(generation_config, GenerationConfig):
gapic_generation_config = generation_config._raw_generation_config
elif isinstance(generation_config, Dict):
gapic_generation_config = gapic_content_types.GenerationConfig(
**generation_config
)
gapic_generation_config = _to_generation_config(generation_config)

gapic_safety_settings = None
if safety_settings:
@@ -872,7 +870,11 @@ async def async_generator():
return async_generator()

def count_tokens(
self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
self,
contents: ContentsType,
*,
tools: Optional[List["Tool"]] = None,
generation_config: Optional[GenerationConfigType] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
"""Counts tokens.

@@ -885,6 +887,7 @@ def count_tokens(
* List[Union[str, Image, Part]],
* List[Content]
tools: A list of tools (functions) that the model can try calling.
            generation_config: Parameters for the generate_content method; settings such as response_schema are included in the token count.

Returns:
A CountTokensResponse object that has the following attributes:
@@ -894,19 +897,22 @@ def count_tokens(
request = self._prepare_request(
contents=contents,
tools=tools,
generation_config=generation_config,
)
return self._gapic_count_tokens(
prediction_resource_name=self._prediction_resource_name,
contents=request.contents,
system_instruction=request.system_instruction,
tools=request.tools,
generation_config=request.generation_config,
)

async def count_tokens_async(
self,
contents: ContentsType,
*,
tools: Optional[List["Tool"]] = None,
generation_config: Optional[GenerationConfigType] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
"""Counts tokens asynchronously.

@@ -919,6 +925,7 @@ async def count_tokens_async(
* List[Union[str, Image, Part]],
* List[Content]
tools: A list of tools (functions) that the model can try calling.
            generation_config: Parameters for the generate_content method; settings such as response_schema are included in the token count.

Returns:
            An awaitable for a CountTokensResponse object that has the following attributes:
@@ -928,12 +935,14 @@ async def count_tokens_async(
request = self._prepare_request(
contents=contents,
tools=tools,
generation_config=generation_config,
)
return await self._gapic_count_tokens_async(
prediction_resource_name=self._prediction_resource_name,
contents=request.contents,
system_instruction=request.system_instruction,
tools=request.tools,
generation_config=request.generation_config,
)

def _gapic_count_tokens(
@@ -942,14 +951,17 @@ def _gapic_count_tokens(
contents: List[gapic_content_types.Content],
system_instruction: Optional[gapic_content_types.Content] = None,
tools: Optional[List[gapic_tool_types.Tool]] = None,
generation_config: Optional[gapic_content_types.GenerationConfig] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
request = gapic_prediction_service_types.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
print("=====>", request)
return self._prediction_client.count_tokens(request=request)

async def _gapic_count_tokens_async(
@@ -958,14 +970,17 @@ async def _gapic_count_tokens_async(
contents: List[gapic_content_types.Content],
system_instruction: Optional[gapic_content_types.Content] = None,
tools: Optional[List[gapic_tool_types.Tool]] = None,
generation_config: Optional[gapic_content_types.GenerationConfig] = None,
) -> gapic_prediction_service_types.CountTokensResponse:
request = gapic_prediction_service_types.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
print("=====>", request)
return await self._prediction_async_client.count_tokens(request=request)

def compute_tokens(
@@ -2790,6 +2805,18 @@ def _to_content(
return gapic_content_types.Content(parts=parts, role=role)


def _to_generation_config(
generation_config: GenerationConfigType,
) -> gapic_content_types.GenerationConfig:
"""Converts generation config to gapic_prediction_service_types.GenerationConfig object."""
if isinstance(generation_config, gapic_content_types.GenerationConfig):
return generation_config
elif isinstance(generation_config, GenerationConfig):
return generation_config._raw_generation_config
elif isinstance(generation_config, Dict):
return gapic_content_types.GenerationConfig(**generation_config)


def _append_response(
base_response: GenerationResponse,
new_response: GenerationResponse,
@@ -3186,14 +3213,17 @@ def _gapic_count_tokens(
contents: List[types_v1.Content],
system_instruction: Optional[types_v1.Content] = None,
tools: Optional[List[types_v1.Tool]] = None,
generation_config: Optional[types_v1.GenerationConfig] = None,
) -> types_v1.CountTokensResponse:
request = types_v1.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
print("=====>", request)
return self._llm_utility_client.count_tokens(request=request)

async def _gapic_count_tokens_async(
@@ -3202,14 +3232,17 @@ async def _gapic_count_tokens_async(
contents: List[types_v1.Content],
system_instruction: Optional[types_v1.Content] = None,
tools: Optional[List[types_v1.Tool]] = None,
generation_config: Optional[types_v1.GenerationConfig] = None,
) -> types_v1.CountTokensResponse:
request = types_v1.CountTokensRequest(
endpoint=prediction_resource_name,
model=prediction_resource_name,
contents=contents,
system_instruction=system_instruction,
tools=tools,
generation_config=generation_config,
)
print("=====>", request)
return await self._llm_utility_async_client.count_tokens(request=request)

# The compute_tokens methods need to be overridden since the request types differ.
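With the conversion centralized in _to_generation_config, all three accepted forms (a gapic GenerationConfig, the SDK GenerationConfig wrapper, or a plain dict) produce the same CountTokensRequest. A hedged sketch of the dict form on the async API; the model name is an assumption:

import asyncio

from vertexai import generative_models

async def main() -> None:
    # Assumed model name; any GenerationConfigType form is accepted.
    model = generative_models.GenerativeModel("gemini-1.5-pro")
    response = await model.count_tokens_async(
        "Hello",
        generation_config={"response_mime_type": "application/json"},
    )
    print(response.total_tokens)

asyncio.run(main())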
27 changes: 26 additions & 1 deletion vertexai/tokenization/_tokenizers.py
@@ -27,9 +27,13 @@
Image,
Tool,
PartsType,
GenerationConfigType,
_validate_generation_config_type,
_validate_tools_type_as_valid_sequence,
_validate_contents_type_as_valid_sequence,
_content_types_to_gapic_contents,
_to_content,
_to_generation_config,
)

from vertexai.tokenization._tokenizer_loading import (
@@ -172,7 +176,6 @@ def _to_gapic_contents(
_validate_contents_type_as_valid_sequence(contents)
_assert_no_image_contents_type(contents)
gapic_contents = _content_types_to_gapic_contents(contents)
# _assert_text_only_content_types_sequence(gapic_contents)
return gapic_contents


@@ -325,6 +328,18 @@ def add_function_response(
f"Function response argument contains unsupported types for token counting. Supported fields {counted_function_response}. Got {function_response}."
)

def add_generation_config(
    self, generation_config: gapic_content_types.GenerationConfig
) -> None:
    if generation_config.response_schema:
        # Keep only the response_schema fields that count toward tokens.
        counted_response_schema = self._schema_traverse(
            generation_config._pb.response_schema
        )
        # Compare against the original response_schema (not the whole
        # GenerationConfig message) so supported schemas do not raise.
        if counted_response_schema._pb != generation_config._pb.response_schema:
            raise ValueError(
                f"Generation config argument contains unsupported types for token counting. Supported fields {counted_response_schema}. Got {generation_config}."
            )

def _function_declaration_traverse(
self, function_declaration: gapic_tool_types.FunctionDeclaration
) -> gapic_tool_types.FunctionDeclaration:
@@ -450,6 +465,7 @@ def count_tokens(
*,
tools: Optional[List["Tool"]] = None,
system_instruction: Optional[PartsType] = None,
generation_config: Optional[GenerationConfigType] = None,
) -> CountTokensResult:
r"""Counts the number of tokens in the text-only contents.

@@ -472,6 +488,11 @@ def count_tokens(
A CountTokensResult object containing the total number of tokens in
the contents.
"""
if generation_config:
_validate_generation_config_type(generation_config)

if tools:
_validate_tools_type_as_valid_sequence(tools)

text_accumulator = _TextsAccumulator()
if _is_string_inputs(contents):
@@ -490,6 +511,10 @@ else:
else:
text_accumulator.add_content(_to_content(system_instruction))

if generation_config:
canonical_generation_config = _to_generation_config(generation_config)
text_accumulator.add_generation_config(canonical_generation_config)

return self._sentencepiece_adapter.count_tokens(text_accumulator.get_texts())

def compute_tokens(self, contents: ContentsType) -> ComputeTokensResult:
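A corresponding sketch for the local tokenizer path added above; the model name is an assumption, and no API call is made since counting happens locally:

from vertexai.generative_models import GenerationConfig
from vertexai.tokenization import get_tokenizer_for_model

tokenizer = get_tokenizer_for_model("gemini-1.5-pro")  # assumed model name
result = tokenizer.count_tokens(
    "Describe the weather today.",
    generation_config=GenerationConfig(
        response_schema={
            "type": "object",
            "properties": {"name": {"type": "string"}},
        }
    ),
)
print(result.total_tokens)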