Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions unstructured/metrics/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import torch
from numba import njit

IOU_THRESHOLDS = torch.tensor(
[0.5000, 0.5500, 0.6000, 0.6500, 0.7000, 0.7500, 0.8000, 0.8500, 0.9000, 0.9500]
Expand Down Expand Up @@ -303,8 +304,8 @@ def _change_bbox_bounds_for_image_size(
Returns:
clipped_boxes: Clipped bboxes in XYXY format of [..., 4] shape
"""
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(min=0, max=img_shape[1])
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(min=0, max=img_shape[0])
# Use Numba-accelerated function for inplace fast clipping
_numba_clip_boxes(boxes, img_shape[0], img_shape[1])
return boxes

@staticmethod
Expand Down Expand Up @@ -697,6 +698,42 @@ def _compute_detection_metrics_per_cls(
return ap, precision, recall


@njit(cache=True, fastmath=True)
def _numba_clip_boxes(boxes: np.ndarray, img_height: int, img_width: int) -> np.ndarray:
# This helper efficiently clips _inplace_ for (..., 4)-shaped XYXY boxes:
# x1 = min(max(x1, 0), img_width)
# y1 = min(max(y1, 0), img_height)
# x2 = min(max(x2, 0), img_width)
# y2 = min(max(y2, 0), img_height)
#
# Works with both 2D (N, 4) and higher dimensional (..., 4) arrays.
boxes.shape
boxes_reshaped = boxes.reshape(-1, 4)
for i in range(boxes_reshaped.shape[0]):
# x1
if boxes_reshaped[i, 0] < 0:
boxes_reshaped[i, 0] = 0
if boxes_reshaped[i, 0] > img_width:
boxes_reshaped[i, 0] = img_width
# y1
if boxes_reshaped[i, 1] < 0:
boxes_reshaped[i, 1] = 0
if boxes_reshaped[i, 1] > img_height:
boxes_reshaped[i, 1] = img_height
# x2
if boxes_reshaped[i, 2] < 0:
boxes_reshaped[i, 2] = 0
if boxes_reshaped[i, 2] > img_width:
boxes_reshaped[i, 2] = img_width
# y2
if boxes_reshaped[i, 3] < 0:
boxes_reshaped[i, 3] = 0
if boxes_reshaped[i, 3] > img_height:
boxes_reshaped[i, 3] = img_height
# No need to reshape as it is a view of original
return boxes


if __name__ == "__main__":
from dataclasses import asdict

Expand Down