@@ -176,6 +176,7 @@ def export_onnx(
176176 opset : int = DEFAULT_ONNX_OPSET ,
177177 disable_bn_fusing : bool = True ,
178178 convert_qat : bool = False ,
179+ ** export_kwargs ,
179180 ):
180181 """
181182 Export an onnx file for the current module and for a sample batch.
@@ -197,7 +198,14 @@ def export_onnx(
197198 the module being exported, the resulting QAT ONNX model will be converted
198199 to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
199200 is False.
201+ :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
202+ call. Useful to pass in dyanmic_axes, input_names, output_names, etc.
203+ See more on the torch.onnx.export api spec in the PyTorch docs:
204+ https://pytorch.org/docs/stable/onnx.html
200205 """
206+ if not export_kwargs :
207+ export_kwargs = {}
208+
201209 if isinstance (sample_batch , Dict ) and not isinstance (
202210 sample_batch , collections .OrderedDict
203211 ):
@@ -217,20 +225,26 @@ def export_onnx(
217225 sample_batch , self ._module , check_feat_lab_inp = False
218226 )
219227
220- input_names = None
221- if isinstance (sample_batch , Tensor ):
222- input_names = ["input" ]
223- elif isinstance (sample_batch , Dict ):
224- input_names = list (sample_batch .keys ())
225- sample_batch = tuple ([sample_batch [f ] for f in input_names ])
226- elif isinstance (sample_batch , Iterable ):
227- input_names = [
228- "input_{}" .format (index ) for index , _ in enumerate (iter (sample_batch ))
229- ]
230- if isinstance (sample_batch , List ):
231- sample_batch = tuple (sample_batch ) # torch.onnx.export requires tuple
228+ if "input_names" not in export_kwargs :
229+ if isinstance (sample_batch , Tensor ):
230+ export_kwargs ["input_names" ] = ["input" ]
231+ elif isinstance (sample_batch , Dict ):
232+ export_kwargs ["input_names" ] = list (sample_batch .keys ())
233+ sample_batch = tuple (
234+ [sample_batch [f ] for f in export_kwargs ["input_names" ]]
235+ )
236+ elif isinstance (sample_batch , Iterable ):
237+ export_kwargs ["input_names" ] = [
238+ "input_{}" .format (index )
239+ for index , _ in enumerate (iter (sample_batch ))
240+ ]
241+ if isinstance (sample_batch , List ):
242+ sample_batch = tuple (
243+ sample_batch
244+ ) # torch.onnx.export requires tuple
232245
233- output_names = self .get_output_names (out )
246+ if "output_names" not in export_kwargs :
247+ export_kwargs ["output_names" ] = self .get_output_names (out )
234248
235249 # disable active quantization observers because they cannot be exported
236250 disabled_observers = []
@@ -259,11 +273,10 @@ def export_onnx(
259273 export_module ,
260274 sample_batch ,
261275 onnx_path ,
262- input_names = input_names ,
263- output_names = output_names ,
264276 strip_doc_string = True ,
265277 verbose = False ,
266278 opset_version = opset ,
279+ ** export_kwargs ,
267280 )
268281
269282 # re-enable disabled quantization observers
0 commit comments