2121 tuning_job as gca_tuning_job_types ,
2222)
2323from vertexai import generative_models
24- from vertexai .tuning import _tuning
24+ from vertexai .tuning import (
25+ SourceModel ,
26+ TuningJob ,
27+ )
2528
2629
2730def train (
2831 * ,
29- source_model : Union [str , generative_models .GenerativeModel ],
32+ source_model : Union [str , generative_models .GenerativeModel , SourceModel ],
3033 train_dataset : Union [str , datasets .MultimodalDataset ],
3134 validation_dataset : Optional [Union [str , datasets .MultimodalDataset ]] = None ,
3235 tuned_model_display_name : Optional [str ] = None ,
36+ tuning_mode : Optional [Literal ["FULL" , "PEFT_ADAPTER" ]] = None ,
3337 epochs : Optional [int ] = None ,
38+ learning_rate : Optional [float ] = None ,
3439 learning_rate_multiplier : Optional [float ] = None ,
3540 adapter_size : Optional [Literal [1 , 4 , 8 , 16 , 32 ]] = None ,
3641 labels : Optional [Dict [str , str ]] = None ,
42+ output_uri : Optional [str ] = None ,
3743) -> "SupervisedTuningJob" :
3844 """Tunes a model using supervised training.
3945
@@ -44,14 +50,41 @@ def train(
4450 tuned_model_display_name: The display name of the
4551 [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
4652 128 characters long and can consist of any UTF-8 characters.
53+ tuning_mode: Tuning mode for this tuning job. Can only be used with OSS
54+ models.
4755 epochs: Number of training epoches for this tuning job.
48- learning_rate_multiplier: Learning rate multiplier for tuning.
56+ learning_rate: Learning rate for tuning. Can only be used with OSS
57+ models. Mutually exclusive with `learning_rate_multiplier`.
58+ learning_rate_multiplier: Learning rate multiplier for tuning. Can only
59+ be used with OSS models. Mutually exclusive with `learning_rate`.
4960 adapter_size: Adapter size for tuning.
5061 labels: User-defined metadata to be associated with trained models
62+ output_uri: The Google Cloud Storage URI to write the tuned model to.
63+ Can only be used with OSS models.
5164
5265 Returns:
5366 A `TuningJob` object.
5467 """
68+ if tuning_mode is None :
69+ tuning_mode_value = None
70+ elif tuning_mode == "FULL" :
71+ tuning_mode_value = (
72+ gca_tuning_job_types .SupervisedTuningSpec .TuningMode .TUNING_MODE_FULL
73+ )
74+ elif tuning_mode == "PEFT_ADAPTER" :
75+ tuning_mode_value = (
76+ gca_tuning_job_types .SupervisedTuningSpec .TuningMode .TUNING_MODE_PEFT_ADAPTER
77+ )
78+ else :
79+ raise ValueError (
80+ f"Unsupported tuning mode: { tuning_mode } . The supported tuning modes are [FULL, PEFT_ADAPTER]"
81+ )
82+
83+ if learning_rate and learning_rate_multiplier :
84+ raise ValueError (
85+ "Only one of `learning_rate` and `learning_rate_multiplier` can be set."
86+ )
87+
5588 if adapter_size is None :
5689 adapter_size_value = None
5790 elif adapter_size == 1 :
@@ -83,10 +116,12 @@ def train(
83116 if isinstance (validation_dataset , datasets .MultimodalDataset ):
84117 validation_dataset = validation_dataset .resource_name
85118 supervised_tuning_spec = gca_tuning_job_types .SupervisedTuningSpec (
119+ tuning_mode = tuning_mode_value ,
86120 training_dataset_uri = train_dataset ,
87121 validation_dataset_uri = validation_dataset ,
88122 hyper_parameters = gca_tuning_job_types .SupervisedHyperParameters (
89123 epoch_count = epochs ,
124+ learning_rate = learning_rate ,
90125 learning_rate_multiplier = learning_rate_multiplier ,
91126 adapter_size = adapter_size_value ,
92127 ),
@@ -95,20 +130,26 @@ def train(
95130 if isinstance (source_model , generative_models .GenerativeModel ):
96131 source_model = source_model ._prediction_resource_name .rpartition ("/" )[- 1 ]
97132
133+ if labels is None :
134+ labels = {}
135+ if "mg-source" not in labels and output_uri :
136+ labels ["mg-source" ] = "sdk"
137+
98138 supervised_tuning_job = (
99139 SupervisedTuningJob ._create ( # pylint: disable=protected-access
100140 base_model = source_model ,
101141 tuning_spec = supervised_tuning_spec ,
102142 tuned_model_display_name = tuned_model_display_name ,
103143 labels = labels ,
144+ output_uri = output_uri ,
104145 )
105146 )
106147 _ipython_utils .display_model_tuning_button (supervised_tuning_job )
107148
108149 return supervised_tuning_job
109150
110151
111- class SupervisedTuningJob (_tuning . TuningJob ):
152+ class SupervisedTuningJob (TuningJob ):
112153 def __init__ (self , tuning_job_name : str ):
113154 super ().__init__ (tuning_job_name = tuning_job_name )
114155 _ipython_utils .display_model_tuning_button (self )
0 commit comments