Skip to content

Commit

Permalink
feat: Add partner model tuning to SDK
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 689488147
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Oct 25, 2024
1 parent 1f3b2d8 commit dd24d59
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 6 deletions.
54 changes: 48 additions & 6 deletions tests/unit/vertexai/tuning/test_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import importlib
from typing import Dict, Iterable
from unittest import mock
from unittest.mock import patch
import uuid

from google import auth
from google.auth import credentials as auth_credentials
from google.cloud import storage
from google.cloud import aiplatform
import vertexai
from google.cloud.aiplatform import compat
Expand All @@ -34,19 +36,18 @@
from google.cloud.aiplatform.metadata import experiment_resources
from google.cloud.aiplatform_v1beta1.services import gen_ai_tuning_service
from google.cloud.aiplatform_v1beta1.types import job_state
from google.cloud.aiplatform_v1beta1.types import tuning_job as gca_tuning_job
from google.cloud.aiplatform_v1beta1.types import (
tuning_job as gca_tuning_job,
)
from vertexai.preview import tuning
from vertexai.preview.tuning import (
sft as preview_supervised_tuning,
)
from vertexai.tuning import sft as supervised_tuning
from vertexai.tuning import _distillation
from google.cloud import storage

from vertexai.tuning import _partner_model_tuning
from vertexai.tuning import sft as supervised_tuning
import pytest

from unittest.mock import patch

from google.rpc import status_pb2


Expand Down Expand Up @@ -294,3 +295,44 @@ def test_genai_tuning_service_distillation_distill_model(self):
assert tuning_job.has_ended
assert tuning_job.has_succeeded
assert tuning_job.tuned_model_name

@mock.patch.object(
    target=tuning.TuningJob,
    attribute="client_class",
    new=MockTuningJobClientWithOverride,
)
def test_genai_tuning_service_partner_model_tuning(self):
    """Checks a partner model tuning job through successive refreshes.

    The job is created through `_partner_model_tuning.train` while
    `tuning.TuningJob.client_class` is patched with
    `MockTuningJobClientWithOverride`. The test expects the state sequence
    PENDING -> PENDING -> RUNNING -> SUCCEEDED, with one `refresh()` call
    between consecutive observations.
    """
    job = _partner_model_tuning.train(
        base_model="llama3-8b-instruct-maas",
        tuned_model_display_name="tuned-llama3-8b-instruct-maas",
        training_dataset_uri="gs://some-bucket/some_dataset.jsonl",
        validation_dataset_uri="gs://some-bucket/some_dataset.jsonl",
        hyper_parameters={
            "learning_rate": 0.0001,
            "batch_size": 32,
            "num_epochs": 10,
        },
    )

    # States expected while the job is still in flight: the first entry is
    # observed right after creation, each later one after a refresh.
    unfinished_states = (
        job_state.JobState.JOB_STATE_PENDING,
        job_state.JobState.JOB_STATE_PENDING,
        job_state.JobState.JOB_STATE_RUNNING,
    )
    for observation, expected_state in enumerate(unfinished_states):
        if observation > 0:
            # Refreshing the job.
            job.refresh()
        assert job.state == expected_state
        assert not job.has_ended
        assert not job.has_succeeded

    # Final refresh: the job is expected to have completed successfully
    # and to expose the tuned model name.
    job.refresh()
    assert job.state == job_state.JobState.JOB_STATE_SUCCEEDED
    assert job.has_ended
    assert job.has_succeeded
    assert job.tuned_model_name
66 changes: 66 additions & 0 deletions vertexai/tuning/_partner_model_tuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=protected-access
"""Classes for partner model tuning."""

from typing import Any, Dict, Optional

from google.cloud.aiplatform_v1beta1.types import (
tuning_job as gca_tuning_job_types,
)
from vertexai.tuning import _tuning


def train(
    *,
    base_model: str,
    training_dataset_uri: str,
    validation_dataset_uri: Optional[str] = None,
    tuned_model_display_name: Optional[str] = None,
    hyper_parameters: Optional[Dict[str, Any]] = None,
    labels: Optional[Dict[str, str]] = None,
) -> "PartnerModelTuningJob":
    """Tunes a third party partner model.

    Builds a `PartnerModelTuningSpec` from the dataset URIs and hyper
    parameters, then creates the job via `PartnerModelTuningJob._create`.

    Args:
        base_model: The base model to tune.
        training_dataset_uri: The training dataset uri.
        validation_dataset_uri: The validation dataset uri.
        tuned_model_display_name: The display name of the
            [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up
            to 128 characters long and can consist of any UTF-8 characters.
        hyper_parameters: The hyper parameters of the tuning job.
        labels: User-defined metadata to be associated with trained models

    Returns:
        The `PartnerModelTuningJob` produced by
        `PartnerModelTuningJob._create`.
    """
    # The spec carries only the data locations and hyper parameters; the
    # base model, display name and labels are passed to `_create` directly.
    spec = gca_tuning_job_types.PartnerModelTuningSpec(
        training_dataset_uri=training_dataset_uri,
        validation_dataset_uri=validation_dataset_uri,
        hyper_parameters=hyper_parameters,
    )
    return PartnerModelTuningJob._create(
        base_model=base_model,
        tuning_spec=spec,
        tuned_model_display_name=tuned_model_display_name,
        labels=labels,
    )


class PartnerModelTuningJob(_tuning.TuningJob):
    """Tuning job for a partner model.

    Defines no members of its own; all behavior comes from
    `_tuning.TuningJob`.
    """
2 changes: 2 additions & 0 deletions vertexai/tuning/_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def _create(
gca_tuning_job.supervised_tuning_spec = tuning_spec
elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec):
gca_tuning_job.distillation_spec = tuning_spec
elif isinstance(tuning_spec, gca_tuning_job_types.PartnerModelTuningSpec):
gca_tuning_job.partner_model_tuning_spec = tuning_spec
else:
raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")

Expand Down

0 comments on commit dd24d59

Please sign in to comment.