@@ -16,7 +16,7 @@ def build_runner(
1616
1717 Parameters
1818 ----------
19- model_type: str, optional
19+ model_type: str
2020 Which model to use. For the PyTorch engine, options are [`pytorch`]. For the
2121 TensorFlow engine, options are [`base`, `tensorrt`, `lite`].
2222 model_path: str, Path
@@ -33,13 +33,13 @@ def build_runner(
3333 -------
3434
3535 """
36- if model_type . lower ( ) == "pytorch" :
36+ if Engine . from_model_type ( model_type ) == Engine . PYTORCH :
3737 from dlclive .pose_estimation_pytorch .runner import PyTorchRunner
3838
3939 valid = {"device" , "precision" , "single_animal" , "dynamic" , "top_down_config" }
4040 return PyTorchRunner (model_path , ** filter_keys (valid , kwargs ))
4141
42- elif model_type . lower () in ( "tensorflow" , "base" , "tensorrt" , "lite" ) :
42+ elif Engine . from_model_type ( model_type ) == Engine . TENSORFLOW :
4343 from dlclive .pose_estimation_tensorflow .runner import TensorFlowRunner
4444
4545 if model_type .lower () == "tensorflow" :
@@ -54,3 +54,19 @@ def build_runner(
5454def filter_keys (valid : set [str ], kwargs : dict ) -> dict :
5555 """Filters the keys in kwargs, only keeping those in valid."""
5656 return {k : v for k , v in kwargs .items () if k in valid }
57+
58+
59+ from enum import Enum
60+
61+ class Engine (Enum ):
62+ TENSORFLOW = "tensorflow"
63+ PYTORCH = "pytorch"
64+
65+ @classmethod
66+ def from_model_type (cls , model_type : str ) -> "Engine" :
67+ if model_type .lower () == "pytorch" :
68+ return cls .PYTORCH
69+ elif model_type .lower () in ("tensorflow" , "base" , "tensorrt" , "lite" ):
70+ return cls .TENSORFLOW
71+ else :
72+ raise ValueError (f"Unknown model type: { model_type } " )
0 commit comments