From 03217364f93899d303a04e51a96ed5dea6eabdf4 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 2 May 2024 20:41:59 -0400 Subject: [PATCH 1/7] untested exporter --- export.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 export.py diff --git a/export.py b/export.py new file mode 100644 index 0000000..08c9453 --- /dev/null +++ b/export.py @@ -0,0 +1,99 @@ +import os +import torch +import tqdm +import argparse +import subprocess +import numpy as np +from modules.model import * + +def build_onnx_engine(weights: str, + onnx_weight: str, + imgsz: tuple = (480,640), + use_dyanmic_axis: bool = True, + onnx_opset: int = 17) -> None: + if onnx_weight is None: + raise Exception("Onnx file path cannot be None.") + dev = 'cpu' + net = XFeatModel().to(dev) + net.load_state_dict(torch.load(weights, map_location=dev)) + net.eval() + onnx_weight = None + #Random input + x = torch.randn(1,3,*imgsz).to(dev) + if onnx_weight is None: + raise Exception("Onnx file path cannot be None.") + if use_dyanmic_axis: + dyanmic_axis = { + "image": {0: "batch"}, + } + else: + dyanmic_axis = {} + + torch.onnx.export( + net, + x, + onnx_weight, + input_names=["image"], + output_names=["feats", "keypoints", "heatmap"], + dynamic_axes=dyanmic_axis, + opset_version=onnx_opset, + ) + +def build_tensorrt_engine(weights: str, + imgsz: tuple = (480,640), + fp16_mode: bool = True, + use_dyanmic_axis: bool = True, + onnx_opset: int = 17) -> None: + + if weights.endswith(".pt"): + # Replace ".pt" with ".onnx" + onnx_weight = weights[:-3] + ".onnx" + else: + raise Exception("File path does not end with '.pt'.") + + build_onnx_engine(weights, onnx_weight, imgsz, use_dyanmic_axis, onnx_opset) + + if not os.path.exists(onnx_weight): + raise Exception("ONNX export does not exist") + + if onnx_weight.endswith(".onnx"): + # Replace ".pt" with ".onnx" + engine_weight = weights[:-5] + ".engine" + else: + raise Exception("File path does not end with '.onnx'.") + + args = ["/usr/src/tensorrt/bin/trtexec"] + args.append(f"--onnx={onnx_weight}") + args.append(f"--saveEngine={engine_weight}") + + if fp16_mode: + args += ["--fp16"] + + args += [f"--shapes=image:1x3x{imgsz[0]}x{imgsz[1]}"] + + subprocess.call(args) + print(f"Finished TensorRT engine export to {engine_weight}.") + +def main(): + parser = argparse.ArgumentParser(description='Create ONNX and TensorRT export for XFeat.') + parser.add_argument('--weights', type=str, default=f'{os.path.abspath(os.path.dirname(__file__))}/weights/xfeat.pt', help='Path to the weights pt file to process') + parser.add_argument('--imgsz', type=tuple, default=(480,640), help='Input image size') + parser.add_argument("--fp16_mode", type=bool, default=True) + parser.add_argument("--use_dyanmic_axis", type=bool, default=True) + parser.add_argument("--onnx_opset", type=int, default=17) + args = parser.parse_args() + weights = args.weights + imgsz = args.imgsz + fp16_mode = args.fp16_mode + onnx_opset = args.onnx_opset + use_dyanmic_axis = args.use_dyanmic_axis + build_tensorrt_engine(weights, imgsz, fp16_mode, use_dyanmic_axis, onnx_opset) + +if __name__ == '__main__': + main() + + + + + + From 41fa366e1e7e696509382e7de5f946994ba2403f Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 2 May 2024 23:20:21 -0400 Subject: [PATCH 2/7] tensorrt export serialization does not work --- export.py | 24 ++++++++++----- modules/model.py | 16 ++++++++++ modules/xfeat.py | 78 +++++++++++++++++++++++++++++++++++++++++++----- realtime_demo.py | 2 +- 4 files changed, 103 insertions(+), 17 deletions(-) diff --git a/export.py b/export.py index 08c9453..70a96ee 100644 --- a/export.py +++ b/export.py @@ -15,9 +15,17 @@ def build_onnx_engine(weights: str, raise Exception("Onnx file path cannot be None.") dev = 'cpu' net = XFeatModel().to(dev) - net.load_state_dict(torch.load(weights, map_location=dev)) - net.eval() - onnx_weight = None + class TempModule(torch.nn.Module): + def __init__(self, parent): + super().__init__() + self.parent = parent + def forward(self, image): + output = self.parent.forward(image) + return ( + output[0], #feats + output[1], #keypoints + output[2] #heatmap + ) #Random input x = torch.randn(1,3,*imgsz).to(dev) if onnx_weight is None: @@ -28,13 +36,13 @@ def build_onnx_engine(weights: str, } else: dyanmic_axis = {} - + net = TempModule(net) torch.onnx.export( net, x, onnx_weight, - input_names=["image"], - output_names=["feats", "keypoints", "heatmap"], + input_names=XFeatModel.get_xfeat_input_names(), + output_names=XFeatModel.get_xfeat_output_names(), dynamic_axes=dyanmic_axis, opset_version=onnx_opset, ) @@ -58,7 +66,7 @@ def build_tensorrt_engine(weights: str, if onnx_weight.endswith(".onnx"): # Replace ".pt" with ".onnx" - engine_weight = weights[:-5] + ".engine" + engine_weight = onnx_weight[:-5] + ".engine" else: raise Exception("File path does not end with '.onnx'.") @@ -73,7 +81,7 @@ def build_tensorrt_engine(weights: str, subprocess.call(args) print(f"Finished TensorRT engine export to {engine_weight}.") - + def main(): parser = argparse.ArgumentParser(description='Create ONNX and TensorRT export for XFeat.') parser.add_argument('--weights', type=str, default=f'{os.path.abspath(os.path.dirname(__file__))}/weights/xfeat.pt', help='Path to the weights pt file to process') diff --git a/modules/model.py b/modules/model.py index 57539fd..46907b2 100644 --- a/modules/model.py +++ b/modules/model.py @@ -9,6 +9,8 @@ import torch.nn.functional as F import time +from typing import List +from dataclasses import dataclass class BasicLayer(nn.Module): """ Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU @@ -152,3 +154,17 @@ def forward(self, x): keypoints = self.keypoint_head(self._unfold2d(x, ws=8)) #Keypoint map logits return feats, keypoints, heatmap + + @staticmethod + def get_xfeat_input_names() -> List[str]: + return ["image"] + + @staticmethod + def get_xfeat_output_names() -> List[str]: + return ["feats", "keypoints", "heatmap"] + +@dataclass +class XFeatModelOutput: + feats: torch.Tensor + keypoints: torch.Tensor + heatmap: torch.Tensor diff --git a/modules/xfeat.py b/modules/xfeat.py index f60a63d..100a948 100644 --- a/modules/xfeat.py +++ b/modules/xfeat.py @@ -11,6 +11,8 @@ import tqdm +from typing import List + from modules.model import * from modules.interpolator import InterpolateSparse2d @@ -20,18 +22,25 @@ class XFeat(nn.Module): It supports inference for both sparse and semi-dense feature extraction & matching. """ - def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096): + def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096, use_engine=False): super().__init__() self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - self.net = XFeatModel().to(self.dev).eval() self.top_k = top_k - - if weights is not None: - if isinstance(weights, str): - print('loading weights from: ' + weights) - self.net.load_state_dict(torch.load(weights, map_location=self.dev)) + if use_engine: + if os.path.exists(os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.engine'): + weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.engine' else: - self.net.load_state_dict(weights) + raise Exception('Engine file does not exist.') + self.net = XFeat.load_xfeat_engine(weights) + self.dev = 'cuda' # force cuda for TensorRT + else: + self.net = XFeatModel().to(self.dev).eval() + if weights is not None: + if isinstance(weights, str): + print('loading weights from: ' + weights) + self.net.load_state_dict(torch.load(weights, map_location=self.dev)) + else: + self.net.load_state_dict(weights) self.interpolator = InterpolateSparse2d('bicubic') @@ -344,3 +353,56 @@ def parse_input(self, x): x = torch.tensor(x).permute(0,3,1,2)/255 return x + + @staticmethod + def load_xfeat_engine(engine_path: str): + if not engine_path.endswith(".engine"): + raise Exception('Invalid Engine file.') + + import tensorrt as trt + from torch2trt import TRTModule + + with trt.Logger() as logger, trt.Runtime(logger) as runtime: + with open(engine_path, 'rb') as f: + engine_bytes = f.read() + engine = runtime.deserialize_cuda_engine(engine_bytes) + + base_module = TRTModule( + engine, + input_names=XFeatModel.get_xfeat_input_names(), + output_names=XFeatModel.get_xfeat_output_names(), + ) + + class Wrapper(torch.nn.Module): + def __init__(self, base_module: TRTModule, max_batch_size: int = 1): + super().__init__() + self.base_module = base_module + self.max_batch_size = max_batch_size + + @torch.no_grad() + def forward(self, image): + + b = image.shape[0] + + results = [] + + for start_index in range(0, b, self.max_batch_size): + end_index = min(b, start_index + self.max_batch_size) + image_slice = image[start_index:end_index] + # with torch_timeit_sync("run_engine"): + output = self.base_module(image_slice) + results.append( + output + ) + + return XFeatModelOutput( + image_embeds=torch.cat([r[0] for r in results], dim=0), + image_class_embeds=torch.cat([r[1] for r in results], dim=0), + logit_shift=torch.cat([r[2] for r in results], dim=0), + logit_scale=torch.cat([r[3] for r in results], dim=0), + pred_boxes=torch.cat([r[4] for r in results], dim=0) + ) + + xfeat_engine = Wrapper(base_module=base_module) + return xfeat_engine + diff --git a/realtime_demo.py b/realtime_demo.py index 42f7ec5..354b2d1 100644 --- a/realtime_demo.py +++ b/realtime_demo.py @@ -65,7 +65,7 @@ def init_method(method, max_kpts): elif method == "SIFT": return Method(descriptor=cv2.SIFT_create(max_kpts, contrastThreshold=-1, edgeThreshold=1000), matcher=cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)) elif method == "XFeat": - return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts)), matcher=XFeat()) + return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts, use_engine=True)), matcher=XFeat(use_engine=True)) else: raise RuntimeError("Invalid Method.") From 463a45b45078b6efab8b09f2b2f7ea445b4a34d6 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 3 May 2024 02:34:39 -0400 Subject: [PATCH 3/7] clean commit --- export.py | 44 +++++++++++++++++++------------------------- modules/model.py | 5 ----- modules/xfeat.py | 44 ++++++-------------------------------------- realtime_demo.py | 11 ++++++----- 4 files changed, 31 insertions(+), 73 deletions(-) diff --git a/export.py b/export.py index 70a96ee..78d9f25 100644 --- a/export.py +++ b/export.py @@ -9,49 +9,40 @@ def build_onnx_engine(weights: str, onnx_weight: str, imgsz: tuple = (480,640), - use_dyanmic_axis: bool = True, + use_dynamic_axis: bool = True, onnx_opset: int = 17) -> None: if onnx_weight is None: raise Exception("Onnx file path cannot be None.") dev = 'cpu' - net = XFeatModel().to(dev) - class TempModule(torch.nn.Module): - def __init__(self, parent): - super().__init__() - self.parent = parent - def forward(self, image): - output = self.parent.forward(image) - return ( - output[0], #feats - output[1], #keypoints - output[2] #heatmap - ) + net = XFeatModel().to(dev).eval() + net.load_state_dict(torch.load(weights, map_location=dev)) #Random input x = torch.randn(1,3,*imgsz).to(dev) if onnx_weight is None: raise Exception("Onnx file path cannot be None.") - if use_dyanmic_axis: - dyanmic_axis = { + if use_dynamic_axis: + dynamic_axis = { "image": {0: "batch"}, } else: - dyanmic_axis = {} - net = TempModule(net) + dynamic_axis = {} + # net = TempModule(net) torch.onnx.export( net, x, onnx_weight, input_names=XFeatModel.get_xfeat_input_names(), output_names=XFeatModel.get_xfeat_output_names(), - dynamic_axes=dyanmic_axis, + dynamic_axes=dynamic_axis, opset_version=onnx_opset, ) def build_tensorrt_engine(weights: str, imgsz: tuple = (480,640), fp16_mode: bool = True, - use_dyanmic_axis: bool = True, - onnx_opset: int = 17) -> None: + use_dynamic_axis: bool = True, + onnx_opset: int = 17, + workspace: int = 4096) -> None: if weights.endswith(".pt"): # Replace ".pt" with ".onnx" @@ -59,7 +50,7 @@ def build_tensorrt_engine(weights: str, else: raise Exception("File path does not end with '.pt'.") - build_onnx_engine(weights, onnx_weight, imgsz, use_dyanmic_axis, onnx_opset) + build_onnx_engine(weights, onnx_weight, imgsz, use_dynamic_axis, onnx_opset) if not os.path.exists(onnx_weight): raise Exception("ONNX export does not exist") @@ -73,6 +64,7 @@ def build_tensorrt_engine(weights: str, args = ["/usr/src/tensorrt/bin/trtexec"] args.append(f"--onnx={onnx_weight}") args.append(f"--saveEngine={engine_weight}") + args.append(f"--workspace={workspace}") if fp16_mode: args += ["--fp16"] @@ -85,17 +77,19 @@ def build_tensorrt_engine(weights: str, def main(): parser = argparse.ArgumentParser(description='Create ONNX and TensorRT export for XFeat.') parser.add_argument('--weights', type=str, default=f'{os.path.abspath(os.path.dirname(__file__))}/weights/xfeat.pt', help='Path to the weights pt file to process') - parser.add_argument('--imgsz', type=tuple, default=(480,640), help='Input image size') + parser.add_argument('--imgsz', nargs=2, type=int, default=[480,640], help='Input image size') parser.add_argument("--fp16_mode", type=bool, default=True) - parser.add_argument("--use_dyanmic_axis", type=bool, default=True) + parser.add_argument("--use_dynamic_axis", type=bool, default=True) parser.add_argument("--onnx_opset", type=int, default=17) + parser.add_argument("--workspace", type=int, default=4096) args = parser.parse_args() weights = args.weights imgsz = args.imgsz fp16_mode = args.fp16_mode onnx_opset = args.onnx_opset - use_dyanmic_axis = args.use_dyanmic_axis - build_tensorrt_engine(weights, imgsz, fp16_mode, use_dyanmic_axis, onnx_opset) + use_dynamic_axis = args.use_dynamic_axis + workspace = args.workspace + build_tensorrt_engine(weights, imgsz, fp16_mode, use_dynamic_axis, onnx_opset, workspace) if __name__ == '__main__': main() diff --git a/modules/model.py b/modules/model.py index 46907b2..97bcfa7 100644 --- a/modules/model.py +++ b/modules/model.py @@ -163,8 +163,3 @@ def get_xfeat_input_names() -> List[str]: def get_xfeat_output_names() -> List[str]: return ["feats", "keypoints", "heatmap"] -@dataclass -class XFeatModelOutput: - feats: torch.Tensor - keypoints: torch.Tensor - heatmap: torch.Tensor diff --git a/modules/xfeat.py b/modules/xfeat.py index 100a948..31cd2fe 100644 --- a/modules/xfeat.py +++ b/modules/xfeat.py @@ -11,8 +11,6 @@ import tqdm -from typing import List - from modules.model import * from modules.interpolator import InterpolateSparse2d @@ -26,7 +24,8 @@ def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../w super().__init__() self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.top_k = top_k - if use_engine: + self.use_engine = use_engine + if self.use_engine: if os.path.exists(os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.engine'): weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.engine' else: @@ -64,12 +63,12 @@ def detectAndCompute(self, x, top_k = None): B, _, _H1, _W1 = x.shape M1, K1, H1 = self.net(x) + M1 = F.normalize(M1, dim=1) #Convert logits to heatmap and extract kpts K1h = self.get_kpts_heatmap(K1) mkpts = self.NMS(K1h, threshold=0.05, kernel_size=5) - #Compute reliability scores _nearest = InterpolateSparse2d('nearest') _bilinear = InterpolateSparse2d('bilinear') @@ -281,7 +280,6 @@ def match(self, feats1, feats2, min_cossim = 0.82): cossim = feats1 @ feats2.t() cossim_t = feats2 @ feats1.t() - _, match12 = cossim.max(dim=1) _, match21 = cossim_t.max(dim=1) @@ -361,48 +359,18 @@ def load_xfeat_engine(engine_path: str): import tensorrt as trt from torch2trt import TRTModule + trt.init_libnvinfer_plugins(None,'') with trt.Logger() as logger, trt.Runtime(logger) as runtime: with open(engine_path, 'rb') as f: engine_bytes = f.read() engine = runtime.deserialize_cuda_engine(engine_bytes) - base_module = TRTModule( + xfeat_trt = TRTModule( engine, input_names=XFeatModel.get_xfeat_input_names(), output_names=XFeatModel.get_xfeat_output_names(), ) - class Wrapper(torch.nn.Module): - def __init__(self, base_module: TRTModule, max_batch_size: int = 1): - super().__init__() - self.base_module = base_module - self.max_batch_size = max_batch_size - - @torch.no_grad() - def forward(self, image): - - b = image.shape[0] - - results = [] - - for start_index in range(0, b, self.max_batch_size): - end_index = min(b, start_index + self.max_batch_size) - image_slice = image[start_index:end_index] - # with torch_timeit_sync("run_engine"): - output = self.base_module(image_slice) - results.append( - output - ) - - return XFeatModelOutput( - image_embeds=torch.cat([r[0] for r in results], dim=0), - image_class_embeds=torch.cat([r[1] for r in results], dim=0), - logit_shift=torch.cat([r[2] for r in results], dim=0), - logit_scale=torch.cat([r[3] for r in results], dim=0), - pred_boxes=torch.cat([r[4] for r in results], dim=0) - ) - - xfeat_engine = Wrapper(base_module=base_module) - return xfeat_engine + return xfeat_trt diff --git a/realtime_demo.py b/realtime_demo.py index 354b2d1..9b08b6b 100644 --- a/realtime_demo.py +++ b/realtime_demo.py @@ -22,6 +22,7 @@ def argparser(): parser.add_argument('--max_kpts', type=int, default=3_000, help='Maximum number of keypoints.') parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat'], default='XFeat', help='Local feature detection method to use.') parser.add_argument('--cam', type=int, default=0, help='Webcam device number.') + parser.add_argument('--use_engine', type=bool, default=False, help='Use generated TensorRT engine file.') return parser.parse_args() @@ -59,13 +60,13 @@ def __init__(self, descriptor, matcher): self.descriptor = descriptor self.matcher = matcher -def init_method(method, max_kpts): +def init_method(method, max_kpts, use_engine=False): if method == "ORB": return Method(descriptor=cv2.ORB_create(max_kpts, fastThreshold=10), matcher=cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)) elif method == "SIFT": return Method(descriptor=cv2.SIFT_create(max_kpts, contrastThreshold=-1, edgeThreshold=1000), matcher=cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)) elif method == "XFeat": - return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts, use_engine=True)), matcher=XFeat(use_engine=True)) + return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts, use_engine=use_engine)), matcher=XFeat(use_engine=use_engine)) else: raise RuntimeError("Invalid Method.") @@ -97,7 +98,7 @@ def __init__(self, args): self.max_cnt = 30 #avg FPS over this number of frames #Set local feature method here -- we expect cv2 or Kornia convention - self.method = init_method(args.method, max_kpts=args.max_kpts) + self.method = init_method(args.method, max_kpts=args.max_kpts, use_engine=args.use_engine) # Setting up font for captions self.font = cv2.FONT_HERSHEY_SIMPLEX @@ -129,9 +130,9 @@ def setup_camera(self): def draw_quad(self, frame, point_list): if len(self.corners) > 1: for i in range(len(self.corners) - 1): - cv2.line(frame, point_list[i], point_list[i + 1], self.line_color, self.line_thickness, lineType = self.line_type) + cv2.line(frame, tuple(point_list[i]), tuple(point_list[i + 1]), self.line_color, self.line_thickness, lineType = self.line_type) if len(self.corners) == 4: # Close the quadrilateral if 4 corners are defined - cv2.line(frame, point_list[3], point_list[0], self.line_color, self.line_thickness, lineType = self.line_type) + cv2.line(frame, tuple(point_list[3]), tuple(point_list[0]), self.line_color, self.line_thickness, lineType = self.line_type) def mouse_callback(self, event, x, y, flags, param): if event == cv2.EVENT_LBUTTONDOWN: From b9145dbb0c7b2f158f57b1ac09521348e8535c1a Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 3 May 2024 02:52:47 -0400 Subject: [PATCH 4/7] remove unused import --- modules/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/model.py b/modules/model.py index 97bcfa7..45eb49d 100644 --- a/modules/model.py +++ b/modules/model.py @@ -10,7 +10,7 @@ import time from typing import List -from dataclasses import dataclass + class BasicLayer(nn.Module): """ Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU From 1f304ce13f62621b138dfb4ed5963a0d3976bcf8 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 3 May 2024 02:53:34 -0400 Subject: [PATCH 5/7] ignore model exports --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 68bc17f..ce39e3c 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +*.onnx +*.engine From 35bd72e993898984b7f5deec81d6c45559cace9e Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 3 May 2024 03:05:43 -0400 Subject: [PATCH 6/7] update readme docs --- README.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/README.md b/README.md index 92c93f9..97355ca 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,31 @@ xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True output = xfeat.detectAndCompute(torch.randn(1,3,480,640), top_k = 4096)[0] ``` +### TensorRT - Export +Its advisible to use a [NGC container](https://catalog.ngc.nvidia.com/containers). For example, for the NVIDIA Jetson platform refer to [L4T ML](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-ml/tags). The two additional dependencies you need are [TensorRT](https://github.com/NVIDIA/TensorRT), which is usually available in L4T-ML containers, and [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt). + +``` +python3 export.py --help + +usage: export.py [-h] [--weights WEIGHTS] [--imgsz IMGSZ IMGSZ] [--fp16_mode FP16_MODE] [--use_dynamic_axis USE_DYNAMIC_AXIS] + [--onnx_opset ONNX_OPSET] [--workspace WORKSPACE] + +Create ONNX and TensorRT export for XFeat. + +optional arguments: + -h, --help show this help message and exit + --weights WEIGHTS Path to the weights pt file to process + --imgsz IMGSZ IMGSZ Input image size + --fp16_mode FP16_MODE + --use_dynamic_axis USE_DYNAMIC_AXIS + --onnx_opset ONNX_OPSET + --workspace WORKSPACE +``` +### TensorRT - Demo +``` +python3 realtime_demo.py --method XFeat --use_engine True +``` + ### Training XFeat training code will be released soon. Please stay tuned. From 8b2b710dc4f313e06d1cd735755109b61ad66452 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 3 May 2024 10:40:29 -0400 Subject: [PATCH 7/7] use fp16 precision --- README.md | 6 +++++- modules/xfeat.py | 10 +++++++++- realtime_demo.py | 10 +++++++--- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 97355ca..f571a8d 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,11 @@ output = xfeat.detectAndCompute(torch.randn(1,3,480,640), top_k = 4096)[0] ``` ### TensorRT - Export -Its advisible to use a [NGC container](https://catalog.ngc.nvidia.com/containers). For example, for the NVIDIA Jetson platform refer to [L4T ML](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-ml/tags). The two additional dependencies you need are [TensorRT](https://github.com/NVIDIA/TensorRT), which is usually available in L4T-ML containers, and [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt). +Its advisible to use a [NGC container](https://catalog.ngc.nvidia.com/containers). For example, for the NVIDIA Jetson platform refer to [L4T ML](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-ml/tags). The additional dependencies you need: +- [TensorRT](https://github.com/NVIDIA/TensorRT) - usually availble inside the docker container +- `onnx` +- `onnxruntime` +- [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt) ``` python3 export.py --help diff --git a/modules/xfeat.py b/modules/xfeat.py index 31cd2fe..eb200c5 100644 --- a/modules/xfeat.py +++ b/modules/xfeat.py @@ -20,11 +20,15 @@ class XFeat(nn.Module): It supports inference for both sparse and semi-dense feature extraction & matching. """ - def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096, use_engine=False): + def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', + top_k = 4096, + use_engine=False, + use_fp16=False): super().__init__() self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.top_k = top_k self.use_engine = use_engine + self.use_fp16 = use_fp16 if self.use_engine: if os.path.exists(os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.engine'): weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.engine' @@ -32,6 +36,8 @@ def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../w raise Exception('Engine file does not exist.') self.net = XFeat.load_xfeat_engine(weights) self.dev = 'cuda' # force cuda for TensorRT + if self.use_fp16: + self.net.half() else: self.net = XFeatModel().to(self.dev).eval() if weights is not None: @@ -182,6 +188,8 @@ def preprocess_tensor(self, x): if isinstance(x, np.ndarray) and x.shape == 3: x = torch.tensor(x).permute(2,0,1)[None] x = x.to(self.dev).float() + if self.use_fp16: + x.half() H, W = x.shape[-2:] _H, _W = (H//32) * 32, (W//32) * 32 diff --git a/realtime_demo.py b/realtime_demo.py index 9b08b6b..9ee7052 100644 --- a/realtime_demo.py +++ b/realtime_demo.py @@ -23,6 +23,7 @@ def argparser(): parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat'], default='XFeat', help='Local feature detection method to use.') parser.add_argument('--cam', type=int, default=0, help='Webcam device number.') parser.add_argument('--use_engine', type=bool, default=False, help='Use generated TensorRT engine file.') + parser.add_argument('--use_fp16', type=bool, default=False, help='Use generated TensorRT engine file with fp16 precision.') return parser.parse_args() @@ -60,13 +61,13 @@ def __init__(self, descriptor, matcher): self.descriptor = descriptor self.matcher = matcher -def init_method(method, max_kpts, use_engine=False): +def init_method(method, max_kpts, use_engine=False, use_fp16=False): if method == "ORB": return Method(descriptor=cv2.ORB_create(max_kpts, fastThreshold=10), matcher=cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)) elif method == "SIFT": return Method(descriptor=cv2.SIFT_create(max_kpts, contrastThreshold=-1, edgeThreshold=1000), matcher=cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)) elif method == "XFeat": - return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts, use_engine=use_engine)), matcher=XFeat(use_engine=use_engine)) + return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts, use_engine=use_engine, use_fp16=use_fp16)), matcher=XFeat(use_engine=use_engine, use_fp16=use_fp16)) else: raise RuntimeError("Invalid Method.") @@ -98,7 +99,10 @@ def __init__(self, args): self.max_cnt = 30 #avg FPS over this number of frames #Set local feature method here -- we expect cv2 or Kornia convention - self.method = init_method(args.method, max_kpts=args.max_kpts, use_engine=args.use_engine) + self.method = init_method(args.method, + max_kpts=args.max_kpts, + use_engine=args.use_engine, + use_fp16=args.use_fp16) # Setting up font for captions self.font = cv2.FONT_HERSHEY_SIMPLEX