Skip to content

Commit 6d91504

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Enable Vertex Model Garden Managed OSS Fine Tuning.
PiperOrigin-RevId: 815806517
1 parent e2aa3eb commit 6d91504

File tree

3 files changed

+98
-12
lines changed

3 files changed

+98
-12
lines changed

vertexai/tuning/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
# We just want to re-export certain classes
1818
# pylint: disable=g-multiple-import,g-importing-member
19+
from vertexai.tuning._tuning import SourceModel
1920
from vertexai.tuning._tuning import TuningJob
2021

2122
__all__ = [
23+
"SourceModel",
2224
"TuningJob",
2325
]

vertexai/tuning/_supervised_tuning.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,25 @@
2121
tuning_job as gca_tuning_job_types,
2222
)
2323
from vertexai import generative_models
24-
from vertexai.tuning import _tuning
24+
from vertexai.tuning import (
25+
SourceModel,
26+
TuningJob,
27+
)
2528

2629

2730
def train(
2831
*,
29-
source_model: Union[str, generative_models.GenerativeModel],
32+
source_model: Union[str, generative_models.GenerativeModel, SourceModel],
3033
train_dataset: Union[str, datasets.MultimodalDataset],
3134
validation_dataset: Optional[Union[str, datasets.MultimodalDataset]] = None,
3235
tuned_model_display_name: Optional[str] = None,
36+
tuning_mode: Optional[Literal["FULL", "PEFT_ADAPTER"]] = None,
3337
epochs: Optional[int] = None,
38+
learning_rate: Optional[float] = None,
3439
learning_rate_multiplier: Optional[float] = None,
3540
adapter_size: Optional[Literal[1, 4, 8, 16, 32]] = None,
3641
labels: Optional[Dict[str, str]] = None,
42+
output_uri: Optional[str] = None,
3743
) -> "SupervisedTuningJob":
3844
"""Tunes a model using supervised training.
3945
@@ -44,14 +50,41 @@ def train(
4450
tuned_model_display_name: The display name of the
4551
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
4652
128 characters long and can consist of any UTF-8 characters.
53+
tuning_mode: Tuning mode for this tuning job. Can only be used with OSS
54+
models.
4755
epochs: Number of training epoches for this tuning job.
48-
learning_rate_multiplier: Learning rate multiplier for tuning.
56+
learning_rate: Learning rate for tuning. Can only be used with OSS
57+
models. Mutually exclusive with `learning_rate_multiplier`.
58+
learning_rate_multiplier: Learning rate multiplier for tuning. Can only
59+
be used with OSS models. Mutually exclusive with `learning_rate`.
4960
adapter_size: Adapter size for tuning.
5061
labels: User-defined metadata to be associated with trained models
62+
output_uri: The Google Cloud Storage URI to write the tuned model to.
63+
Can only be used with OSS models.
5164
5265
Returns:
5366
A `TuningJob` object.
5467
"""
68+
if tuning_mode is None:
69+
tuning_mode_value = None
70+
elif tuning_mode == "FULL":
71+
tuning_mode_value = (
72+
gca_tuning_job_types.SupervisedTuningSpec.TuningMode.TUNING_MODE_FULL
73+
)
74+
elif tuning_mode == "PEFT_ADAPTER":
75+
tuning_mode_value = (
76+
gca_tuning_job_types.SupervisedTuningSpec.TuningMode.TUNING_MODE_PEFT_ADAPTER
77+
)
78+
else:
79+
raise ValueError(
80+
f"Unsupported tuning mode: {tuning_mode}. The supported tuning modes are [FULL, PEFT_ADAPTER]"
81+
)
82+
83+
if learning_rate and learning_rate_multiplier:
84+
raise ValueError(
85+
"Only one of `learning_rate` and `learning_rate_multiplier` can be set."
86+
)
87+
5588
if adapter_size is None:
5689
adapter_size_value = None
5790
elif adapter_size == 1:
@@ -83,10 +116,12 @@ def train(
83116
if isinstance(validation_dataset, datasets.MultimodalDataset):
84117
validation_dataset = validation_dataset.resource_name
85118
supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
119+
tuning_mode=tuning_mode_value,
86120
training_dataset_uri=train_dataset,
87121
validation_dataset_uri=validation_dataset,
88122
hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters(
89123
epoch_count=epochs,
124+
learning_rate=learning_rate,
90125
learning_rate_multiplier=learning_rate_multiplier,
91126
adapter_size=adapter_size_value,
92127
),
@@ -95,20 +130,26 @@ def train(
95130
if isinstance(source_model, generative_models.GenerativeModel):
96131
source_model = source_model._prediction_resource_name.rpartition("/")[-1]
97132

133+
if labels is None:
134+
labels = {}
135+
if "mg-source" not in labels and output_uri:
136+
labels["mg-source"] = "sdk"
137+
98138
supervised_tuning_job = (
99139
SupervisedTuningJob._create( # pylint: disable=protected-access
100140
base_model=source_model,
101141
tuning_spec=supervised_tuning_spec,
102142
tuned_model_display_name=tuned_model_display_name,
103143
labels=labels,
144+
output_uri=output_uri,
104145
)
105146
)
106147
_ipython_utils.display_model_tuning_button(supervised_tuning_job)
107148

108149
return supervised_tuning_job
109150

110151

111-
class SupervisedTuningJob(_tuning.TuningJob):
152+
class SupervisedTuningJob(TuningJob):
112153
def __init__(self, tuning_job_name: str):
113154
super().__init__(tuning_job_name=tuning_job_name)
114155
_ipython_utils.display_model_tuning_button(self)

vertexai/tuning/_tuning.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,42 @@
4343
_LOGGER = aiplatform_base.Logger(__name__)
4444

4545

46+
class SourceModel:
47+
r"""A model that is used in managed OSS supervised tuning.
48+
49+
Usage:
50+
```
51+
model = SourceModel(
52+
base_model="meta/llama3_1@llama-3.1-8b",
53+
custom_base_model="gs://user-bucket/custom-weights",
54+
)
55+
sft_tuning_job = sft.train(
56+
source_model=model,
57+
train_dataset="gs://my-bucket/train.jsonl",
58+
validation_dataset="gs://my-bucket/validation.jsonl",
59+
epochs=4,
60+
tuned_model_display_name="my-tuned-model",
61+
output_uri="gs://user-bucket/tuned-model"
62+
)
63+
64+
while not sft_tuning_job.has_ended:
65+
time.sleep(60)
66+
sft_tuning_job.refresh()
67+
68+
tuned_model = aiplatform.Model(sft_tuning_job.tuned_model_name)
69+
```
70+
"""
71+
72+
def __init__(
73+
self,
74+
base_model: str,
75+
custom_base_model: str = "",
76+
):
77+
r"""Initializes SourceModel."""
78+
self.base_model = base_model
79+
self.custom_base_model = custom_base_model
80+
81+
4682
class TuningJobClientWithOverride(aiplatform_utils.ClientWithOverride):
4783
_is_temporary = True
4884
_default_version = compat.V1BETA1
@@ -133,7 +169,7 @@ def tuning_data_statistics(self) -> gca_tuning_job_types.TuningDataStats:
133169
def _create(
134170
cls,
135171
*,
136-
base_model: str,
172+
base_model: Union[str, SourceModel],
137173
tuning_spec: Union[
138174
gca_tuning_job_types.SupervisedTuningSpec,
139175
gca_tuning_job_types.DistillationSpec,
@@ -144,15 +180,13 @@ def _create(
144180
project: Optional[str] = None,
145181
location: Optional[str] = None,
146182
credentials: Optional[auth_credentials.Credentials] = None,
183+
output_uri: Optional[str] = None,
147184
) -> "TuningJob":
148185
r"""Submits TuningJob.
149186
150187
Args:
151-
base_model (str):
152-
Model name for tuning, e.g., "gemini-1.0-pro"
153-
or "gemini-1.0-pro-001".
154-
155-
This field is a member of `oneof`_ ``source_model``.
188+
base_model: Model for tuning.
189+
Supported types: str, SourceModel.
156190
tuning_spec: Tuning Spec for Fine Tuning.
157191
Supported types: SupervisedTuningSpec, DistillationSpec.
158192
tuned_model_display_name: The display name of the
@@ -179,6 +213,7 @@ def _create(
179213
Overrides location set in aiplatform.init.
180214
credentials: Custom credentials to use to call tuning job service.
181215
Overrides credentials set in aiplatform.init.
216+
output_uri: The Google Cloud Storage location to write the artifacts. This is only used for OSS models.
182217
183218
Returns:
184219
Submitted TuningJob.
@@ -192,17 +227,25 @@ def _create(
192227
tuned_model_display_name = cls._generate_display_name()
193228

194229
gca_tuning_job = gca_tuning_job_types.TuningJob(
195-
base_model=base_model,
196230
tuned_model_display_name=tuned_model_display_name,
197231
description=description,
198232
labels=labels,
199-
# The tuning_spec one_of is set later
233+
# The tuning_spec one_of is set later.
234+
output_uri=output_uri,
200235
)
201236

202237
if isinstance(tuning_spec, gca_tuning_job_types.SupervisedTuningSpec):
203238
gca_tuning_job.supervised_tuning_spec = tuning_spec
239+
if isinstance(base_model, SourceModel):
240+
gca_tuning_job.base_model = base_model.base_model
241+
gca_tuning_job.custom_base_model = base_model.custom_base_model
242+
else:
243+
gca_tuning_job.base_model = base_model
204244
elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec):
205245
gca_tuning_job.distillation_spec = tuning_spec
246+
if isinstance(base_model, SourceModel):
247+
raise RuntimeError("Distillation is not supported for custom models.")
248+
gca_tuning_job.base_model = base_model
206249
else:
207250
raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")
208251

0 commit comments

Comments
 (0)