From dd24d59d23ad51f693fd99d2e0b633c0cc1c4eac Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Thu, 24 Oct 2024 13:03:09 -0700 Subject: [PATCH] feat: Add partner model tuning to SDK PiperOrigin-RevId: 689488147 --- tests/unit/vertexai/tuning/test_tuning.py | 54 ++++++++++++++++--- vertexai/tuning/_partner_model_tuning.py | 66 +++++++++++++++++++++++ vertexai/tuning/_tuning.py | 2 + 3 files changed, 116 insertions(+), 6 deletions(-) create mode 100644 vertexai/tuning/_partner_model_tuning.py diff --git a/tests/unit/vertexai/tuning/test_tuning.py b/tests/unit/vertexai/tuning/test_tuning.py index 5faf590bc1..919e93fe79 100644 --- a/tests/unit/vertexai/tuning/test_tuning.py +++ b/tests/unit/vertexai/tuning/test_tuning.py @@ -22,10 +22,12 @@ import importlib from typing import Dict, Iterable from unittest import mock +from unittest.mock import patch import uuid from google import auth from google.auth import credentials as auth_credentials +from google.cloud import storage from google.cloud import aiplatform import vertexai from google.cloud.aiplatform import compat @@ -34,19 +36,18 @@ from google.cloud.aiplatform.metadata import experiment_resources from google.cloud.aiplatform_v1beta1.services import gen_ai_tuning_service from google.cloud.aiplatform_v1beta1.types import job_state -from google.cloud.aiplatform_v1beta1.types import tuning_job as gca_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + tuning_job as gca_tuning_job, +) from vertexai.preview import tuning from vertexai.preview.tuning import ( sft as preview_supervised_tuning, ) -from vertexai.tuning import sft as supervised_tuning from vertexai.tuning import _distillation -from google.cloud import storage - +from vertexai.tuning import _partner_model_tuning +from vertexai.tuning import sft as supervised_tuning import pytest -from unittest.mock import patch - from google.rpc import status_pb2 @@ -294,3 +295,44 @@ def test_genai_tuning_service_distillation_distill_model(self): assert tuning_job.has_ended assert tuning_job.has_succeeded assert tuning_job.tuned_model_name + + @mock.patch.object( + target=tuning.TuningJob, + attribute="client_class", + new=MockTuningJobClientWithOverride, + ) + def test_genai_tuning_service_partner_model_tuning(self): + partner_model_train = _partner_model_tuning.train + tuning_job = partner_model_train( + base_model="llama3-8b-instruct-maas", + training_dataset_uri="gs://some-bucket/some_dataset.jsonl", + tuned_model_display_name="tuned-llama3-8b-instruct-maas", + validation_dataset_uri="gs://some-bucket/some_dataset.jsonl", + hyper_parameters={ + "learning_rate": 0.0001, + "batch_size": 32, + "num_epochs": 10, + }, + ) + assert tuning_job.state == job_state.JobState.JOB_STATE_PENDING + assert not tuning_job.has_ended + assert not tuning_job.has_succeeded + + # Refreshing the job. + tuning_job.refresh() + assert tuning_job.state == job_state.JobState.JOB_STATE_PENDING + assert not tuning_job.has_ended + assert not tuning_job.has_succeeded + + # Refreshing the job. + tuning_job.refresh() + assert tuning_job.state == job_state.JobState.JOB_STATE_RUNNING + assert not tuning_job.has_ended + assert not tuning_job.has_succeeded + + # Refreshing the job + tuning_job.refresh() + assert tuning_job.state == job_state.JobState.JOB_STATE_SUCCEEDED + assert tuning_job.has_ended + assert tuning_job.has_succeeded + assert tuning_job.tuned_model_name diff --git a/vertexai/tuning/_partner_model_tuning.py b/vertexai/tuning/_partner_model_tuning.py new file mode 100644 index 0000000000..561d1965a4 --- /dev/null +++ b/vertexai/tuning/_partner_model_tuning.py @@ -0,0 +1,66 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access +"""Classes for partner model tuning.""" + +from typing import Any, Dict, Optional + +from google.cloud.aiplatform_v1beta1.types import ( + tuning_job as gca_tuning_job_types, +) +from vertexai.tuning import _tuning + + +def train( + *, + base_model: str, + training_dataset_uri: str, + validation_dataset_uri: Optional[str] = None, + tuned_model_display_name: Optional[str] = None, + hyper_parameters: Optional[Dict[str, Any]] = None, + labels: Optional[Dict[str, str]] = None, +) -> "PartnerModelTuningJob": + """Tunes a third party partner model. + + Args: + base_model: The base model to tune. + training_dataset_uri: The training dataset uri. + validation_dataset_uri: The validation dataset uri. + tuned_model_display_name: The display name of the + [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up + to 128 characters long and can consist of any UTF-8 characters. + hyper_parameters: The hyper parameters of the tuning job. + labels: User-defined metadata to be associated with trained models + + Returns: + A `TuningJob` object. + """ + partner_model_tuning_spec = gca_tuning_job_types.PartnerModelTuningSpec( + training_dataset_uri=training_dataset_uri, + validation_dataset_uri=validation_dataset_uri, + hyper_parameters=hyper_parameters, + ) + partner_model_tuning_job = PartnerModelTuningJob._create( + base_model=base_model, + tuning_spec=partner_model_tuning_spec, + tuned_model_display_name=tuned_model_display_name, + labels=labels, + ) + + return partner_model_tuning_job + + +class PartnerModelTuningJob(_tuning.TuningJob): + pass diff --git a/vertexai/tuning/_tuning.py b/vertexai/tuning/_tuning.py index ced3fac9b1..9dc5b9d0af 100644 --- a/vertexai/tuning/_tuning.py +++ b/vertexai/tuning/_tuning.py @@ -197,6 +197,8 @@ def _create( gca_tuning_job.supervised_tuning_spec = tuning_spec elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec): gca_tuning_job.distillation_spec = tuning_spec + elif isinstance(tuning_spec, gca_tuning_job_types.PartnerModelTuningSpec): + gca_tuning_job.partner_model_tuning_spec = tuning_spec else: raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")