 # TODO: Refactor this into multiple files...
 HfModelPath = str

+
 class QaConfig(BaseModel):
-    llm_tests: Optional[List[str]] = Field([], description="list of tests that needs to be connected")
+    llm_tests: Optional[List[str]] = Field([], description="list of tests that needs to be connected")


 class DataConfig(BaseModel):
-    file_type: Literal["json", "csv", "huggingface"] = Field(
-        None, description="File type"
-    )
-    path: Union[FilePath, HfModelPath] = Field(
-        None, description="Path to the file or HuggingFace model"
-    )
-    prompt: str = Field(
-        None, description="Prompt for the model. Use {} brackets for column name"
-    )
+    file_type: Literal["json", "csv", "huggingface"] = Field(None, description="File type")
+    path: Union[FilePath, HfModelPath] = Field(None, description="Path to the file or HuggingFace model")
+    prompt: str = Field(None, description="Prompt for the model. Use {} brackets for column name")
     prompt_stub: str = Field(
         None,
         description="Stub for the prompt; this is injected during training. Use {} brackets for column name",
@@ -47,9 +42,7 @@ class DataConfig(BaseModel):


 class BitsAndBytesConfig(BaseModel):
-    load_in_8bit: Optional[bool] = Field(
-        False, description="Enable 8-bit quantization with LLM.int8()"
-    )
+    load_in_8bit: Optional[bool] = Field(False, description="Enable 8-bit quantization with LLM.int8()")
     llm_int8_threshold: Optional[float] = Field(
         6.0, description="Outlier threshold for outlier detection in 8-bit quantization"
     )
@@ -60,9 +53,7 @@ class BitsAndBytesConfig(BaseModel):
         False,
         description="Enable splitting model parts between int8 on GPU and fp32 on CPU",
     )
-    llm_int8_has_fp16_weight: Optional[bool] = Field(
-        False, description="Run LLM.int8() with 16-bit main weights"
-    )
+    llm_int8_has_fp16_weight: Optional[bool] = Field(False, description="Run LLM.int8() with 16-bit main weights")

     load_in_4bit: Optional[bool] = Field(
         True,
@@ -85,14 +76,10 @@ class ModelConfig(BaseModel):
8576 "NousResearch/Llama-2-7b-hf" ,
8677 description = "Path to the model (huggingface repo or local path)" ,
8778 )
88- device_map : Optional [str ] = Field (
89- "auto" , description = "device onto which to load the model"
90- )
79+ device_map : Optional [str ] = Field ("auto" , description = "device onto which to load the model" )
9180
9281 quantize : Optional [bool ] = Field (False , description = "Flag to enable quantization" )
93- bitsandbytes : BitsAndBytesConfig = Field (
94- None , description = "Bits and Bytes configuration"
95- )
82+ bitsandbytes : BitsAndBytesConfig = Field (None , description = "Bits and Bytes configuration" )
9683
9784 # @validator("hf_model_ckpt")
9885 # def validate_model(cls, v, **kwargs):
@@ -115,22 +102,12 @@ def set_device_map_to_none(cls, v, values, **kwargs):

 class LoraConfig(BaseModel):
     r: Optional[int] = Field(8, description="Lora rank")
-    task_type: Optional[str] = Field(
-        "CAUSAL_LM", description="Base Model task type during training"
-    )
+    task_type: Optional[str] = Field("CAUSAL_LM", description="Base Model task type during training")

-    lora_alpha: Optional[int] = Field(
-        16, description="The alpha parameter for Lora scaling"
-    )
-    bias: Optional[str] = Field(
-        "none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'"
-    )
-    lora_dropout: Optional[float] = Field(
-        0.1, description="The dropout probability for Lora layers"
-    )
-    target_modules: Optional[List[str]] = Field(
-        None, description="The names of the modules to apply Lora to"
-    )
+    lora_alpha: Optional[int] = Field(16, description="The alpha parameter for Lora scaling")
+    bias: Optional[str] = Field("none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'")
+    lora_dropout: Optional[float] = Field(0.1, description="The dropout probability for Lora layers")
+    target_modules: Optional[List[str]] = Field(None, description="The names of the modules to apply Lora to")
     fan_in_fan_out: Optional[bool] = Field(
         False,
         description="Flag to indicate if the layer to replace stores weight like (fan_in, fan_out)",
@@ -139,9 +116,7 @@ class LoraConfig(BaseModel):
         None,
         description="List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint",
     )
-    layers_to_transform: Optional[Union[List[int], int]] = Field(
-        None, description="The layer indexes to transform"
-    )
+    layers_to_transform: Optional[Union[List[int], int]] = Field(None, description="The layer indexes to transform")
     layers_pattern: Optional[str] = Field(None, description="The layer pattern name")
     # rank_pattern: Optional[Dict[str, int]] = Field(
     #     {}, description="The mapping from layer names or regexp expression to ranks"
@@ -154,15 +129,9 @@ class LoraConfig(BaseModel):
 # TODO: Get comprehensive Args!
 class TrainingArgs(BaseModel):
     num_train_epochs: Optional[int] = Field(1, description="Number of training epochs")
-    per_device_train_batch_size: Optional[int] = Field(
-        1, description="Batch size per training device"
-    )
-    gradient_accumulation_steps: Optional[int] = Field(
-        1, description="Number of steps for gradient accumulation"
-    )
-    gradient_checkpointing: Optional[bool] = Field(
-        True, description="Flag to enable gradient checkpointing"
-    )
+    per_device_train_batch_size: Optional[int] = Field(1, description="Batch size per training device")
+    gradient_accumulation_steps: Optional[int] = Field(1, description="Number of steps for gradient accumulation")
+    gradient_checkpointing: Optional[bool] = Field(True, description="Flag to enable gradient checkpointing")
     optim: Optional[str] = Field("paged_adamw_32bit", description="Optimizer")
     logging_steps: Optional[int] = Field(100, description="Number of logging steps")
     learning_rate: Optional[float] = Field(2.0e-4, description="Learning rate")
@@ -171,9 +140,7 @@ class TrainingArgs(BaseModel):
     fp16: Optional[bool] = Field(False, description="Flag to enable fp16")
     max_grad_norm: Optional[float] = Field(0.3, description="Maximum gradient norm")
     warmup_ratio: Optional[float] = Field(0.03, description="Warmup ratio")
-    lr_scheduler_type: Optional[str] = Field(
-        "constant", description="Learning rate scheduler type"
-    )
+    lr_scheduler_type: Optional[str] = Field("constant", description="Learning rate scheduler type")


 # TODO: Get comprehensive Args!
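
For context on how config models like these are usually consumed, below is a minimal sketch (not part of this commit) that composes them into one top-level object and parses it from a YAML file. The Config wrapper, the load_config helper, and the YAML layout are assumptions for illustration only; the sketch presumes it sits in the same module as the classes in the diff, so QaConfig, DataConfig, ModelConfig, LoraConfig, and TrainingArgs are already in scope.

# Hypothetical glue code, not from this commit.
import yaml  # PyYAML


class Config(BaseModel):
    # Assumed top-level layout; section names mirror the model classes above.
    data: DataConfig
    model: ModelConfig
    lora: LoraConfig
    training: TrainingArgs
    qa: Optional[QaConfig] = None


def load_config(path: str) -> Config:
    """Parse a YAML file into the typed config tree; pydantic validates each
    section (types, Literal choices, defaults) at construction time."""
    with open(path) as f:
        raw = yaml.safe_load(f)
    return Config(**raw)


# Example usage:
# cfg = load_config("config.yml")
# cfg.training.learning_rate  -> 2e-4 unless overridden in the YAML

Since the field names mirror those of transformers.BitsAndBytesConfig and peft.LoraConfig, the nested models can usually be unpacked straight into those constructors (e.g. peft.LoraConfig(**cfg.lora.dict())), though that wiring is outside this diff.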