Skip to content

Commit

Permalink
feat: Add async REST support for transport override
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686157935
  • Loading branch information
matthew29tang authored and copybara-github committed Oct 23, 2024
1 parent 0866009 commit 1965559
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .kokoro/continuous/common.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ before_action {
fetch_keystore {
keystore_resource {
keystore_config_id: 73713
keyname: "vertexai-staging-endpoint"
keyname: "vertexai-staging-endpoint-1"
}
}
}
Expand Down
25 changes: 19 additions & 6 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,24 @@ def create_client(
user_agent=user_agent,
)

# Async rest requires async credentials
client_credentials = credentials or self.credentials
if (
self._api_transport == "rest"
and "Async" in client_class.__name__
and not isinstance(client_credentials, google.auth.aio.credentials.Credentials)
):
raise ValueError(
"Async REST clients require async credentials. "
+ "Please pass async credentials into vertexai.init()\n"
+ "Example:\n"
+ "from google.auth.aio.credentials import StaticCredentials\n"
+ "async_credentials = StaticCredentials(token=YOUR_TOKEN_HERE)\n"
+ "vertexai.init(project=PROJECT_ID, location=LOCATION, credentials=async_credentials)"
)

kwargs = {
"credentials": credentials or self.credentials,
"credentials": client_credentials,
"client_options": self.get_client_options(
location_override=location_override,
prediction_client=prediction_client,
Expand All @@ -570,11 +586,8 @@ def create_client(
# Do not pass "grpc", rely on gapic defaults unless "rest" is specified
if self._api_transport == "rest":
if "Async" in client_class.__name__:
# Warn user that "rest" is not supported and use grpc instead
logging.warning(
"REST is not supported for async clients, "
+ "falling back to grpc."
)
# Need to specify rest_asyncio for async clients
kwargs["transport"] = "rest_asyncio"
else:
kwargs["transport"] = self._api_transport

Expand Down
4 changes: 3 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,9 @@ def system(session):
CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
)
system_test_path = os.path.join("tests", "system.py")
system_test_folder_path = os.path.join("tests", "system")
system_test_folder_path = os.path.join(
"tests", "system/vertexai/test_generative_models.py"
) # Temporary change

# Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true.
if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false":
Expand Down
4 changes: 2 additions & 2 deletions testing/constraints-3.8.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# This constraints file is used to check that lower bounds
# are correct in setup.py
# List *all* library dependencies and extras in this file.
google-api-core==2.17.1 # Increased for gapic owlbot presubmit tests
google-auth==2.14.1
google-api-core
google-auth==2.35.0
proto-plus==1.22.3
protobuf
mock==4.0.2
Expand Down
31 changes: 26 additions & 5 deletions tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,31 @@ def test_text_generation_preview_count_tokens(self, api_transport):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
async def test_text_generation_model_predict_async(self, api_transport):
aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
api_transport=api_transport,
)
# Create async credentials from default credentials for async REST
if api_transport == "rest":
default_credentials, _ = auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
auth_req = auth.transport.requests.Request()
default_credentials.refresh(auth_req)

# Create async credentials from default credentials
from google.auth.aio.credentials import StaticCredentials

async_credentials = StaticCredentials(token=default_credentials.token)

aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
credentials=async_credentials,
api_transport=api_transport,
)
else:
aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
api_transport=api_transport,
)

model = TextGenerationModel.from_pretrained("google/text-bison@001")
grounding_source = language_models.GroundingSource.WebSearch()
Expand All @@ -106,6 +126,7 @@ async def test_text_generation_model_predict_async(self, api_transport):
grounding_source=grounding_source,
)
assert response.text or response.is_blocked
await model.close_async_client()

@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
def test_text_generation_streaming(self, api_transport):
Expand Down
39 changes: 36 additions & 3 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class TestGenerativeModels(e2e_base.TestEndToEnd):
_temp_prefix = "temp_generative_models_test_"

@pytest.fixture(scope="function", autouse=True)
def setup_method(self, api_endpoint_env_name):
def setup_method(self, api_endpoint_env_name, api_transport):
super().setup_method()
credentials, _ = auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
Expand All @@ -112,6 +112,7 @@ def setup_method(self, api_endpoint_env_name):
location=e2e_base._LOCATION,
credentials=credentials,
api_endpoint=api_endpoint,
api_transport=api_transport,
)

def test_generate_content_with_cached_content_from_text(
Expand Down Expand Up @@ -154,7 +155,8 @@ def test_generate_content_with_cached_content_from_text(
finally:
cached_content.delete()

def test_generate_content_from_text(self, api_endpoint_env_name):
@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
def test_generate_content_from_text(self, api_endpoint_env_name, api_transport):
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
response = model.generate_content(
"Why is sky blue?",
Expand Down Expand Up @@ -226,7 +228,37 @@ def test_generate_content_streaming(self, api_endpoint_env_name):
)

@pytest.mark.asyncio
async def test_generate_content_streaming_async(self, api_endpoint_env_name):
@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
async def test_generate_content_streaming_async(
self, api_endpoint_env_name, api_transport
):
# Retrieve access token from ADC required to construct
# google.auth.aio.credentials.StaticCredentials for async REST transport.
# TODO: Update this when google.auth.aio.default is supported for async.
if api_transport == "rest":
default_credentials, _ = auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
auth_req = auth.transport.requests.Request()
default_credentials.refresh(auth_req)
if api_endpoint_env_name == STAGING_API_ENDPOINT:
api_endpoint = os.getenv(api_endpoint_env_name)
else:
api_endpoint = None

# Construct google.auth.aio.credentials.StaticCredentials
# using the access token from ADC for async REST transport.
from google.auth.aio.credentials import StaticCredentials

async_credentials = StaticCredentials(token=default_credentials.token)

aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
credentials=async_credentials,
api_endpoint=api_endpoint,
api_transport=api_transport,
)
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
async_stream = await model.generate_content_async(
"Why is sky blue?",
Expand All @@ -239,6 +271,7 @@ async def test_generate_content_streaming_async(self, api_endpoint_env_name):
or chunk.candidates[0].finish_reason
is generative_models.FinishReason.STOP
)
await model.close_async_client()

def test_generate_content_with_parameters(self, api_endpoint_env_name):
model = generative_models.GenerativeModel(
Expand Down
30 changes: 25 additions & 5 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2074,11 +2074,29 @@ def test_text_generation_multiple_candidates_grounding(self):
@pytest.mark.asyncio
async def test_text_generation_async(self, api_transport):
"""Tests the text generation model."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
api_transport=api_transport,
)
if api_transport == "rest":
from google import auth

default_credentials, _ = auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)

# Create async credentials from default credentials
from google.auth.aio.credentials import StaticCredentials

async_credentials = StaticCredentials(token=default_credentials.token)
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
credentials=async_credentials,
api_transport=api_transport,
)
else:
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
api_transport=api_transport,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
Expand Down Expand Up @@ -2115,6 +2133,8 @@ async def test_text_generation_async(self, api_transport):
assert prediction_parameters["stopSequences"] == ["\n"]
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]

await model.close_async_client()

@pytest.mark.asyncio
async def test_text_generation_multiple_candidates_grounding_async(self):
"""Tests the text generation model with multiple candidates async with web grounding."""
Expand Down
4 changes: 4 additions & 0 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,10 @@ async def async_generator():

return async_generator()

async def close_async_client(self) -> None:
if self._prediction_async_client:
return await self._prediction_async_client.transport.close()

def count_tokens(
self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
) -> gapic_prediction_service_types.CountTokensResponse:
Expand Down
4 changes: 4 additions & 0 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1682,6 +1682,10 @@ async def predict_streaming_async(
)
yield _parse_text_generation_model_response(prediction_obj)

async def close_async_client(self) -> None:
if self._endpoint._prediction_async_client:
return await self._endpoint._prediction_async_client.transport.close()


def _create_text_generation_prediction_request(
prompt: str,
Expand Down

0 comments on commit 1965559

Please sign in to comment.