Skip to content

feat: [vertexai] Add support for managed OSS fine tuning #5282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vertexai/tuning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.tuning._tuning import SourceModel
from vertexai.tuning._tuning import TuningJob

# Public API of this package: the names re-exported from vertexai.tuning._tuning.
__all__ = ["SourceModel", "TuningJob"]
10 changes: 8 additions & 2 deletions vertexai/tuning/_supervised_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,23 @@
)
from vertexai import generative_models
from vertexai.tuning import _tuning
from vertexai.tuning import SourceModel


def train(
*,
source_model: Union[str, generative_models.GenerativeModel],
source_model: Union[
str,
generative_models.GenerativeModel,
SourceModel],
train_dataset: str,
validation_dataset: Optional[str] = None,
tuned_model_display_name: Optional[str] = None,
epochs: Optional[int] = None,
learning_rate_multiplier: Optional[float] = None,
adapter_size: Optional[Literal[1, 4, 8, 16]] = None,
labels: Optional[Dict[str, str]] = None,
output_uri: Optional[str] = None,
) -> "SupervisedTuningJob":
"""Tunes a model using supervised training.

Expand All @@ -49,7 +54,7 @@ def train(
learning_rate_multiplier: Learning rate multiplier for tuning.
adapter_size: Adapter size for tuning.
labels: User-defined metadata to be associated with trained models

output_uri: The Google Cloud Storage location to write the model artifacts.
Returns:
A `TuningJob` object.
"""
Expand Down Expand Up @@ -94,6 +99,7 @@ def train(
tuning_spec=supervised_tuning_spec,
tuned_model_display_name=tuned_model_display_name,
labels=labels,
output_uri=output_uri,
)
)
_ipython_utils.display_model_tuning_button(supervised_tuning_job)
Expand Down
58 changes: 52 additions & 6 deletions vertexai/tuning/_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,42 @@

_LOGGER = aiplatform_base.Logger(__name__)

class SourceModel:
    r"""A model that is used in managed OSS supervised tuning.

    Bundles the name of a base model with an optional location of custom
    weights so that both can be passed to tuning as one `source_model` value.

    Usage:
        ```
        model = SourceModel(
            base_model="meta/llama3.1-8b",
            custom_base_model="gs://user-bucket/custom-weights",
        )
        sft_tuning_job = sft.train(
            source_model=model,
            train_dataset="gs://my-bucket/train.jsonl",
            validation_dataset="gs://my-bucket/validation.jsonl",
            epochs=4,
            learning_rate_multiplier=0.5,
            tuned_model_display_name="my-tuned-model",
            output_uri="gs://user-bucket/tuned-model"
        )

        while not sft_tuning_job.has_ended:
            time.sleep(60)
            sft_tuning_job.refresh()

        tuned_model = aiplatform.Model(sft_tuning_job.tuned_model_name)
        ```

    Attributes:
        base_model: Name of the base model, e.g. "meta/llama3.1-8b".
        custom_base_model: Location of custom weights to start from, e.g.
            "gs://user-bucket/custom-weights". Empty string when not set.
    """

    def __init__(
        self,
        base_model: str,
        custom_base_model: str = "",
    ) -> None:
        r"""Initializes SourceModel.

        Args:
            base_model: Name of the base model to tune.
            custom_base_model: Optional location of custom weights for the
                base model. Defaults to an empty string (no custom weights).
        """
        # Values are stored as given; no validation is performed here.
        self.base_model = base_model
        self.custom_base_model = custom_base_model

    def __repr__(self) -> str:
        r"""Returns a constructor-style representation for logs and errors."""
        return (
            f"{type(self).__name__}("
            f"base_model={self.base_model!r}, "
            f"custom_base_model={self.custom_base_model!r})"
        )


class TuningJobClientWithOverride(aiplatform_utils.ClientWithOverride):
_is_temporary = True
Expand Down Expand Up @@ -132,7 +168,7 @@ def tuning_data_statistics(self) -> gca_tuning_job_types.TuningDataStats:
def _create(
cls,
*,
base_model: str,
base_model: Union[str, SourceModel],
tuning_spec: Union[
gca_tuning_job_types.SupervisedTuningSpec,
gca_tuning_job_types.DistillationSpec,
Expand All @@ -143,13 +179,13 @@ def _create(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
output_uri: Optional[str] = None,
) -> "TuningJob":
r"""Submits TuningJob.

Args:
base_model (str):
Model name for tuning, e.g., "gemini-1.0-pro"
or "gemini-1.0-pro-001".
base_model: Model for tuning:
Supported types: str, SourceModel.

This field is a member of `oneof`_ ``source_model``.
tuning_spec: Tuning Spec for Fine Tuning.
Expand Down Expand Up @@ -178,6 +214,7 @@ def _create(
Overrides location set in aiplatform.init.
credentials: Custom credentials to use to call tuning job service.
Overrides credentials set in aiplatform.init.
output_uri: The Google Cloud Storage location to write the artifacts.

Returns:
Submitted TuningJob.
Expand All @@ -191,17 +228,26 @@ def _create(
tuned_model_display_name = cls._generate_display_name()

gca_tuning_job = gca_tuning_job_types.TuningJob(
base_model=base_model,
tuned_model_display_name=tuned_model_display_name,
description=description,
labels=labels,
# The tuning_spec one_of is set later
output_uri=output_uri,
)

if isinstance(tuning_spec, gca_tuning_job_types.SupervisedTuningSpec):
gca_tuning_job.supervised_tuning_spec = tuning_spec
if isinstance(base_model, SourceModel):
gca_tuning_job.base_model = base_model.base_model
gca_tuning_job.custom_base_model = base_model.custom_base_model
else:
gca_tuning_job.base_model = base_model
elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec):
gca_tuning_job.distillation_spec = tuning_spec
if isinstance(base_model, SourceModel):
raise RuntimeError(
"Distillation is not supported for custom models."
)
gca_tuning_job.base_model = base_model
else:
raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")

Expand Down