diff --git a/unstructured/metrics/object_detection.py b/unstructured/metrics/object_detection.py index 7c28721518..6eb8175b47 100644 --- a/unstructured/metrics/object_detection.py +++ b/unstructured/metrics/object_detection.py @@ -8,6 +8,7 @@ import numpy as np import torch +from numba import njit IOU_THRESHOLDS = torch.tensor( [0.5000, 0.5500, 0.6000, 0.6500, 0.7000, 0.7500, 0.8000, 0.8500, 0.9000, 0.9500] @@ -303,8 +304,8 @@ def _change_bbox_bounds_for_image_size( Returns: clipped_boxes: Clipped bboxes in XYXY format of [..., 4] shape """ - boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(min=0, max=img_shape[1]) - boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(min=0, max=img_shape[0]) + # Use Numba-accelerated function for inplace fast clipping + _numba_clip_boxes(boxes, img_shape[0], img_shape[1]) return boxes @staticmethod @@ -697,6 +698,42 @@ def _compute_detection_metrics_per_cls( return ap, precision, recall +@njit(cache=True, fastmath=True) +def _numba_clip_boxes(boxes: np.ndarray, img_height: int, img_width: int) -> np.ndarray: + # This helper efficiently clips _inplace_ for (..., 4)-shaped XYXY boxes: + # x1 = min(max(x1, 0), img_width) + # y1 = min(max(y1, 0), img_height) + # x2 = min(max(x2, 0), img_width) + # y2 = min(max(y2, 0), img_height) + # + # Works with both 2D (N, 4) and higher dimensional (..., 4) arrays. + boxes.shape + boxes_reshaped = boxes.reshape(-1, 4) + for i in range(boxes_reshaped.shape[0]): + # x1 + if boxes_reshaped[i, 0] < 0: + boxes_reshaped[i, 0] = 0 + if boxes_reshaped[i, 0] > img_width: + boxes_reshaped[i, 0] = img_width + # y1 + if boxes_reshaped[i, 1] < 0: + boxes_reshaped[i, 1] = 0 + if boxes_reshaped[i, 1] > img_height: + boxes_reshaped[i, 1] = img_height + # x2 + if boxes_reshaped[i, 2] < 0: + boxes_reshaped[i, 2] = 0 + if boxes_reshaped[i, 2] > img_width: + boxes_reshaped[i, 2] = img_width + # y2 + if boxes_reshaped[i, 3] < 0: + boxes_reshaped[i, 3] = 0 + if boxes_reshaped[i, 3] > img_height: + boxes_reshaped[i, 3] = img_height + # No need to reshape as it is a view of original + return boxes + + if __name__ == "__main__": from dataclasses import asdict