|
26 | 26 | import click |
27 | 27 | from sparseml.pytorch.models.registry import ModelRegistry |
28 | 28 | from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET |
| 29 | +from sparseml.pytorch.optim.manager import ScheduledModifierManager |
29 | 30 | from sparseml.pytorch.torchvision import presets |
30 | 31 | from sparseml.pytorch.utils import ModuleExporter |
31 | 32 | from sparseml.pytorch.utils.model import load_model |
|
60 | 61 | help="The root dir path where the dataset is stored or should " |
61 | 62 | "be downloaded to if available", |
62 | 63 | ) |
| 64 | +@click.option( |
| 65 | + "--one-shot", |
| 66 | + default=None, |
| 67 | + type=str, |
| 68 | + help="Path to recipe to use to apply in a one-shot manner", |
| 69 | +) |
63 | 70 | @click.option( |
64 | 71 | "--labels-to-class-mapping", |
65 | 72 | type=click.Path(dir_okay=False, file_okay=True, exists=True, path_type=Path), |
@@ -118,6 +125,7 @@ def main( |
118 | 125 | arch_key: str, |
119 | 126 | checkpoint_path: str, |
120 | 127 | dataset_path: Path, |
| 128 | + one_shot: Optional[str], |
121 | 129 | labels_to_class_mapping: Optional[Path], |
122 | 130 | num_samples: int, |
123 | 131 | onnx_opset: int, |
@@ -159,6 +167,9 @@ def main( |
159 | 167 |
|
160 | 168 | load_model(checkpoint_path, model, strict=True) |
161 | 169 |
|
| 170 | + if one_shot is not None: |
| 171 | + ScheduledModifierManager.from_yaml(one_shot).apply(model) |
| 172 | + |
162 | 173 | if labels_to_class_mapping is not None: |
163 | 174 | with open(labels_to_class_mapping) as fp: |
164 | 175 | labels_to_class_mapping = json.load(fp) |
|
0 commit comments