diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index ab94f203c231..d3b1522e50c9 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -178,6 +178,12 @@ def __init__(self, module): self._load_params = module["load_params"] self._share_params = module["share_params"] + self._get_workspace_dtype = module["get_workspace_dtype"] + self._get_workspace_size = module["get_workspace_size"] + self._get_function_list = module["get_function_list"] + self._get_storageid = module["get_storageid"] + self._get_output_eid = module["get_output_eid"] + def set_input(self, key=None, value=None, **params): """Set inputs to the module via kwargs @@ -512,3 +518,49 @@ def benchmark( cooldown_interval_ms=cooldown_interval_ms, repeats_to_cooldown=repeats_to_cooldown, )() + + def get_workspace_dtype(self): + """Get the dtype of workspace to the graph + + Returns + ------- + dtype : str + The dtypes of workspace. + """ + return self._get_workspace_dtype() + + def get_workspace_size(self): + """Get the dtype of workspace to the graph + + Returns + ------- + dtype : int + The bytes size of workspace. + """ + return self._get_workspace_size() + + def get_function_list(self): + """Get the Host Function execute order + + Returns + ------- + dtype : str + The Host function execute order + """ + return self._get_function_list() + + def get_storageid(self): + return self._get_storageid() + + def get_output_eid(self, index): + """Get index-th output to out + + Parameters + ---------- + index : int + The output index + + out : NDArray + The output array container + """ + return self._get_output_eid(index) \ No newline at end of file diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py index eee3169400ff..9095ae8e59d5 100644 --- a/python/tvm/relay/backend/executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -180,6 +180,7 @@ def __init__( libmod_name, params, function_metadata, + constant_params = None ): assert isinstance(graph_json_str, string_types) fcreate = get_global_func("tvm.graph_executor_factory.create") @@ -199,6 +200,12 @@ def __init__( self.iter_cnt = 0 self.function_metadata = function_metadata + self.constant_params = constant_params + self.device_function_list = get_global_func("tir.transform.retrieve_device_function_list") + self.device_function_thread_config = get_global_func("runtime.module.retrieve_device_function_thread_config") + self.device_memory_size = get_global_func("tir.transform.retrieve_device_memory_size") + + def export_library(self, file_name, fcompile=None, addons=None, **kwargs): return self.module.export_library(file_name, fcompile, addons, **kwargs) @@ -216,3 +223,15 @@ def get_executor_config(self): def get_lib(self): return self.lib + + def get_constant_params(self): + return self.constant_params + + def get_device_function_list(self): + return self.device_function_list() + + def get_grid_block_thread_config(self): + return self.device_function_thread_config() + + def get_device_memory_size(self): + return self.device_memory_size() \ No newline at end of file diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 40a91cc75a00..1621255d3df2 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -70,6 +70,7 @@ def __init__(self): self._get_executor_codegen_metadata = self.mod["get_executor_codegen_metadata"] self._get_devices = self.mod["get_devices"] self._get_irmodule = 
self.mod["get_irmodule"] + self._get_constant_params = self.mod["get_constant_params"] def build( self, @@ -249,6 +250,14 @@ def get_params(self): ret[key] = value.data return ret + def get_constant_params(self): + """Return the constant params.""" + params = self._get_constant_params() + ret = {} + for key, value in params.items(): + ret[key] = value.data.asnumpy() + return ret + def get_irmodule(self): """Returns the TargetIRModule's post-lowering""" return self._get_irmodule() @@ -372,6 +381,7 @@ def build( mod_name=mod_name, ) func_metadata = bld_mod.get_function_metadata() + constant_params = bld_mod.get_constant_params() devices = bld_mod.get_devices() lowered_ir_mods = bld_mod.get_irmodule() executor_codegen_metadata = bld_mod.get_executor_codegen_metadata() @@ -400,6 +410,7 @@ def build( mod_name, params, func_metadata, + constant_params=constant_params ) else: assert False, "Executor " + executor + " not supported" diff --git a/python/tvm/tpat/__init__.py b/python/tvm/tpat/__init__.py new file mode 100644 index 000000000000..44b1fdcc5697 --- /dev/null +++ b/python/tvm/tpat/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from . import cuda \ No newline at end of file diff --git a/python/tvm/tpat/cuda/__init__.py b/python/tvm/tpat/cuda/__init__.py new file mode 100644 index 000000000000..ee0bce8a0d32 --- /dev/null +++ b/python/tvm/tpat/cuda/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .pipeline import pipeline \ No newline at end of file diff --git a/python/tvm/tpat/cuda/kernel.py b/python/tvm/tpat/cuda/kernel.py new file mode 100644 index 000000000000..80877d4892e9 --- /dev/null +++ b/python/tvm/tpat/cuda/kernel.py @@ -0,0 +1,233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +import tvm +import tvm.contrib.graph_executor as runtime +import tvm.relay as relay +from tvm import dlight +from tvm import meta_schedule as ms + + +class Config(object): + def __init__(self, name, onnx_model, input_shapes, target, tunning_option) -> None: + self.name = name + self.onnx_model = onnx_model + self.input_shapes = input_shapes + self.tunning_option = tunning_option + self.work_dir = ( + f"{tunning_option['work_dir']}/{name}" + if tunning_option["work_dir"] + else f"./log_db/{name}" + ) + + if target == "gpu": + self.target = self._detect_cuda_target() + + def _tune_option(self): + default = { + "target": self.target, + "builder": ms.builder.LocalBuilder(), + "runner": ms.runner.LocalRunner(), + "max_trials_global": 1000, + "max_trials_per_task": 100, + } + + default.update(self.tunning_option) + default["work_dir"] = self.work_dir + + return default + + def _detect_cuda_target(self): + dev = tvm.cuda() + if not dev.exist: + return None + + return tvm.target.Target( + { + "kind": "cuda", + "max_shared_memory_per_block": dev.max_shared_memory_per_block, + "max_threads_per_block": dev.max_threads_per_block, + "thread_warp_size": dev.warp_size, + "registers_per_block": 65536, + "arch": "sm_" + tvm.cuda().compute_version.replace(".", ""), + } + ) + + +class Kernel(object): + def __init__(self, name, onnx_model, input_shapes, enable_tunning, tunning_option): + self._name = name + self._enable_tunning = enable_tunning + self._config = Config(name, onnx_model, input_shapes, "gpu", tunning_option) + + self._lib = None + self._module = None + + def run(self): + """ + Tvm Auto Scheduler + """ + + # 1. Model -> Relay + mod, params = relay.frontend.from_onnx(self._config.onnx_model) + + # 2. Tune it + if self._enable_tunning and not os.path.exists(self._config.work_dir): + tunning_option = self._config._tune_option() + ms.relay_integration.tune_relay(mod=mod, params=params, **tunning_option) + + # 3. Compiling + try: + if self._enable_tunning: + db = ms.Database.create(kind="json", work_dir=self._config.work_dir) + with db, self._config.target as target, tvm.transform.PassContext(opt_level=3): + mod = dlight.ApplyDefaultSchedule(dlight.gpu.Fallback())(mod) # type: ignore + mod = tvm.tir.transform.ForceNarrowIndexToInt32()(mod) + lib = ms.relay_integration.compile_relay( + database=db, + mod=mod, + target=target, + params=params, + ) + else: + with self._config.target as target, tvm.transform.PassContext(opt_level=3): + mod = dlight.ApplyDefaultSchedule(dlight.gpu.Fallback())(mod) # type: ignore + mod = tvm.tir.transform.ForceNarrowIndexToInt32()(mod) + lib = relay.build(mod, target=target, params=params) + + # load parameters + dev = tvm.cuda(0) + module_exec = runtime.GraphModule(lib["default"](dev)) # type: ignore + + self._lib = lib + self._module = module_exec + + # 4. 
Running + self._module.run() + except Exception as e: + print("[ERROR]: ", e) + self._lib = None + self._module = None + + @property + def build_module(self): + return self._lib + + @property + def graph_module(self): + return self._module + + @property + def cuda_source_code(self): + """Return source code of this kernel. + + Returns + ------- + str + source code of kernel + """ + if not self._lib: + return None + + try: + source_code = self._lib.get_lib().imported_modules[0].get_source() + # consistent type + source_code = source_code.replace("signed char*", "int*") + source_code = source_code.replace("uint64_t*", "int*") + source_code = source_code.replace("long long", "int") + source_code = source_code.replace("double", "float") + except IndexError: + return None + return source_code + + @property + def constant_params(self): + """Get constant params of the built module. + + It's a map, whose key is the storage id of param, + value is the numpy data of param. + """ + return self._lib.get_constant_params() if self._lib else None + + @property + def device_function_list(self): + """Get a list of functions which will executed by device. + + The format is: param1 param2 ... paramn. + + If param is in constant params list, it will be an address, + or it will be an index which indicates the order of it. + """ + return self._lib.get_device_function_list() if self._lib else None + + @property + def device_function_thread_config(self): + """Get block and grid dim config for kernel functions. + + The format is: grid=(x, y, z) block=(x, y, z). + """ + return self._lib.get_grid_block_thread_config() if self._lib else None + + @property + def device_allocate_memory_size(self): + """Get allocate memory for kernel functions. + + The format is: + """ + return self._lib.get_device_memory_size() if self._lib else None + + @property + def num_inputs(self): + """Get input number of node.""" + return self._module.get_num_inputs() if self._module else None + + @property + def num_outputs(self): + """Get output number of node.""" + return self._module.get_num_outputs() if self._module else None + + @property + def workspace_dtype(self): + """Get dtype of inputs and outputs. + + You can use dtype.split()[eid] to get workspace type of specific entry id. + """ + return self._module.get_workspace_dtype() if self._module else None + + @property + def workspace_size(self): + """Get size of inputs and outputs. + + You can use size.split()[eid] to get workspace size of specific entry id. + """ + return self._module.get_workspace_size() if self._module else None + + @property + def host_function_list(self): + """Get host function list.""" + return self._module.get_function_list() if self._module else None + + @property + def storageid(self): + """Get storage id.""" + return self._module.get_storageid() if self._module else None + + @property + def plugin_name(self): + return self._name diff --git a/python/tvm/tpat/cuda/onnx_util.py b/python/tvm/tpat/cuda/onnx_util.py new file mode 100644 index 000000000000..dd2ef1ab0c33 --- /dev/null +++ b/python/tvm/tpat/cuda/onnx_util.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os + +import onnx +import onnx_graphsurgeon as gs +from onnx import shape_inference + +from .type_mapping import onnx_type_mapping + + +def load_model(onnx_file): + try: + onnx_model = onnx.load(onnx_file) + inferred_model = shape_inference.infer_shapes(onnx_model) + except: + dummy_file = "tensor_shape_inference.onnx" + shape_inference.infer_shapes_path(onnx_file, output_path=dummy_file) + inferred_model = onnx.load(dummy_file) + os.remove(dummy_file) + + return inferred_model + + +def _handle_trt_not_support_type( + graph, + output_model_path, + node_name_to_plugin_name, + onnx_original_tensor_type, +): + count = 0 + insert_cast_nodes = False + + for node in graph.nodes: + if node.name in node_name_to_plugin_name: + node.op = node_name_to_plugin_name[node.name] + for i, inp in enumerate(node.inputs): + if inp.is_empty(): + node.inputs.remove(inp) + graph.cleanup() + continue + if onnx_original_tensor_type[inp.name] in onnx_type_mapping: + cast_node = gs.Node( + op="Cast", + name="cast_to_int32_for_" + inp.name.split(":")[0], + attrs={"to": 6}, + ) # 6: INT32 + + cast_node.inputs = [inp] + cast_node_out = gs.Variable(cast_node.name + ":0") + cast_node.outputs = [cast_node_out] + node.inputs[i] = cast_node_out + graph.nodes.append(cast_node) + graph.cleanup() + insert_cast_nodes = True + for i, oup in enumerate(node.outputs): + if onnx_original_tensor_type[oup.name] in onnx_type_mapping: + dtype = onnx_type_mapping[onnx_original_tensor_type[oup.name]] + cast_node = gs.Node( + op="Cast", + name="cast_back_for_" + oup.name.split(":")[0], + attrs={"to": dtype}, + ) + + cast_node.outputs = [oup] + cast_node_out = gs.Variable(cast_node.name + ":0") + cast_node.inputs = [cast_node_out] + node.outputs[i] = cast_node_out + graph.nodes.append(cast_node) + graph.cleanup() + insert_cast_nodes = True + count = count + 1 + assert count == len(node_name_to_plugin_name) + if insert_cast_nodes: + _remove_unnecessary_cast_nodes(graph) + + try: + onnx.save(gs.export_onnx(graph), output_model_path["name"]) + except: + onnx.save( + gs.export_onnx(graph), + output_model_path["name"], + save_as_external_data=True, + location=output_model_path["weights"], + ) + + +def _remove_unnecessary_cast_nodes(graph): + graph.toposort() + cast_nodes = [ + node + for node in graph.nodes + if (node.op == "Cast" and node.outputs[0] not in graph.outputs and node.o().op == "Cast") + ] + for node in cast_nodes: + if ( + node.attrs["to"] == 13 # uint64 + and len(node.inputs[0].inputs) <= 1 + and len(node.outputs[0].outputs) <= 1 + ): + node.o().inputs = node.inputs + node.inputs.clear() + graph.cleanup() + + +def _compute_tensor_type(graph, tunning_nodes): + onnx_original_tensor_type = {} + + for tunning_node in tunning_nodes: + for inp in tunning_node.inputs: + if inp.is_empty(): + continue + onnx_original_tensor_type[inp.name] = inp.dtype.name + + for oup in tunning_node.outputs: + if oup.is_empty(): + continue + onnx_original_tensor_type[oup.name] = oup.dtype.name + + return onnx_original_tensor_type + + +def rewrite( + model, + tunning_nodes, + node_name_to_plugin_name, + output_model_path, +): + """ + Insert 
cast operator for operators which inputs or outputs has bool type. + Modify operator type in onnx model for tensorRT can run plugin. + """ + + graph = gs.import_onnx(model) + + _onnx_original_tensor_type = _compute_tensor_type(graph, tunning_nodes) + + _handle_trt_not_support_type( + graph, + output_model_path, + node_name_to_plugin_name, + _onnx_original_tensor_type, + ) diff --git a/python/tvm/tpat/cuda/pipeline.py b/python/tvm/tpat/cuda/pipeline.py new file mode 100644 index 000000000000..0b9143ce0db6 --- /dev/null +++ b/python/tvm/tpat/cuda/pipeline.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from typing import Tuple + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import onnxruntime as ort + +from tvm.tpat.cuda.kernel import Kernel +from tvm.tpat.cuda.template import StaticBatchPluginTemplate +from tvm.tpat.cuda.template_params import PluginTemplateParams + +from tvm.tpat.cuda.onnx_util import rewrite, load_model + + +def _enhance_onnx_shape(graph, inputs, outputs): + graph.outputs = [] + graph.outputs.extend(inputs) + graph.outputs.extend(outputs) + + graph.cleanup() + + half_model = gs.export_onnx(graph) + half_model_path = "half_model.onnx" + onnx.save(half_model, half_model_path) + + EP_list = ["CPUExecutionProvider", "CUDAExecutionProvider"] + session = ort.InferenceSession(half_model_path, providers=EP_list) + outname = [output.name for output in session.get_outputs()] + dummy_input = {} + for gi in graph.inputs: + dummy_input[gi.name] = (1 + np.random.random([int(i) for i in gi.shape])).astype(gi.dtype) + dummy_output = session.run(outname, dummy_input) + + tensor_shapes = [] + for i in range(len(inputs)): + assert inputs[i].name == outname[i] + tensor_shapes.append(dummy_output[i].shape) + for i in range(len(outputs)): + assert outputs[i].name == outname[len(inputs) + i] + tensor_shapes.append(dummy_output[len(inputs) + i].shape) + os.remove(half_model_path) + return tensor_shapes + + +def _extract_target_onnx_node(model, tunning_node): + """ + Extract target node from onnx graph + """ + + graph = gs.import_onnx(model) + + tensors = graph.tensors() + + subgraph_inputs = [ + tensors[inp.name].to_variable(dtype=inp.dtype, shape=inp.shape) + for inp in tunning_node.inputs + if (inp.__class__ == gs.Variable and not inp.is_empty()) + ] + subgraph_outputs = [ + tensors[oup.name].to_variable(dtype=oup.dtype, shape=oup.shape) + for oup in tunning_node.outputs + ] + + computed_tensor_shapes = _enhance_onnx_shape(graph, subgraph_inputs, subgraph_outputs) + + for i in range(len(subgraph_inputs)): + subgraph_inputs[i].shape = computed_tensor_shapes[i] + for i in range(len(subgraph_outputs)): + subgraph_outputs[i].shape = computed_tensor_shapes[len(subgraph_inputs) + i] + + input_shapes = [(inp.name, 
inp.shape, inp.dtype.name) for inp in subgraph_inputs] + output_shapes = [oup.shape for oup in subgraph_outputs] + + graph.inputs = subgraph_inputs + graph.outputs = subgraph_outputs + graph.cleanup() + submodel = gs.export_onnx(graph) + + return submodel, input_shapes, output_shapes + + +def _get_node_to_be_tunned(model, node_names): + graph = gs.import_onnx(model) + + # 2. retrieve all node which need to transform to plugins + if node_names is None or len(node_names) == 0: + return [] + + node_to_be_tunned = [node for node in graph.nodes if node.name in node_names] + + return node_to_be_tunned + + +def pipeline( + onnx_file: str, + node_names: list[str], + enable_tunning: bool, + tunning_option: object, + output_onnx: object, +) -> Tuple[str, list[str]]: + """Generate plugins for specified nodes in an ONNX model. + + This function is the entry point for generating plugins for specific nodes as requested by users. + + Parameters + ---------- + onnx_file : str + Path to the input ONNX file. + node_names : list[str] + Names of the nodes to be generated as TensorRT plugins. + enable_tunning : bool + Flag indicating whether tunning is enabled. + tunning_option : object + Tunning option provided for ms.relay_integration.tune_relay, you don't need to specify mod, params and target. + output_onnx : object + { "name": xx, "weights": xx } + Path to the output ONNX file where the modified model will be saved. + It will firstly try to save without weights, if it fails, it will then + save it with weights. + + Returns + ------- + Tuple[str, List[str]] + A tuple containing the path to the output ONNX file and a list of generated plugin paths. + """ + + # 1. load onnx and inference shapes + model = load_model(onnx_file) + + # 2. retrieve all node which need to transform to plugins + node_to_be_tunned = _get_node_to_be_tunned(model, node_names) + + assert len(node_to_be_tunned) > 0, "The number of nodes to be tunned should larger than zero" + + # 3. generate plugins for each of them + node_name_to_plugin_name = {} + plugin_path = [] + for node in node_to_be_tunned: + name = node.name + print(f"Processing ---- {name}") + plugin_name = "tpat_{}".format(name.replace("/", "_").replace(".", "_")) + + submodel, input_shapes, output_shapes = _extract_target_onnx_node(model, node) + + try: + kernel = Kernel(plugin_name, submodel, input_shapes, enable_tunning, tunning_option) + kernel.run() + + ## 3.1 fill in template + params = PluginTemplateParams(kernel, submodel, output_shapes, node, name) + template = StaticBatchPluginTemplate(params) + lib = template.fill() + + if lib: + plugin_path.append(lib) + node_name_to_plugin_name[name] = plugin_name + except Exception as e: + print(f"Skip {name}, ERROR:: {e}") + continue + + # 4. generate the modified onnx + rewrite(model, node_to_be_tunned, node_name_to_plugin_name, output_onnx) + + return output_onnx, plugin_path diff --git a/python/tvm/tpat/cuda/plugin/Makefile b/python/tvm/tpat/cuda/plugin/Makefile new file mode 100644 index 000000000000..1aa97fcb7b62 --- /dev/null +++ b/python/tvm/tpat/cuda/plugin/Makefile @@ -0,0 +1,77 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Variables need to be defined by Users +CUDA_PATH = /path/to/cuda +CUDNN_PATH = /path/to/cudnn +TRT_PATH = /path/to/TensorRT +ARCH = sm_86 +######################################## + +CUDA_INC_PATH = $(CUDA_PATH)/include +CUDA_LIB_PATH = $(CUDA_PATH)/lib +CUDA_COM_PATH = $(CUDA_PATH)/samples/common/inc + +CUDNN_INC_PATH = $(CUDNN_PATH)/include +CUDNN_LIB_PATH = $(CUDNN_PATH)/lib + +TRT_INC_PATH = $(TRT_PATH)/include +TRT_LIB_PATH = $(TRT_PATH)/lib + +GCC = g++ +NVCC = $(CUDA_PATH)/bin/nvcc +CCFLAGS = -w -std=c++11 +INCLUDES := -I. -I$(CUDA_COM_PATH) -I$(CUDA_INC_PATH) -I$(CUDNN_INC_PATH) -I$(TRT_INC_PATH) -I/usr/include + +LDFLAGS := -L$(CUDA_LIB_PATH) -L$(CUDNN_LIB_PATH) -L$(TRT_LIB_PATH) +LDFLAGS += -lnvinfer -lcudart -lcuda + +LDFLAGS += -Wl,-rpath=$(CUDA_LIB_PATH) +LDFLAGS += -Wl,-rpath=$(CUDNN_LIB_PATH) +LDFLAGS += -Wl,-rpath=$(TRT_LIB_PATH) + +SO = $(plugin_name).so +OBJ = $(shell find . -name '*.o') +DEP = $(OBJ:.o=.d) + +SRCDIR := ./src +OBJDIR := ./obj +LIBDIR := ./lib + +all: $(SO) + +$(plugin_name).so: $(plugin_name).o + +-include $(DEP) + +clean: + rm -rf $(LIBDIR)/$(SO) $(OBJDIR)/* + +%.o: $(SRCDIR)/%.cpp + $(AT)if [ ! -d $(OBJDIR) ]; then mkdir -p $(OBJDIR); fi + $(GCC) $(CCFLAGS) -fPIC -MD -MP $(INCLUDES) -o $@ -c $< + +%.o: $(SRCDIR)/%.cu + $(AT)if [ ! -d $(OBJDIR) ]; then mkdir -p $(OBJDIR); fi + $(NVCC) $(CCFLAGS) -M -MT $@ $(INCLUDES) -o $(@:.o=.d) $< + $(NVCC) $(CCFLAGS) $(INCLUDES) -Xcompiler -fPIC -arch=$(ARCH) -o $@ -c $< + +$(SO): + $(GCC) $(CCFLAGS) -shared -o $@ $+ $(LDFLAGS) + $(AT)if [ ! 
-d $(LIBDIR) ]; then mkdir -p $(LIBDIR); fi + $(AT) mv *.o $(OBJDIR)/ + $(AT) mv *.d $(OBJDIR)/ + $(AT) mv *.so $(LIBDIR)/ diff --git a/python/tvm/tpat/cuda/plugin/trt8.0_plugin_cu.template b/python/tvm/tpat/cuda/plugin/trt8.0_plugin_cu.template new file mode 100644 index 000000000000..48f843f19741 --- /dev/null +++ b/python/tvm/tpat/cuda/plugin/trt8.0_plugin_cu.template @@ -0,0 +1,54 @@ +#include "{{plugin_name}}.h" +#include +#include +#include +#include +#include + +#define BLOCKSIZE_X 16 +#define BLOCKSIZE_Y 16 + +using namespace nvinfer1; +using namespace plugin; + +// CUDA Runtime error messages +#ifdef __DRIVER_TYPES_H__ +static const char *_cudaGetErrorEnum(cudaError_t error) +{ + return cudaGetErrorName(error); +} +#endif + +template +void check(T result, char const *const func, const char *const file, + int const line) +{ + if (result) + { + fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line, + static_cast(result), _cudaGetErrorEnum(result), func); + exit(EXIT_FAILURE); + } +} +#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__) + + +{{plugin_source_code}} + +PluginFieldCollection {{plugin_name}}Creator::mFC{}; +std::vector {{plugin_name}}Creator::mPluginAttributes; + +int {{plugin_name}}::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { + {% for constant in plugin_workspace_constant %} + const {{constant.type}} constant_{{constant.index}}[{{constant.length}}] = { {{constant.value}} }; + checkCudaErrors(cudaMemcpyAsync({{constant.pos}}, &constant_{{constant.index}}, {{constant.length}} * sizeof({{constant.type}}), cudaMemcpyHostToDevice, stream)); + {% endfor %} + dim3 dimBlock, dimGrid; + {% for kernel in plugin_device_function_configuration %} + dimGrid = dim3{{kernel.grid_dim}}; + dimBlock = dim3{{kernel.block_dim}}; + {{kernel.name}}<<>>({{kernel.enqueue_params}}); + {% endfor %} +} + +REGISTER_TENSORRT_PLUGIN({{plugin_name}}Creator); diff --git a/python/tvm/tpat/cuda/plugin/trt8.0_plugin_h.template b/python/tvm/tpat/cuda/plugin/trt8.0_plugin_h.template new file mode 100644 index 000000000000..22b3d0a8deb1 --- /dev/null +++ b/python/tvm/tpat/cuda/plugin/trt8.0_plugin_h.template @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "NvInfer.h" +#include +#include +#include +#include + +namespace nvinfer1 +{ +namespace plugin +{ + +class {{plugin_name}}: public IPluginV2DynamicExt { +public: + {{plugin_name}}() {} + + {{plugin_name}}(const void *buffer, size_t length) { + } + + virtual size_t getSerializationSize() const noexcept override { + return 0; + } + virtual void serialize(void *buffer) const noexcept override {} + + //! The combination of kLINEAR + kFLOAT is supported. 
+ bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override + { + bool condition = true; + {% for tensor_format in plugin_tensor_format %}if (pos == {{ loop.index0 }}){ + //std::cout << (inOut[pos].format == nvinfer1::TensorFormat::k{{tensor_format.format}}) << ", " << (inOut[pos].type == nvinfer1::DataType::k{{tensor_format.type}}) << std::endl; + condition &= inOut[pos].format == nvinfer1::TensorFormat::k{{tensor_format.format}}; + condition &= inOut[pos].type == nvinfer1::DataType::k{{tensor_format.type}}; + } + {% endfor %} + return condition; + } + + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override { + return new {{plugin_name}}(); + } + int getNbOutputs() const noexcept override { + //std::cout << __FUNCTION__ << std::endl; + return {{plugin_output_number}}; + } + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override { + //std::cout << __FUNCTION__ << std::endl; + {% for tensor_dims in plugin_output_shape %}if (outputIndex == {{ loop.index0 }}){ + nvinfer1::DimsExprs output_shape; + output_shape.nbDims = {{tensor_dims.nbdims}}; + {% for s in tensor_dims.shape %}output_shape.d[{{loop.index0}}] = exprBuilder.constant({{s}}); + {% endfor %} + return output_shape; + } + {% endfor %} + } + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override{ + //std::cout << __FUNCTION__ << std::endl; + {% for type in plugin_output_dtype %}if (index == {{ loop.index0 }}){ + return nvinfer1::DataType::k{{type}}; + } + {% endfor %} + } + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override{ + return {{plugin_workspace_size}}; + } + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override {} + int initialize() noexcept override {return 0;} + void terminate() noexcept override {} + void destroy() noexcept override { delete this; } + void setPluginNamespace(const char* szNamespace) noexcept override {mNamespace = szNamespace;} + const char* getPluginNamespace() const noexcept override {return mNamespace.c_str();} + const char* getPluginType() const noexcept override {return "{{plugin_name}}";} + const char* getPluginVersion() const noexcept override {return "1";} + void attachToContext(cudnnContext * /*cudnn*/, cublasContext * /*cublas*/, nvinfer1::IGpuAllocator * /*allocator*/) noexcept {} + void detachFromContext() noexcept {} + +private: + const char* mPluginNamespace; + std::string mNamespace; +}; + +class {{plugin_name}}Creator: public nvinfer1::IPluginCreator { +public: + {{plugin_name}}Creator(){ + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override { + {{plugin_name}}* obj = new {{plugin_name}}{serialData, serialLength}; + obj->setPluginNamespace(mNamespace.c_str()); + return obj; + } + + const char* getPluginName() const noexcept override {return "{{plugin_name}}";} + const 
char* getPluginVersion() const noexcept override {return "1";} + + void setPluginNamespace(const char* szNamespace) noexcept override {mNamespace = szNamespace;} + const char* getPluginNamespace() const noexcept override {return mNamespace.c_str();} + + const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override { + //std::cout << __FUNCTION__ << std::endl; + return &mFC; + } + nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override { + //std::cout << __FUNCTION__ << std::endl; + {{plugin_name}}* obj = new {{plugin_name}}{}; + obj->setPluginNamespace(mNamespace.c_str()); + return obj; + } +private: + std::string mNamespace; + static PluginFieldCollection mFC; + static std::vector mPluginAttributes; +}; + +} // namespace plugin + +} // namespace nvinfer1 diff --git a/python/tvm/tpat/cuda/template.py b/python/tvm/tpat/cuda/template.py new file mode 100644 index 000000000000..4e3fd66e8c14 --- /dev/null +++ b/python/tvm/tpat/cuda/template.py @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import contextlib +import os + +from jinja2 import Environment, FileSystemLoader + + +@contextlib.contextmanager +def pushd(new_dir): + pre_dir = os.getcwd() + os.chdir(new_dir) + try: + yield + finally: + os.chdir(pre_dir) + + +class PluginTemplate(object): + def __init__(self, template_params): + with pushd(os.path.normpath(os.path.dirname(__file__))): + template_loader = FileSystemLoader(searchpath="./") + self._template_env = Environment(loader=template_loader) + + self._plugin_name = template_params.plugin_name + self._plugin_output_number = template_params.num_outputs + self._plugin_output_dtype = template_params.output_dtype + self._plugin_workspace_size = template_params.total_workspace_size + self._plugin_source_code = template_params.cuda_source_code + self._plugin_output_shape = self._parse_plugin_output_shape( + template_params.output_shape + ) + self._plugin_tensor_format = self._parse_plugin_tensor_format( + template_params.tensor_type + ) + self._plugin_device_function_configuration = ( + self._parse_plugin_device_function_configuration( + template_params.device_function_configuration, + template_params.device_function_list, + ) + ) + self._plugin_workspace_constant = self._parse_plugin_workspace_constant( + template_params.workspace_constant + ) + + class TensorDims: + def __init__(self, nbdims, shape): + self.nbdims = nbdims + self.shape = tuple(shape) + + def __str__(self): + return f"TensorDims(nbdims={self.nbdims}, shape={self.shape})" + + def __repr__(self): + return str(self) + + class TensorFormat: + def __init__(self, format, type): + self.format = format + self.type = type + + def __str__(self): + return f"TensorFormat(format={self.format}, type={self.type})" + + def __repr__(self): + return str(self) + + class Kernel: + def __init__( + self, + name, + grid_dim, + block_dim, + enqueue_params, + kernel_params=None, + code=None, + ): + self.name = name + self.grid_dim = grid_dim + self.block_dim = block_dim + self.enqueue_params = enqueue_params + self.kernel_params = kernel_params + self.code = code + + def __str__(self): + return f"Kernel(name={self.name}, grid_dim={self.grid_dim}, block_dim={self.block_dim}, enqueue_params={self.enqueue_params})" + + def __repr__(self): + return str(self) + + class Constant: + def __init__(self, pos, value, type, index, length): + self.pos = pos + self.value = value + self.type = type + self.index = index + self.length = length + + def __str__(self): + return f"Constant(pos={self.pos}, length={self.length}, type={self.type}, index={self.index})" + + def __repr__(self): + return str(self) + + class Case: + def __init__( + self, + batch_size, + plugin_template, + dy_plugin_input_size_type_without_bs=None, + dy_plugin_output_size_type_without_bs=None, + ): + self.batch_size = batch_size + self.plugin_template = plugin_template + self.dy_plugin_input_size_type_without_bs = ( + dy_plugin_input_size_type_without_bs + ) + self.dy_plugin_output_size_type_without_bs = ( + dy_plugin_output_size_type_without_bs + ) + + class Shape: + def __init__(self, size, dtype): + self.size = size + self.dtype = dtype + + def _parse_plugin_output_shape(self, output_shape): + plugin_output_shape = [] + for s in output_shape: + nbdims = len(s) + shape = s + plugin_output_shape.append(self.TensorDims(nbdims, shape)) + return plugin_output_shape + + def _parse_plugin_tensor_format(self, tensor_type): + plugin_tensor_format = [] + for dtype in tensor_type: + plugin_tensor_format.append(self.TensorFormat("LINEAR", dtype)) + return plugin_tensor_format + + 
def _parse_plugin_device_function_configuration( + self, device_function_configuration, device_function_list + ): + frequency = {} + kernel_configuration = [] + for func_name in device_function_list: + if func_name not in frequency.keys(): + frequency[func_name] = 0 + key_name = func_name + else: + frequency[func_name] += 1 + key_name = f"{func_name}_{frequency[func_name]}" + + kernel_configuration.append( + self.Kernel( + func_name, + device_function_configuration[key_name]["grid_dim"], + device_function_configuration[key_name]["block_dim"], + device_function_configuration[key_name]["enqueue_params"], + ) + ) + return kernel_configuration + + def _parse_plugin_workspace_constant(self, workspace_constant): + plugin_constant_init = [] + for init_constant in workspace_constant.items(): + value_str = ", ".join(str(ele) for ele in init_constant[1][0]) + value_str = value_str.strip(",") + plugin_constant_init.append( + self.Constant( + init_constant[0], + value_str, + init_constant[1][1], + init_constant[1][2], + len(init_constant[1][0]), + ) + ) + return plugin_constant_init + + def generate_header_file(self): + raise Exception("not implement method") + + def generate_source_file(self): + raise Exception("not implement method") + + def fill(self): + plugin_header_path = f"./plugin/src/{self._plugin_name}.h" + plugin_source_path = f"./plugin/src/{self._plugin_name}.cu" + if os.path.isfile(plugin_header_path): + os.remove(plugin_header_path) + if os.path.isfile(plugin_source_path): + os.remove(plugin_source_path) + + with pushd(os.path.normpath(os.path.dirname(__file__))): + self.generate_header_file() + self.generate_source_file() + result = self._build_plugin() + + if result: + return f"{os.path.dirname(os.path.abspath(__file__))}/plugin/lib/{self._plugin_name}.so" + else: + return False + + def _build_plugin(self): + os.chdir("./plugin") + + os.system(f"make clean plugin_name={self._plugin_name}") + os.system(f"make plugin_name={self._plugin_name}") + + os.chdir("../") + return True + + +class StaticBatchPluginTemplate(PluginTemplate): + """ + Fill in the useable params which generated by PluginTemplateParams to plugin template. + The plugin template is compatible with TensorRT-8.0. 
+ """ + + def __init__( + self, + template_params, + TEMPLATE_HEADER_FILE="./plugin/trt8.0_plugin_h.template", + TEMPLATE_SOURCE_FILE="./plugin/trt8.0_plugin_cu.template", + ): + super(StaticBatchPluginTemplate, self).__init__(template_params) + + self._template_header_file = TEMPLATE_HEADER_FILE + self._template_source_file = TEMPLATE_SOURCE_FILE + + def generate_header_file(self): + template = self._template_env.get_template(self._template_header_file) + output_text = template.render( + plugin_name=self._plugin_name, + plugin_output_number=self._plugin_output_number, + plugin_output_shape=self._plugin_output_shape, + plugin_output_dtype=self._plugin_output_dtype, + plugin_workspace_size=self._plugin_workspace_size, + plugin_tensor_format=self._plugin_tensor_format, + ) + with open("./plugin/src/{}.h".format(self._plugin_name), "w") as f: + f.write(output_text) + + def generate_source_file(self): + template = self._template_env.get_template(self._template_source_file) + output_text = template.render( + plugin_name=self._plugin_name, + plugin_device_function_configuration=self._plugin_device_function_configuration, + plugin_source_code=self._plugin_source_code, + plugin_workspace_constant=self._plugin_workspace_constant, + ) + with open("./plugin/src/{}.cu".format(self._plugin_name), "w") as f: + f.write(output_text) diff --git a/python/tvm/tpat/cuda/template_params.py b/python/tvm/tpat/cuda/template_params.py new file mode 100644 index 000000000000..c03f9d83a9dd --- /dev/null +++ b/python/tvm/tpat/cuda/template_params.py @@ -0,0 +1,401 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re + +from .type_mapping import plugin_type_size, python_to_trt_type_mapping, tvm_to_c_type_mapping + + +class PluginTemplateParams(object): + """ + Generate useable params for TensorRT plugin. 
+ """ + + def __init__(self, kernel, model, output_shapes, tunning_node, name): + self._kernel = kernel + self._model = model + self._tunning_name = name + self._tunning_node = tunning_node + + self._input_dict = {} + + self._cuda_source_code = None + + self._workspace_size = [] # eid -> workspace size + self._workspace_dtype = [] # eid -> workspace dtype + self._total_workspace_size = 0 # total workspace size need by plugin + + # Kernel related params + self._device_function_params = ( + {} + ) # kernel -> index for params of host function or address based on workspace + self._device_thread_config = {} # kernel -> thread dim + self._device_function_list = [] # kernel invoke order + self._device_allocate_memory_size = {} # address -> (dtype, extent), intermediate variable + + # Host side function attrs + self._host_function_params = {} # function -> eid of params (firstly inputs, then outputs) + + self._nums_inputs = 0 # number of inputs + self._nums_outputs = 0 # number of outputs + self._output_dtype = [] # dtype of outputs + self._output_shape = output_shapes # shape of outputs + self._constant_params = {} # constant params, storage_id -> data + self._trt_workspace_constant = {} + + self._tensor_type = [] # tensor type of inputs and outputs + + self._storage_id = [] # eid -> storage id + self._device_function_configuration = None + + self._parse_tensor_type() + self._parse_kernel_params() + self._prepare_input_dict() + self._prepare_device_function_config() + + def _describe(self): + """Use for debug.""" + print(f"Cuda source code >>> {self._cuda_source_code}") + print(f"Constant params >>> {self._constant_params}") + print(f"Device Function Param >>> {self._device_function_params}") + print(f"Device Thread Config >>> {self._device_thread_config}") + print(f"Device Function List >>> {self._device_function_list}") + print(f"Nums Input >>> {self._nums_inputs}") + print(f"Nums Output >>> {self._nums_outputs}") + print(f"Workspace Data Type >>> {self._workspace_dtype}") + print(f"Workspace Size >>> {self._workspace_size}") + print(f"Host Function Params >>> {self._host_function_params}") + print(f"Storage Id >>> {self._storage_id}") + print(f"Device Memory Size >>> {self._device_allocate_memory_size}") + + # Parse Constant. + def _parse_constant_params(self, constant_params): + tvm_constant = {} + for key, value in constant_params.items(): + tvm_constant[key] = value.flatten() + return tvm_constant + + def _parse_device_function_list(self, device_function_thread_config): + function_list = [] + for item in device_function_thread_config.split("\n"): + if len(item) == 0: + continue + item = item.split() + + function_list.append(item[0]) + + return function_list + + # Parse device functions params order. + def _parse_device_function_params(self, device_function_list): + frequency = {} + result = {} + for device_function in device_function_list.split("\n"): + if len(device_function) == 0: + continue + item = device_function.split() + name = item[0] + params = item[1:] + + if name not in result.keys(): + result[name] = params + frequency[name] = 0 + else: + frequency[name] += 1 + func_name = f"{name}_{frequency[name]}" + result[func_name] = params + return result + + # Parse device functions thread config. 
+ def _parse_device_function_thread_config(self, device_function_thread_config): + frequency = {} + kernel_thread_config = {} + for item in device_function_thread_config.split("\n"): + if len(item) == 0: + continue + config = item.split() + kernel_name = config[0] + params = config[1:] + + if kernel_name not in kernel_thread_config.keys(): + kernel_thread_config[kernel_name] = params + frequency[kernel_name] = 0 + else: + frequency[kernel_name] += 1 + func_name = f"{kernel_name}_{frequency[kernel_name]}" + kernel_thread_config[func_name] = params + return kernel_thread_config + + # Parse global memory allocated in device side. + def _parse_device_allocate_memory_size(self, device_allocate_global_memory): + allocate_global_memory = {} + for allocate_memory in device_allocate_global_memory.split("\n"): + if len(allocate_memory) == 0: + continue + allocate = allocate_memory.split() + allocate_global_memory[allocate[0]] = allocate[1:] + return allocate_global_memory + + # Parse variables storage index. + def _parse_storageid(self, storageid): + storage_id = [] + for sid in storageid.split("\n"): + if len(sid) == 0: + continue + storage_id = sid.split() + return storage_id + + # Parse numbers of input, only variable. + def _parse_nums_input(self, nums_input): + real_nums_input = int(nums_input) - int(len(self._constant_params)) + return real_nums_input + + # Parse numbers of output. + def _parse_nums_output(self, nums_output): + real_nums_output = int(nums_output) + return real_nums_output + + # Parse datatype of variables in memory. + def _parse_workspace_dtype(self, workspaces_dtype): + return workspaces_dtype.split() + + # Parse size of variables in memory. + def _parse_workspace_size(self, workspace_size): + return workspace_size.split() + + def _parse_host_function_params(self, host_function_list): + """ + Parse the list of host functions. + """ + frequency = {} + result = {} + for function in host_function_list.split("\n"): + if len(function) == 0: + continue + data = function.split() + name = data[0] + params = data[1:] + + if name not in result.keys(): + result[name] = params + frequency[name] = 0 + else: + frequency[name] += 1 + func_name = f"{name}_{frequency[name]}" + result[func_name] = params + return result + + def _parse_kernel_params(self): + self._cuda_source_code = self._kernel.cuda_source_code + self._constant_params = self._parse_constant_params(self._kernel.constant_params) + self._device_function_params = self._parse_device_function_params( + self._kernel.device_function_list + ) + self._device_function_list = self._parse_device_function_list( + self._kernel.device_function_thread_config + ) + self._device_thread_config = self._parse_device_function_thread_config( + self._kernel.device_function_thread_config + ) + self._device_allocate_memory_size = self._parse_device_allocate_memory_size( + self._kernel.device_allocate_memory_size + ) + self._nums_inputs = self._parse_nums_input(self._kernel.num_inputs) + self._nums_outputs = self._parse_nums_output(self._kernel.num_outputs) + self._workspace_dtype = self._parse_workspace_dtype(self._kernel.workspace_dtype) + self._workspace_size = self._parse_workspace_size(self._kernel.workspace_size) + self._host_function_params = self._parse_host_function_params( + self._kernel.host_function_list + ) + self._storage_id = self._parse_storageid(self._kernel.storageid) + + self._describe() + + def _parse_tensor_type(self): + """ + Infer for input and output shape. 
+ """ + tunning_node = self._tunning_node + + for inp in tunning_node.inputs: + self._tensor_type.append(python_to_trt_type_mapping[inp.dtype.name]) + + for oup in tunning_node.outputs: + self._tensor_type.append(python_to_trt_type_mapping[oup.dtype.name]) + + def _prepare_input_dict(self): + """ + The memory address used by functions params. + """ + workspace_size = 0 + input_slot_dict = {} # storageid -> xx + + # 1. for outputs + for i in range(self._nums_outputs): + # given index of outputs, return entry id + eid = self._kernel.graph_module.get_output_eid(i) + sid = int(self._storage_id[eid]) + # resolve output type given entry id + self._output_dtype.append(python_to_trt_type_mapping[self._workspace_dtype[eid]]) + self._input_dict[str(eid)] = f"outputs[{i}]" + input_slot_dict[sid] = f"outputs[{i}]" + + # 2. for inputs, including variable and constants + storage_id_to_workspace_size = {} # different entry id may map to same storage id + for eid in range(len(self._workspace_size)): + sid = int(self._storage_id[eid]) + if sid not in storage_id_to_workspace_size.keys(): + storage_id_to_workspace_size[sid] = 0 + storage_id_to_workspace_size[sid] = max( + int(self._workspace_size[eid]), int(storage_id_to_workspace_size[sid]) + ) + + for eid in range(len(self._workspace_size)): + sid = int(self._storage_id[eid]) + if sid in input_slot_dict.keys(): + self._input_dict[str(eid)] = input_slot_dict[sid] + continue + if eid < self._nums_inputs: + # it must be variable + self._input_dict[str(eid)] = "inputs[" + str(eid) + "]" + elif eid < len(self._workspace_size) - self._nums_outputs: + # it must be constant + if eid == self._nums_inputs: + # the first one + self._input_dict[str(eid)] = "workspace" + else: + self._input_dict[str(eid)] = f"(workspace + {workspace_size})" + workspace_size += int(storage_id_to_workspace_size[sid]) + + key = self._input_dict[str(eid)] + if ( + not key in self._trt_workspace_constant.keys() + and str(sid) in self._constant_params.keys() + ): + self._trt_workspace_constant[key] = ( + self._constant_params[str(sid)], # value + tvm_to_c_type_mapping[self._workspace_dtype[eid]], # type + int(eid), # id + ) + input_slot_dict[sid] = self._input_dict[str(eid)] + + if len(self._device_allocate_memory_size) != 0: + for key, value in self._device_allocate_memory_size.items(): + self._input_dict[key] = ( + "(" + + tvm_to_c_type_mapping[value[0]] + + "*)(workspace + " + + str(workspace_size) + + ")" + ) + workspace_size += int(value[1]) * plugin_type_size[value[0]] + + self._total_workspace_size = workspace_size + + def _prepare_device_function_config(self): + """ + Grid, Block Layout, etc. 
+ """ + configuration = {} + frequency = {} + + for i in range(len(self._device_function_list)): + device_function_name = self._device_function_list[i] + host_function_name = re.sub(r"_kernel_?\d*", "", device_function_name, count=1) + + if device_function_name not in configuration.keys(): + configuration[device_function_name] = {} + frequency[device_function_name] = 0 + else: + frequency[device_function_name] += 1 + host_function_name = f"{host_function_name}_{frequency[device_function_name]}" + device_function_name = f"{device_function_name}_{frequency[device_function_name]}" + configuration[device_function_name] = {} + + # grid and block dim + configuration[device_function_name]["grid_dim"] = self._device_thread_config[ + device_function_name + ][0].strip("grid=") + configuration[device_function_name]["block_dim"] = self._device_thread_config[ + device_function_name + ][1].strip("block=") + + device_params = self._device_function_params[device_function_name] + host_params = self._host_function_params[host_function_name] # eid of params + + enqueue_params = "" + for j in range(len(device_params)): + if device_params[j].isdigit(): # correspond to eid + eid = host_params[int(device_params[j])] + dtype = self._workspace_dtype[int(eid)] + enqueue_params += ( + "(" + tvm_to_c_type_mapping[dtype] + "*)" + self._input_dict[str(eid)] + ) + else: + if ( + device_params[j] in self._input_dict.keys() + ): # correspond to device memory, intermediate variable + enqueue_params += self._input_dict[device_params[j]] + + if j != len(device_params) - 1: + enqueue_params += ", " + configuration[device_function_name]["enqueue_params"] = enqueue_params + self._device_function_configuration = configuration + + @property + def device_function_list(self): + return self._device_function_list + + @property + def device_function_configuration(self): + return self._device_function_configuration + + @property + def total_workspace_size(self): + return self._total_workspace_size + + @property + def num_outputs(self): + return self._nums_outputs + + @property + def output_dtype(self): + return self._output_dtype + + @property + def output_shape(self): + return self._output_shape + + @property + def tensor_type(self): + return self._tensor_type + + @property + def workspace_constant(self): + return self._trt_workspace_constant + + @property + def cuda_source_code(self): + return self._cuda_source_code + + @property + def plugin_name(self): + return self._kernel.plugin_name + + @property + def onnx_op_type(self): + return self._kernel.onnx_op_type diff --git a/python/tvm/tpat/cuda/type_mapping.py b/python/tvm/tpat/cuda/type_mapping.py new file mode 100644 index 000000000000..492d36930982 --- /dev/null +++ b/python/tvm/tpat/cuda/type_mapping.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# type mapping : tvm -> c, used by c++ +tvm_to_c_type_mapping = { + "bool": "int", + "int16": "int", + "int32": "int", + "int64": "int", + "uint8": "uchar", + "uint32": "int", + "uint64": "int", + "float16": "half", + "float32": "float", + "float64": "float", +} + +# type mapping : python -> trt, used by TensorRT's getOutputDataType +python_to_trt_type_mapping = { + "bool": "INT32", + "int32": "INT32", + "int64": "INT32", + "uint64": "INT32", + "uint8": "INT8", + "float16": "FLOAT", + "float32": "FLOAT", + "float64": "FLOAT", +} + +# type size : trt workspace, sizeof c++ data type +plugin_type_size = { + "bool": 4, + "int16": 4, + "int32": 4, + "int64": 4, + "uint8": 1, + "uint32": 4, + "uint64": 4, + "float16": 4, + "float32": 4, + "float64": 4, +} + +# onnx type, used by CAST operator +# "int32": 6 +onnx_type_mapping = {"int64": 7, "bool": 9, "uint32": 12, "uint64": 13} \ No newline at end of file diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index f698c654d6d8..ee4e98b4b22e 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1228,6 +1228,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Collect any constants extracted by external codegen. ret.params = std::unordered_map(); + Map const_name_to_constant = lowered_mod->GetAttr>(tvm::attr::kConstNameToConstant) .value_or({}); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 83c252d831c5..216a375b7b53 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -86,6 +86,17 @@ struct ExecutorCodegen { return ret; } + std::unordered_map GetParamIds() { + std::unordered_map ret; + auto names = CallFunc>("list_params_name", nullptr); + for (const auto& expr : names) { + // Implicit cast from runtime::String to std::string + std::string key = expr; + ret[key] = CallFunc("get_param_id", key); + } + return ret; + } + Array GetExternalModules() { return CallFunc>("get_external_modules", nullptr); } @@ -222,6 +233,9 @@ class RelayBuildModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2); *rv = this->Optimize(args[0], args[1]); }); + } else if (name == "get_constant_params") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetConstantParams(); }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); @@ -268,6 +282,21 @@ class RelayBuildModule : public runtime::ModuleNode { return ret; } + /*! + * \brief Get params dictionary, but key is ParamIdx + * + * \return Map params dictionary + */ + Map GetConstantParams() { + Map ret; + auto param_ids = this->executor_codegen_->GetParamIds(); + + for (const auto& kv : ret_.params) { + ret.Set(std::to_string(param_ids[kv.first]), Constant(kv.second)); + } + return ret; + } + /*! 
* \brief Set the parameters * diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 868173d28c13..15c62d7f8fae 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -266,6 +266,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator(); + Map const_name_to_constant = lowered_mod->GetAttr>(tvm::attr::kConstNameToConstant) .value_or({}); @@ -292,6 +293,10 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator param_storage_ids() { + return param_storage_ids_; + } + protected: /*! * \brief Add node to graph @@ -663,6 +668,14 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { CHECK(it != this->output_.params.end()) << "no such parameter " << key; *rv = (*it).second; }); + } else if (name == "get_param_id") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + String key = args[0]; + auto it = this->output_.params.find(key); + CHECK(it != this->output_.params.end()) << "no such parameter " << key; + auto storage_ids = this->codegen_->param_storage_ids(); + *rv = static_cast(storage_ids[(*it).first]); + }); } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.lowered_funcs; diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index acaea425d178..97b28a021903 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -304,6 +304,7 @@ struct LoweredOutput { * to the constant's value. */ std::unordered_map params; + ExecutorCodegenMetadata metadata; }; diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index f54aefe8c4eb..39c59f17f40b 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -41,6 +42,9 @@ namespace tvm { namespace runtime { +// funcs thread config +std::vector funcs_thread_config; + // Module to support thread-safe multi-GPU execution. 
// cuModule is a per-GPU module // The runtime will contain a per-device module table @@ -204,6 +208,13 @@ class CUDAWrappedFunc { << cuda; } LOG(FATAL) << os.str(); + } else { + std::stringstream ss; + ss << func_name_ << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," + << wl.grid_dim(2) << ") " + << " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2) + << ")\n"; + funcs_thread_config.push_back(ss.str()); } } @@ -263,6 +274,7 @@ PackedFunc CUDAModuleNode::GetFunction(const String& name, const ObjectPtr fmap, std::string cuda_source) { + funcs_thread_config.clear(); auto n = make_object(data, fmt, fmap, cuda_source); return Module(n); } @@ -289,10 +301,21 @@ Module CUDAModuleLoadBinary(void* strm) { return CUDAModuleCreate(data, fmt, fmap, std::string()); } +String CUDAModuleGetThreadConfig() { + String ret = ""; + for (const String& func_config : funcs_thread_config) { + ret = ret + func_config; + } + return ret; +} + TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile); TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary); + +TVM_REGISTER_GLOBAL("runtime.module.retrieve_device_function_thread_config") + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CUDAModuleGetThreadConfig(); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index 777a5a442a98..cbdae9a510ab 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -375,6 +375,49 @@ void GraphExecutor::DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv) { *rv = NDArray(GetObjectPtr(container)); } +String GraphExecutor::GetWorkspaceDtype() { + std::ostringstream os; + for (const std::string& s_type : attrs_.dltype) { + os << s_type << " "; + } + return os.str(); +} + +String GraphExecutor::GetWorkspaceSize() { + std::ostringstream os; + for (size_t i = 0; i < data_entry_.size(); ++i) { + const DLTensor* tmp = data_entry_[i].operator->(); + os << GetDataSize(*tmp) << " "; + } + return os.str(); +} + +String GraphExecutor::GetFunctionList() { + std::ostringstream os; + for (auto funcs : exec_func_) { + for (auto func : funcs) { + os << func << " "; + } + os << "\n"; + } + return os.str(); +} + +String GraphExecutor::GetStorageId() { + std::ostringstream os; + for (auto id : attrs_.storage_id) { + os << id << " "; + } + os << "\n"; + return os.str(); +} + +int GraphExecutor::GetOutputEid(int index) const { + ICHECK_LT(static_cast(index), outputs_.size()); + uint32_t eid = this->entry_id(outputs_[index]); + return eid; +} + void GraphExecutor::SetupStorage() { // Grab saved optimization plan from graph. 
std::vector vtype; @@ -510,14 +553,23 @@ void GraphExecutor::SetupOpExecs() { const auto& inode = nodes_[nid]; if (inode.op_type == "null") continue; std::vector args; + std::vector eids; + std::vector funcs; for (const auto& e : inode.inputs) { uint32_t eid = this->entry_id(e); args.push_back(const_cast(data_entry_[eid].operator->())); + eids.push_back(eid); // entry id of inputs } for (uint32_t index = 0; index < inode.param.num_outputs; ++index) { uint32_t eid = this->entry_id(nid, index); args.push_back(const_cast(data_entry_[eid].operator->())); + eids.push_back(eid); // entry id of outputs } + funcs.push_back(inode.param.func_name); + for (auto eid : eids) { + funcs.push_back(std::to_string(eid)); + } + exec_func_.push_back(funcs); ICHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op"; std::shared_ptr op_args = nullptr; @@ -738,6 +790,21 @@ PackedFunc GraphExecutor::GetFunction(const String& name, const ObjectPtrGetWorkspaceDtype(); }); + } else if (name == "get_workspace_size") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetWorkspaceSize(); }); + } else if (name == "get_function_list") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetFunctionList(); }); + } else if (name == "get_storageid") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetStorageId(); }); + } else if (name == "get_output_eid") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutputEid(args[0]); }); } else { return PackedFunc(); } diff --git a/src/runtime/graph_executor/graph_executor.h b/src/runtime/graph_executor/graph_executor.h index 2f6b8b8147e5..9d044cdf8a2f 100644 --- a/src/runtime/graph_executor/graph_executor.h +++ b/src/runtime/graph_executor/graph_executor.h @@ -416,6 +416,15 @@ class TVM_DLL GraphExecutor : public ModuleNode { } ICHECK_EQ(bitmask, 1 | 2 | 4 | 8 | 16) << "invalid format"; } + /*! \brief get the storage dtype */ + String GetWorkspaceDtype(); + /*! \brief get the storage size */ + String GetWorkspaceSize(); + /*! \brief get the exec func in order*/ + String GetFunctionList(); + /*! \brief get storage ids*/ + String GetStorageId(); + int GetOutputEid(int index) const; /*! \brief PackedFunc to lookup a linked paramter from a local Module. */ void DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv); /*! \brief Delete NDArray::Container with linked (i.e. static) data. */ @@ -430,6 +439,8 @@ class TVM_DLL GraphExecutor : public ModuleNode { * \param eid The data_enrty_ index. */ void CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const; + /*! \brief Store execute function in order */ + std::vector> exec_func_; /*! * \brief Create an execution function given input. * \param attrs The node attributes. 
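Editor's note (illustrative only, not part of the patch): a minimal sketch of how the introspection hooks added above might be exercised together from Python. The Relay module `mod`, its `params`, and the CUDA target/device are placeholder assumptions from any front end; the grid/block query only has data after the kernels have actually been launched, since the CUDA module records launch configurations at call time.

import tvm
from tvm import relay
from tvm.contrib import graph_executor

# `mod` and `params` are assumed placeholders from an arbitrary Relay front end.
lib = relay.build(mod, target="cuda", params=params)
dev = tvm.cuda(0)
gmod = graph_executor.GraphModule(lib["default"](dev))
gmod.run()  # run once so the CUDA runtime records grid/block dims per kernel

# Host-side view: op execution order, per-entry dtypes/sizes, storage ids, output eid.
print(gmod.get_function_list())
print(gmod.get_workspace_dtype())
print(gmod.get_workspace_size())
print(gmod.get_storageid())
print(gmod.get_output_eid(0))

# Build-time view from the factory module: device kernel argument bindings,
# launch dimensions, device allocations, and storage-id-keyed constants.
print(lib.get_device_function_list())
print(lib.get_grid_block_thread_config())
print(lib.get_device_memory_size())
print(lib.get_constant_params())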
diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 932116485fa1..9000f04e2626 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -29,11 +29,17 @@ #include #include +#include + #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" namespace tvm { namespace tir { +extern std::unordered_map> host_function_name_to_params; +extern std::unordered_map name_to_prefix; +std::vector device_funcs; +std::vector device_memory_size; namespace { struct KernelInfo { @@ -120,6 +126,8 @@ class DeviceInfoCollector : public StmtVisitor { } void VisitStmt_(const AllocateNode* op) final { + ResolveDeviceMemorySize(op); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { ICHECK(!dyn_shmem_size.defined()) << "Only one dynamic shared memory allocation is allowed."; @@ -136,6 +144,16 @@ class DeviceInfoCollector : public StmtVisitor { StmtVisitor::VisitStmt_(op); } + void ResolveDeviceMemorySize(const AllocateNode* op) { + std::stringstream ss; + ss << op->buffer_var.get() << " " << op->dtype << " "; + for (auto extent : op->extents) { + ss << extent << " "; + } + ss << "\n"; + device_memory_size.push_back(ss.str()); + } + // The collected results KernelInfo info_; // recording what thread axis have been visited. @@ -298,6 +316,7 @@ class DeviceKernelMutator : public StmtExprMutator { device_kernel_launch_.insert(gvar); Array call_args; + call_args.push_back(StringImm(dev_info.global_symbol)); for (PrimExpr arg : node->args) { call_args.push_back(arg); @@ -306,11 +325,33 @@ class DeviceKernelMutator : public StmtExprMutator { call_args.push_back(Substitute(launch_arg, param_map)); } + ResolveDeviceFuncs(gvar->name_hint, node->args); + auto dtype = node->dtype.is_void() ? 
DataType::Int(32) : node->dtype; return Call(dtype, builtin::tvm_call_packed(), call_args); } + void ResolveDeviceFuncs(const String& name_hint, const Array& args) { + std::stringstream ss; + ss << name_hint << " "; + for (auto arg : args) { + bool find_param_in_host = false; + auto params = host_function_name_to_params[name_to_prefix[name_hint]]; + for (int i = 0; i < params.size(); ++i) { + if (arg.same_as(params[i])) { + ss << i << " "; + find_param_in_host = true; + } + } + if (!find_param_in_host) { + ss << arg.get() << " "; + } + } + ss << "\n"; + device_funcs.push_back(ss.str()); + } + Optional current_target_; std::unordered_map device_info_map_; std::unordered_set device_kernel_launch_; @@ -318,9 +359,27 @@ class DeviceKernelMutator : public StmtExprMutator { }; namespace transform { +String GetDeviceFunctionList() { + String ret = ""; + for (auto func : device_funcs) { + ret = ret + func; + } + return ret; +} + +String GetDeviceMemorySize() { + String ret = ""; + for (auto m : device_memory_size) { + ret = ret + m; + } + return ret; +} Pass LowerDeviceKernelLaunch() { auto pass_func = [](IRModule mod, PassContext ctx) -> IRModule { + device_funcs.clear(); + device_memory_size.clear(); + auto mutator = [&mod]() { std::unordered_map device_info_map; for (const auto& [gvar, base_func] : mod->functions) { @@ -372,6 +431,12 @@ Pass LowerDeviceKernelLaunch() { TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceKernelLaunch") .set_body_typed(LowerDeviceKernelLaunch); +TVM_REGISTER_GLOBAL("tir.transform.retrieve_device_function_list") + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = GetDeviceFunctionList(); }); + +TVM_REGISTER_GLOBAL("tir.transform.retrieve_device_memory_size") + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = GetDeviceMemorySize(); }); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 94e245b636a8..18acbda1bee8 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -41,6 +41,7 @@ namespace tvm { namespace tir { static constexpr const char* kDeviceContextVar = "device_api_context"; +std::unordered_map> host_function_name_to_params; namespace { class ReturnRewriter : public StmtMutator { @@ -278,6 +279,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { std::vector> var_def; std::vector> buffer_def; + std::vector params_of_function; + for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i]; @@ -290,6 +293,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { var_def.emplace_back(f_arg_value(param.dtype(), i), param); if (func_ptr->buffer_map.count(param)) { + params_of_function.push_back(func_ptr->buffer_map[param]->data); buffer_def.emplace_back(param, func_ptr->buffer_map[param]); } @@ -316,6 +320,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { } } + host_function_name_to_params[name_hint] = params_of_function; + Array args{v_packed_args, buf_packed_arg_type_ids->data, v_num_packed_args, v_out_ret_value, v_out_ret_tcode, v_resource_handle}; @@ -386,6 +392,8 @@ namespace transform { Pass MakePackedAPI() { auto pass_func = [](IRModule mod, PassContext ctx) { + host_function_name_to_params.clear(); + Map packed_func_methods; for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 9b1dbf1a6618..d79e30520b94 100644 --- a/src/tir/transforms/split_host_device.cc +++ 
b/src/tir/transforms/split_host_device.cc @@ -41,10 +41,12 @@ namespace tvm { namespace tir { +std::unordered_map name_to_prefix; + class HostDeviceSplitter : public StmtMutator { public: - explicit HostDeviceSplitter(IRModule* device_mod, std::function var_supply) - : device_mod_(device_mod), var_supply_(var_supply) {} + explicit HostDeviceSplitter(IRModule* device_mod, std::function var_supply, std::string name_prefix = "") + : device_mod_(device_mod), var_supply_(var_supply), name_prefix_(name_prefix) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tvm::attr::kTarget) { @@ -92,6 +94,9 @@ class HostDeviceSplitter : public StmtMutator { } GlobalVar kernel_symbol_global = var_supply_(); + + name_to_prefix[kernel_symbol_global->name_hint] = name_prefix_; + PrimFunc device_func(params, body, kernel_ret_type); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, {tir::attr::kNoAlias, Bool(true)}, @@ -117,11 +122,13 @@ class HostDeviceSplitter : public StmtMutator { IRModule* device_mod_; // Generate new GlobalVar for the kernel std::function var_supply_; + // name prefix of function + std::string name_prefix_; }; PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, - std::function var_supply) { - HostDeviceSplitter splitter(device_mod, var_supply); + std::function var_supply, std::string name_prefix = "") { + HostDeviceSplitter splitter(device_mod, var_supply, name_prefix); if (auto body = splitter(func->body); !body.same_as(func->body)) { func.CopyOnWrite()->body = body; @@ -139,6 +146,8 @@ Pass SplitHostDevice() { IRModule device_mod = IRModule(Map({})); IRModule updates = IRModule(Map({})); + name_to_prefix.clear(); + for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { PrimFunc func = opt.value(); @@ -150,7 +159,7 @@ Pass SplitHostDevice() { return global_var_supply->FreshGlobal(kernel_name, false); }; - func = SplitHostDevice(std::move(func), &device_mod, var_supply); + func = SplitHostDevice(std::move(func), &device_mod, var_supply, name_prefix); if (!func.same_as(base_func)) { updates->Add(gvar, func); } diff --git a/tests/python/tpat/cuda/__init__.py b/tests/python/tpat/cuda/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/python/tpat/cuda/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/python/tpat/cuda/common.py b/tests/python/tpat/cuda/common.py new file mode 100644 index 000000000000..58ef60c7ce91 --- /dev/null +++ b/tests/python/tpat/cuda/common.py @@ -0,0 +1,3548 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+
+import numpy as np
+import onnx
+import pycuda.autoinit
+import pycuda.driver as cuda
+import pytest
+import tensorflow.compat.v1 as tf
+import tensorrt as trt
+from onnx import TensorProto, helper, mapping, numpy_helper
+from onnx.backend.test.case.node import _extract_value_info
+
+from tvm import tpat
+
+from .trt import (
+    allocate_buffers,
+    build_engine,
+    do_inference,
+    load_plugin,
+    remove_plugin,
+)
+
+tf.disable_v2_behavior()
+
+INPUT_MODEL_FILE = "test_op_plugin.onnx"
+OUTPUT_MODEL_FILE = "test_op_trt.onnx"
+
+TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+
+# enable GPU memory growth for tensorflow
+gpu_devices = tf.config.experimental.list_physical_devices("GPU")
+for device in gpu_devices:
+    tf.config.experimental.set_memory_growth(device, True)
+
+
+def convert_to_list(x):
+    if not isinstance(x, list):
+        x = [x]
+    return x
+
+
+def run_tf_graph(sess, input_data, input_node, output_node):
+    """Generic function to execute a TensorFlow graph."""
+    input_data = convert_to_list(input_data)
+    input_node = convert_to_list(input_node)
+    output_node = convert_to_list(output_node)
+
+    tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node]
+
+    input_dict = {e: input_data[i] for i, e in enumerate(input_node)}
+    output_data = sess.run(tensor, input_dict)
+    return output_data
+
+
+def verify_tf_with_trt_result(in_data, in_name, out_name, op_name):
+    def name_without_num(name):
+        return name.split(":")[0] if ":" in name else name
+
+    out_name = convert_to_list(out_name)
+    out_node = [name_without_num(name) for name in out_name]
+    in_data = convert_to_list(in_data)
+    in_name = convert_to_list(in_name)
+    with tf.Session() as sess:
+        sess.run(tf.global_variables_initializer())
+        tf_result = run_tf_graph(sess, in_data, in_name, out_name)
+        frozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, out_node)
+        with open("./test_op_{}.pb".format(op_name), "wb") as ofile:
+            ofile.write(frozen_graph.SerializeToString())
+    os.system(
+        "python3 -m tf2onnx.convert --input ./test_op_{}.pb --inputs {} --outputs {} --output {} --opset 11".format(
+            op_name, ",".join(in_name), ",".join(out_name), INPUT_MODEL_FILE
+        )
+    )
+    ops_name = [op_name]
+
+    _, trt_plugin_names = tpat.cuda.pipeline(
+        INPUT_MODEL_FILE,
+        ops_name,
+        False,
+        {"work_dir": "./log_db", "max_trials_per_task": 500},
+        OUTPUT_MODEL_FILE,
+    )
+
+    load_plugin(trt_plugin_names)
+    engine = build_engine(OUTPUT_MODEL_FILE, trt_engine_datatype=trt.DataType.HALF)
+
+    inputs, outputs, bindings, stream = allocate_buffers(engine)
+    with engine.create_execution_context() as context:
+        for i in range(len(inputs)):
input_data = in_data[i].ravel() + np.copyto(inputs[i].host, input_data) + + trt_result = do_inference( + context, + bindings=bindings, + inputs=inputs, + outputs=outputs, + stream=stream, + ) + + ret = True + if len(trt_result) == 1: + ret = compare_tf_trt_result(tf_result, trt_result) + else: + for i in range(len(trt_result)): + ret &= compare_tf_trt_result(tf_result[i], trt_result[i]) + assert ret, "result check False" + return ret + + +def compare_tf_trt_result(tf_result, trt_result): + print(tf_result) + print("================") + print(trt_result) + tf_reshape = np.array(tf_result).reshape(-1) + trt_reshape = np.array(trt_result).reshape(-1) + + if ( + isinstance(tf_result, list) + and isinstance(trt_result, list) + and len(tf_result) > 0 + and len(trt_result) > 0 + and np.isnan(tf_result[0]).any() + and np.isnan(trt_result[0]).any() + ): + return True + elif ( + isinstance(tf_result, list) + and isinstance(trt_result, list) + and len(tf_result) > 0 + and len(trt_result) > 0 + and np.isinf(tf_result[0]).any() + and np.isinf(trt_result[0]).any() + ): + return True + elif np.isnan(tf_reshape).any() and np.isnan(trt_reshape).any(): + return True + print( + "trt cross_check output ", + str(np.allclose(tf_reshape.flatten(), trt_reshape.flatten(), atol=1e-5)), + flush=True, + ) + return bool(np.allclose(tf_reshape.flatten(), trt_reshape.flatten(), atol=1e-5)) + + +def get_onnxruntime_output(model, inputs): + import onnxruntime.backend + + rep = onnxruntime.backend.prepare(model, "GPU") + if isinstance(inputs, list) and len(inputs) == 1: + inp = inputs[0] + else: + inp = inputs + output = rep.run(inp) + # Unpack output if there's only a single value. + if len(output) == 1: + output = output[0] + return output + + +def verify_with_ort_with_trt( + model, + inputs, + op_name, + opset=None, + dtype="float32", + opt_level=1, + np_result=None, + use_vm=False, + layout=0, +): + if opset is not None: + model.opset_import[0].version = opset + onnx.save(model, INPUT_MODEL_FILE) + if np_result is None: + ort_result = get_onnxruntime_output(model, inputs) + else: + ort_result = np_result + + in_data = convert_to_list(inputs) + ops_name = [op_name] + + _, trt_plugin_names = tpat.cuda.pipeline( + INPUT_MODEL_FILE, + ops_name, + False, + {"work_dir": "./log_db", "max_trials_per_task": 500}, + OUTPUT_MODEL_FILE, + ) + + libs = load_plugin(trt_plugin_names) + engine = build_engine(OUTPUT_MODEL_FILE, trt_engine_datatype=trt.DataType.HALF) + + inputs, outputs, bindings, stream = allocate_buffers(engine) + with engine.create_execution_context() as context: + for i in range(len(inputs)): + input_data = in_data[i].ravel() + np.copyto(inputs[i].host, input_data) + + trt_result = do_inference( + context, + bindings=bindings, + inputs=inputs, + outputs=outputs, + stream=stream, + ) + + remove_plugin(libs) + + ret = True + if len(trt_result) == 1: + ret = compare_tf_trt_result(ort_result, trt_result) + else: + for i in range(len(trt_result)): + ret &= compare_tf_trt_result(ort_result[i], trt_result[i]) + assert ret, "result check False" + return ret + + +def make_constant_node(name, data_type, dims, vals): + return helper.make_node( + "Constant", + inputs=[], + outputs=[name], + value=helper.make_tensor(name=name, data_type=data_type, dims=dims, vals=vals), + ) + + +def make_onnx_model(node, inputs, outputs, name, **kwargs): + present_inputs = [x for x in node.input if (x != "")] + present_outputs = [x for x in node.output if (x != "")] + input_type_protos = [None] * len(inputs) + if "input_type_protos" in kwargs: + 
input_type_protos = kwargs[str("input_type_protos")] + del kwargs[str("input_type_protos")] + output_type_protos = [None] * len(outputs) + if "output_type_protos" in kwargs: + output_type_protos = kwargs[str("output_type_protos")] + del kwargs[str("output_type_protos")] + inputs_vi = [ + _extract_value_info(arr, arr_name, input_type) + for arr, arr_name, input_type in zip(inputs, present_inputs, input_type_protos) + ] + outputs_vi = [ + _extract_value_info(arr, arr_name, output_type) + for arr, arr_name, output_type in zip(outputs, present_outputs, output_type_protos) + ] + graph = helper.make_graph(nodes=[node], name=name, inputs=inputs_vi, outputs=outputs_vi) + kwargs[str("producer_name")] = "TRTPluginAutoGen-test" + model = onnx.helper.make_model(graph, **kwargs) + return model + + +def op_expect(node, inputs, outputs, op_type, op_name, np_result=None): + model = make_onnx_model(node, inputs=inputs, outputs=outputs, name="test_{}".format(op_type)) + verify_with_ort_with_trt(model, inputs, op_name, np_result=np_result) + + +# ==================================================================================== +# ---UnitTest +# ==================================================================================== + + +def test_abs(): + op_name = "abs_0" + op_type = "Abs" + x = np.random.randn(3, 4, 5).astype(np.float32) + y = abs(x) + node = helper.make_node(op_type, inputs=["x"], outputs=["y"], name=op_name) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_acos(): + op_name = "acos_0" + op_type = "Acos" + node = onnx.helper.make_node("Acos", inputs=["x"], outputs=["y"], name=op_name) + x = np.array([-0.5, 0, 0.5]).astype(np.float32) + y = np.arccos(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "acos_1" + op_type = "Acos" + node = onnx.helper.make_node("Acos", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.rand(3, 4, 5).astype(np.float32) + y = np.arccos(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +@pytest.mark.skip(reason="TensorRT segmentfault") +def test_and(): + op_name = "and_0" + op_type = "And" + node = onnx.helper.make_node("And", inputs=["x", "y"], outputs=["and"], name=op_name) + # 2d + x = (np.random.randn(3, 4) > 0).astype(bool) + y = (np.random.randn(3, 4) > 0).astype(bool) + z = np.logical_and(x, y) + op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + op_name = "and_1" + op_type = "And" + node = onnx.helper.make_node("And", inputs=["x", "y"], outputs=["and"], name=op_name) + x = (np.random.randn(3, 4, 5) > 0).astype(bool) + y = (np.random.randn(3, 4, 5) > 0).astype(bool) + z = np.logical_and(x, y) + op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + op_name = "and_2" + op_type = "And" + node = onnx.helper.make_node("And", inputs=["x", "y"], outputs=["and"], name=op_name) + x = (np.random.randn(3, 4, 5, 6) > 0).astype(bool) + y = (np.random.randn(3, 4, 5, 6) > 0).astype(bool) + z = np.logical_and(x, y) + op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + +def test_add(): + op_name = "add_0" + op_type = "Add" + node = onnx.helper.make_node("Add", inputs=["x", "y"], outputs=["sum"], name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.random.randn(3, 4, 5).astype(np.float32) + op_expect(node, inputs=[x, y], outputs=[x + y], op_type=op_type, op_name=op_name) + + op_name = "add_1" + op_type = "Add" + node = 
onnx.helper.make_node("Add", inputs=["x", "y"], outputs=["sum"], name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.random.randn(5).astype(np.float32) + op_expect(node, inputs=[x, y], outputs=[x + y], op_type=op_type, op_name=op_name) + + +def test_argmax(): + op_type = "ArgMax" + op_name = "argmax_0" + data = np.array([[2, 1, 3, 10], [3, 4, 5, 6]], dtype=np.float32) + keepdims = 1 + axis = -1 + node = onnx.helper.make_node( + "ArgMax", + inputs=["data"], + outputs=["result"], + keepdims=keepdims, + axis=axis, + name=op_name, + ) + + # result: [[1], [1]] + from onnx.backend.test.case.node.argmax import argmax_use_numpy + + result = argmax_use_numpy(data, keepdims=keepdims, axis=axis) + op_expect(node, inputs=[data], outputs=[result], op_type=op_type, op_name=op_name) + + op_name = "argmax_1" + node = onnx.helper.make_node( + "ArgMax", + inputs=["data"], + outputs=["result"], + keepdims=keepdims, + axis=axis, + name=op_name, + ) + + data = np.random.uniform(-10, 10, [2, 3, 4]).astype(np.float32) + # result's shape: [1, 3, 4] + result = argmax_use_numpy(data, keepdims=keepdims, axis=axis) + op_expect(node, inputs=[data], outputs=[result], op_type=op_type, op_name=op_name) + + +def test_argmin(): + op_type = "ArgMin" + op_name = "argmin_0" + data = np.array([[2, 1], [3, 10]], dtype=np.float32) + keepdims = 1 + axis = 1 + node = onnx.helper.make_node( + "ArgMin", + inputs=["data"], + outputs=["result"], + keepdims=keepdims, + axis=axis, + name=op_name, + ) + + # result: [[1], [1]] + from onnx.backend.test.case.node.argmin import argmin_use_numpy + + result = argmin_use_numpy(data, keepdims=keepdims, axis=axis) + op_expect(node, inputs=[data], outputs=[result], op_type=op_type, op_name=op_name) + + +def test_asin(): + op_name = "asin_0" + op_type = "Asin" + node = onnx.helper.make_node("Asin", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-0.5, 0, 0.5]).astype(np.float32) + y = np.arcsin(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "asin_1" + op_type = "Asin" + node = onnx.helper.make_node("Asin", inputs=["x"], outputs=["y"], name=op_name) + + x = np.random.rand(3, 4, 5).astype(np.float32) + y = np.arcsin(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_asinh(): + op_name = "asinh_0" + op_type = "Asinh" + node = onnx.helper.make_node("Asinh", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-1, 0, 1]).astype(np.float32) + y = np.arcsinh(x) # expected output [-0.88137358, 0., 0.88137358] + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "asinh_1" + op_type = "Asinh" + node = onnx.helper.make_node("Asinh", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.arcsinh(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_atan(): + op_type = "Atan" + op_name = "atan_0" + node = onnx.helper.make_node("Atan", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-1, 0, 1]).astype(np.float32) + y = np.arctan(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_type = "Atan" + op_name = "atan_1" + node = onnx.helper.make_node("Atan", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.arctan(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_atanh(): + op_name = "atanh_0" + op_type = "Atanh" + node = 
onnx.helper.make_node("Atanh", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-0.5, 0, 0.5]).astype(np.float32) + y = np.arctanh(x) # expected output [-0.54930615, 0., 0.54930615] + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "atanh_1" + op_type = "Atanh" + node = onnx.helper.make_node("Atanh", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.uniform(0.0, 1.0, (3, 4, 5)).astype(np.float32) + y = np.arctanh(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_averagepool(): + op_name = "averagepool_1d_default" + op_type = "AveragePool" + """ + input_shape: [1, 3, 32] + output_shape: [1, 3, 31] + """ + node = onnx.helper.make_node( + "AveragePool", inputs=["x"], outputs=["y"], kernel_shape=[2], name=op_name + ) + x = np.random.randn(1, 3, 32).astype(np.float32) + x_shape = np.shape(x) + kernel_shape = [2] + strides = [1] + from onnx.backend.test.case.node.pool_op_common import get_output_shape, pool + + out_shape = get_output_shape("VALID", x_shape[2:], kernel_shape, strides) + padded = x + y = pool(padded, x_shape, kernel_shape, strides, out_shape, [0], "AVG") + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "averagepool_2d_ceil" + op_type = "AveragePool" + node = onnx.helper.make_node( + "AveragePool", + inputs=["x"], + outputs=["y"], + kernel_shape=[3, 3], + strides=[2, 2], + ceil_mode=True, + name=op_name, + ) + x = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ] + ).astype(np.float32) + y = np.array([[[[6, 7.5], [12, 13.5]]]]).astype(np.float32) + + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +@pytest.mark.skip(reason="TensorRT segmentfault") +def test_batchnormalization(): + op_name = "batchnormalization_0" + op_type = "BatchNormalization" + # input size: (2, 3, 4, 5) + x = np.random.randn(2, 3, 4, 5).astype(np.float32) + s = np.random.randn(3).astype(np.float32) + bias = np.random.randn(3).astype(np.float32) + mean = np.random.randn(3).astype(np.float32) + var = np.random.rand(3).astype(np.float32) + from onnx.backend.test.case.node.batchnorm import _batchnorm_test_mode + + y = _batchnorm_test_mode(x, s, bias, mean, var).astype(np.float32) + + node = onnx.helper.make_node( + "BatchNormalization", + inputs=["x", "s", "bias", "mean", "var"], + outputs=["y"], + name=op_name, + ) + + # output size: (2, 3, 4, 5) + op_expect( + node, + inputs=[x, s, bias, mean, var], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + +def test_ceil(): + op_name = "ceil_0" + op_type = "Ceil" + node = onnx.helper.make_node("Ceil", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-1.5, 1.2]).astype(np.float32) + y = np.ceil(x) # expected output [-1., 2.] 
+ op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "ceil_1" + op_type = "Ceil" + node = onnx.helper.make_node("Ceil", inputs=["x"], outputs=["y"], name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.ceil(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_celu(): + op_name = "celu_0" + op_type = "Celu" + alpha = 2.0 + node = onnx.helper.make_node("Celu", inputs=["X"], outputs=["Y"], alpha=alpha, name=op_name) + + input_data = np.array( + [ + [ + [[0.8439683], [0.5665144], [0.05836735]], + [[0.02916367], [0.12964272], [0.5060197]], + [[0.79538304], [0.9411346], [0.9546573]], + ], + [ + [[0.17730942], [0.46192095], [0.26480448]], + [[0.6746842], [0.01665257], [0.62473077]], + [[0.9240844], [0.9722341], [0.11965699]], + ], + [ + [[0.41356155], [0.9129373], [0.59330076]], + [[0.81929934], [0.7862604], [0.11799799]], + [[0.69248444], [0.54119414], [0.07513223]], + ], + ], + dtype=np.float32, + ) + + # Calculate expected output data + positive_input = np.maximum(0, input_data) + negative_input = np.minimum(0, alpha * (np.exp(input_data / alpha) - 1)) + expected_output = positive_input + negative_input + + op_expect( + node, + inputs=[input_data], + outputs=[expected_output], + op_type=op_type, + op_name=op_name, + ) + + +def test_clip(): + op_name = "Clip_0" + op_type = "Clip" + node = onnx.helper.make_node("Clip", inputs=["x", "min", "max"], outputs=["y"], name=op_name) + x = np.array([-2, 0, 2]).astype(np.float32) + min_val = np.array([-1.0]).astype(np.float32) # .float32(-1.0) + max_val = np.array([1.0]).astype(np.float32) # .float32(1.0) + y = np.clip(x, min_val, max_val) # expected output [-1., 0., 1.] + op_expect( + node, + inputs=[x, min_val, max_val], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + op_name = "Clip_1" + op_type = "Clip" + node = onnx.helper.make_node("Clip", inputs=["x", "min", "max"], outputs=["y"], name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.clip(x, min_val, max_val) + op_expect( + node, + inputs=[x, min_val, max_val], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + op_name = "Clip_2" + op_type = "Clip" + node = onnx.helper.make_node("Clip", inputs=["x", "min", "max"], outputs=["y"], name=op_name) + min_val = np.array([-5.0]).astype(np.float32) # .float32(-1.0) + max_val = np.array([5.0]).astype(np.float32) # .float32(1.0) + op_name = "Clip_3" + op_type = "Clip" + node = onnx.helper.make_node("Clip", inputs=["x", "min", "max"], outputs=["y"], name=op_name) + + x = np.array([-1, 0, 1]).astype(np.float32) + y = np.array([-1, 0, 1]).astype(np.float32) + op_expect( + node, + inputs=[x, min_val, max_val], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + op_name = "Clip_4" + op_type = "Clip" + node = onnx.helper.make_node("Clip", inputs=["x", "min", "max"], outputs=["y"], name=op_name) + x = np.array([-6, 0, 6]).astype(np.float32) + y = np.array([-5, 0, 5]).astype(np.float32) + op_expect( + node, + inputs=[x, min_val, max_val], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + op_name = "Clip_5" + op_type = "Clip" + node = onnx.helper.make_node("Clip", inputs=["x", "min", "max"], outputs=["y"], name=op_name) + x = np.array([-1, 0, 6]).astype(np.float32) + y = np.array([-1, 0, 5]).astype(np.float32) + op_expect( + node, + inputs=[x, min_val, max_val], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + +def test_concat(): + test_cases = { + "1d": ([1, 2], [3, 4]), + "2d": ([[1, 2], [3, 
4]], [[5, 6], [7, 8]]), + "3d": ( + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + [[[9, 10], [11, 12]], [[13, 14], [15, 16]]], + ), + } # type: Dict[Text, Sequence[Any]] + + for test_case, values_ in test_cases.items(): + values = [np.asarray(v, dtype=np.float32) for v in values_] + for i in range(len(values[0].shape)): + op_name = "concat_{}_{}".format(test_case, i) + op_type = "Concat" + in_args = ["value" + str(k) for k in range(len(values))] + node = onnx.helper.make_node( + "Concat", + inputs=[s for s in in_args], + outputs=["output"], + axis=i, + name=op_name, + ) + output = np.concatenate(values, i) + op_expect( + node, + inputs=[v for v in values], + outputs=[output], + op_type=op_type, + op_name=op_name, + ) + + for i in range(-len(values[0].shape), 0): + op_name = "concat_{}_1_{}".format(test_case, abs(i)) + op_type = "Concat" + in_args = ["value" + str(k) for k in range(len(values))] + node = onnx.helper.make_node( + "Concat", + inputs=[s for s in in_args], + outputs=["output"], + axis=i, + name=op_name, + ) + output = np.concatenate(values, i) + op_expect( + node, + inputs=[v for v in values], + outputs=[output], + op_type=op_type, + op_name=op_name, + ) + + +def test_conv(): + # ------Conv + op_name, op_type = "test_basic_conv_with_padding", "Conv" + x = np.array( + [ + [ + [ + [0.0, 1.0, 2.0, 3.0, 4.0], # (1, 1, 5, 5) input tensor + [5.0, 6.0, 7.0, 8.0, 9.0], + [10.0, 11.0, 12.0, 13.0, 14.0], + [15.0, 16.0, 17.0, 18.0, 19.0], + [20.0, 21.0, 22.0, 23.0, 24.0], + ] + ] + ] + ).astype(np.float32) + # NOCC:invalid-name(其他:onnx example) + W = np.array( + [ + [ + [ + [1.0, 1.0, 1.0], # (1, 1, 3, 3) tensor for convolution weights + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + ] + ] + ).astype(np.float32) + + # Convolution with padding + node_with_padding = onnx.helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=[3, 3], + # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1 + pads=[1, 1, 1, 1], + name=op_name, + ) + y_with_padding = np.array( + [ + [ + [ + [12.0, 21.0, 27.0, 33.0, 24.0], # (1, 1, 5, 5) output tensor + [33.0, 54.0, 63.0, 72.0, 51.0], + [63.0, 99.0, 108.0, 117.0, 81.0], + [93.0, 144.0, 153.0, 162.0, 111.0], + [72.0, 111.0, 117.0, 123.0, 84.0], + ] + ] + ] + ).astype(np.float32) + op_expect( + node_with_padding, + inputs=[x, W], + outputs=[y_with_padding], + op_type=op_type, + op_name=op_name, + ) + + op_name, op_type = "test_basic_conv_without_padding", "Conv" + # Convolution without padding + node_without_padding = onnx.helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=[3, 3], + # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1 + pads=[0, 0, 0, 0], + name=op_name, + ) + y_without_padding = np.array( + [ + [ + [ + [54.0, 63.0, 72.0], # (1, 1, 3, 3) output tensor + [99.0, 108.0, 117.0], + [144.0, 153.0, 162.0], + ] + ] + ] + ).astype(np.float32) + op_expect( + node_without_padding, + inputs=[x, W], + outputs=[y_without_padding], + op_type=op_type, + op_name=op_name, + ) + + # conv_with_autopad_same + op_name, op_type = "test_conv_with_autopad_same", "Conv" + x = np.array( + [ + [ + [ + [0.0, 1.0, 2.0, 3.0, 4.0], # (1, 1, 5, 5) input tensor + [5.0, 6.0, 7.0, 8.0, 9.0], + [10.0, 11.0, 12.0, 13.0, 14.0], + [15.0, 16.0, 17.0, 18.0, 19.0], + [20.0, 21.0, 22.0, 23.0, 24.0], + ] + ] + ] + ).astype(np.float32) + # NOCC:invalid-name(其他:onnx example) + W = np.array( + [ + [ + [ + [1.0, 1.0, 1.0], # (1, 1, 3, 3) tensor for convolution weights + [1.0, 1.0, 1.0], + [1.0, 
1.0, 1.0], + ] + ] + ] + ).astype(np.float32) + + # Convolution with auto_pad='SAME_LOWER' and strides=2 + node = onnx.helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + auto_pad="SAME_LOWER", + kernel_shape=[3, 3], + strides=[2, 2], + name=op_name, + ) + y = np.array([[[[12.0, 27.0, 24.0], [63.0, 108.0, 81.0], [72.0, 117.0, 84.0]]]]).astype( + np.float32 + ) + op_expect(node, inputs=[x, W], outputs=[y], op_type=op_type, op_name=op_name) + + # conv_with_strides + op_name, op_type = "test_conv_with_strides_padding", "Conv" + x = np.array( + [ + [ + [ + [0.0, 1.0, 2.0, 3.0, 4.0], # (1, 1, 7, 5) input tensor + [5.0, 6.0, 7.0, 8.0, 9.0], + [10.0, 11.0, 12.0, 13.0, 14.0], + [15.0, 16.0, 17.0, 18.0, 19.0], + [20.0, 21.0, 22.0, 23.0, 24.0], + [25.0, 26.0, 27.0, 28.0, 29.0], + [30.0, 31.0, 32.0, 33.0, 34.0], + ] + ] + ] + ).astype(np.float32) + # NOCC:invalid-name(其他:onnx example) + W = np.array( + [ + [ + [ + [1.0, 1.0, 1.0], # (1, 1, 3, 3) tensor for convolution weights + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + ] + ] + ).astype(np.float32) + + # Convolution with strides=2 and padding + node_with_padding = onnx.helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[ + 2, + 2, + ], # Default values for other attributes: dilations=[1, 1], groups=1 + name=op_name, + ) + y_with_padding = np.array( + [ + [ + [ + [12.0, 27.0, 24.0], # (1, 1, 4, 3) output tensor + [63.0, 108.0, 81.0], + [123.0, 198.0, 141.0], + [112.0, 177.0, 124.0], + ] + ] + ] + ).astype(np.float32) + op_expect( + node_with_padding, + inputs=[x, W], + outputs=[y_with_padding], + op_type=op_type, + op_name=op_name, + ) + + op_name = "test_conv_with_strides_no_padding" + # Convolution with strides=2 and no padding + node_without_padding = onnx.helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=[3, 3], + pads=[0, 0, 0, 0], + strides=[ + 2, + 2, + ], # Default values for other attributes: dilations=[1, 1], groups=1 + name=op_name, + ) + y_without_padding = np.array( + [[[[54.0, 72.0], [144.0, 162.0], [234.0, 252.0]]]] # (1, 1, 3, 2) output tensor + ).astype(np.float32) + op_expect( + node_without_padding, + inputs=[x, W], + outputs=[y_without_padding], + op_type=op_type, + op_name=op_name, + ) + + op_name = "test_conv_with_strides_and_asymmetric_padding" + # Convolution with strides=2 and padding only along one dimension (the H dimension in NxCxHxW tensor) + node_with_asymmetric_padding = onnx.helper.make_node( + "Conv", + inputs=["x", "W"], + outputs=["y"], + kernel_shape=[3, 3], + pads=[1, 0, 1, 0], + strides=[ + 2, + 2, + ], # Default values for other attributes: dilations=[1, 1], groups=1 + name=op_name, + ) + y_with_asymmetric_padding = np.array( + [ + [ + [ + [21.0, 33.0], # (1, 1, 4, 2) output tensor + [99.0, 117.0], + [189.0, 207.0], + [171.0, 183.0], + ] + ] + ] + ).astype(np.float32) + op_expect( + node_with_asymmetric_padding, + inputs=[x, W], + outputs=[y_with_asymmetric_padding], + op_type=op_type, + op_name=op_name, + ) + + +def test_convtranspose(): + op_name, op_type = "test_convtranspose", "ConvTranspose" + x = np.array([[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]]).astype( # (1, 1, 3, 3) + np.float32 + ) + + # NOCC:invalid-name(其他:onnx example) + W = np.array( + [ + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], # (1, 2, 3, 3) + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ] + ).astype(np.float32) + + node = onnx.helper.make_node("ConvTranspose", ["X", "W"], ["Y"], 
name=op_name) + + y = np.array( + [ + [ + [ + [0.0, 1.0, 3.0, 3.0, 2.0], # (1, 2, 5, 5) + [3.0, 8.0, 15.0, 12.0, 7.0], + [9.0, 21.0, 36.0, 27.0, 15.0], + [9.0, 20.0, 33.0, 24.0, 13.0], + [6.0, 13.0, 21.0, 15.0, 8.0], + ], + [ + [0.0, 1.0, 3.0, 3.0, 2.0], + [3.0, 8.0, 15.0, 12.0, 7.0], + [9.0, 21.0, 36.0, 27.0, 15.0], + [9.0, 20.0, 33.0, 24.0, 13.0], + [6.0, 13.0, 21.0, 15.0, 8.0], + ], + ] + ] + ).astype(np.float32) + + op_expect(node, inputs=[x, W], outputs=[y], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_convtranspose_1d", "ConvTranspose" + + x = np.array([[[0.0, 1.0, 2.0]]]).astype(np.float32) # (1, 1, 3) + + # NOCC:invalid-name(其他:onnx example) + W = np.array([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]).astype(np.float32) # (1, 2, 3) + + node = onnx.helper.make_node("ConvTranspose", ["X", "W"], ["Y"], name=op_name) + + y = np.array([[[0.0, 1.0, 3.0, 3.0, 2.0], [0.0, 1.0, 3.0, 3.0, 2.0]]]).astype( # (1, 2, 5) + np.float32 + ) + + op_expect(node, inputs=[x, W], outputs=[y], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_convtranspose_3d", "ConvTranspose" + x = np.array( + [ + [ + [ + [ + [0.0, 1.0, 2.0, 3.0, 4.0], # (1, 1, 3, 4, 5) + [5.0, 6.0, 7.0, 8.0, 9.0], + [10.0, 11.0, 12.0, 13.0, 14.0], + [15.0, 16.0, 17.0, 18.0, 19.0], + ], + [ + [20.0, 21.0, 22.0, 23.0, 24.0], + [25.0, 26.0, 27.0, 28.0, 29.0], + [30.0, 31.0, 32.0, 33.0, 34.0], + [35.0, 36.0, 37.0, 38.0, 39.0], + ], + [ + [40.0, 41.0, 42.0, 43.0, 44.0], + [45.0, 46.0, 47.0, 48.0, 49.0], + [50.0, 51.0, 52.0, 53.0, 54.0], + [55.0, 56.0, 57.0, 58.0, 59.0], + ], + ] + ] + ] + ).astype(np.float32) + + # NOCC:invalid-name(其他:onnx example) + W = np.array( + [ + [ + [ + [ + [1.0, 1.0, 1.0], # (1, 2, 3, 3, 3) + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ], + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ], + ] + ] + ).astype(np.float32) + + node = onnx.helper.make_node("ConvTranspose", ["X", "W"], ["Y"], name=op_name) + + y = np.array( + [ + [ + [ + [ + [0.0, 1.0, 3.0, 6.0, 9.0, 7.0, 4.0], # (1, 2, 5, 6, 7) + [5.0, 12.0, 21.0, 27.0, 33.0, 24.0, 13.0], + [15.0, 33.0, 54.0, 63.0, 72.0, 51.0, 27.0], + [30.0, 63.0, 99.0, 108.0, 117.0, 81.0, 42.0], + [25.0, 52.0, 81.0, 87.0, 93.0, 64.0, 33.0], + [15.0, 31.0, 48.0, 51.0, 54.0, 37.0, 19.0], + ], + [ + [20.0, 42.0, 66.0, 72.0, 78.0, 54.0, 28.0], + [50.0, 104.0, 162.0, 174.0, 186.0, 128.0, 66.0], + [90.0, 186.0, 288.0, 306.0, 324.0, 222.0, 114.0], + [120.0, 246.0, 378.0, 396.0, 414.0, 282.0, 144.0], + [90.0, 184.0, 282.0, 294.0, 306.0, 208.0, 106.0], + [50.0, 102.0, 156.0, 162.0, 168.0, 114.0, 58.0], + ], + [ + [60.0, 123.0, 189.0, 198.0, 207.0, 141.0, 72.0], + [135.0, 276.0, 423.0, 441.0, 459.0, 312.0, 159.0], + [225.0, 459.0, 702.0, 729.0, 756.0, 513.0, 261.0], + [270.0, 549.0, 837.0, 864.0, 891.0, 603.0, 306.0], + [195.0, 396.0, 603.0, 621.0, 639.0, 432.0, 219.0], + [105.0, 213.0, 324.0, 333.0, 342.0, 231.0, 117.0], + ], + [ + [60.0, 122.0, 186.0, 192.0, 198.0, 134.0, 68.0], + [130.0, 264.0, 402.0, 414.0, 426.0, 288.0, 146.0], + [210.0, 426.0, 648.0, 666.0, 684.0, 462.0, 234.0], + [240.0, 486.0, 738.0, 756.0, 774.0, 522.0, 264.0], + [170.0, 344.0, 522.0, 534.0, 546.0, 368.0, 186.0], + [90.0, 182.0, 276.0, 282.0, 288.0, 194.0, 98.0], + ], + [ + [40.0, 81.0, 123.0, 126.0, 129.0, 87.0, 44.0], + [85.0, 172.0, 261.0, 267.0, 273.0, 184.0, 
93.0], + [135.0, 273.0, 414.0, 423.0, 432.0, 291.0, 147.0], + [150.0, 303.0, 459.0, 468.0, 477.0, 321.0, 162.0], + [105.0, 212.0, 321.0, 327.0, 333.0, 224.0, 113.0], + [55.0, 111.0, 168.0, 171.0, 174.0, 117.0, 59.0], + ], + ], + [ + [ + [0.0, 1.0, 3.0, 6.0, 9.0, 7.0, 4.0], + [5.0, 12.0, 21.0, 27.0, 33.0, 24.0, 13.0], + [15.0, 33.0, 54.0, 63.0, 72.0, 51.0, 27.0], + [30.0, 63.0, 99.0, 108.0, 117.0, 81.0, 42.0], + [25.0, 52.0, 81.0, 87.0, 93.0, 64.0, 33.0], + [15.0, 31.0, 48.0, 51.0, 54.0, 37.0, 19.0], + ], + [ + [20.0, 42.0, 66.0, 72.0, 78.0, 54.0, 28.0], + [50.0, 104.0, 162.0, 174.0, 186.0, 128.0, 66.0], + [90.0, 186.0, 288.0, 306.0, 324.0, 222.0, 114.0], + [120.0, 246.0, 378.0, 396.0, 414.0, 282.0, 144.0], + [90.0, 184.0, 282.0, 294.0, 306.0, 208.0, 106.0], + [50.0, 102.0, 156.0, 162.0, 168.0, 114.0, 58.0], + ], + [ + [60.0, 123.0, 189.0, 198.0, 207.0, 141.0, 72.0], + [135.0, 276.0, 423.0, 441.0, 459.0, 312.0, 159.0], + [225.0, 459.0, 702.0, 729.0, 756.0, 513.0, 261.0], + [270.0, 549.0, 837.0, 864.0, 891.0, 603.0, 306.0], + [195.0, 396.0, 603.0, 621.0, 639.0, 432.0, 219.0], + [105.0, 213.0, 324.0, 333.0, 342.0, 231.0, 117.0], + ], + [ + [60.0, 122.0, 186.0, 192.0, 198.0, 134.0, 68.0], + [130.0, 264.0, 402.0, 414.0, 426.0, 288.0, 146.0], + [210.0, 426.0, 648.0, 666.0, 684.0, 462.0, 234.0], + [240.0, 486.0, 738.0, 756.0, 774.0, 522.0, 264.0], + [170.0, 344.0, 522.0, 534.0, 546.0, 368.0, 186.0], + [90.0, 182.0, 276.0, 282.0, 288.0, 194.0, 98.0], + ], + [ + [40.0, 81.0, 123.0, 126.0, 129.0, 87.0, 44.0], + [85.0, 172.0, 261.0, 267.0, 273.0, 184.0, 93.0], + [135.0, 273.0, 414.0, 423.0, 432.0, 291.0, 147.0], + [150.0, 303.0, 459.0, 468.0, 477.0, 321.0, 162.0], + [105.0, 212.0, 321.0, 327.0, 333.0, 224.0, 113.0], + [55.0, 111.0, 168.0, 171.0, 174.0, 117.0, 59.0], + ], + ], + ] + ] + ).astype(np.float32) + + op_expect(node, inputs=[x, W], outputs=[y], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_convtranspose_pads", "ConvTranspose" + + x = np.array([[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]]).astype( # (1, 1, 3, 3) + np.float32 + ) + + # NOCC:invalid-name(其他:onnx example) + W = np.array( + [ + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], # (1, 2, 3, 3) + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ] + ).astype(np.float32) + + node = onnx.helper.make_node( + "ConvTranspose", + ["X", "W"], + ["Y"], + strides=[3, 2], + pads=[1, 2, 1, 2], + name=op_name, + ) + + y = np.array( + [ + [ + [ + [1.0, 1.0, 3.0], # (1, 2, 7, 3) + [1.0, 1.0, 3.0], + [7.0, 4.0, 9.0], + [7.0, 4.0, 9.0], + [7.0, 4.0, 9.0], + [13.0, 7.0, 15.0], + [13.0, 7.0, 15.0], + ], + [ + [1.0, 1.0, 3.0], + [1.0, 1.0, 3.0], + [7.0, 4.0, 9.0], + [7.0, 4.0, 9.0], + [7.0, 4.0, 9.0], + [13.0, 7.0, 15.0], + [13.0, 7.0, 15.0], + ], + ] + ] + ).astype(np.float32) + + op_expect(node, inputs=[x, W], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_cos(): + op_name, op_type = "test_cos_example", "Cos" + node = onnx.helper.make_node("Cos", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-1, 0, 1]).astype(np.float32) + y = np.cos(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_cos", "Cos" + node = onnx.helper.make_node("Cos", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.cos(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_cosh(): + op_name, op_type = "test_cosh_example", "Cosh" + node = 
onnx.helper.make_node("Cosh", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-1, 0, 1]).astype(np.float32) + y = np.cosh(x) # expected output [1.54308069, 1., 1.54308069] + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_cosh", "Cosh" + node = onnx.helper.make_node("Cosh", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.cosh(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_depthtospace(): + op_name, op_type = "test_depthtospace_crd_mode_example", "DepthToSpace" + node = onnx.helper.make_node( + "DepthToSpace", + inputs=["x"], + outputs=["y"], + blocksize=2, + mode="CRD", + name=op_name, + ) + + # (1, 8, 2, 3) input tensor + x = np.array( + [ + [ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[9.0, 10.0, 11.0], [12.0, 13.0, 14.0]], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + [[27.0, 28.0, 29.0], [30.0, 31.0, 32.0]], + [[36.0, 37.0, 38.0], [39.0, 40.0, 41.0]], + [[45.0, 46.0, 47.0], [48.0, 49.0, 50.0]], + [[54.0, 55.0, 56.0], [57.0, 58.0, 59.0]], + [[63.0, 64.0, 65.0], [66.0, 67.0, 68.0]], + ] + ] + ).astype(np.float32) + + # (1, 2, 4, 6) output tensor + y = np.array( + [ + [ + [ + [0.0, 9.0, 1.0, 10.0, 2.0, 11.0], + [18.0, 27.0, 19.0, 28.0, 20.0, 29.0], + [3.0, 12.0, 4.0, 13.0, 5.0, 14.0], + [21.0, 30.0, 22.0, 31.0, 23.0, 32.0], + ], + [ + [36.0, 45.0, 37.0, 46.0, 38.0, 47.0], + [54.0, 63.0, 55.0, 64.0, 56.0, 65.0], + [39.0, 48.0, 40.0, 49.0, 41.0, 50.0], + [57.0, 66.0, 58.0, 67.0, 59.0, 68.0], + ], + ] + ] + ).astype(np.float32) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_depthtospace_example" + node = onnx.helper.make_node( + "DepthToSpace", + inputs=["x"], + outputs=["y"], + blocksize=2, + mode="DCR", + name=op_name, + ) + + # (1, 8, 2, 3) input tensor + x = np.array( + [ + [ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[9.0, 10.0, 11.0], [12.0, 13.0, 14.0]], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + [[27.0, 28.0, 29.0], [30.0, 31.0, 32.0]], + [[36.0, 37.0, 38.0], [39.0, 40.0, 41.0]], + [[45.0, 46.0, 47.0], [48.0, 49.0, 50.0]], + [[54.0, 55.0, 56.0], [57.0, 58.0, 59.0]], + [[63.0, 64.0, 65.0], [66.0, 67.0, 68.0]], + ] + ] + ).astype(np.float32) + + # (1, 2, 4, 6) output tensor + y = np.array( + [ + [ + [ + [0.0, 18.0, 1.0, 19.0, 2.0, 20.0], + [36.0, 54.0, 37.0, 55.0, 38.0, 56.0], + [3.0, 21.0, 4.0, 22.0, 5.0, 23.0], + [39.0, 57.0, 40.0, 58.0, 41.0, 59.0], + ], + [ + [9.0, 27.0, 10.0, 28.0, 11.0, 29.0], + [45.0, 63.0, 46.0, 64.0, 47.0, 65.0], + [12.0, 30.0, 13.0, 31.0, 14.0, 32.0], + [48.0, 66.0, 49.0, 67.0, 50.0, 68.0], + ], + ] + ] + ).astype(np.float32) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_div(): + op_name, op_type = "test_div_example", "Div" + node = onnx.helper.make_node("Div", inputs=["x", "y"], outputs=["z"], name=op_name) + + x = np.array([3, 4]).astype(np.float32) + y = np.array([1, 2]).astype(np.float32) + z = x / y # expected output [3., 2.] 
+ op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_div", "Div" + node = onnx.helper.make_node("Div", inputs=["x", "y"], outputs=["z"], name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.random.rand(3, 4, 5).astype(np.float32) + 1.0 + z = x / y + op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_div_bcast", "Div" + node = onnx.helper.make_node("Div", inputs=["x", "y"], outputs=["z"], name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.random.rand(5).astype(np.float32) + 1.0 + z = x / y + op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + +@pytest.mark.skip(reason="TensorRT segmentfault") +def test_einsum(): + op_name, op_type = "test_einsum_batch_diagonal", "Einsum" + eqn = "...ii ->...i" + node = onnx.helper.make_node("Einsum", inputs=["x"], outputs=["y"], equation=eqn, name=op_name) + + # NOCC:invalid-name(其他:onnx example) + X = np.random.randn(3, 5, 5).astype(np.float32) + from onnx.backend.test.case.node.einsum import einsum_reference_implementation + + # NOCC:invalid-name(其他:onnx example) + Z = einsum_reference_implementation(eqn, (X,)) + op_expect(node, inputs=[X], outputs=[Z], op_type=op_type, op_name=op_name) + + +def test_elu(): + op_name, op_type = "test_elu_example", "Elu" + node = onnx.helper.make_node("Elu", inputs=["x"], outputs=["y"], alpha=2.0, name=op_name) + + x = np.array([-1, 0, 1]).astype(np.float32) + # expected output [-1.2642411, 0., 1.] + y = np.clip(x, 0, np.inf) + (np.exp(np.clip(x, -np.inf, 0)) - 1) * 2.0 + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_elu", "Elu" + node = onnx.helper.make_node("Elu", inputs=["x"], outputs=["y"], alpha=2.0, name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.clip(x, 0, np.inf) + (np.exp(np.clip(x, -np.inf, 0)) - 1) * 2.0 + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_elu_default", "Elu" + default_alpha = 1.0 + node = onnx.helper.make_node("Elu", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.clip(x, 0, np.inf) + (np.exp(np.clip(x, -np.inf, 0)) - 1) * default_alpha + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_erf(): + op_name, op_type = "test_erf", "Erf" + node = onnx.helper.make_node("Erf", inputs=["x"], outputs=["y"], name=op_name) + + x = np.random.randn(1, 3, 32, 32).astype(np.float32) + import math + + y = np.vectorize(math.erf)(x).astype(np.float32) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_exp(): + op_name, op_type = "test_exp_example", "Exp" + node = onnx.helper.make_node("Exp", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-1, 0, 1]).astype(np.float32) + y = np.exp(x) # expected output [0.36787945, 1., 2.71828175] + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_exp", "Exp" + node = onnx.helper.make_node("Exp", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.exp(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_eyelike(): + op_name, op_type = "test_eyelike_populate_off_main_diagonal", "EyeLike" + shape = (4, 5) + off_diagonal_offset = 1 + node = onnx.helper.make_node( + "EyeLike", 
+ inputs=["x"], + outputs=["y"], + k=off_diagonal_offset, + dtype=onnx.TensorProto.FLOAT, + name=op_name, + ) + + x = np.random.randint(0, 100, size=shape, dtype=np.int32) + y = np.eye(shape[0], shape[1], k=off_diagonal_offset, dtype=np.float32) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_eyelike_with_dtype" + shape = (3, 4) + node = onnx.helper.make_node( + "EyeLike", + inputs=["x"], + outputs=["y"], + dtype=onnx.TensorProto.FLOAT, + name=op_name, + ) + + x = np.random.randint(0, 100, size=shape, dtype=np.int32) + y = np.eye(shape[0], shape[1], dtype=np.float32) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_eyelike_without_dtype" + shape = (4, 4) + node = onnx.helper.make_node("EyeLike", inputs=["x"], outputs=["y"], name=op_name) + + x = np.random.randint(0, 100, size=shape, dtype=np.int32) + y = np.eye(shape[0], shape[1], dtype=np.int32) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_floor(): + op_name, op_type = "test_floor_example", "Floor" + node = onnx.helper.make_node("Floor", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-1.5, 1.2, 2]).astype(np.float32) + y = np.floor(x) # expected output [-2., 1., 2.] + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name, op_type = "test_floor", "Floor" + node = onnx.helper.make_node("Floor", inputs=["x"], outputs=["y"], name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.floor(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def verify_rnn( + seq_length, + batch_size, + input_size, + hidden_size, + rnn_type="LSTM", + use_bias=False, + activations=None, + alphas=None, + betas=None, + use_initial_state=False, + use_peep=False, + linear_before_reset=False, + op_name=None, + layout=0, +): + if rnn_type == "LSTM": + multiplier = 4 + elif rnn_type == "GRU": + multiplier = 3 + else: + raise NotImplementedError("%s RNNs not yet supported." % rnn_type) + + x_np = np.random.uniform(size=(seq_length, batch_size, input_size)).astype("float32") + w_np = np.random.uniform(size=(1, multiplier * hidden_size, input_size)).astype("float32") + r_np = np.random.uniform(size=(1, multiplier * hidden_size, hidden_size)).astype("float32") + input_names = ["X", "W", "R"] + + input_tensors = [ + helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_np.shape)), + helper.make_tensor_value_info("W", TensorProto.FLOAT, list(w_np.shape)), + helper.make_tensor_value_info("R", TensorProto.FLOAT, list(r_np.shape)), + ] + + input_values = [x_np, w_np, r_np] + + if use_bias: + b_np = np.random.uniform(size=(1, multiplier * 2 * hidden_size)).astype("float32") + input_names.append("B") + input_tensors.append( + helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, multiplier * 2 * hidden_size]) + ) + input_values.append(b_np) + + if use_initial_state: + assert use_bias is True, "Initial states must have bias specified." 
+ sequence_np = np.repeat(seq_length, batch_size).astype("int32") + input_names.append("sequence_lens") + input_tensors.append( + helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, [batch_size]) + ) + input_values.append(sequence_np) + + initial_h_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype("float32") + input_names.append("initial_h") + input_tensors.append( + helper.make_tensor_value_info( + "initial_h", TensorProto.FLOAT, [1, batch_size, hidden_size] + ) + ) + input_values.append(initial_h_np) + + if rnn_type == "LSTM": + initial_c_np = np.random.uniform(size=(1, batch_size, hidden_size)).astype("float32") + input_names.append("initial_c") + input_tensors.append( + helper.make_tensor_value_info( + "initial_c", TensorProto.FLOAT, [1, batch_size, hidden_size] + ) + ) + input_values.append(initial_c_np) + + if use_peep and rnn_type == "LSTM": + assert use_initial_state is True, "Peepholes require initial state to be specified." + p_np = np.random.uniform(size=(1, 3 * hidden_size)).astype("float32") + input_names.append("P") + input_tensors.append( + helper.make_tensor_value_info("P", TensorProto.FLOAT, [1, 3 * hidden_size]) + ) + input_values.append(p_np) + + Y_shape = [seq_length, 1, batch_size, hidden_size] + Y_h_shape = [1, batch_size, hidden_size] + outputs = ["Y", "Y_h"] + + graph_outputs = [ + helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(Y_shape)), + helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, list(Y_h_shape)), + ] + output_shapes = [Y_shape, Y_h_shape] + + if rnn_type == "LSTM": + Y_c_shape_0 = [1, batch_size, hidden_size] + outputs.append("Y_c") + graph_outputs.append( + helper.make_tensor_value_info("Y_c", TensorProto.FLOAT, list(Y_c_shape_0)) + ) + output_shapes.append(Y_c_shape_0) + + rnn_node = helper.make_node( + rnn_type, + inputs=input_names, + outputs=outputs, + hidden_size=hidden_size, + layout=0, + name=op_name, + ) + if activations is not None: + activations_attr = helper.make_attribute("activations", activations) + rnn_node.attribute.append(activations_attr) + if alphas is not None: + alphas_attr = helper.make_attribute("activation_alpha", alphas) + rnn_node.attribute.append(alphas_attr) + if betas is not None: + betas_attr = helper.make_attribute("activation_beta", betas) + rnn_node.attribute.append(betas_attr) + if linear_before_reset and rnn_type == "GRU": + lbr_attr = helper.make_attribute("linear_before_reset", 1) + rnn_node.attribute.append(lbr_attr) + + graph = helper.make_graph([rnn_node], "rnn_test", inputs=input_tensors, outputs=graph_outputs) + + model = helper.make_model(graph, producer_name="rnn_test") + + verify_with_ort_with_trt(model, input_values, op_name, layout=layout) + + +def test_gather(): + op_name, op_type = "test_gather_0", "Gather" + node = onnx.helper.make_node( + "Gather", inputs=["data", "indices"], outputs=["y"], axis=0, name=op_name + ) + data = np.random.randn(5, 4, 3, 2).astype(np.float32) + indices = np.array([0, 1, 3]) + y = np.take(data, indices, axis=0) + + op_expect( + node, + inputs=[data, indices.astype(np.int64)], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + op_name = "test_gather_1" + node = onnx.helper.make_node( + "Gather", inputs=["data", "indices"], outputs=["y"], axis=1, name=op_name + ) + data = np.random.randn(5, 4, 3, 2).astype(np.float32) + indices = np.array([0, 1, 3]) + y = np.take(data, indices, axis=1) + + op_expect( + node, + inputs=[data, indices.astype(np.int64)], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + op_name = 
"test_gather_2d_indices" + node = onnx.helper.make_node( + "Gather", inputs=["data", "indices"], outputs=["y"], axis=1, name=op_name + ) + data = np.random.randn(3, 3).astype(np.float32) + indices = np.array([[0, 2]]) + y = np.take(data, indices, axis=1) + + op_expect( + node, + inputs=[data, indices.astype(np.int64)], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + op_name = "test_gather_negative_indices" + node = onnx.helper.make_node( + "Gather", inputs=["data", "indices"], outputs=["y"], axis=0, name=op_name + ) + data = np.arange(10).astype(np.float32) + indices = np.array([0, -9, -10]) + y = np.take(data, indices, axis=0) + + # print(y) + # [0. 1. 0.] + + op_expect( + node, + inputs=[data, indices.astype(np.int64)], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + +def test_gatherelement(): + op_name, op_type = "test_gather_elements_0", "GatherElements" + axis = 1 + node = onnx.helper.make_node( + "GatherElements", + inputs=["data", "indices"], + outputs=["y"], + axis=axis, + name=op_name, + ) + data = np.array([[1, 2], [3, 4]], dtype=np.float32) + indices = np.array([[0, 0], [1, 0]], dtype=np.int32) + + from onnx.backend.test.case.node.gatherelements import gather_elements + + y = gather_elements(data, indices, axis) + # print(y) produces + # [[1, 1], + # [4, 3]] + + op_expect( + node, + inputs=[data, indices.astype(np.int64)], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + op_name = "test_gather_elements_1" + axis = 0 + node = onnx.helper.make_node( + "GatherElements", + inputs=["data", "indices"], + outputs=["y"], + axis=axis, + name=op_name, + ) + data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32) + indices = np.array([[1, 2, 0], [2, 0, 0]], dtype=np.int32) + + y = gather_elements(data, indices, axis) + # print(y) produces + # [[4, 8, 3], + # [7, 2, 3]] + op_expect( + node, + inputs=[data, indices.astype(np.int64)], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + op_name = "test_gather_elements_negative_indices" + axis = 0 + node = onnx.helper.make_node( + "GatherElements", + inputs=["data", "indices"], + outputs=["y"], + axis=axis, + name=op_name, + ) + data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32) + indices = np.array([[-1, -2, 0], [-2, 0, 0]], dtype=np.int32) + + y = gather_elements(data, indices, axis) + # print(y) produces + # [[7, 5, 3], + # [4, 2, 3]] + op_expect( + node, + inputs=[data, indices.astype(np.int64)], + outputs=[y], + op_type=op_type, + op_name=op_name, + ) + + +def test_gathernd(): + op_name, op_type = "test_gathernd_example_float32", "GatherND" + node = onnx.helper.make_node( + "GatherND", inputs=["data", "indices"], outputs=["output"], name=op_name + ) + + data = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.float32) + indices = np.array([[[0, 1]], [[1, 0]]], dtype=np.int64) + from onnx.backend.test.case.node.gathernd import gather_nd_impl + + output = gather_nd_impl(data, indices, 0) + expected_output = np.array([[[2, 3]], [[4, 5]]], dtype=np.float32) + assert np.array_equal(output, expected_output) + op_expect(node, inputs=[data, indices], outputs=[output], op_type=op_type, op_name=op_name) + + op_name = "test_gathernd_example_int32" + node = onnx.helper.make_node( + "GatherND", inputs=["data", "indices"], outputs=["output"], name=op_name + ) + + data = np.array([[0, 1], [2, 3]], dtype=np.int32) + indices = np.array([[0, 0], [1, 1]], dtype=np.int64) + output = gather_nd_impl(data, indices, 0) + expected_output = np.array([0, 3], dtype=np.int32) + assert 
np.array_equal(output, expected_output) + op_expect(node, inputs=[data, indices], outputs=[output], op_type=op_type, op_name=op_name) + + op_name = "test_gathernd_example_int32_batch_dim1" + node = onnx.helper.make_node( + "GatherND", + inputs=["data", "indices"], + outputs=["output"], + batch_dims=1, + name=op_name, + ) + + data = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=np.int32) + indices = np.array([[1], [0]], dtype=np.int64) + output = gather_nd_impl(data, indices, 1) + expected_output = np.array([[2, 3], [4, 5]], dtype=np.int32) + assert np.array_equal(output, expected_output) + op_expect(node, inputs=[data, indices], outputs=[output], op_type=op_type, op_name=op_name) + + +def test_gemm(): + op_name, op_type = "test_gemm_all_attributes", "Gemm" + node = onnx.helper.make_node( + "Gemm", + inputs=["a", "b", "c"], + outputs=["y"], + alpha=0.25, + beta=0.35, + transA=1, + transB=1, + name=op_name, + ) + a = np.random.ranf([4, 3]).astype(np.float32) + b = np.random.ranf([5, 4]).astype(np.float32) + c = np.random.ranf([1, 5]).astype(np.float32) + from onnx.backend.test.case.node.gemm import gemm_reference_implementation + + y = gemm_reference_implementation(a, b, c, transA=1, transB=1, alpha=0.25, beta=0.35) + op_expect(node, inputs=[a, b, c], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_gemm_alpha" + node = onnx.helper.make_node( + "Gemm", inputs=["a", "b", "c"], outputs=["y"], alpha=0.5, name=op_name + ) + a = np.random.ranf([3, 5]).astype(np.float32) + b = np.random.ranf([5, 4]).astype(np.float32) + c = np.zeros([1, 4]).astype(np.float32) + y = gemm_reference_implementation(a, b, c, alpha=0.5) + op_expect(node, inputs=[a, b, c], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_gemm_beta" + node = onnx.helper.make_node( + "Gemm", inputs=["a", "b", "c"], outputs=["y"], beta=0.5, name=op_name + ) + a = np.random.ranf([2, 7]).astype(np.float32) + b = np.random.ranf([7, 4]).astype(np.float32) + c = np.random.ranf([1, 4]).astype(np.float32) + y = gemm_reference_implementation(a, b, c, beta=0.5) + op_expect(node, inputs=[a, b, c], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_globalaveragepool(): + op_name, op_type = "test_globalaveragepool", "GlobalAveragePool" + node = onnx.helper.make_node("GlobalAveragePool", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(1, 3, 5, 5).astype(np.float32) + y = np.mean(x, axis=tuple(range(2, np.ndim(x))), keepdims=True) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_globalaveragepool_precomputed" + node = onnx.helper.make_node("GlobalAveragePool", inputs=["x"], outputs=["y"], name=op_name) + x = np.array( + [ + [ + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ] + ] + ] + ).astype(np.float32) + y = np.array([[[[5]]]]).astype(np.float32) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_globalmaxpool(): + op_name = "test_globalmaxpool" + op_type = "GlobalMaxPool" + node = onnx.helper.make_node("GlobalMaxPool", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(1, 3, 5, 5).astype(np.float32) + y = np.max(x, axis=tuple(range(2, np.ndim(x))), keepdims=True) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_globalmaxpool_precomputed" + node = onnx.helper.make_node("GlobalMaxPool", inputs=["x"], outputs=["y"], name=op_name) + x = np.array( + [ + [ + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ] + ] + ] + ).astype(np.float32) + y = 
np.array([[[[9]]]]).astype(np.float32) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_hardsigmoid(): + op_name, op_type = "test_hardsigmoid_example", "HardSigmoid" + node = onnx.helper.make_node( + "HardSigmoid", inputs=["x"], outputs=["y"], alpha=0.5, beta=0.6, name=op_name + ) + + x = np.array([-1, 0, 1]).astype(np.float32) + y = np.clip(x * 0.5 + 0.6, 0, 1) # expected output [0.1, 0.6, 1.] + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_hardsigmoid" + node = onnx.helper.make_node( + "HardSigmoid", inputs=["x"], outputs=["y"], alpha=0.5, beta=0.6, name=op_name + ) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.clip(x * 0.5 + 0.6, 0, 1) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_hardsigmoid_default" + + default_alpha = 0.2 + default_beta = 0.5 + node = onnx.helper.make_node("HardSigmoid", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.clip(x * default_alpha + default_beta, 0, 1) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_hardswish(): + op_name, op_type = "test_hardswish", "HardSwish" + node = onnx.helper.make_node("HardSwish", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + from onnx.backend.test.case.node.hardswish import hardswish + + y = hardswish(x) + + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_hardmax(): + op_name, op_type = "test_hardmax_example", "Hardmax" + node = onnx.helper.make_node("Hardmax", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([[3, 0, 1, 2], [2, 5, 1, 0], [0, 1, 3, 2], [0, 1, 2, 3]]).astype(np.float32) + # expect result: + # [[1. 0. 0. 0.] + # [0. 1. 0. 0.] + # [0. 0. 1. 0.] + # [0. 0. 0. 
1.]] + from onnx.backend.test.case.node.hardmax import hardmax + + y = hardmax(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_identity(): + op_name, op_type = "test_identity", "Identity" + node = onnx.helper.make_node("Identity", inputs=["x"], outputs=["y"], name=op_name) + + data = np.array( + [ + [ + [ + [1, 2], + [3, 4], + ] + ] + ], + dtype=np.float32, + ) + + op_expect(node, inputs=[data], outputs=[data], op_type=op_type, op_name=op_name) + + +def test_instancenormalization(): + op_name, op_type = "test_instancenorm_example", "InstanceNormalization" + + def _instancenorm_test_mode(x, s, bias, epsilon=1e-5): # type: ignore + dims_x = len(x.shape) + axis = tuple(range(2, dims_x)) + mean = np.mean(x, axis=axis, keepdims=True) + var = np.var(x, axis=axis, keepdims=True) + dim_ones = (1,) * (dims_x - 2) + s = s.reshape(-1, *dim_ones) + bias = bias.reshape(-1, *dim_ones) + return s * (x - mean) / np.sqrt(var + epsilon) + bias + + # input size: (1, 2, 1, 3) + x = np.array([[[[-1, 0, 1]], [[2, 3, 4]]]]).astype(np.float32) + s = np.array([1.0, 1.5]).astype(np.float32) + bias = np.array([0, 1]).astype(np.float32) + y = _instancenorm_test_mode(x, s, bias).astype(np.float32) + + node = onnx.helper.make_node( + "InstanceNormalization", inputs=["x", "s", "bias"], outputs=["y"], name=op_name + ) + + # output size: (1, 2, 1, 3) + op_expect(node, inputs=[x, s, bias], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_instancenorm_epsilon" + # input size: (2, 3, 4, 5) + x = np.random.randn(2, 3, 4, 5).astype(np.float32) + s = np.random.randn(3).astype(np.float32) + bias = np.random.randn(3).astype(np.float32) + epsilon = 1e-2 + y = _instancenorm_test_mode(x, s, bias, epsilon).astype(np.float32) + + node = onnx.helper.make_node( + "InstanceNormalization", + inputs=["x", "s", "bias"], + outputs=["y"], + epsilon=epsilon, + name=op_name, + ) + + # output size: (2, 3, 4, 5) + op_expect(node, inputs=[x, s, bias], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_leakyrelu(): + op_name, op_type = "test_leakyrelu_example", "LeakyRelu" + node = onnx.helper.make_node("LeakyRelu", inputs=["x"], outputs=["y"], alpha=0.1, name=op_name) + + x = np.array([-1, 0, 1]).astype(np.float32) + # expected output [-0.1, 0., 1.] 
+ y = np.clip(x, 0, np.inf) + np.clip(x, -np.inf, 0) * 0.1 + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_leakyrelu" + node = onnx.helper.make_node("LeakyRelu", inputs=["x"], outputs=["y"], alpha=0.1, name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.clip(x, 0, np.inf) + np.clip(x, -np.inf, 0) * 0.1 + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_leakyrelu_default" + default_alpha = 0.01 + node = onnx.helper.make_node("LeakyRelu", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.clip(x, 0, np.inf) + np.clip(x, -np.inf, 0) * default_alpha + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_log(): + op_name = "test_log_example" + op_type = "Log" + node = onnx.helper.make_node("Log", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([1, 10]).astype(np.float32) + y = np.log(x) # expected output [0., 2.30258512] + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_log" + node = onnx.helper.make_node("Log", inputs=["x"], outputs=["y"], name=op_name) + + x = np.exp(np.random.randn(3, 4, 5).astype(np.float32)) + y = np.log(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +@pytest.mark.skip(reason="Wrong answer, at axis 1") +def test_logsoftmax(): + op_name, op_type = "test_logsoftmax_example_1", "LogSoftmax" + node = onnx.helper.make_node("LogSoftmax", inputs=["x"], outputs=["y"], name=op_name) + x = np.array([[-1, 0, 1]]).astype(np.float32) + # expected output + # [[-2.4076061 -1.407606 -0.407606 ]] + from onnx.backend.test.case.node.logsoftmax import logsoftmax + + y = logsoftmax(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + x = np.array([[0, 1, 2, 3], [10000, 10001, 10002, 10003]]).astype(np.float32) + axis_order = [0, 1, -1] + for axis in axis_order: + op_name = "test_logsoftmax_axis_{}".format(str(axis + 1)) + node = onnx.helper.make_node( + "LogSoftmax", inputs=["x"], outputs=["y"], axis=axis, name=op_name + ) + y = logsoftmax(x, axis=axis) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_matmul(): + op_name, op_type = "test_matmul_2d", "MatMul" + node = onnx.helper.make_node("MatMul", inputs=["a", "b"], outputs=["c"], name=op_name) + + # 2d + a = np.random.randn(3, 4).astype(np.float32) + b = np.random.randn(4, 3).astype(np.float32) + c = np.matmul(a, b) + op_expect(node, inputs=[a, b], outputs=[c], op_type=op_type, op_name=op_name) + + +def test_max(): + op_name = "test_max_example" + op_type = "Max" + data_0 = np.array([3, 2, 1]).astype(np.float32) + data_1 = np.array([1, 4, 4]).astype(np.float32) + data_2 = np.array([2, 5, 3]).astype(np.float32) + result = np.array([3, 5, 4]).astype(np.float32) + node = onnx.helper.make_node( + "Max", inputs=["data_0", "data_1", "data_2"], outputs=["result"], name=op_name + ) + op_expect( + node, + inputs=[data_0, data_1, data_2], + outputs=[result], + op_type=op_type, + op_name=op_name, + ) + + op_name = "test_max_two_inputs" + result = np.maximum(data_0, data_1) + node = onnx.helper.make_node( + "Max", inputs=["data_0", "data_1"], outputs=["result"], name=op_name + ) + op_expect( + node, + inputs=[data_0, data_1], + outputs=[result], + op_type=op_type, + op_name=op_name, + ) + + +def _test_maxpool_2d_ceil(): + op_name, op_type = "test_maxpool_2d_ceil", "MaxPool" + node = 
onnx.helper.make_node( + "MaxPool", + inputs=["x"], + outputs=["y"], + kernel_shape=[3, 3], + strides=[2, 2], + ceil_mode=True, + name=op_name, + ) + x = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ] + ).astype(np.float32) + y = np.array([[[[11, 12], [15, 16]]]]).astype(np.float32) + + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def _test_maxpool_1d_default(): + op_name, op_type = "test_maxpool_1d_default", "MaxPool" + node = onnx.helper.make_node( + "MaxPool", inputs=["x"], outputs=["y"], kernel_shape=[2], name=op_name + ) + x = np.random.randn(1, 3, 32).astype(np.float32) + x_shape = np.shape(x) + kernel_shape = [2] + strides = [1] + from onnx.backend.test.case.node.pool_op_common import get_output_shape, pool + + out_shape = get_output_shape("VALID", x_shape[2:], kernel_shape, strides) + padded = x + y = pool(padded, x_shape, kernel_shape, strides, out_shape, [0], "MAX") + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_maxpool(): + _test_maxpool_2d_ceil() + _test_maxpool_1d_default() + + +def test_mean(): + op_name, op_type = "test_mean_example", "Mean" + data_0 = np.array([3, 0, 2]).astype(np.float32) + data_1 = np.array([1, 3, 4]).astype(np.float32) + data_2 = np.array([2, 6, 6]).astype(np.float32) + result = np.array([2, 3, 4]).astype(np.float32) + node = onnx.helper.make_node( + "Mean", inputs=["data_0", "data_1", "data_2"], outputs=["result"], name=op_name + ) + op_expect( + node, + inputs=[data_0, data_1, data_2], + outputs=[result], + op_type=op_type, + op_name=op_name, + ) + + op_name = "test_mean_two_inputs" + result = np.divide(np.add(data_0, data_1), 2.0) + node = onnx.helper.make_node( + "Mean", inputs=["data_0", "data_1"], outputs=["result"], name=op_name + ) + op_expect( + node, + inputs=[data_0, data_1], + outputs=[result], + op_type=op_type, + op_name=op_name, + ) + + +def test_min(): + op_name, op_type = "test_min_example", "Min" + data_0 = np.array([3, 2, 1]).astype(np.float32) + data_1 = np.array([1, 4, 4]).astype(np.float32) + data_2 = np.array([2, 5, 0]).astype(np.float32) + result = np.array([1, 2, 0]).astype(np.float32) + node = onnx.helper.make_node( + "Min", inputs=["data_0", "data_1", "data_2"], outputs=["result"], name=op_name + ) + op_expect( + node, + inputs=[data_0, data_1, data_2], + outputs=[result], + op_type=op_type, + op_name=op_name, + ) + + op_name = "test_min_two_inputs" + result = np.minimum(data_0, data_1) + node = onnx.helper.make_node( + "Min", inputs=["data_0", "data_1"], outputs=["result"], name=op_name + ) + op_expect( + node, + inputs=[data_0, data_1], + outputs=[result], + op_type=op_type, + op_name=op_name, + ) + + +def test_mul(): + op_name, op_type = "test_mul_example", "Mul" + node = onnx.helper.make_node("Mul", inputs=["x", "y"], outputs=["z"], name=op_name) + + x = np.array([1, 2, 3]).astype(np.float32) + y = np.array([4, 5, 6]).astype(np.float32) + z = x * y # expected output [4., 10., 18.] 
+ op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + op_name = "test_mul" + node = onnx.helper.make_node("Mul", inputs=["x", "y"], outputs=["z"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.random.randn(3, 4, 5).astype(np.float32) + z = x * y + op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + op_name = "test_mul_bcast" + node = onnx.helper.make_node("Mul", inputs=["x", "y"], outputs=["z"], name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.random.randn(5).astype(np.float32) + z = x * y + op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + +def test_neg(): + op_name, op_type = "test_neg_example", "Neg" + node = onnx.helper.make_node("Neg", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-4, 2]).astype(np.float32) + y = np.negative(x) # expected output [4., -2.] + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_neg" + node = onnx.helper.make_node("Neg", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.randn(3, 4, 5).astype(np.float32) + y = np.negative(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_negativeloglikelihoodloss(): + op_name, op_type = "test_nllloss_NC", "NegativeLogLikelihoodLoss" + reduction = "none" + node = onnx.helper.make_node( + "NegativeLogLikelihoodLoss", + inputs=["input", "target"], + outputs=["loss"], + reduction=reduction, + name=op_name, + ) + + # NOCC:invalid-name(other: onnx example) + N, C = 3, 5 + np.random.seed(0) + input = np.random.rand(N, C).astype(np.float32) + target = np.random.randint(0, high=C, size=(N,)).astype(np.int64) + from onnx.backend.test.case.node.negativeloglikelihoodloss import ( + compute_negative_log_likelihood_loss, + ) + + negative_log_likelihood_loss = compute_negative_log_likelihood_loss( + input, target, weight=None, reduction=reduction + ) + + op_expect( + node, + inputs=[input, target], + outputs=[negative_log_likelihood_loss], + op_type=op_type, + op_name=op_name, + ) + + +def test_prelu(): + op_name, op_type = "test_prelu_example", "PRelu" + node = onnx.helper.make_node("PRelu", inputs=["x", "slope"], outputs=["y"], name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + slope = np.random.randn(3, 4, 5).astype(np.float32) + y = np.clip(x, 0, np.inf) + np.clip(x, -np.inf, 0) * slope + + op_expect(node, inputs=[x, slope], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_prelu_broadcast" + node = onnx.helper.make_node("PRelu", inputs=["x", "slope"], outputs=["y"], name=op_name) + + x = np.random.randn(3, 4, 5).astype(np.float32) + slope = np.random.randn(5).astype(np.float32) + y = np.clip(x, 0, np.inf) + np.clip(x, -np.inf, 0) * slope + + op_expect(node, inputs=[x, slope], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_pow(): + op_name, op_type = "test_pow_example", "Pow" + node = onnx.helper.make_node("Pow", inputs=["x", "y"], outputs=["z"], name=op_name) + + x = np.array([1, 2, 3]).astype(np.float32) + y = np.array([4, 5, 6]).astype(np.float32) + z = pow(x, y) # expected output [1., 32., 729.]
+ op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + op_name = "test_pow" + node = onnx.helper.make_node("Pow", inputs=["x", "y"], outputs=["z"], name=op_name) + x = np.arange(60).reshape(3, 4, 5).astype(np.float32) + y = np.random.randn(3, 4, 5).astype(np.float32) + z = pow(x, y) + op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + op_name = "test_pow_bcast_scalar" + node = onnx.helper.make_node("Pow", inputs=["x", "y"], outputs=["z"], name=op_name) + + x = np.array([1, 2, 3]).astype(np.float32) + y = np.array([2]).astype(np.float32) + z = pow(x, y) # expected output [1., 4., 9.] + op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + op_name = "test_pow_bcast_array" + node = onnx.helper.make_node("Pow", inputs=["x", "y"], outputs=["z"], name=op_name) + x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32) + y = np.array([[1, 2, 3]]).astype(np.float32) + # expected output [[1, 4, 27], [4, 25, 216]] + z = pow(x, y) + op_expect(node, inputs=[x, y], outputs=[z], op_type=op_type, op_name=op_name) + + +def test_reciprocal(): + op_name, op_type = "test_reciprocal_example", "Reciprocal" + node = onnx.helper.make_node("Reciprocal", inputs=["x"], outputs=["y"], name=op_name) + + x = np.array([-4, 2]).astype(np.float32) + y = np.reciprocal(x) # expected output [-0.25, 0.5], + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + op_name = "test_reciprocal" + node = onnx.helper.make_node("Reciprocal", inputs=["x"], outputs=["y"], name=op_name) + x = np.random.rand(3, 4, 5).astype(np.float32) + 0.5 + y = np.reciprocal(x) + op_expect(node, inputs=[x], outputs=[y], op_type=op_type, op_name=op_name) + + +def test_reducel1(): + op_name, op_type = "test_reduce_l1_default_axes_keepdims_example", "ReduceL1" + shape = [3, 2, 2] + axes = None + keepdims = 1 + + node = onnx.helper.make_node( + "ReduceL1", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + name=op_name, + ) + + data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + # print(data) + # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] + + reduced = np.sum(a=np.abs(data), axis=axes, keepdims=keepdims == 1) + # print(reduced) + # [[[78.]]] + + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + np.random.seed(0) + data = np.random.uniform(-10, 10, shape).astype(np.float32) + reduced = np.sum(a=np.abs(data), axis=axes, keepdims=keepdims == 1) + + op_name = "test_reduce_l1_default_axes_keepdims_random" + node = onnx.helper.make_node( + "ReduceL1", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + name=op_name, + ) + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + +def test_reducel2(): + op_name, op_type = "test_reduce_l2_default_axes_keepdims_example", "ReduceL2" + shape = [3, 2, 2] + axes = None + keepdims = 1 + + node = onnx.helper.make_node( + "ReduceL2", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + name=op_name, + ) + + data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) + # print(data) + # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] + + reduced = np.sqrt(np.sum(a=np.square(data), axis=axes, keepdims=keepdims == 1)) + # print(reduced) + # [[[25.49509757]]] + + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + op_name = "test_reduce_l2_default_axes_keepdims_random" + 
np.random.seed(0) + data = np.random.uniform(-10, 10, shape).astype(np.float32) + reduced = np.sqrt(np.sum(a=np.square(data), axis=axes, keepdims=keepdims == 1)) + node = onnx.helper.make_node( + "ReduceL2", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + name=op_name, + ) + + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + +@pytest.mark.skip(reason="ORT: Unrecognized attribute: axes for operator ReduceLogSum") +def test_reducelogsum(): + op_name, op_type = "test_reduce_log_sum_default", "ReduceLogSum" + node = onnx.helper.make_node("ReduceLogSum", inputs=["data"], outputs=["reduced"], name=op_name) + data = np.random.ranf([3, 4, 5]).astype(np.float32) + reduced = np.log(np.sum(data, keepdims=True)) + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + op_name = "test_reduce_log_sum_negative_axes" + node = onnx.helper.make_node( + "ReduceLogSum", inputs=["data"], outputs=["reduced"], axes=[-2], name=op_name + ) + data = np.random.ranf([3, 4, 5]).astype(np.float32) + reduced = np.log(np.sum(data, axis=(-2), keepdims=True)) + # print(reduced) + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + op_name = "test_reduce_log_sum_desc_axes" + node = onnx.helper.make_node( + "ReduceLogSum", + inputs=["data"], + outputs=["reduced"], + axes=[2, 1], + keepdims=0, + name=op_name, + ) + data = np.random.ranf([3, 4, 5]).astype(np.float32) + reduced = np.log(np.sum(data, axis=(2, 1), keepdims=False)) + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + op_name = "test_reduce_log_sum_asc_axes" + node = onnx.helper.make_node( + "ReduceLogSum", + inputs=["data"], + outputs=["reduced"], + axes=[0, 1], + keepdims=0, + name=op_name, + ) + data = np.random.ranf([3, 4, 5]).astype(np.float32) + reduced = np.log(np.sum(data, axis=(0, 1), keepdims=False)) + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + +def test_reducelogsumexp(): + op_name, op_type = ( + "test_reduce_log_sum_exp_default_axes_keepdims_example", + "ReduceLogSumExp", + ) + shape = [3, 2, 2] + axes = None + keepdims = 1 + + node = onnx.helper.make_node( + "ReduceLogSumExp", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + name=op_name, + ) + + data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32) + reduced = np.log(np.sum(np.exp(data), axis=axes, keepdims=keepdims == 1)) + # print(reduced) + # [[[60.00671387]]] + + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + op_name = "test_reduce_log_sum_exp_default_axes_keepdims_random" + node = onnx.helper.make_node( + "ReduceLogSumExp", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + name=op_name, + ) + + np.random.seed(0) + data = np.random.uniform(-10, 10, shape).astype(np.float32) + reduced = np.log(np.sum(np.exp(data), axis=axes, keepdims=keepdims == 1)) + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + +def test_reducemax(): + op_name, op_type = "test_reduce_max_default_axes_keepdim_example", "ReduceMax" + shape = [3, 2, 2] + axes = None + keepdims = 1 + node = onnx.helper.make_node( + "ReduceMax", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + name=op_name, + ) + + data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32) + reduced = np.maximum.reduce(data, axis=axes, keepdims=keepdims == 1) + #
print(reduced) + # [[[60.]]] + + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + op_name = "test_reduce_max_default_axes_keepdims_random" + node = onnx.helper.make_node( + "ReduceMax", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + name=op_name, + ) + np.random.seed(0) + data = np.random.uniform(-10, 10, shape).astype(np.float32) + reduced = np.maximum.reduce(data, axis=axes, keepdims=keepdims == 1) + + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + +def test_reducemean(): + op_name, op_type = "test_reduce_mean_default_axes_keepdims_example", "ReduceMean" + shape = [3, 2, 2] + axes = None + keepdims = 1 + + node = onnx.helper.make_node( + "ReduceMean", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + name=op_name, + ) + + data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32) + reduced = np.mean(data, axis=axes, keepdims=keepdims == 1) + # print(reduced) + # [[[18.25]]] + + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + op_name = "test_reduce_mean_default_axes_keepdims_random" + + node = onnx.helper.make_node( + "ReduceMean", + inputs=["data"], + outputs=["reduced"], + keepdims=keepdims, + name=op_name, + ) + np.random.seed(0) + data = np.random.uniform(-10, 10, shape).astype(np.float32) + reduced = np.mean(data, axis=axes, keepdims=keepdims == 1) + + op_expect(node, inputs=[data], outputs=[reduced], op_type=op_type, op_name=op_name) + + +def test_reducesum(): + batch_size = 32 + op_name = "reduce_sum_1" + with tf.Graph().as_default(): + input_ph = tf.placeholder( + dtype=tf.float32, shape=[batch_size, 256], name="input" + ) # [batchsize, 10] + input_data = np.random.rand(batch_size, 256).astype(np.float32) + x = tf.math.reduce_sum(input_ph, axis=1, name=op_name) + _ = tf.identity(x, name="output") + verify_tf_with_trt_result([input_data], ["input:0"], ["output:0"], op_name=op_name) + + +def test_maxunpool(): + def verify_maxunpool( + data, indices, kernel_shape, strides, output_shape=None, pads=None, op_name=None + ): + input_names = ["xT", "xI"] + input_info = [ + helper.make_tensor_value_info("xT", TensorProto.FLOAT, list(data.shape)), + helper.make_tensor_value_info("xI", TensorProto.INT64, list(indices.shape)), + ] + input_values = [data, indices] + # input_values = [data ] + if output_shape is not None: + input_names.append("output_shape") + input_info.append( + helper.make_tensor_value_info( + "output_shape", TensorProto.INT64, list(output_shape.shape) + ) + ) + input_values.append(output_shape) + else: + # Compute expected output shape + output_shape = np.asarray(([1, 1] + list(strides))) * np.asarray(list(data.shape)) + output_shape += np.asarray(([0, 0] + list(kernel_shape))) - np.asarray( + ([0, 0] + list(strides)) + ) + if pads is not None: + output_shape -= np.asarray( + [0, 0] + list(np.sum(np.reshape(list(pads), [-1, 2]), axis=-1)) + ) + output_shape = [int(i) for i in output_shape] + + node = helper.make_node( + "MaxUnpool", + inputs=input_names, + outputs=["y"], + kernel_shape=kernel_shape, + name=op_name, + ) + + if pads is not None: + pad_attr = helper.make_attribute("pads", pads) + node.attribute.append(pad_attr) + + if strides is not None: + strides_attr = helper.make_attribute("strides", strides) + node.attribute.append(strides_attr) + + graph = helper.make_graph( + [node], + "maxunpool_test", + inputs=input_info, + outputs=[helper.make_tensor_value_info("y", 
TensorProto.FLOAT, output_shape)], + ) + + model = helper.make_model(graph, producer_name="size_test") + verify_with_ort_with_trt(model, input_values, op_name=op_name, opset=11) + + # NOCC:invalid-name(other: onnx example) + xT = np.array([[[[5, 6], [7, 8]]]], dtype=np.float32) + # NOCC:invalid-name(other: onnx example) + xI = np.array([[[[0, 7], [13, 15]]]], dtype=np.int64) + verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], op_name="max_unpool_1") + + +def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype, op_name): + inp_array1 = np.random.randint(0, 5, size=indices_shape) + with tf.Graph().as_default(): + in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype, name="input") + out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype, name=op_name) + out = tf.identity(out, "output") + verify_tf_with_trt_result([inp_array1], ["input:0"], ["output:0"], op_name) + + +def test_forward_one_hot(): + _test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32", "onehot_2") + + +def test_where(): + op_name, op_type = "test_where", "Where" + node = onnx.helper.make_node( + "Where", inputs=["condition", "x", "y"], outputs=["z"], name=op_name + ) + condition = np.array([[1, 0], [1, 1]], dtype=bool) + x = np.array([[1, 2], [3, 4]], dtype=np.int64) + y = np.array([[9, 8], [7, 6]], dtype=np.int64) + z = np.where(condition, x, y) # expected output [[1, 8], [3, 4]] + op_expect(node, inputs=[condition, x, y], outputs=[z], op_type=op_type, op_name=op_name) + + +def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None): + op_name = "slice_0" + if axes: + y = helper.make_node( + "Slice", ["in"], ["out"], axes=axes, starts=starts, ends=ends, name=op_name + ) + else: + y = helper.make_node("Slice", ["in"], ["out"], starts=starts, ends=ends, name=op_name) + + graph = helper.make_graph( + [y], + "slice_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outdata.shape))], + ) + + model = helper.make_model(graph, producer_name="slice_test") + # verify_with_ort_with_trt(model, [indata], [outdata.shape], op_name=op_name, opset=1) + verify_with_ort_with_trt(model, [indata], op_name=op_name, opset=1) + + +def test_slice(): + x = np.random.randn(20, 10, 5).astype(np.float32) + _test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1)) + + +def verify_pad_v11(indata, pads, mode="constant", value=0.0): + op_name = "pad_001" + indata = np.array(indata).astype(np.float32) + # numpy expected result + len_dim = len(pads) // 2 + np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)] + pads = np.array(pads) + # onnx graph + if mode in ["edge", "reflect"]: + inputs = [indata] + outdata = np.pad(indata, pad_width=np_pads, mode=mode) + node = helper.make_node( + "Pad", inputs=["input", "pads"], outputs=["output"], mode=mode, name=op_name + ) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)), + ], + initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads)], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + else: + inputs = [indata] + outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) + node = helper.make_node( + "Pad", + inputs=["input", "pads",
"constant_value"], + outputs=["output"], + mode="constant", + name=op_name, + ) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("pads", TensorProto.INT64, (len(pads),)), + helper.make_tensor_value_info("constant_value", TensorProto.FLOAT, (1,)), + ], + initializer=[ + helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads), + helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value]), + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + model = helper.make_model(graph, producer_name="pad_test") + verify_with_ort_with_trt(model, inputs, op_name, opset=11) + + +@pytest.mark.skip(reason="TensorRT segmentfault") +def test_pad(): + verify_pad_v11(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], "constant", 0.0) + + +@pytest.mark.skip(reason="TensorRT segmentfault") +def test_batch_norm(): + def verify_batch_norm(in_shape): + op_name = "batchNorm_{}".format(sum(in_shape)) + batchnorm = onnx.helper.make_node( + "BatchNormalization", + inputs=["x", "scale", "B", "mean", "var"], + outputs=["Y"], + name=op_name, + ) + + graph = helper.make_graph( + [batchnorm], + "batchnorm_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("scale", TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("mean", TensorProto.FLOAT, [in_shape[1]]), + helper.make_tensor_value_info("var", TensorProto.FLOAT, [in_shape[1]]), + ], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(in_shape))], + ) + + model = helper.make_model(graph, producer_name="batchnorm_test") + # X, scale, b, mean, var + inshapes = [in_shape, in_shape[1], in_shape[1], in_shape[1], in_shape[1]] + inputs = [np.random.uniform(size=ishape).astype("float32") for ishape in inshapes] + + verify_with_ort_with_trt(model, inputs, op_name=op_name) + + verify_batch_norm([1, 3, 224, 224]) + verify_batch_norm([1, 3, 24, 24]) + verify_batch_norm([16, 3, 24, 24]) + verify_batch_norm([16, 16, 24, 24]) + verify_batch_norm([16, 16, 10, 10]) + + +def verify_softmax(inshape, axis, op_name): + indata = np.random.uniform(size=inshape).astype(np.float32) + outshape = inshape + y = helper.make_node("Softmax", ["in"], ["out"], name=op_name) + if axis is not None: + axis_attr = helper.make_attribute("axis", axis) + y.attribute.append(axis_attr) + + graph = helper.make_graph( + [y], + "Softmax_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))], + ) + + model = helper.make_model(graph, producer_name="Softmax_test") + verify_with_ort_with_trt(model, [indata], op_name=op_name) + + +def test_softmax(): + verify_softmax((1, 10), None, op_name="softmax_0") + # verify_softmax((1, 10), 1, op_name='softmax_1') + + +def verify_mod(x_shape, y_shape, fmod, out_shape, dtype="float32", op_name=""): + x_np = np.random.uniform(-100.0, 100.0, x_shape).astype(dtype) + y_np = np.random.uniform(-100.0, 100.0, y_shape).astype(dtype) + y_np = np.where(y_np == 0, 1, y_np) # remove 0's to avoid division by zero error + + mod_node = helper.make_node("Mod", inputs=["x", "y"], outputs=["z"], fmod=fmod, name=op_name) + + onnx_dtype = TensorProto.FLOAT if dtype == "float32" else TensorProto.INT32 + 
graph = helper.make_graph( + [mod_node], + "mod_test", + inputs=[ + helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)), + helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)), + ], + outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="mod_test") + # verify_with_ort_with_trt(model, [x_np, y_np], [out_shape], op_name=op_name) + verify_with_ort_with_trt(model, [x_np, y_np], op_name=op_name) + + +def test_mod(): + # Mod + verify_mod( + x_shape=[1, 32, 32], + y_shape=[1, 1, 32], + fmod=0, + out_shape=(1, 32, 32), + dtype="int32", + op_name="tvm_mod", + ) + + +def verify_mean(input_dim, op_name): + dtype = "float32" + a_np1 = np.random.uniform(size=input_dim).astype(dtype) + a_np2 = np.random.uniform(size=input_dim).astype(dtype) + a_np3 = np.random.uniform(size=input_dim).astype(dtype) + + mean_node = helper.make_node("Mean", ["a_np1", "a_np2", "a_np3"], ["out"], name=op_name) + + graph = helper.make_graph( + [mean_node], + "Mean_test", + inputs=[ + helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, list(input_dim)), + helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, list(input_dim)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(input_dim))], + ) + + model = helper.make_model(graph, producer_name="Mean_test") + verify_with_ort_with_trt(model, [a_np1, a_np2, a_np3], op_name=op_name) + + +def test_forward_mean(): + verify_mean((1, 3, 20, 20), op_name="mean_111") + verify_mean((20, 20), op_name="mean_222") + + +def verify_instance_norm(shape, axis=1, op_name="default"): + x = np.random.randn(*shape).astype(np.float32) + gamma = np.random.randn(shape[1]).astype(np.float32) + beta = np.random.randn(shape[1]).astype(np.float32) + epsilon = 1e-5 + + node = onnx.helper.make_node( + "InstanceNormalization", + inputs=["x", "gamma", "beta"], + outputs=["y"], + epsilon=epsilon, + name=op_name, + ) + graph = helper.make_graph( + [node], + "instance_norm_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)), + helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],)), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))], + ) + model = helper.make_model(graph, producer_name="instance_norm_test") + verify_with_ort_with_trt(model, [x, gamma, beta], op_name=op_name) + + +def test_instance_norm(): + verify_instance_norm((2, 3, 4, 5), op_name="instance_norm") + # verify_instance_norm((32, 64, 80, 64)) + # verify_instance_norm((8, 6, 5)) + # verify_instance_norm((8, 7, 6, 5, 4)) + + +def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None, op_name=None): + in_array = np.random.uniform(size=shape).astype(dtype) + + if alpha is None and beta is None and bias is None: + alpha = 0.0001 + beta = 0.75 + bias = 1.0 + node = onnx.helper.make_node( + "LRN", inputs=["in"], outputs=["out"], size=nsize, name=op_name + ) + else: + node = onnx.helper.make_node( + "LRN", + inputs=["in"], + outputs=["out"], + alpha=alpha, + beta=beta, + bias=bias, + size=nsize, + name=op_name, + ) + + graph = helper.make_graph( + [node], + "lrn_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(shape))], + ) + model = helper.make_model(graph, 
producer_name="lrn_test") + verify_with_ort_with_trt(model, [in_array], op_name=op_name) + + +def test_lrn(): + verify_lrn((5, 5, 5, 5), 3, "float32", op_name="test_lrn_1") + verify_lrn( + (5, 5, 5, 5), + 3, + "float32", + alpha=0.0002, + beta=0.5, + bias=2.0, + op_name="test_lrn_2", + ) + + +def test_lstm(): + # # Different activation testing. + # # Default value hardsigmoid. + verify_rnn( + seq_length=2, + batch_size=1, + input_size=16, + hidden_size=32, + use_bias=False, + activations=["HardSigmoid", "Tanh", "Tanh"], + rnn_type="LSTM", + op_name="test_lstm_without_bias", + layout=1, + ) + + +def test_binary_ops(): + in_shape = (1, 2, 3, 3) + dtype = "float32" + out_shape = in_shape + + def verify_binary_ops(op, x, y, out_type="float32", op_name=None): + z = helper.make_node(op, ["in1", "in2"], ["out"], name=op_name) + graph = helper.make_graph( + [z], + "_test", + inputs=[ + helper.make_tensor_value_info("in1", TensorProto.FLOAT, x.shape), + helper.make_tensor_value_info("in2", TensorProto.FLOAT, y.shape), + ], + outputs=[ + helper.make_tensor_value_info( + "out", + mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(out_type)], + list(out_shape), + ) + ], + ) + model = helper.make_model(graph, producer_name="_test") + verify_with_ort_with_trt(model, [x, y], op_name=op_name) + + x = np.random.uniform(size=in_shape).astype(dtype) + y = np.random.uniform(size=in_shape).astype(dtype) + z = np.random.uniform(size=(3,)).astype(dtype) + verify_binary_ops("Sub", x, y, op_name="sub_1") + verify_binary_ops("Sub", x, z, op_name="sub_2") + + +def verify_reduce_func(func, data, axis, keepdims, op_name=None): + inshape = data.shape + outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape + + if axis: + node = onnx.helper.make_node( + func, + inputs=["x"], + outputs=["y"], + axes=axis, + keepdims=keepdims, + name=op_name, + ) + else: + node = onnx.helper.make_node( + func, inputs=["x"], outputs=["y"], keepdims=keepdims, name=op_name + ) + + graph = helper.make_graph( + [node], + "reduce_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))], + ) + + model = helper.make_model(graph, producer_name="reduce_test") + + verify_with_ort_with_trt(model, [data], opset=11, op_name=op_name) + + +def test_all_reduce_funcs(): + funcs = [ + # "ReduceMax", + # "ReduceMean", + # "ReduceMin", + # "ReduceProd", + # "ReduceSum", + # "ReduceSumSquare", + "ReduceLogSum", + "ReduceLogSumExp", + "ReduceL1", + "ReduceL2", + ] + + for func in funcs: + for keepdims in [True, False]: + verify_reduce_func( + func, + np.random.randn(3, 2, 2).astype(np.float32), + axis=None, + keepdims=keepdims, + op_name=func + str(int(keepdims)) + "1", + ) + + verify_reduce_func( + func, + np.random.randn(3, 2, 3).astype(np.float32), + axis=None, + keepdims=keepdims, + op_name=func + str(int(keepdims)) + "2", + ) + + verify_reduce_func( + func, + np.random.randn(3, 3, 3).astype(np.float32), + axis=(1,), + keepdims=keepdims, + op_name=func + str(int(keepdims)) + "3", + ) + + verify_reduce_func( + func, + np.random.randn(3, 3, 3, 1).astype(np.float32), + axis=(1, 2), + keepdims=keepdims, + op_name=func + str(int(keepdims)) + "4", + ) + + verify_reduce_func( + func, + np.random.randn(3, 3, 3, 1).astype(np.float32), + axis=(1,), + keepdims=keepdims, + op_name=func + str(int(keepdims)) + "5", + ) + + verify_reduce_func( + func, + np.random.randn(1, 3, 4, 1).astype(np.float32), + axis=(1,), + keepdims=keepdims, + op_name=func + 
str(int(keepdims)) + "6", + ) + + +def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11, op_name=None): + indata = np.array(indata).astype(np.float32) + outdatas = [np.array(o).astype(np.float32) for o in outdatas] + inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape))] + input_names = ["input"] + initializer = [] + + if split: + split_index = range(len(split)) + else: + split_index = range(len(outdatas)) + + if pass_split: + if opset >= 13: + input_names.append("split") + np_split = np.array(split).astype(np.int64) + inputs.append( + helper.make_tensor_value_info("split", TensorProto.INT64, list(np_split.shape)) + ) + indata = [indata, np_split] + initializer.append( + helper.make_tensor("split", TensorProto.INT64, list(np_split.shape), np_split) + ) + node = helper.make_node( + "Split", + inputs=input_names, + outputs=["output_{}".format(i) for i in range(len(split_index))], + axis=axis, + name=op_name, + ) + + if pass_split and opset < 13: + split_attr = helper.make_attribute("split", split) + node.attribute.append(split_attr) + + graph = helper.make_graph( + [node], + "split_test", + inputs=inputs, + initializer=initializer, + outputs=[ + helper.make_tensor_value_info( + "output_{}".format(i), TensorProto.FLOAT, list(outdatas[i].shape) + ) + for i in range(len(split_index)) + ], + ) + model = helper.make_model(graph, producer_name="split_test") + verify_with_ort_with_trt(model, indata, opset=opset, op_name=op_name) + + +def test_split(): + # 1D + verify_split( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + [2, 2, 2], + 0, + op_name="split_1", + ) + verify_split( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + [2, 2, 2], + 0, + False, + op_name="split_2", + ) + # 2D + verify_split( + [[1.0, 2.0, 3.0, 4.0], [7.0, 8.0, 9.0, 10.0]], + [[[1.0, 2.0], [7.0, 8.0]], [[3.0, 4.0], [9.0, 10.0]]], + [2, 2], + 1, + op_name="split_4", + ) + # Split evenly (unstack) + verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False, op_name="split_5") + # Split a single value to a single value + verify_split([1], [[1]], [1], pass_split=True, op_name="split_6") + + +def verify_xor(x_shape, y_shape, op_name=None): + x_np = np.random.choice(a=[False, True], size=x_shape).astype("bool") + y_np = np.random.choice(a=[False, True], size=y_shape).astype("bool") + + np_out = np.logical_xor(x_np, y_np) + out_shape = np_out.shape + + xor_node = helper.make_node("Xor", inputs=["x", "y"], outputs=["z"], name=op_name) + + onnx_dtype = TensorProto.BOOL + graph = helper.make_graph( + [xor_node], + "xor_test", + inputs=[ + helper.make_tensor_value_info("x", onnx_dtype, list(x_shape)), + helper.make_tensor_value_info("y", onnx_dtype, list(y_shape)), + ], + outputs=[helper.make_tensor_value_info("z", onnx_dtype, list(out_shape))], + ) + model = helper.make_model(graph, producer_name="xor_test") + verify_with_ort_with_trt(model, [x_np, y_np], op_name=op_name) + + +@pytest.mark.skip(reason="TensorRT segmentation fault") +def test_xor(): + # XOR + verify_xor(x_shape=[1, 32, 32], y_shape=[1, 32, 32], op_name="test_xor_1") + + # Xor broadcast + verify_xor(x_shape=[1, 32, 32], y_shape=[1, 1, 32], op_name="test_xor_2") + + +def verify_if(cond_array, op_name): + # Given a bool scalar input cond. + # return constant tensor x if cond is True, otherwise return constant tensor y.
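+ # Each branch is a one-node subgraph whose Constant output becomes the If node's result, so the outer graph only needs the boolean "cond" input.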
+ then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5]) + else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5]) + + x = np.array([1, 2, 3, 4, 5]).astype(np.float32) + y = np.array([5, 4, 3, 2, 1]).astype(np.float32) + + then_const_node = onnx.helper.make_node( + "Constant", inputs=[], outputs=["then_out"], value=numpy_helper.from_array(x) + ) + + else_const_node = onnx.helper.make_node( + "Constant", inputs=[], outputs=["else_out"], value=numpy_helper.from_array(y) + ) + + then_body = onnx.helper.make_graph([then_const_node], "then_body", [], [then_out]) + + else_body = onnx.helper.make_graph([else_const_node], "else_body", [], [else_out]) + + if_node = onnx.helper.make_node( + "If", + inputs=["cond"], + outputs=["res"], + then_branch=then_body, + else_branch=else_body, + name=op_name, + ) + + if_graph = onnx.helper.make_graph( + [if_node], + "if_outer", + inputs=[ + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + ], + outputs=[ + onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [5]), + ], + ) + + if_model = onnx.helper.make_model(if_graph) + if cond_array: + cond = np.array([1]).astype("bool") + else: + cond = np.array(1).astype("bool") + verify_with_ort_with_trt(if_model, [cond], op_name=op_name) + + +@pytest.mark.skip( + reason="ORT: NOT_IMPLEMENTED : Could not find an implementation for If(19) node with name 'if_test_1'" +) +def test_if(): + # Confirm that if works with cond as an array or scalar. + verify_if(cond_array=False, op_name="if_test_1") + verify_if(cond_array=True, op_name="if_test_2") + + +def test_softmax_cross_entropyloss(): + op_name = "test_SoftmaxCrossEntropyLoss" + reduction = "mean" + ignore_index = np.int64(-1) + + node = onnx.helper.make_node( + "SoftmaxCrossEntropyLoss", + inputs=["x", "y", "w"], + outputs=["z"], + reduction=reduction, + ignore_index=ignore_index, + name=op_name, + ) + # NOCC:invalid-name(other: onnx example) + N, C, dim1 = 3, 5, 6 + np.random.seed(0) + x = np.random.rand(N, C, dim1).astype(np.float32) + labels = np.random.randint(0, high=C, size=(N, dim1)).astype(np.int64) + labels[0][0] = -1 + weight = np.random.rand(C).astype(np.float32) + from onnx.backend.test.case.node.softmaxcrossentropy import softmaxcrossentropy + + sce = softmaxcrossentropy( + x, labels, weight=weight, reduction=reduction, ignore_index=ignore_index + ) + + op_expect( + node, + inputs=[x, labels, weight], + outputs=[sce], + op_name=op_name, + op_type="SoftmaxCrossEntropyLoss", + ) + + +def _test_logical(method, op_name): + batch_size = 128 + input_data = (2 * np.random.rand(batch_size, 256) - 1).astype(np.float32) + with tf.Graph().as_default(): + input_ph = tf.placeholder(dtype=tf.float32, shape=[batch_size, 256], name="input") + x = tf.nn.relu(input_ph) + mask = tf.cast(x, tf.bool) + x = tf.nn.relu(tf.layers.dense(x, 256)) + y = x + x = tf.cast(x, tf.bool) + if method == "or": + x = tf.math.logical_or(x, mask, name=op_name) + elif method == "and": + x = tf.math.logical_and(x, mask, name=op_name) + elif method == "not": + x = tf.math.logical_not(x, name=op_name) + elif method == "equal": + x = tf.math.equal(x, mask, name=op_name) + elif method == "greater": + x = tf.math.greater(y, input_ph, name=op_name) + elif method == "xor": + x = tf.math.logical_xor(x, mask, name=op_name) + elif method == "is_inf": + x = tf.math.is_inf(input_ph, name=op_name) + elif method == "is_nan": + x = tf.math.is_nan(input_ph, name=op_name) + _ = tf.identity(x, name="output") +
+        verify_tf_with_trt_result([input_data], ["input:0"], ["output:0"], op_name)
+
+
+@pytest.mark.skip(reason="TensorRT segfault")
+def test_logical():
+    _test_logical("or", "test_logical_or")
+    _test_logical("and", "test_logical_and")
+    _test_logical("not", "test_logical_not")
+    _test_logical("equal", "test_logical_equal")
+    _test_logical("greater", "test_logical_greater")
+    _test_logical("xor", "test_logical_xor")
+    _test_logical("is_inf", "test_logical_inf")
+    _test_logical("is_nan", "test_logical_nan")
+
+
+@pytest.mark.skip(reason="TensorFlow segfault")
+def test_scatternd():
+    batch_size = 32
+    op_name = "scatternd"
+    with tf.Graph().as_default():
+        input_ph = tf.placeholder(
+            dtype=tf.float32, shape=[batch_size, 10], name="input"
+        )  # [batchsize, 10]
+        input_data = np.random.rand(batch_size, 10).astype(np.float32)
+        x = tf.layers.dense(input_ph, 1)
+        # duplicated indices case (undefined)
+        # test ScatterND (32, 128, 128, 256) (32, 600, 3) (32, 600, 256)
+        data = tf.tile(tf.reshape(tf.layers.dense(x, 128 * 128), [-1, 128, 128, 1]), [1, 1, 1, 256])
+        x = tf.add(x, 1)
+        idx = tf.reshape(tf.layers.dense(x, 600 * 3), [-1, 600, 3])
+        idx = tf.cast(tf.clip_by_value(idx, 0, 1), tf.int32)
+        indices = idx
+        # indices = tf.zeros([32, 600, 3], dtype=tf.dtypes.int32)
+        # indices = tf.stack([tf.range(tf.shape(x)[0]), idx], axis=1)
+        x = tf.add(x, 2)
+        updates = tf.reshape(tf.layers.dense(x, 600 * 256), [-1, 600, 256])
+        # updates = tf.ones([32, 600, 256])
+        x = tf.tensor_scatter_nd_update(data, indices, updates, name=op_name)
+        # x = tf.scatter_nd(indices, updates, data.shape)
+        _ = tf.identity(x, name="output")
+        verify_tf_with_trt_result([input_data], ["input:0"], ["output:0"], op_name)
+
+
+if __name__ == "__main__":
+    test_abs()
+    test_acos()
+    test_and()
+    test_add()
+    test_argmax()
+    test_argmin()
+    test_asin()
+    test_asinh()
+    test_atan()
+    test_atanh()
+    test_averagepool()
+    test_batchnormalization()
+    test_ceil()
+    test_celu()
+    test_clip()
+    test_concat()
+    test_conv()
+    test_convtranspose()
+    test_cos()
+    test_cosh()
+    test_depthtospace()
+    test_div()
+    # ------100 limited library
+    test_einsum()
+    test_elu()
+    test_erf()
+    test_exp()
+    test_eyelike()
+    test_floor()
+    test_gather()
+    test_gatherelement()
+    test_gathernd()
+    test_gemm()
+    test_globalaveragepool()
+    test_globalmaxpool()
+    test_hardsigmoid()
+    test_hardswish()
+    test_hardmax()
+    test_identity()
+    test_instancenormalization()
+    test_leakyrelu()
+    test_log()
+    test_logsoftmax()
+    test_matmul()
+    test_max()
+    test_maxpool()
+    test_mean()
+    test_min()
+    test_mul()
+    test_neg()
+    test_negativeloglikelihoodloss()
+    # ---------100 limited library
+    test_prelu()
+    test_pow()
+    test_reciprocal()
+    test_reducel1()
+    test_reducel2()
+    test_reducelogsum()
+    test_reducelogsumexp()
+    test_reducemax()
+    test_reducemean()
+    test_reducesum()
+    test_maxunpool()
+    test_forward_one_hot()
+    test_where()
+    test_slice()
+    test_pad()
+    test_batch_norm()
+    test_softmax()
+    test_mod()
+    test_forward_mean()
+    test_instance_norm()
+    test_lrn()
+    test_binary_ops()
+    test_all_reduce_funcs()
+    test_split()
+    test_xor()
+    test_if()
+    test_logical()
+    test_scatternd()
diff --git a/tests/python/tpat/cuda/trt.py b/tests/python/tpat/cuda/trt.py
new file mode 100644
index 000000000000..4cf4151c2f43
--- /dev/null
+++ b/tests/python/tpat/cuda/trt.py
@@ -0,0 +1,178 @@
+#
+# Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
+#
+# NOTICE TO LICENSEE:
+#
+# This source code and/or documentation ("Licensed Deliverables") are
+# subject to NVIDIA intellectual property rights under U.S. and
+# international Copyright laws.
+#
+# These Licensed Deliverables contained herein is PROPRIETARY and
+# CONFIDENTIAL to NVIDIA and is being provided under the terms and
+# conditions of a form of NVIDIA software license agreement by and
+# between NVIDIA and Licensee ("License Agreement") or electronically
+# accepted by Licensee. Notwithstanding any terms or conditions to
+# the contrary in the License Agreement, reproduction or disclosure
+# of the Licensed Deliverables to any third party without the express
+# written consent of NVIDIA is prohibited.
+#
+# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
+# LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
+# SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
+# PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
+# NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
+# DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
+# NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
+# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
+# LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
+# SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
+# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
+# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
+# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
+# OF THESE LICENSED DELIVERABLES.
+#
+# U.S. Government End Users. These Licensed Deliverables are a
+# "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
+# 1995), consisting of "commercial computer software" and "commercial
+# computer software documentation" as such terms are used in 48
+# C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
+# only as a commercial end item. Consistent with 48 C.F.R.12.212 and
+# 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
+# U.S. Government End Users acquire the Licensed Deliverables with
+# only those rights set forth herein.
+#
+# Any use of the Licensed Deliverables in individual and commercial
+# software must include, in the user documentation and internal
+# comments to the code, the above Disclaimer and U.S. Government End
+# Users Notice.
+#
+
+import ctypes
+import os
+
+import numpy as np
+import pycuda.autoinit
+import pycuda.driver as cuda
+import tensorrt as trt
+
+
+def GiB(val):
+    return val * 1 << 30
+
+
+# Simple helper data class that's a little nicer to use than a 2-tuple.
+class HostDeviceMem(object):
+    def __init__(self, host_mem, device_mem):
+        self.host = host_mem
+        self.device = device_mem
+
+    def __str__(self):
+        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
+def allocate_buffers(engine):
+    inputs = []
+    outputs = []
+    bindings = []
+    stream = cuda.Stream()
+    for binding in engine:
+        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
+        dtype = trt.nptype(engine.get_binding_dtype(binding))
+        # Allocate host and device buffers
+        host_mem = cuda.pagelocked_empty(size, dtype)
+        device_mem = cuda.mem_alloc(host_mem.nbytes)
+        # Append the device buffer to device bindings.
+        bindings.append(int(device_mem))
+        # Append to the appropriate list.
+        if engine.binding_is_input(binding):
+            inputs.append(HostDeviceMem(host_mem, device_mem))
+        else:
+            outputs.append(HostDeviceMem(host_mem, device_mem))
+    return inputs, outputs, bindings, stream
+
+
+# This function is generalized for multiple inputs/outputs.
+# inputs and outputs are expected to be lists of HostDeviceMem objects.
+def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
+    # Transfer input data to the GPU.
+    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+    # Run inference.
+    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+    # Transfer predictions back from the GPU.
+    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+    # Synchronize the stream.
+    stream.synchronize()
+    # Return only the host outputs.
+    return [out.host for out in outputs]
+
+
+def build_engine(
+    onnx_model_path,
+    trt_logger=trt.Logger(trt.Logger.WARNING),
+    trt_engine_datatype=trt.DataType.FLOAT,
+    batch_size=1,
+    silent=False,
+):
+    try:
+        with trt.Builder(trt_logger) as builder, builder.create_network(  # type: ignore
+            1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)  # type: ignore
+        ) as network, trt.OnnxParser(  # type: ignore
+            network, trt_logger
+        ) as parser:
+            # https://github.com/NVIDIA/TensorRT/blob/main/demo/BERT/builder.py#L405
+            builder_config = builder.create_builder_config()
+            builder_config.max_workspace_size = 2 << 30  # cap builder workspace at 2 GiB
+            builder.max_batch_size = batch_size
+
+            if trt_engine_datatype == trt.DataType.HALF:
+                builder_config.set_flag(trt.BuilderFlag.FP16)
+            elif trt_engine_datatype == trt.DataType.INT8:
+                builder_config.set_flag(trt.BuilderFlag.INT8)
+
+            with open(onnx_model_path, "rb") as model:
+                # Parse the ONNX model and report any parser errors.
+                parser.parse(model.read())
+                for i in range(parser.num_errors):
+                    print(parser.get_error(i))
+            engine = builder.build_engine(network, builder_config)
+            if engine is None:
+                print("[ERROR] engine is None")
+                exit(-1)
+            return engine
+    except Exception as e:
+        print(e)
+
+
+def save_engine(engine, engine_dest_path):
+    buf = engine.serialize()
+    with open(engine_dest_path, "wb") as f:
+        f.write(buf)
+
+
+def load_engine(trt_runtime, engine_path):
+    with open(engine_path, "rb") as f:
+        engine_data = f.read()
+    engine = trt_runtime.deserialize_cuda_engine(engine_data)
+    return engine
+
+
+def load_plugin(trt_plugins):
+    libs = []
+    for trt_plugin in trt_plugins:
+        assert os.path.isfile(trt_plugin)
+        lib = ctypes.CDLL(trt_plugin, winmode=0)
+        libs.append(lib)
+    return libs
+
+
+def remove_plugin(libs):
+    for lib in libs:
+        _unload_lib(lib)
+
+
+def _unload_lib(lib):
+    # Dropping the Python reference; the shared library itself may stay mapped
+    # until the process exits.
+    del lib
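+
+
+# ---------------------------------------------------------------------------
+# Illustrative usage sketch (not part of the original helpers): it shows how
+# the functions above are meant to be chained. The paths "plugin.so",
+# "model.onnx" and "model.engine" are placeholders, not files shipped with
+# the test suite.
+# ---------------------------------------------------------------------------
+if __name__ == "__main__":
+    # Load any custom TensorRT plugin libraries before parsing the ONNX model.
+    plugin_libs = load_plugin(["./plugin.so"])
+
+    # Build an FP32 engine from the ONNX model and serialize it to disk.
+    engine = build_engine("./model.onnx")
+    save_engine(engine, "./model.engine")
+
+    # Reload the serialized engine and run a single inference on random data.
+    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
+    engine = load_engine(runtime, "./model.engine")
+    context = engine.create_execution_context()
+    inputs, outputs, bindings, stream = allocate_buffers(engine)
+    # Fill the first (page-locked) input buffer in place; real models may have several inputs.
+    inputs[0].host[:] = np.random.random_sample(inputs[0].host.shape)
+    results = do_inference(context, bindings, inputs, outputs, stream)
+    print([r.shape for r in results])
+
+    # Release the plugin handles once the engine is no longer needed.
+    remove_plugin(plugin_libs)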