Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit af6aab2

Browse files
rahul-tulibfineran
authored andcommitted
Remove: export to onnx by default (#591)
* Remove: `export` to `onnx` by default * Remove: not needed params
1 parent c4664fd commit af6aab2

File tree

2 files changed

+1
-11
lines changed

2 files changed

+1
-11
lines changed

src/sparseml/pytorch/image_classification/train.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,6 @@ def train(
615615
helpers.save_model_training(
616616
model,
617617
optim,
618-
input_shape,
619618
"checkpoint-best",
620619
save_dir,
621620
epoch,
@@ -633,7 +632,6 @@ def train(
633632
helpers.save_model_training(
634633
model,
635634
optim,
636-
input_shape,
637635
f"checkpoint-{epoch:04d}-{val_metric:.04f}",
638636
save_dir,
639637
epoch,
@@ -648,7 +646,7 @@ def train(
648646
# only convert qat -> quantized ONNX graph for finalized model
649647
# TODO: change this to all checkpoints when conversion times improve
650648
helpers.save_model_training(
651-
model, optim, input_shape, "model", save_dir, epoch - 1, val_res, True
649+
model, optim, "model", save_dir, epoch - 1, val_res
652650
)
653651

654652
LOGGER.info("layer sparsities:")

src/sparseml/pytorch/image_classification/utils/helpers.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,10 @@ def save_recipe(
378378
def save_model_training(
379379
model: Module,
380380
optim: Optimizer,
381-
input_shape: Tuple[int, ...],
382381
save_name: str,
383382
save_dir: str,
384383
epoch: int,
385384
val_res: Union[ModuleRunResults, None],
386-
convert_qat: bool = False,
387385
):
388386
"""
389387
:param model: model architecture
@@ -404,12 +402,6 @@ def save_model_training(
404402
)
405403
exporter = ModuleExporter(model, save_dir)
406404
exporter.export_pytorch(optim, epoch, f"{save_name}.pth")
407-
exporter.export_onnx(
408-
torch.randn(1, *input_shape),
409-
f"{save_name}.onnx",
410-
convert_qat=convert_qat,
411-
)
412-
413405
info_path = os.path.join(save_dir, f"{save_name}.txt")
414406

415407
with open(info_path, "w") as info_file:

0 commit comments

Comments
 (0)