Skip to content

Commit

Permalink
feat: Release API key support for GenerateContent to Public Preview
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685926683
  • Loading branch information
yinghsienwu authored and copybara-github committed Oct 21, 2024
1 parent af50cd4 commit 9ec51c4
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 25 deletions.
17 changes: 17 additions & 0 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,14 @@ def init(
f"{api_transport} is not a valid transport type. "
+ f"Valid transport types: {VALID_TRANSPORT_TYPES}"
)
else:
# Raise error if api_transport other than rest is specified for usage with API key.
if not project and not api_transport:
api_transport = "rest"
elif not project and api_transport != "rest":
raise ValueError(
f"{api_transport} is not supported with API keys. "
)
if location:
utils.validate_region(location)
if experiment_description and experiment is None:
Expand All @@ -236,6 +244,9 @@ def init(
logging.info("project/location updated, reset Experiment config.")
metadata._experiment_tracker.reset()

if project and api_key:
logging.info("Both a project and API key have been provided. The project will take precedence over the API key.")

# Then we change the main state
if api_endpoint is not None:
self._api_endpoint = api_endpoint
Expand Down Expand Up @@ -438,7 +449,13 @@ def get_client_options(

api_endpoint = self.api_endpoint

if api_endpoint is None and not self._project and not self._location and not location_override:
# Default endpoint is location invariant if using API key
api_endpoint = "aiplatform.googleapis.com"

# If both project and API key are passed in, project takes precedence.
if api_endpoint is None:
# Form the default endpoint to use with no API key.
if not (self.location or location_override):
raise ValueError(
"No location found. Provide or initialize SDK with a location."
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/vertex_rag/test_rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class TestRagDataManagement:
def setup_method(self):
    """Reset SDK global state and initialize it with the test project/region.

    Reloading the initializer and the top-level package guarantees each test
    starts from a pristine global config; `init` must receive an explicit
    project and location now that a bare `init()` may fall back to API-key
    (rest-transport) mode.
    """
    importlib.reload(aiplatform.initializer)
    importlib.reload(aiplatform)
    aiplatform.init(project=tc.TEST_PROJECT, location=tc.TEST_REGION)

def teardown_method(self):
    """Drain the SDK's shared thread pool so tests do not leak workers."""
    pool = aiplatform.initializer.global_pool
    pool.shutdown(wait=True)
Expand Down
88 changes: 64 additions & 24 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,13 +388,25 @@ def __init__(
def _prediction_client(self) -> prediction_service.PredictionServiceClient:
    """Lazily create and cache the synchronous PredictionService client.

    When the global config carries an API key and no project, the client is
    built with API-key auth (project/location are irrelevant in that mode);
    otherwise the standard project/location-scoped client is created. If both
    a project and an API key are set, the project takes precedence.

    Returns:
        prediction_service.PredictionServiceClient: the cached client.
    """
    # Switch to @functools.cached_property once its available.
    if not getattr(self, "_prediction_client_value", None):
        if (
            aiplatform_initializer.global_config.api_key
            and not aiplatform_initializer.global_config.project
        ):
            self._prediction_client_value = (
                aiplatform_initializer.global_config.create_client(
                    client_class=prediction_service.PredictionServiceClient,
                    api_key=aiplatform_initializer.global_config.api_key,
                    prediction_client=True,
                )
            )
        else:
            self._prediction_client_value = (
                aiplatform_initializer.global_config.create_client(
                    client_class=prediction_service.PredictionServiceClient,
                    location_override=self._location,
                    prediction_client=True,
                )
            )
    return self._prediction_client_value

@property
def _prediction_async_client(
    self,
) -> prediction_service.PredictionServiceAsyncClient:
    """Lazily create and cache the asynchronous PredictionService client.

    Returns:
        prediction_service.PredictionServiceAsyncClient: the cached client.

    Raises:
        RuntimeError: if the global config holds an API key without a
            project — API-key auth is not yet supported for async clients.
    """
    # Switch to @functools.cached_property once its available.
    if not getattr(self, "_prediction_async_client_value", None):
        if (
            aiplatform_initializer.global_config.api_key
            and not aiplatform_initializer.global_config.project
        ):
            raise RuntimeError(
                "Using an api key is not supported yet for async clients."
            )
        else:
            self._prediction_async_client_value = (
                aiplatform_initializer.global_config.create_client(
                    client_class=prediction_service.PredictionServiceAsyncClient,
                    location_override=self._location,
                    prediction_client=True,
                )
            )
    return self._prediction_async_client_value

@property
def _llm_utility_client(self) -> llm_utility_service.LlmUtilityServiceClient:
    """Lazily create and cache the synchronous LlmUtilityService client.

    Mirrors `_prediction_client`: API-key auth is used only when an API key
    is configured and no project is set; otherwise the project/location
    client is created.

    Returns:
        llm_utility_service.LlmUtilityServiceClient: the cached client.
    """
    # Switch to @functools.cached_property once its available.
    if not getattr(self, "_llm_utility_client_value", None):
        if (
            aiplatform_initializer.global_config.api_key
            and not aiplatform_initializer.global_config.project
        ):
            self._llm_utility_client_value = (
                aiplatform_initializer.global_config.create_client(
                    client_class=llm_utility_service.LlmUtilityServiceClient,
                    api_key=aiplatform_initializer.global_config.api_key,
                    prediction_client=True,
                )
            )
        else:
            self._llm_utility_client_value = (
                aiplatform_initializer.global_config.create_client(
                    client_class=llm_utility_service.LlmUtilityServiceClient,
                    location_override=self._location,
                    prediction_client=True,
                )
            )
    return self._llm_utility_client_value

@property
def _llm_utility_async_client(
    self,
) -> llm_utility_service.LlmUtilityServiceAsyncClient:
    """Lazily create and cache the asynchronous LlmUtilityService client.

    Returns:
        llm_utility_service.LlmUtilityServiceAsyncClient: the cached client.

    Raises:
        RuntimeError: if the global config holds an API key without a
            project — API-key auth is not yet supported for async clients.
    """
    # Switch to @functools.cached_property once its available.
    if not getattr(self, "_llm_utility_async_client_value", None):
        if (
            aiplatform_initializer.global_config.api_key
            and not aiplatform_initializer.global_config.project
        ):
            raise RuntimeError(
                "Using an api key is not supported yet for async clients."
            )
        else:
            self._llm_utility_async_client_value = (
                aiplatform_initializer.global_config.create_client(
                    client_class=llm_utility_service.LlmUtilityServiceAsyncClient,
                    location_override=self._location,
                    prediction_client=True,
                )
            )
    return self._llm_utility_async_client_value

def _prepare_request(
Expand Down

0 comments on commit 9ec51c4

Please sign in to comment.