Commit c9d0bc8

ruff formatting
1 parent 36688e0 commit c9d0bc8

10 files changed: +51 -154 lines changed
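
All hunks in this commit are mechanical: calls and Field() declarations that previously spanned several lines are collapsed wherever they fit within the formatter's line length. As a hedged sketch, a configuration along these lines would produce such collapses (the repository's actual settings are not shown on this page, and the 120-column limit is an assumption inferred from the width of the collapsed lines):

```toml
# pyproject.toml (illustrative only; not taken from this repository)
[tool.ruff]
# The collapsed lines in this diff exceed ruff's default 88 columns,
# so a raised limit such as 120 is assumed here.
line-length = 120
```

With a setting like that in place, running `ruff format .` from the repository root rewrites the files in place to match.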

llmtune/data/dataset_generator.py

Lines changed: 2 additions & 6 deletions
@@ -61,12 +61,8 @@ def _format_one_prompt(self, example, is_test: bool = False):
         return example
 
     def _format_prompts(self):
-        self.dataset["train"] = self.dataset["train"].map(
-            partial(self._format_one_prompt, is_test=False)
-        )
-        self.dataset["test"] = self.dataset["test"].map(
-            partial(self._format_one_prompt, is_test=True)
-        )
+        self.dataset["train"] = self.dataset["train"].map(partial(self._format_one_prompt, is_test=False))
+        self.dataset["test"] = self.dataset["test"].map(partial(self._format_one_prompt, is_test=True))
 
     def get_dataset(self) -> Tuple[Dataset, Dataset]:
         self._train_test_split()
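
The collapsed calls keep the same `Dataset.map` + `functools.partial` pattern. A minimal, self-contained sketch of that pattern on a toy dataset (the column names and prompt template here are illustrative assumptions, not llmtune's):

```python
from functools import partial

from datasets import Dataset


def format_one_prompt(example, is_test: bool = False):
    # At train time the target is appended to the prompt; at test time it is withheld.
    example["text"] = f"Instruction: {example['input']}"
    if not is_test:
        example["text"] += f"\n{example['output']}"
    return example


ds = Dataset.from_dict({"input": ["summarize a b c"], "output": ["abc"]})
train = ds.map(partial(format_one_prompt, is_test=False))
test = ds.map(partial(format_one_prompt, is_test=True))
print(train[0]["text"])  # prompt plus target
print(test[0]["text"])   # prompt only
```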

llmtune/data/ingestor.py

Lines changed: 1 addition & 3 deletions
@@ -13,9 +13,7 @@ def get_ingestor(data_type: str):
     elif data_type == "huggingface":
         return HuggingfaceIngestor
     else:
-        raise ValueError(
-            f"'type' must be one of 'json', 'csv', or 'huggingface', you have {data_type}"
-        )
+        raise ValueError(f"'type' must be one of 'json', 'csv', or 'huggingface', you have {data_type}")
 
 
 class Ingestor(ABC):

llmtune/finetune/lora.py

Lines changed: 1 addition & 3 deletions
@@ -94,9 +94,7 @@ def _inject_lora(self):
         self.model = get_peft_model(self.model, self._lora_config)
 
         if not self.config.accelerate:
-            self.optimizer = bnb.optim.Adam8bit(
-                self.model.parameters(), lr=self._training_args.learning_rate
-            )
+            self.optimizer = bnb.optim.Adam8bit(self.model.parameters(), lr=self._training_args.learning_rate)
             self.lr_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer)
         if self.config.accelerate:
             self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
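
For reference, the collapsed optimizer line pairs bitsandbytes' 8-bit Adam with a constant LR schedule. A minimal sketch of that setup on a toy module (a CUDA device and the `bitsandbytes` package are assumed to be available; the linear layer stands in for the PEFT-wrapped model):

```python
import bitsandbytes as bnb
import torch

# Adam8bit is a drop-in Adam variant that keeps optimizer state in 8 bits.
model = torch.nn.Linear(32, 32).cuda()
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=2e-4)
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
```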

llmtune/inference/lora.py

Lines changed: 7 additions & 23 deletions
@@ -33,9 +33,7 @@ def __init__(
         self.device_map = self.config.model.device_map
         self._weights_path = dir_helper.save_paths.weights
 
-        self.model, self.tokenizer = self._get_merged_model(
-            dir_helper.save_paths.weights
-        )
+        self.model, self.tokenizer = self._get_merged_model(dir_helper.save_paths.weights)
 
     def _get_merged_model(self, weights_path: str):
         # purge VRAM
@@ -45,20 +43,14 @@ def _get_merged_model(self, weights_path: str):
         dtype = (
             torch.float16
             if self.config.training.training_args.fp16
-            else (
-                torch.bfloat16
-                if self.config.training.training_args.bf16
-                else torch.float32
-            )
+            else (torch.bfloat16 if self.config.training.training_args.bf16 else torch.float32)
         )
 
         self.model = AutoPeftModelForCausalLM.from_pretrained(
             weights_path,
             torch_dtype=dtype,
             device_map=self.device_map,
-            quantization_config=(
-                BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump())
-            ),
+            quantization_config=(BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump())),
         )
 
         """TODO: figure out multi-gpu
@@ -68,9 +60,7 @@ def _get_merged_model(self, weights_path: str):
 
         model = self.model.merge_and_unload()
 
-        tokenizer = AutoTokenizer.from_pretrained(
-            self._weights_path, device_map=self.device_map
-        )
+        tokenizer = AutoTokenizer.from_pretrained(self._weights_path, device_map=self.device_map)
 
         return model, tokenizer
 
@@ -81,9 +71,7 @@ def infer_all(self):
 
         # inference loop
         for idx, (prompt, label) in enumerate(zip(prompts, labels)):
-            RichUI.inference_ground_truth_display(
-                f"Generating on test set: {idx+1}/{len(prompts)}", prompt, label
-            )
+            RichUI.inference_ground_truth_display(f"Generating on test set: {idx+1}/{len(prompts)}", prompt, label)
 
             try:
                 result = self.infer_one(prompt)
@@ -101,9 +89,7 @@ def infer_all(self):
                 writer.writerow(row)
 
     def infer_one(self, prompt: str) -> str:
-        input_ids = self.tokenizer(
-            prompt, return_tensors="pt", truncation=True
-        ).input_ids.cuda()
+        input_ids = self.tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
 
         # stream processor
         streamer = TextIteratorStreamer(
@@ -113,9 +99,7 @@ def infer_one(self, prompt: str) -> str:
             timeout=60,  # 60 sec timeout for generation; to handle OOM errors
         )
 
-        generation_kwargs = dict(
-            input_ids=input_ids, streamer=streamer, **self.config.inference.model_dump()
-        )
+        generation_kwargs = dict(input_ids=input_ids, streamer=streamer, **self.config.inference.model_dump())
 
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()
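
The `infer_one` hunks preserve the standard transformers streaming pattern: `generate` blocks until finished, so it runs on a worker thread while the caller drains a `TextIteratorStreamer`. A minimal sketch with a tiny model standing in for the merged PEFT model (the model choice and generation kwargs are illustrative, not llmtune's):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=60)

# generate() runs in a background thread while the main thread
# consumes decoded text chunks from the streamer as they arrive.
thread = Thread(target=model.generate, kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=16))
thread.start()
result = "".join(chunk for chunk in streamer)
thread.join()
print(result)
```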

llmtune/pydantic_models/config_model.py

Lines changed: 19 additions & 52 deletions
@@ -7,20 +7,15 @@
 # TODO: Refactor this into multiple files...
 HfModelPath = str
 
+
 class QaConfig(BaseModel):
-    llm_tests: Optional[List[str]] = Field([], description = "list of tests that needs to be connected")
+    llm_tests: Optional[List[str]] = Field([], description="list of tests that needs to be connected")
 
 
 class DataConfig(BaseModel):
-    file_type: Literal["json", "csv", "huggingface"] = Field(
-        None, description="File type"
-    )
-    path: Union[FilePath, HfModelPath] = Field(
-        None, description="Path to the file or HuggingFace model"
-    )
-    prompt: str = Field(
-        None, description="Prompt for the model. Use {} brackets for column name"
-    )
+    file_type: Literal["json", "csv", "huggingface"] = Field(None, description="File type")
+    path: Union[FilePath, HfModelPath] = Field(None, description="Path to the file or HuggingFace model")
+    prompt: str = Field(None, description="Prompt for the model. Use {} brackets for column name")
     prompt_stub: str = Field(
         None,
         description="Stub for the prompt; this is injected during training. Use {} brackets for column name",
@@ -47,9 +42,7 @@ class DataConfig(BaseModel):
 
 
 class BitsAndBytesConfig(BaseModel):
-    load_in_8bit: Optional[bool] = Field(
-        False, description="Enable 8-bit quantization with LLM.int8()"
-    )
+    load_in_8bit: Optional[bool] = Field(False, description="Enable 8-bit quantization with LLM.int8()")
     llm_int8_threshold: Optional[float] = Field(
         6.0, description="Outlier threshold for outlier detection in 8-bit quantization"
     )
@@ -60,9 +53,7 @@ class BitsAndBytesConfig(BaseModel):
         False,
         description="Enable splitting model parts between int8 on GPU and fp32 on CPU",
     )
-    llm_int8_has_fp16_weight: Optional[bool] = Field(
-        False, description="Run LLM.int8() with 16-bit main weights"
-    )
+    llm_int8_has_fp16_weight: Optional[bool] = Field(False, description="Run LLM.int8() with 16-bit main weights")
 
     load_in_4bit: Optional[bool] = Field(
         True,
@@ -85,14 +76,10 @@ class ModelConfig(BaseModel):
         "NousResearch/Llama-2-7b-hf",
         description="Path to the model (huggingface repo or local path)",
     )
-    device_map: Optional[str] = Field(
-        "auto", description="device onto which to load the model"
-    )
+    device_map: Optional[str] = Field("auto", description="device onto which to load the model")
 
     quantize: Optional[bool] = Field(False, description="Flag to enable quantization")
-    bitsandbytes: BitsAndBytesConfig = Field(
-        None, description="Bits and Bytes configuration"
-    )
+    bitsandbytes: BitsAndBytesConfig = Field(None, description="Bits and Bytes configuration")
 
     # @validator("hf_model_ckpt")
     # def validate_model(cls, v, **kwargs):
@@ -115,22 +102,12 @@ def set_device_map_to_none(cls, v, values, **kwargs):
 
 class LoraConfig(BaseModel):
     r: Optional[int] = Field(8, description="Lora rank")
-    task_type: Optional[str] = Field(
-        "CAUSAL_LM", description="Base Model task type during training"
-    )
+    task_type: Optional[str] = Field("CAUSAL_LM", description="Base Model task type during training")
 
-    lora_alpha: Optional[int] = Field(
-        16, description="The alpha parameter for Lora scaling"
-    )
-    bias: Optional[str] = Field(
-        "none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'"
-    )
-    lora_dropout: Optional[float] = Field(
-        0.1, description="The dropout probability for Lora layers"
-    )
-    target_modules: Optional[List[str]] = Field(
-        None, description="The names of the modules to apply Lora to"
-    )
+    lora_alpha: Optional[int] = Field(16, description="The alpha parameter for Lora scaling")
+    bias: Optional[str] = Field("none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'")
+    lora_dropout: Optional[float] = Field(0.1, description="The dropout probability for Lora layers")
+    target_modules: Optional[List[str]] = Field(None, description="The names of the modules to apply Lora to")
     fan_in_fan_out: Optional[bool] = Field(
         False,
         description="Flag to indicate if the layer to replace stores weight like (fan_in, fan_out)",
@@ -139,9 +116,7 @@ class LoraConfig(BaseModel):
         None,
         description="List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint",
     )
-    layers_to_transform: Optional[Union[List[int], int]] = Field(
-        None, description="The layer indexes to transform"
-    )
+    layers_to_transform: Optional[Union[List[int], int]] = Field(None, description="The layer indexes to transform")
     layers_pattern: Optional[str] = Field(None, description="The layer pattern name")
     # rank_pattern: Optional[Dict[str, int]] = Field(
     #     {}, description="The mapping from layer names or regexp expression to ranks"
@@ -154,15 +129,9 @@ class LoraConfig(BaseModel):
 # TODO: Get comprehensive Args!
 class TrainingArgs(BaseModel):
     num_train_epochs: Optional[int] = Field(1, description="Number of training epochs")
-    per_device_train_batch_size: Optional[int] = Field(
-        1, description="Batch size per training device"
-    )
-    gradient_accumulation_steps: Optional[int] = Field(
-        1, description="Number of steps for gradient accumulation"
-    )
-    gradient_checkpointing: Optional[bool] = Field(
-        True, description="Flag to enable gradient checkpointing"
-    )
+    per_device_train_batch_size: Optional[int] = Field(1, description="Batch size per training device")
+    gradient_accumulation_steps: Optional[int] = Field(1, description="Number of steps for gradient accumulation")
+    gradient_checkpointing: Optional[bool] = Field(True, description="Flag to enable gradient checkpointing")
     optim: Optional[str] = Field("paged_adamw_32bit", description="Optimizer")
     logging_steps: Optional[int] = Field(100, description="Number of logging steps")
     learning_rate: Optional[float] = Field(2.0e-4, description="Learning rate")
@@ -171,9 +140,7 @@ class TrainingArgs(BaseModel):
     fp16: Optional[bool] = Field(False, description="Flag to enable fp16")
     max_grad_norm: Optional[float] = Field(0.3, description="Maximum gradient norm")
     warmup_ratio: Optional[float] = Field(0.03, description="Warmup ratio")
-    lr_scheduler_type: Optional[str] = Field(
-        "constant", description="Learning rate scheduler type"
-    )
+    lr_scheduler_type: Optional[str] = Field("constant", description="Learning rate scheduler type")
 
 
 # TODO: Get comprehensive Args!
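
The `Field(default, description=...)` declarations collapsed above all follow the same pydantic pattern: optional fields with defaults, typically populated from a parsed YAML dict. A minimal sketch mirroring a few of the `LoraConfig` fields (the input dict is illustrative):

```python
from typing import List, Optional

from pydantic import BaseModel, Field


class LoraConfig(BaseModel):
    r: Optional[int] = Field(8, description="Lora rank")
    lora_alpha: Optional[int] = Field(16, description="The alpha parameter for Lora scaling")
    lora_dropout: Optional[float] = Field(0.1, description="The dropout probability for Lora layers")
    target_modules: Optional[List[str]] = Field(None, description="The names of the modules to apply Lora to")


# Unset keys fall back to the Field defaults.
config = LoraConfig(**{"r": 16, "target_modules": ["q_proj", "v_proj"]})
print(config.lora_alpha)  # 16
```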

llmtune/qa/generics.py

Lines changed: 4 additions & 13 deletions
@@ -14,9 +14,7 @@ def test_name(self) -> str:
         pass
 
     @abstractmethod
-    def get_metric(
-        self, prompt: str, grount_truth: str, model_pred: str
-    ) -> Union[float, int, bool]:
+    def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[float, int, bool]:
         pass
 
 
@@ -45,7 +43,6 @@ def __init__(
         ground_truths: List[str],
         model_preds: List[str],
     ) -> None:
-
         self.tests = tests
         self.prompts = prompts
         self.ground_truths = ground_truths
@@ -57,9 +54,7 @@ def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
         test_results = {}
         for test in zip(self.tests):
             metrics = []
-            for prompt, ground_truth, model_pred in zip(
-                self.prompts, self.ground_truths, self.model_preds
-            ):
+            for prompt, ground_truth, model_pred in zip(self.prompts, self.ground_truths, self.model_preds):
                 metrics.append(test.get_metric(prompt, ground_truth, model_pred))
             test_results[test.test_name] = metrics
 
@@ -74,14 +69,10 @@ def print_test_results(self):
         result_dictionary = self.test_results()
         column_data = {key: list(result_dictionary[key]) for key in result_dictionary}
         mean_values = {key: statistics.mean(column_data[key]) for key in column_data}
-        median_values = {
-            key: statistics.median(column_data[key]) for key in column_data
-        }
+        median_values = {key: statistics.median(column_data[key]) for key in column_data}
        stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data}
         # Use the RichUI class to display the table
-        RichUI.display_table(
-            result_dictionary, mean_values, median_values, stdev_values
-        )
+        RichUI.display_table(result_dictionary, mean_values, median_values, stdev_values)
 
     def save_test_results(self, path: str):
         # TODO: save these!
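
The collapsed dict comprehensions in `print_test_results` all compute per-test summary statistics over a column of metric values. A standalone sketch of that computation on illustrative data:

```python
import statistics

# Toy results: one column of metric values per registered test.
test_results = {"jaccard_similarity": [0.52, 0.61, 0.47], "rouge_score": [0.30, 0.45, 0.38]}

mean_values = {key: statistics.mean(col) for key, col in test_results.items()}
median_values = {key: statistics.median(col) for key, col in test_results.items()}
stdev_values = {key: statistics.stdev(col) for key, col in test_results.items()}
print(mean_values, median_values, stdev_values)
```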

llmtune/qa/qa_tests.py

Lines changed: 10 additions & 30 deletions
@@ -27,9 +27,7 @@ class LengthTest(LLMQaTest):
     def test_name(self) -> str:
         return "summary_length"
 
-    def get_metric(
-        self, prompt: str, ground_truth: str, model_prediction: str
-    ) -> Union[float, int, bool]:
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
         return abs(len(ground_truth) - len(model_prediction))
 
 
@@ -39,9 +37,7 @@ class JaccardSimilarityTest(LLMQaTest):
     def test_name(self) -> str:
         return "jaccard_similarity"
 
-    def get_metric(
-        self, prompt: str, ground_truth: str, model_prediction: str
-    ) -> Union[float, int, bool]:
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
         set_ground_truth = set(ground_truth.lower())
         set_model_prediction = set(model_prediction.lower())
 
@@ -64,14 +60,10 @@ def _encode_sentence(self, sentence):
         outputs = model(**tokens)
         return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
 
-    def get_metric(
-        self, prompt: str, ground_truth: str, model_prediction: str
-    ) -> Union[float, int, bool]:
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
         embedding_ground_truth = self._encode_sentence(ground_truth)
         embedding_model_prediction = self._encode_sentence(model_prediction)
-        dot_product_similarity = np.dot(
-            embedding_ground_truth, embedding_model_prediction
-        )
+        dot_product_similarity = np.dot(embedding_ground_truth, embedding_model_prediction)
         return dot_product_similarity
 
 
@@ -81,9 +73,7 @@ class RougeScoreTest(LLMQaTest):
     def test_name(self) -> str:
         return "rouge_score"
 
-    def get_metric(
-        self, prompt: str, ground_truth: str, model_prediction: str
-    ) -> Union[float, int, bool]:
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
         scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
         scores = scorer.score(model_prediction, ground_truth)
         return float(scores["rouge1"].precision)
@@ -101,9 +91,7 @@ def _remove_stopwords(self, text: str) -> str:
         filtered_text = [word for word in word_tokens if word.lower() not in stop_words]
         return " ".join(filtered_text)
 
-    def get_metric(
-        self, prompt: str, ground_truth: str, model_prediction: str
-    ) -> Union[float, int, bool]:
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]:
         cleaned_model_prediction = self._remove_stopwords(model_prediction)
         cleaned_ground_truth = self._remove_stopwords(ground_truth)
 
@@ -130,12 +118,8 @@ class VerbPercent(PosCompositionTest):
     def test_name(self) -> str:
         return "verb_percent"
 
-    def get_metric(
-        self, prompt: str, ground_truth: str, model_prediction: str
-    ) -> float:
-        return self._get_pos_percent(
-            model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"]
-        )
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
+        return self._get_pos_percent(model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"])
 
 
 @TestRegistry.register("adjective_percent")
@@ -144,9 +128,7 @@ class AdjectivePercent(PosCompositionTest):
     def test_name(self) -> str:
         return "adjective_percent"
 
-    def get_metric(
-        self, prompt: str, ground_truth: str, model_prediction: str
-    ) -> float:
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
         return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"])
 
 
@@ -156,9 +138,7 @@ class NounPercent(PosCompositionTest):
     def test_name(self) -> str:
         return "noun_percent"
 
-    def get_metric(
-        self, prompt: str, ground_truth: str, model_prediction: str
-    ) -> float:
+    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
         return self._get_pos_percent(model_prediction, ["NN", "NNS", "NNP", "NNPS"])
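As a worked example of the character-level Jaccard similarity set up in `JaccardSimilarityTest` above (the hunk shows only the set construction; the intersection-over-union step is inferred from the metric's name and is an assumption here):

```python
ground_truth = "The cat sat."
model_prediction = "A cat sat!"

set_ground_truth = set(ground_truth.lower())          # {'t','h','e',' ','c','a','s','.'}
set_model_prediction = set(model_prediction.lower())  # {'a',' ','c','t','s','!'}

# Jaccard similarity: |intersection| / |union|
overlap = set_ground_truth & set_model_prediction
union = set_ground_truth | set_model_prediction
print(len(overlap) / len(union))  # 5 / 9 ≈ 0.556
```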