@@ -20,7 +20,7 @@
 import os
 import warnings
 from copy import deepcopy
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

 import numpy
 import onnx
@@ -42,7 +42,10 @@
     script_model,
     trace_model,
 )
-from sparseml.pytorch.utils.quantization import quantize_torch_qat_export
+from sparseml.pytorch.utils.quantization import (
+    quantize_torch_qat_export,
+    skip_onnx_input_quantize,
+)
 from sparseml.utils import clean_path, create_parent_dirs


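For orientation, `skip_onnx_input_quantize` is the newly imported helper; later in this diff it is called as `skip_onnx_input_quantize(file_path, file_path)`. A minimal standalone sketch of that call, assuming an already-exported QAT model at an illustrative `model.onnx` path:

```python
# Illustrative standalone use of the newly imported helper; the
# "model.onnx" path is an assumption, not part of this diff.
from sparseml.pytorch.utils.quantization import skip_onnx_input_quantize

# Deletes the leading QuantizeLinear node(s) after the model input and
# sets the model input type to UINT8, writing the result back in place.
skip_onnx_input_quantize("model.onnx", "model.onnx")
```

The usual motivation is letting uint8 image inputs flow directly into the quantized graph instead of quantizing a float input as the first node.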
@@ -152,23 +155,6 @@ def export_to_zoo(
             sample_originals=sample_originals,
         )

-    @classmethod
-    def get_output_names(cls, out: Any):
-        """
-        Get name of output tensors
-
-        :param out: outputs of the model
-        :return: list of names
-        """
-        output_names = None
-        if isinstance(out, Tensor):
-            output_names = ["output"]
-        elif isinstance(out, Iterable):
-            output_names = [
-                "output_{}".format(index) for index, _ in enumerate(iter(out))
-            ]
-        return output_names
-
     def export_onnx(
         self,
         sample_batch: Any,
@@ -203,96 +189,16 @@ def export_onnx(
         See more on the torch.onnx.export api spec in the PyTorch docs:
         https://pytorch.org/docs/stable/onnx.html
         """
-        if not export_kwargs:
-            export_kwargs = {}
-
-        if isinstance(sample_batch, Dict) and not isinstance(
-            sample_batch, collections.OrderedDict
-        ):
-            warnings.warn(
-                "Sample inputs passed into the ONNX exporter should be in "
-                "the same order defined in the model forward function. "
-                "Consider using OrderedDict for this purpose.",
-                UserWarning,
-            )
-
-        sample_batch = tensors_to_device(sample_batch, "cpu")
-        onnx_path = os.path.join(self._output_dir, name)
-        create_parent_dirs(onnx_path)
-
-        with torch.no_grad():
-            out = tensors_module_forward(
-                sample_batch, self._module, check_feat_lab_inp=False
-            )
-
-        if "input_names" not in export_kwargs:
-            if isinstance(sample_batch, Tensor):
-                export_kwargs["input_names"] = ["input"]
-            elif isinstance(sample_batch, Dict):
-                export_kwargs["input_names"] = list(sample_batch.keys())
-                sample_batch = tuple(
-                    [sample_batch[f] for f in export_kwargs["input_names"]]
-                )
-            elif isinstance(sample_batch, Iterable):
-                export_kwargs["input_names"] = [
-                    "input_{}".format(index)
-                    for index, _ in enumerate(iter(sample_batch))
-                ]
-                if isinstance(sample_batch, List):
-                    sample_batch = tuple(
-                        sample_batch
-                    )  # torch.onnx.export requires tuple
-
-        if "output_names" not in export_kwargs:
-            export_kwargs["output_names"] = self.get_output_names(out)
-
-        # disable active quantization observers because they cannot be exported
-        disabled_observers = []
-        for submodule in self._module.modules():
-            if (
-                hasattr(submodule, "observer_enabled")
-                and submodule.observer_enabled[0] == 1
-            ):
-                submodule.observer_enabled[0] = 0
-                disabled_observers.append(submodule)
-
-        is_quant_module = any(
-            hasattr(submodule, "qconfig") and submodule.qconfig
-            for submodule in self._module.modules()
-        )
-        batch_norms_wrapped = False
-        if torch.__version__ >= "1.7" and not is_quant_module and disable_bn_fusing:
-            # prevent batch norm fusing by adding a trivial operation before every
-            # batch norm layer
-            export_module = deepcopy(self._module)
-            batch_norms_wrapped = _wrap_batch_norms(export_module)
-        else:
-            export_module = self._module
-
-        torch.onnx.export(
-            export_module,
-            sample_batch,
-            onnx_path,
-            strip_doc_string=True,
-            verbose=False,
-            opset_version=opset,
+        export_onnx(
+            module=self._module,
+            sample_batch=sample_batch,
+            file_path=os.path.join(self._output_dir, name),
+            opset=opset,
+            disable_bn_fusing=disable_bn_fusing,
+            convert_qat=convert_qat,
             **export_kwargs,
         )

-        # re-enable disabled quantization observers
-        for submodule in disabled_observers:
-            submodule.observer_enabled[0] = 1
-
-        # clean up graph from any injected / wrapped operations
-        if batch_norms_wrapped:
-            onnx_model = onnx.load(onnx_path)
-            _delete_trivial_onnx_adds(onnx_model)
-            onnx.save(onnx_model, onnx_path)
-
-        if convert_qat and is_quant_module:
-            # overwrite exported model with fully quantized version
-            quantize_torch_qat_export(model=onnx_path, output_file_path=onnx_path)
-
     def export_torchscript(
         self,
         name: str = "model.pts",
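The refactor leaves the public method signature intact, so existing call sites keep working. A minimal usage sketch, assuming the usual sparseml import path, `torchvision` for an example model, and an illustrative output directory:

```python
import torch
from torchvision.models import resnet18  # illustrative model choice
from sparseml.pytorch.utils import ModuleExporter

exporter = ModuleExporter(resnet18(), output_dir="./exported")  # path is an assumption
# Same call as before the refactor; the method now simply forwards to the
# module-level export_onnx function added in the next hunk.
exporter.export_onnx(sample_batch=torch.randn(1, 3, 224, 224))
```

Keeping the method as a thin wrapper preserves backwards compatibility while letting code that never constructs a `ModuleExporter` reuse the same export path.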
@@ -425,6 +331,166 @@ def export_samples(
             exp_counter += len(exported_input)


+def export_onnx(
+    module: Module,
+    sample_batch: Any,
+    file_path: str,
+    opset: int = DEFAULT_ONNX_OPSET,
+    disable_bn_fusing: bool = True,
+    convert_qat: bool = False,
+    dynamic_axes: Optional[Union[str, Dict[str, List[int]]]] = None,
+    skip_input_quantize: bool = False,
+    **export_kwargs,
+):
+    """
+    Export an ONNX file for the given module and sample batch. The sample
+    batch is fed through the model to freeze the graph for a particular
+    execution.
+
+    :param module: torch Module object to export
+    :param sample_batch: the batch to export an onnx for; used to create the
+        static graph for onnx as well as to set input dimensions
+    :param file_path: path to the onnx file to save
+    :param opset: onnx opset to use for exported model. Default is 11; if the
+        torch version is 1.2 or below, the default is 9
+    :param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
+        fusing during torch export. Default and suggested setting is True. Batch
+        norm fusing will change the exported parameter names as well as affect
+        sensitivity analyses of the exported graph. Additionally, the DeepSparse
+        inference engine, and other engines, perform batch norm fusing at model
+        compilation.
+    :param convert_qat: if True and quantization aware training is detected in
+        the module being exported, the resulting QAT ONNX model will be converted
+        to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
+        is False.
+    :param dynamic_axes: dictionary of input or output names to list of dimensions
+        of those tensors that should be exported as dynamic. Pass 'batch'
+        to set the first dimension of all inputs and outputs to dynamic. Default
+        is None
+    :param skip_input_quantize: if True, the export flow will attempt to delete
+        the first QuantizeLinear node(s) immediately after the model input and set
+        the model input type to UINT8. Default is False
+    :param export_kwargs: kwargs to be passed as-is to the torch.onnx.export api
+        call. Useful to pass in dynamic_axes, input_names, output_names, etc.
+        See more on the torch.onnx.export api spec in the PyTorch docs:
+        https://pytorch.org/docs/stable/onnx.html
+    """
+    if not export_kwargs:
+        export_kwargs = {}
+
+    if isinstance(sample_batch, Dict) and not isinstance(
+        sample_batch, collections.OrderedDict
+    ):
+        warnings.warn(
+            "Sample inputs passed into the ONNX exporter should be in "
+            "the same order defined in the model forward function. "
+            "Consider using OrderedDict for this purpose.",
+            UserWarning,
+        )
+
+    sample_batch = tensors_to_device(sample_batch, "cpu")
+    create_parent_dirs(file_path)
+
+    module = deepcopy(module).cpu().eval()
+
+    with torch.no_grad():
+        out = tensors_module_forward(sample_batch, module, check_feat_lab_inp=False)
+
+    if "input_names" not in export_kwargs:
+        if isinstance(sample_batch, Tensor):
+            export_kwargs["input_names"] = ["input"]
+        elif isinstance(sample_batch, Dict):
+            export_kwargs["input_names"] = list(sample_batch.keys())
+            sample_batch = tuple(
+                [sample_batch[f] for f in export_kwargs["input_names"]]
+            )
+        elif isinstance(sample_batch, Iterable):
+            export_kwargs["input_names"] = [
+                "input_{}".format(index) for index, _ in enumerate(iter(sample_batch))
+            ]
+            if isinstance(sample_batch, List):
+                sample_batch = tuple(sample_batch)  # torch.onnx.export requires tuple
+
+    if "output_names" not in export_kwargs:
+        export_kwargs["output_names"] = _get_output_names(out)
+
+    if dynamic_axes == "batch":
+        dynamic_axes = {
+            tensor_name: {0: "batch"}
+            for tensor_name in (
+                export_kwargs["input_names"] + export_kwargs["output_names"]
+            )
+        }
+
+    # disable active quantization observers because they cannot be exported
+    disabled_observers = []
+    for submodule in module.modules():
+        if (
+            hasattr(submodule, "observer_enabled")
+            and submodule.observer_enabled[0] == 1
+        ):
+            submodule.observer_enabled[0] = 0
+            disabled_observers.append(submodule)
+
+    is_quant_module = any(
+        hasattr(submodule, "qconfig") and submodule.qconfig
+        for submodule in module.modules()
+    )
+    batch_norms_wrapped = False
+    if torch.__version__ >= "1.7" and not is_quant_module and disable_bn_fusing:
+        # prevent batch norm fusing by adding a trivial operation before every
+        # batch norm layer
+        batch_norms_wrapped = _wrap_batch_norms(module)
+
+    torch.onnx.export(
+        module,
+        sample_batch,
+        file_path,
+        strip_doc_string=True,
+        verbose=False,
+        opset_version=opset,
+        dynamic_axes=dynamic_axes,
+        **export_kwargs,
+    )
+
+    # re-enable disabled quantization observers
+    for submodule in disabled_observers:
+        submodule.observer_enabled[0] = 1
+
+    # clean up graph from any injected / wrapped operations
+    if batch_norms_wrapped:
+        onnx_model = onnx.load(file_path)
+        _delete_trivial_onnx_adds(onnx_model)
+        onnx.save(onnx_model, file_path)
+
+    if convert_qat and is_quant_module:
+        # overwrite exported model with fully quantized version
+        quantize_torch_qat_export(model=file_path, output_file_path=file_path)
+
+    if skip_input_quantize:
+        try:
+            skip_onnx_input_quantize(file_path, file_path)
+        except Exception as e:
+            _LOGGER.warning(
+                f"Unable to skip input QuantizeLinear op with exception {e}"
+            )
+
+
+def _get_output_names(out: Any):
+    """
+    Get names of output tensors
+
+    :param out: outputs of the model
+    :return: list of names
+    """
+    output_names = None
+    if isinstance(out, Tensor):
+        output_names = ["output"]
+    elif isinstance(out, Iterable):
+        output_names = ["output_{}".format(index) for index, _ in enumerate(iter(out))]
+    return output_names
+
+
 class _AddNoOpWrapper(Module):
     # trivial wrapper to break-up Conv-BN blocks

|
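Taken together, the new module-level function can be called without constructing a `ModuleExporter` at all. A hedged sketch of the new dynamic-batch path, assuming the package re-exports the function like the exporter's other utilities, with a tiny float model and an illustrative file path:

```python
import torch
from sparseml.pytorch.utils import export_onnx  # assumes the package re-exports it

model = torch.nn.Linear(16, 4)  # illustrative model
# dynamic_axes="batch" expands to {name: {0: "batch"}} for every inferred
# input and output name, so the exported graph accepts any batch size.
export_onnx(
    module=model,
    sample_batch=torch.randn(8, 16),
    file_path="linear.onnx",  # illustrative path
    dynamic_axes="batch",
)
```

Because the "batch" expansion runs after input and output names are inferred, the shorthand also covers auto-generated names like `input_0` and `output_0`, not just names passed explicitly through `export_kwargs`.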