1515import os
1616import tempfile
1717
18- import numpy as np
18+ import numpy
1919import onnx
2020import pytest
2121
2727 ORTModelRunner ,
2828 quantize_resnet_identity_add_inputs ,
2929)
30- from sparseml .pytorch .datasets import ImagenetteDataset , ImagenetteSize , MNISTDataset
30+ from sparseml .pytorch .datasets import ImagenetteDataset , ImagenetteSize
3131from sparsezoo import Zoo
3232
3333
@@ -66,7 +66,7 @@ def _test_quant_model_output(
6666 quant_outputs = list (quant_outputs .values ())
6767 # Check that the predicted values of outputs are the same
6868 for idx in test_output_idxs :
69- if np .argmax (base_outputs [idx ]) == np .argmax (quant_outputs [idx ]):
69+ if numpy .argmax (base_outputs [idx ]) == numpy .argmax (quant_outputs [idx ]):
7070 n_matches += 1
7171 # check that at least 98% match, should be higher in practice
7272 assert n_matches >= 98 * len (test_output_idxs )
@@ -82,47 +82,6 @@ def _test_resnet_identity_quant(model_path, has_resnet_block, save_optimized):
8282 onnx .save (quant_model , model_path )
8383
8484
85- @pytest .mark .skipif (
86- os .getenv ("NM_ML_SKIP_QUANTIZATION_TESTS" , False ),
87- reason = "Skipping quantization tests" ,
88- )
89- def test_quantize_model_post_training_mnist ():
90- # Prepare model paths
91- mnist_model_path = Zoo .search_models (
92- domain = "cv" ,
93- sub_domain = "classification" ,
94- architecture = "mnistnet" ,
95- framework = "pytorch" ,
96- )[0 ].onnx_file .downloaded_path ()
97- quant_model_path = tempfile .NamedTemporaryFile (suffix = ".onnx" , delete = False ).name
98-
99- # Prepare sample validation dataset
100- batch_size = 1
101- val_dataset = MNISTDataset (train = False )
102- input_dict = [{"input" : img .numpy ()} for (img , _ ) in val_dataset ]
103- data_loader = DataLoader (input_dict , None , batch_size )
104-
105- # Run calibration and quantization
106- quantize_model_post_training (
107- mnist_model_path , data_loader , quant_model_path , show_progress = False
108- )
109-
110- # Verify that ResNet identity has no affect
111- _test_resnet_identity_quant (quant_model_path , False , False )
112-
113- # Verify Convs and MatMuls are quantized
114- _test_model_is_quantized (mnist_model_path , quant_model_path )
115-
116- # Verify quant model accuracy
117- test_data_loader = DataLoader (input_dict , None , 1 ) # initialize a new generator
118- _test_quant_model_output (
119- mnist_model_path , quant_model_path , test_data_loader , [0 ], batch_size
120- )
121-
122- # Clean up
123- os .remove (quant_model_path )
124-
125-
12685@pytest .mark .skipif (
12786 os .getenv ("NM_ML_SKIP_QUANTIZATION_TESTS" , False ),
12887 reason = "Skipping quantization tests" ,
0 commit comments