From 9ec51c42639319c67e6e6a25d93e65aa40671b40 Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Mon, 14 Oct 2024 20:11:21 -0700 Subject: [PATCH] feat: Release API key support for GenerateContent to Public Preview PiperOrigin-RevId: 685926683 --- google/cloud/aiplatform/initializer.py | 17 ++++ tests/unit/vertex_rag/test_rag_data.py | 2 +- .../generative_models/_generative_models.py | 88 ++++++++++++++----- 3 files changed, 82 insertions(+), 25 deletions(-) diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py index 08c3528c57..bf2a6bd39a 100644 --- a/google/cloud/aiplatform/initializer.py +++ b/google/cloud/aiplatform/initializer.py @@ -221,6 +221,14 @@ def init( f"{api_transport} is not a valid transport type. " + f"Valid transport types: {VALID_TRANSPORT_TYPES}" ) + else: + # Raise error if api_transport other than rest is specified for usage with API key. + if not project and not api_transport: + api_transport = "rest" + elif not project and api_transport != "rest": + raise ValueError( + f"{api_transport} is not supported with API keys. " + ) if location: utils.validate_region(location) if experiment_description and experiment is None: @@ -236,6 +244,9 @@ def init( logging.info("project/location updated, reset Experiment config.") metadata._experiment_tracker.reset() + if project and api_key: + logging.info("Both a project and API key have been provided. The project will take precedence over the API key.") + # Then we change the main state if api_endpoint is not None: self._api_endpoint = api_endpoint @@ -438,7 +449,13 @@ def get_client_options( api_endpoint = self.api_endpoint + if api_endpoint is None and not self._project and not self._location and not location_override: + # Default endpoint is location invariant if using API key + api_endpoint = "aiplatform.googleapis.com" + + # If both project and API key are passed in, project takes precedence. if api_endpoint is None: + # Form the default endpoint to use with no API key. if not (self.location or location_override): raise ValueError( "No location found. Provide or initialize SDK with a location." diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py index 907b64db79..9573ccd243 100644 --- a/tests/unit/vertex_rag/test_rag_data.py +++ b/tests/unit/vertex_rag/test_rag_data.py @@ -306,7 +306,7 @@ class TestRagDataManagement: def setup_method(self): importlib.reload(aiplatform.initializer) importlib.reload(aiplatform) - aiplatform.init() + aiplatform.init(project=tc.TEST_PROJECT, location=tc.TEST_REGION) def teardown_method(self): aiplatform.initializer.global_pool.shutdown(wait=True) diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index 805123bc1d..b87fa4cc15 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -388,13 +388,25 @@ def __init__( def _prediction_client(self) -> prediction_service.PredictionServiceClient: # Switch to @functools.cached_property once its available. if not getattr(self, "_prediction_client_value", None): - self._prediction_client_value = ( - aiplatform_initializer.global_config.create_client( - client_class=prediction_service.PredictionServiceClient, - location_override=self._location, - prediction_client=True, + if ( + aiplatform_initializer.global_config.api_key + and not aiplatform_initializer.global_config.project + ): + self._prediction_client_value = ( + aiplatform_initializer.global_config.create_client( + client_class=prediction_service.PredictionServiceClient, + api_key=aiplatform_initializer.global_config.api_key, + prediction_client=True, + ) + ) + else: + self._prediction_client_value = ( + aiplatform_initializer.global_config.create_client( + client_class=prediction_service.PredictionServiceClient, + location_override=self._location, + prediction_client=True, + ) ) - ) return self._prediction_client_value @property @@ -403,26 +415,46 @@ def _prediction_async_client( ) -> prediction_service.PredictionServiceAsyncClient: # Switch to @functools.cached_property once its available. if not getattr(self, "_prediction_async_client_value", None): - self._prediction_async_client_value = ( - aiplatform_initializer.global_config.create_client( - client_class=prediction_service.PredictionServiceAsyncClient, - location_override=self._location, - prediction_client=True, + if ( + aiplatform_initializer.global_config.api_key + and not aiplatform_initializer.global_config.project + ): + raise RuntimeError( + "Using an api key is not supported yet for async clients." + ) + else: + self._prediction_async_client_value = ( + aiplatform_initializer.global_config.create_client( + client_class=prediction_service.PredictionServiceAsyncClient, + location_override=self._location, + prediction_client=True, + ) ) - ) return self._prediction_async_client_value @property def _llm_utility_client(self) -> llm_utility_service.LlmUtilityServiceClient: # Switch to @functools.cached_property once its available. if not getattr(self, "_llm_utility_client_value", None): - self._llm_utility_client_value = ( - aiplatform_initializer.global_config.create_client( - client_class=llm_utility_service.LlmUtilityServiceClient, - location_override=self._location, - prediction_client=True, + if ( + aiplatform_initializer.global_config.api_key + and not aiplatform_initializer.global_config.project + ): + self._llm_utility_client_value = ( + aiplatform_initializer.global_config.create_client( + client_class=llm_utility_service.LlmUtilityServiceClient, + api_key=aiplatform_initializer.global_config.api_key, + prediction_client=True, + ) + ) + else: + self._llm_utility_client_value = ( + aiplatform_initializer.global_config.create_client( + client_class=llm_utility_service.LlmUtilityServiceClient, + location_override=self._location, + prediction_client=True, + ) ) - ) return self._llm_utility_client_value @property @@ -431,13 +463,21 @@ def _llm_utility_async_client( ) -> llm_utility_service.LlmUtilityServiceAsyncClient: # Switch to @functools.cached_property once its available. if not getattr(self, "_llm_utility_async_client_value", None): - self._llm_utility_async_client_value = ( - aiplatform_initializer.global_config.create_client( - client_class=llm_utility_service.LlmUtilityServiceAsyncClient, - location_override=self._location, - prediction_client=True, + if ( + aiplatform_initializer.global_config.api_key + and not aiplatform_initializer.global_config.project + ): + raise RuntimeError( + "Using an api key is not supported yet for async clients." + ) + else: + self._llm_utility_async_client_value = ( + aiplatform_initializer.global_config.create_client( + client_class=llm_utility_service.LlmUtilityServiceAsyncClient, + location_override=self._location, + prediction_client=True, + ) ) - ) return self._llm_utility_async_client_value def _prepare_request(