66from transformers import utils as hf_utils
77from pydantic import ValidationError
88import torch
9+ import typer
910
1011from src .pydantic_models .config_model import Config
1112from src .data .dataset_generator import DatasetGenerator
1920torch ._logging .set_logs (all = logging .CRITICAL )
2021
2122
22- def run_one_experiment (config : Config ) -> None :
23+ app = typer .Typer ()
24+
25+
26+ def run_one_experiment (config : Config , config_path : str ) -> None :
2327 dir_helper = DirectoryHelper (config_path , config )
2428
2529 # Loading Data -------------------------------
@@ -70,7 +74,7 @@ def run_one_experiment(config: Config) -> None:
7074 RichUI .inference_found (results_path )
7175
7276 # QA -------------------------------
73- # console.rule("[bold blue]:thinking_face: Running LLM Unit Tests" )
77+ # RichUI.before_qa( )
7478 # qa_path = dir_helper.save_paths.qa
7579 # if not exists(qa_path) or not listdir(qa_path):
7680 # # TODO: Instantiate unit test classes
@@ -80,9 +84,8 @@ def run_one_experiment(config: Config) -> None:
8084 # pass
8185
8286
83- if __name__ == "__main__" :
84- config_path = "./config.yml" # TODO: parameterize this
85-
87+ @app .command ()
88+ def run (config_path : str = "./config.yml" ) -> None :
8689 # Load YAML config
8790 with open (config_path , "r" ) as file :
8891 config = yaml .safe_load (file )
@@ -92,9 +95,9 @@ def run_one_experiment(config: Config) -> None:
9295 else [config ]
9396 )
9497 for config in configs :
98+ # validate data with pydantic
9599 try :
96100 config = Config (** config )
97- # validate data with pydantic
98101 except ValidationError as e :
99102 print (e .json ())
100103
@@ -105,4 +108,8 @@ def run_one_experiment(config: Config) -> None:
105108 config = yaml .safe_load (file )
106109 config = Config (** config )
107110
108- run_one_experiment (config )
111+ run_one_experiment (config , config_path )
112+
113+
114+ if __name__ == "__main__" :
115+ app ()
0 commit comments