Skip to content

Commit

Permalink
feat: Add async REST support for transport override
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686157935
  • Loading branch information
matthew29tang authored and copybara-github committed Oct 21, 2024
1 parent af50cd4 commit 472f72c
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 15 deletions.
25 changes: 19 additions & 6 deletions google/cloud/aiplatform/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,24 @@ def create_client(
user_agent=user_agent,
)

# Async rest requires async credentials
client_credentials = credentials or self.credentials
if (
self._api_transport == "rest"
and "Async" in client_class.__name__
and isinstance(client_credentials, auth_credentials.Credentials)
):
raise ValueError(
"Async REST clients require async credentials. "
+ "Please pass async credentials into vertexai.init()\n"
+ "Example:\n"
+ "from google.auth.aio.credentials import StaticCredentials\n"
+ "async_credentials = StaticCredentials(token=YOUR_TOKEN_HERE)\n"
+ "vertexai.init(project=PROJECT_ID, location=LOCATION, credentials=async_credentials)"
)

kwargs = {
"credentials": credentials or self.credentials,
"credentials": client_credentials,
"client_options": self.get_client_options(
location_override=location_override,
prediction_client=prediction_client,
Expand All @@ -570,11 +586,8 @@ def create_client(
# Do not pass "grpc", rely on gapic defaults unless "rest" is specified
if self._api_transport == "rest":
if "Async" in client_class.__name__:
# Warn user that "rest" is not supported and use grpc instead
logging.warning(
"REST is not supported for async clients, "
+ "falling back to grpc."
)
# Need to specify rest_asyncio for async clients
kwargs["transport"] = "rest_asyncio"
else:
kwargs["transport"] = self._api_transport

Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def system(session):
CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
)
system_test_path = os.path.join("tests", "system.py")
system_test_folder_path = os.path.join("tests", "system")
system_test_folder_path = os.path.join("tests", "system/vertexai/test_generative_models.py") # Temporary change

# Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true.
if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false":
Expand Down
32 changes: 27 additions & 5 deletions tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,31 @@ def test_text_generation_preview_count_tokens(self, api_transport):
@pytest.mark.asyncio
@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
async def test_text_generation_model_predict_async(self, api_transport):
aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
api_transport=api_transport,
)
# Create async credentials from default credentials for async REST
if api_transport == "rest":
default_credentials, _ = auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
auth_req = auth.transport.requests.Request()
default_credentials.refresh(auth_req)

# Create async credentials from default credentials
from google.auth.aio.credentials import StaticCredentials

async_credentials = StaticCredentials(token=default_credentials.token)

aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
credentials=async_credentials,
api_transport=api_transport,
)
else:
aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
api_transport=api_transport,
)

model = TextGenerationModel.from_pretrained("google/text-bison@001")
grounding_source = language_models.GroundingSource.WebSearch()
Expand All @@ -106,6 +126,8 @@ async def test_text_generation_model_predict_async(self, api_transport):
grounding_source=grounding_source,
)
assert response.text or response.is_blocked
if api_transport == "rest":
await model.close_async_client()

@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
def test_text_generation_streaming(self, api_transport):
Expand Down
37 changes: 34 additions & 3 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class TestGenerativeModels(e2e_base.TestEndToEnd):
_temp_prefix = "temp_generative_models_test_"

@pytest.fixture(scope="function", autouse=True)
def setup_method(self, api_endpoint_env_name):
def setup_method(self, api_endpoint_env_name, api_transport):
super().setup_method()
credentials, _ = auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
Expand All @@ -112,6 +112,7 @@ def setup_method(self, api_endpoint_env_name):
location=e2e_base._LOCATION,
credentials=credentials,
api_endpoint=api_endpoint,
api_transport=api_transport,
)

def test_generate_content_with_cached_content_from_text(
Expand Down Expand Up @@ -154,7 +155,8 @@ def test_generate_content_with_cached_content_from_text(
finally:
cached_content.delete()

def test_generate_content_from_text(self, api_endpoint_env_name):
@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
def test_generate_content_from_text(self, api_endpoint_env_name, api_transport):
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
response = model.generate_content(
"Why is sky blue?",
Expand Down Expand Up @@ -226,7 +228,34 @@ def test_generate_content_streaming(self, api_endpoint_env_name):
)

@pytest.mark.asyncio
async def test_generate_content_streaming_async(self, api_endpoint_env_name):
@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
async def test_generate_content_streaming_async(
self, api_endpoint_env_name, api_transport
):
# Create async credentials from default credentials for async REST
if api_transport == "rest":
default_credentials, _ = auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
auth_req = auth.transport.requests.Request()
default_credentials.refresh(auth_req)
if api_endpoint_env_name == STAGING_API_ENDPOINT:
api_endpoint = os.getenv(api_endpoint_env_name)
else:
api_endpoint = None

# Create async credentials from default credentials
from google.auth.aio.credentials import StaticCredentials

async_credentials = StaticCredentials(token=default_credentials.token)

aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
credentials=async_credentials,
api_endpoint=api_endpoint,
api_transport=api_transport,
)
model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
async_stream = await model.generate_content_async(
"Why is sky blue?",
Expand All @@ -239,6 +268,8 @@ async def test_generate_content_streaming_async(self, api_endpoint_env_name):
or chunk.candidates[0].finish_reason
is generative_models.FinishReason.STOP
)
if api_transport == "rest":
await model.close_async_client()

def test_generate_content_with_parameters(self, api_endpoint_env_name):
model = generative_models.GenerativeModel(
Expand Down
4 changes: 4 additions & 0 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,10 @@ async def async_generator():

return async_generator()

async def close_async_client(self) -> None:
    """Release network resources held by the async prediction client.

    The async REST transport keeps an open aiohttp session that must be
    closed explicitly; callers using the async client over REST should
    await this once they are done issuing requests. No-op when no async
    prediction client has been created.
    """
    async_client = self._prediction_async_client
    if not async_client:
        return None
    return await async_client.transport.close()

def count_tokens(
self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
) -> gapic_prediction_service_types.CountTokensResponse:
Expand Down

0 comments on commit 472f72c

Please sign in to comment.