Skip to content

Commit 8318289

Browse files
committed
create cli command for config-path
1 parent 5d3fa49 commit 8318289

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

toolkit.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from transformers import utils as hf_utils
77
from pydantic import ValidationError
88
import torch
9+
import typer
910

1011
from src.pydantic_models.config_model import Config
1112
from src.data.dataset_generator import DatasetGenerator
@@ -19,7 +20,10 @@
1920
torch._logging.set_logs(all=logging.CRITICAL)
2021

2122

22-
def run_one_experiment(config: Config) -> None:
23+
app = typer.Typer()
24+
25+
26+
def run_one_experiment(config: Config, config_path: str) -> None:
2327
dir_helper = DirectoryHelper(config_path, config)
2428

2529
# Loading Data -------------------------------
@@ -70,7 +74,7 @@ def run_one_experiment(config: Config) -> None:
7074
RichUI.inference_found(results_path)
7175

7276
# QA -------------------------------
73-
# console.rule("[bold blue]:thinking_face: Running LLM Unit Tests")
77+
# RichUI.before_qa()
7478
# qa_path = dir_helper.save_paths.qa
7579
# if not exists(qa_path) or not listdir(qa_path):
7680
# # TODO: Instantiate unit test classes
@@ -80,9 +84,8 @@ def run_one_experiment(config: Config) -> None:
8084
# pass
8185

8286

83-
if __name__ == "__main__":
84-
config_path = "./config.yml" # TODO: parameterize this
85-
87+
@app.command()
88+
def run(config_path: str = "./config.yml") -> None:
8689
# Load YAML config
8790
with open(config_path, "r") as file:
8891
config = yaml.safe_load(file)
@@ -92,9 +95,9 @@ def run_one_experiment(config: Config) -> None:
9295
else [config]
9396
)
9497
for config in configs:
98+
# validate data with pydantic
9599
try:
96100
config = Config(**config)
97-
# validate data with pydantic
98101
except ValidationError as e:
99102
print(e.json())
100103

@@ -105,4 +108,8 @@ def run_one_experiment(config: Config) -> None:
105108
config = yaml.safe_load(file)
106109
config = Config(**config)
107110

108-
run_one_experiment(config)
111+
run_one_experiment(config, config_path)
112+
113+
114+
if __name__ == "__main__":
115+
app()

0 commit comments

Comments
 (0)