Skip to content

Commit

Permalink
feat: Add async REST support for transport override
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686157935
  • Loading branch information
matthew29tang authored and copybara-github committed Oct 25, 2024
1 parent 1f3b2d8 commit c0d8edd
Show file tree
Hide file tree
Showing 13 changed files with 359 additions and 56 deletions.
35 changes: 29 additions & 6 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,34 @@ def create_client(
user_agent=user_agent,
)

# Async rest requires async credentials
client_credentials = credentials or self.credentials
if self._api_transport == "rest" and "Async" in client_class.__name__:
try:
import google.auth.aio
except ImportError:
raise ImportError(
"Async REST transport requires async credentials which is \n"
+ "only supported in google-auth >= 2.35.0. Reinstall "
+ "the SDK using the async_rest extras installation: \n"
+ "pip install google-cloud-aiplatform[async_rest]"
)
if not isinstance(
client_credentials, google.auth.aio.credentials.Credentials
):
raise ValueError(
"Async REST transport requires async credentials. "
+ "Configure the credentials parameter in vertexai.init()"
+ "with supported async credentials of type"
+ "google.auth.aio.credentials.Credentials.\n"
+ "Example:\n"
+ "from google.auth.aio.credentials import StaticCredentials\n"
+ "async_credentials = StaticCredentials(token=YOUR_TOKEN_HERE)\n"
+ "vertexai.init(project=PROJECT_ID, location=LOCATION, credentials=async_credentials)"
)

kwargs = {
"credentials": credentials or self.credentials,
"credentials": client_credentials,
"client_options": self.get_client_options(
location_override=location_override,
prediction_client=prediction_client,
Expand All @@ -592,11 +618,8 @@ def create_client(
# Do not pass "grpc", rely on gapic defaults unless "rest" is specified
if self._api_transport == "rest":
if "Async" in client_class.__name__:
# Warn user that "rest" is not supported and use grpc instead
logging.warning(
"REST is not supported for async clients, "
+ "falling back to grpc."
)
# Need to specify rest_asyncio for async clients
kwargs["transport"] = "rest_asyncio"
else:
kwargs["transport"] = self._api_transport

Expand Down
9 changes: 9 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@
"xgboost_ray",
]

# Don't add this to testing_extra_require so we can test behavior of the
# libraries with and without async REST support.
async_rest_extra_require = [
"google-api-core[grpc, async_rest] >= 2.21.0",
"google-auth[aiohttp] >= 2.35.0",
]

reasoning_engine_extra_require = [
"cloudpickle >= 3.0, < 4.0",
"google-cloud-trace < 2",
Expand Down Expand Up @@ -195,6 +202,7 @@
+ profiler_extra_require
+ tokenization_testing_extra_require
+ [
"aiohttp", # Required for async_rest
"bigframes; python_version>='3.10'",
# google-api-core 2.x is required since kfp requires protobuf > 4
"google-api-core >= 2.11, < 3.0.0",
Expand Down Expand Up @@ -276,6 +284,7 @@
"langchain": langchain_extra_require,
"langchain_testing": langchain_testing_extra_require,
"tokenization": tokenization_extra_require,
"async_rest": async_rest_extra_require,
},
python_requires=">=3.8",
classifiers=[
Expand Down
3 changes: 2 additions & 1 deletion testing/constraints-3.10.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# This constraints file is required for unit tests.
# List all library dependencies and extras in this file.
google-api-core
google-api-core==2.21.0 # Tests google-api-core with rest async support
google-auth==2.35.0 # Tests google-auth with rest async support
proto-plus==1.22.3
protobuf
mock==4.0.2
Expand Down
3 changes: 2 additions & 1 deletion testing/constraints-3.11.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# This constraints file is required for unit tests.
# List all library dependencies and extras in this file.
google-api-core
google-api-core==2.21.0 # Tests google-api-core with rest async support
google-auth==2.35.0 # Tests google-auth with rest async support
proto-plus
protobuf
mock==4.0.2
Expand Down
3 changes: 2 additions & 1 deletion testing/constraints-3.12.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# This constraints file is required for unit tests.
# List all library dependencies and extras in this file.
google-api-core
google-api-core==2.21.0 # Tests google-api-core with rest async support
google-auth==2.35.0 # Tests google-auth with rest async support
proto-plus
protobuf
mock==4.0.2
Expand Down
2 changes: 1 addition & 1 deletion testing/constraints-3.8.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# are correct in setup.py
# List *all* library dependencies and extras in this file.
google-api-core==2.17.1 # Increased for gapic owlbot presubmit tests
google-auth==2.14.1
google-auth==2.14.1 # Tests google-auth without rest async support
proto-plus==1.22.3
protobuf
mock==4.0.2
Expand Down
3 changes: 2 additions & 1 deletion testing/constraints-3.9.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# This constraints file is required for unit tests.
# List all library dependencies and extras in this file.
google-api-core
google-api-core==2.21.0 # Tests google-api-core with rest async support
google-auth==2.35.0 # Tests google-auth with rest async support
proto-plus==1.22.3
protobuf
mock==4.0.2
Expand Down
20 changes: 20 additions & 0 deletions tests/system/aiplatform/test_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
# limitations under the License.
#

import pytest

from google.auth import credentials as auth_credentials

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform_v1beta1.services import prediction_service
from tests.system.aiplatform import e2e_base


Expand All @@ -39,3 +43,19 @@ def test_init_calls_set_google_auth_default(self):
# init() with only project shouldn't overwrite creds
aiplatform.init(project=e2e_base._PROJECT)
assert aiplatform.initializer.global_config.credentials == creds

def test_init_rest_async_incorrect_credentials(self):
# Sync credentials should not be used for async REST transport.
creds = auth_credentials.AnonymousCredentials()
aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
api_transport="rest",
credentials=creds
)

with pytest.raises(ValueError):
aiplatform_initializer.global_config.create_client(
client_class=prediction_service.PredictionServiceAsyncClient,
prediction_client=True,
)
Loading

0 comments on commit c0d8edd

Please sign in to comment.