From bfb83215977b8a1f88c0871904093700939611f1 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Sun, 22 Mar 2026 21:12:15 -0700 Subject: [PATCH 01/25] NMS implementation for CPU --- torchvision/_autograd_registrations.py | 10 ++ torchvision/_meta_registrations.py | 14 +++ .../csrc/ops/cpu/nms_rotated_kernel.cpp | 113 ++++++++++++++++++ torchvision/csrc/ops/nms_rotated.cpp | 33 +++++ torchvision/csrc/ops/nms_rotated.h | 21 ++++ torchvision/csrc/ops/ops.h | 1 + torchvision/ops/__init__.py | 2 + torchvision/ops/boxes.py | 27 +++++ 8 files changed, 221 insertions(+) create mode 100644 torchvision/csrc/ops/cpu/nms_rotated_kernel.cpp create mode 100644 torchvision/csrc/ops/nms_rotated.cpp create mode 100644 torchvision/csrc/ops/nms_rotated.h diff --git a/torchvision/_autograd_registrations.py b/torchvision/_autograd_registrations.py index 18d9ced6c54..564657ee35a 100644 --- a/torchvision/_autograd_registrations.py +++ b/torchvision/_autograd_registrations.py @@ -235,6 +235,15 @@ def _autocast_nms(dets, scores, iou_threshold): ) +def _autocast_nms_rotated(dets, scores, iou_threshold): + with torch._C._ExcludeDispatchKeyGuard(_all_autocast_keys): + return torch.ops.torchvision.nms_rotated( + _autocast_cast(dets), + _autocast_cast(scores), + iou_threshold, + ) + + def _autocast_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): orig_dtype = input.dtype with torch._C._ExcludeDispatchKeyGuard(_all_autocast_keys): @@ -358,6 +367,7 @@ def _autocast_deform_conv2d( # nms and roi_align: registered for all autocast device types for _key in ("AutocastCUDA", "AutocastCPU", "AutocastXPU"): _autocast_lib.impl("nms", _autocast_nms, _key) + _autocast_lib.impl("nms_rotated", _autocast_nms_rotated, _key) _autocast_lib.impl("roi_align", _autocast_roi_align, _key) # Other ops: CUDA autocast only diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index f75bfb77a7f..e7a183250e4 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -174,6 +174,20 @@ def meta_nms(dets, scores, iou_threshold): return dets.new_empty(num_to_keep, dtype=torch.long) +@torch.library.register_fake("torchvision::nms_rotated") +def meta_nms_rotated(dets, scores, iou_threshold): + torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") + torch._check(dets.size(1) == 5, lambda: f"boxes should have 5 elements in dimension 1, got {dets.size(1)}") + torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}") + torch._check( + dets.size(0) == scores.size(0), + lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}", + ) + ctx = torch._custom_ops.get_ctx() + num_to_keep = ctx.create_unbacked_symint() + return dets.new_empty(num_to_keep, dtype=torch.long) + + @register_meta("deform_conv2d") def meta_deform_conv2d( input, diff --git a/torchvision/csrc/ops/cpu/nms_rotated_kernel.cpp b/torchvision/csrc/ops/cpu/nms_rotated_kernel.cpp new file mode 100644 index 00000000000..c0a3b617eae --- /dev/null +++ b/torchvision/csrc/ops/cpu/nms_rotated_kernel.cpp @@ -0,0 +1,113 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// This file contains code adapted from Detectron2's nms_rotated +// implementation, which is licensed under the Apache License, Version 2.0. +// Original source: https://github.com/facebookresearch/detectron2 +// License: https://github.com/facebookresearch/detectron2/blob/main/LICENSE + +#include "../box_iou_rotated_utils.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +template +at::Tensor nms_rotated_cpu_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); + TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); + TORCH_CHECK( + dets.scalar_type() == scores.scalar_type(), + "dets should have the same type as scores"); + + if (dets.numel() == 0) { + return at::empty({0}, dets.options().dtype(at::kLong)); + } + + auto order_t = std::get<1>( + scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + + auto ndets = dets.size(0); + at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); + at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); + + auto suppressed = suppressed_t.data_ptr(); + auto keep = keep_t.data_ptr(); + auto order = order_t.data_ptr(); + + int64_t num_to_keep = 0; + + for (int64_t _i = 0; _i < ndets; _i++) { + auto i = order[_i]; + if (suppressed[i] == 1) { + continue; + } + + keep[num_to_keep++] = i; + + for (int64_t _j = _i + 1; _j < ndets; _j++) { + auto j = order[_j]; + if (suppressed[j] == 1) { + continue; + } + + auto ovr = single_box_iou_rotated( + dets[i].data_ptr(), dets[j].data_ptr()); + if (ovr >= iou_threshold) { + suppressed[j] = 1; + } + } + } + return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); +} + +at::Tensor nms_rotated_cpu( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 5, + "boxes should have 5 elements in dimension 1 (cx, cy, w, h, angle), got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)); + + auto result = at::empty({0}, dets.options()); + + AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated_cpu", [&] { + result = nms_rotated_cpu_kernel(dets, scores, iou_threshold); + }); + return result; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::nms_rotated"), + TORCH_FN(nms_rotated_cpu)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/nms_rotated.cpp b/torchvision/csrc/ops/nms_rotated.cpp new file mode 100644 index 00000000000..da619e6c32c --- /dev/null +++ b/torchvision/csrc/ops/nms_rotated.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "nms_rotated.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +at::Tensor nms_rotated( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms_rotated.nms_rotated"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::nms_rotated", "") + .typed(); + return op.call(dets, scores, iou_threshold); +} + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::nms_rotated(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/nms_rotated.h b/torchvision/csrc/ops/nms_rotated.h new file mode 100644 index 00000000000..98bc225f691 --- /dev/null +++ b/torchvision/csrc/ops/nms_rotated.h @@ -0,0 +1,21 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API at::Tensor nms_rotated( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold); + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/ops.h b/torchvision/csrc/ops/ops.h index 9902c3b1ecd..173f2ef77b1 100644 --- a/torchvision/csrc/ops/ops.h +++ b/torchvision/csrc/ops/ops.h @@ -3,6 +3,7 @@ #include "box_iou_rotated.h" #include "deform_conv2d.h" #include "nms.h" +#include "nms_rotated.h" #include "ps_roi_align.h" #include "ps_roi_pool.h" #include "roi_align.h" diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 827505b842d..7e48945545e 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -10,6 +10,7 @@ generalized_box_iou, masks_to_boxes, nms, + nms_rotated, remove_small_boxes, ) from .ciou_loss import complete_box_iou_loss @@ -35,6 +36,7 @@ "deform_conv2d", "DeformConv2d", "nms", + "nms_rotated", "batched_nms", "remove_small_boxes", "clip_boxes_to_image", diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index a089af2c4ad..6779de7b760 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -48,6 +48,33 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: return torch.ops.torchvision.nms(boxes, scores, iou_threshold) +def nms_rotated(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: + """ + Performs non-maximum suppression (NMS) on the rotated boxes according + to their intersection-over-union (IoU). + + NMS iteratively removes lower scoring boxes which have an + IoU greater than ``iou_threshold`` with another (higher scoring) + box. + + Args: + boxes (Tensor[N, 5])): rotated boxes to perform NMS on. They + are expected to be in ``(cx, cy, w, h, angle)`` format where + ``(cx, cy)`` is the center, ``(w, h)`` is width and height, + and ``angle`` is the rotation angle in degrees. + scores (Tensor[N]): scores for each one of the boxes + iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold + + Returns: + Tensor: int64 tensor with the indices of the elements that have been kept + by NMS, sorted in decreasing order of scores + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(nms_rotated) + _assert_has_ops() + return torch.ops.torchvision.nms_rotated(boxes, scores, iou_threshold) + + def batched_nms( boxes: Tensor, scores: Tensor, From 88758757c1b50ac69f1b3be4f879d745dca2447a Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Sun, 22 Mar 2026 22:58:29 -0700 Subject: [PATCH 02/25] add tests --- test/test_ops.py | 119 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 9521f21a815..1f94f9db061 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2007,6 +2007,125 @@ def test_cuda_cpu_consistency(self): torch.testing.assert_close(iou_cpu, iou_cuda.cpu(), atol=1e-5, rtol=1e-5) +class TestNMSRotated: + def _reference_horizontal_nms(self, boxes, scores, iou_threshold): + """ + Args: + box_scores (N, 5): boxes in corner-form and probabilities. + (Note here 5 == 4 + 1, i.e., 4-dim horizontal box + 1-dim prob) + iou_threshold: intersection over union threshold. + Returns: + picked: a list of indexes of the kept boxes + """ + picked = [] + _, indexes = scores.sort(descending=True) + while len(indexes) > 0: + current = indexes[0] + picked.append(current.item()) + if len(indexes) == 1: + break + current_box = boxes[current, :] + indexes = indexes[1:] + rest_boxes = boxes[indexes, :] + iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1) + indexes = indexes[iou <= iou_threshold] + + return torch.as_tensor(picked) + + @staticmethod + def _nms_edit_distance(keep1, keep2): + """ + Compare the "keep" result of two nms call. + They are allowed to be different in terms of edit distance + due to floating point precision issues, e.g., + if a box happen to have an IoU of 0.5 with another box, + one implementation may choose to keep it while another may discard it. + """ + keep1, keep2 = keep1.cpu(), keep2.cpu() + if torch.equal(keep1, keep2): + # they should be equal most of the time + return 0 + keep1, keep2 = tuple(keep1), tuple(keep2) + m, n = len(keep1), len(keep2) + + # edit distance with DP + f = [np.arange(n + 1), np.arange(n + 1)] + for i in range(m): + cur_row = i % 2 + other_row = (i + 1) % 2 + f[other_row][0] = i + 1 + for j in range(n): + f[other_row][j + 1] = ( + f[cur_row][j] + if keep1[i] == keep2[j] + else min(min(f[cur_row][j], f[cur_row][j + 1]), f[other_row][j]) + 1 + ) + return f[m % 2][n] + + @staticmethod + def _create_tensors(N, device="cpu"): + boxes = torch.rand(N, 4, device=device) * 200 + boxes[:, 2:] += boxes[:, :2] + scores = torch.rand(N, device=device) + return boxes, scores + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_rotated_0_degree(self, iou): + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + + keep_ref = self._reference_horizontal_nms(boxes, scores, iou) + keep = ops.nms_rotated(rotated_boxes, scores, iou) + err_msg = f"Rotated NMS incompatible with reference implementation for IoU={iou}" + assert self._nms_edit_distance(keep, keep_ref) <= 1, err_msg + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_rotated_90_degrees(self, iou): + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + # Swap width and height for 90 degrees so reference horizontal NMS can be used + rotated_boxes[:, 2] = boxes[:, 3] - boxes[:, 1] + rotated_boxes[:, 3] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 4] = 90 + + keep_ref = self._reference_horizontal_nms(boxes, scores, iou) + keep = ops.nms_rotated(rotated_boxes, scores, iou) + err_msg = f"Rotated NMS incompatible with reference implementation for IoU={iou}" + assert self._nms_edit_distance(keep, keep_ref) <= 1, err_msg + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_rotated_180_degrees(self, iou): + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + rotated_boxes[:, 4] = 180 + + keep_ref = self._reference_horizontal_nms(boxes, scores, iou) + keep = ops.nms_rotated(rotated_boxes, scores, iou) + err_msg = f"Rotated NMS incompatible with reference implementation for IoU={iou}" + assert self._nms_edit_distance(keep, keep_ref) <= 1, err_msg + + def test_nms_rotated_scriptable(self): + class TestingModule(torch.nn.Module): + def forward(self, boxes, scores, threshold): + return ops.nms_rotated(boxes, scores, threshold) + + m = TestingModule() + _ = torch.jit.script(m) + + def get_boxes(dtype, device): box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) From 0e924edd7d8d41eebc8f711e1fce79a072ae7f93 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 24 Mar 2026 01:17:10 -0700 Subject: [PATCH 03/25] fuse two implementations --- test/test_ops.py | 8 +- torchvision/csrc/ops/cpu/nms_kernel.cpp | 122 +++++++++++++----- .../csrc/ops/cpu/nms_rotated_kernel.cpp | 113 ---------------- torchvision/ops/__init__.py | 2 - torchvision/ops/boxes.py | 58 ++++----- 5 files changed, 123 insertions(+), 180 deletions(-) delete mode 100644 torchvision/csrc/ops/cpu/nms_rotated_kernel.cpp diff --git a/test/test_ops.py b/test/test_ops.py index 1f94f9db061..d90acf11c25 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2080,7 +2080,7 @@ def test_nms_rotated_0_degree(self, iou): rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] keep_ref = self._reference_horizontal_nms(boxes, scores, iou) - keep = ops.nms_rotated(rotated_boxes, scores, iou) + keep = ops.nms(rotated_boxes, scores, iou, fmt="cxcywhr") err_msg = f"Rotated NMS incompatible with reference implementation for IoU={iou}" assert self._nms_edit_distance(keep, keep_ref) <= 1, err_msg @@ -2097,7 +2097,7 @@ def test_nms_rotated_90_degrees(self, iou): rotated_boxes[:, 4] = 90 keep_ref = self._reference_horizontal_nms(boxes, scores, iou) - keep = ops.nms_rotated(rotated_boxes, scores, iou) + keep = ops.nms(rotated_boxes, scores, iou, fmt="cxcywhr") err_msg = f"Rotated NMS incompatible with reference implementation for IoU={iou}" assert self._nms_edit_distance(keep, keep_ref) <= 1, err_msg @@ -2113,14 +2113,14 @@ def test_nms_rotated_180_degrees(self, iou): rotated_boxes[:, 4] = 180 keep_ref = self._reference_horizontal_nms(boxes, scores, iou) - keep = ops.nms_rotated(rotated_boxes, scores, iou) + keep = ops.nms(rotated_boxes, scores, iou, fmt="cxcywhr") err_msg = f"Rotated NMS incompatible with reference implementation for IoU={iou}" assert self._nms_edit_distance(keep, keep_ref) <= 1, err_msg def test_nms_rotated_scriptable(self): class TestingModule(torch.nn.Module): def forward(self, boxes, scores, threshold): - return ops.nms_rotated(boxes, scores, threshold) + return ops.nms(boxes, scores, threshold, fmt="cxcywhr") m = TestingModule() _ = torch.jit.script(m) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 454ce118a6d..681f005786e 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -1,16 +1,30 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. +// +// The rotated box IoU computation is adapted from Detectron2's implementation, +// which is licensed under the Apache License, Version 2.0. +// Original source: https://github.com/facebookresearch/detectron2 +// License: https://github.com/facebookresearch/detectron2/blob/main/LICENSE + #include #include +#include "../box_iou_rotated_utils.h" + namespace vision { namespace ops { namespace { -template +template at::Tensor nms_kernel_impl( const at::Tensor& dets, const at::Tensor& scores, - double iou_threshold) { + double iou_threshold, + IoUFunc iou_func) { TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); TORCH_CHECK( @@ -21,13 +35,6 @@ at::Tensor nms_kernel_impl( return at::empty({0}, dets.options().dtype(at::kLong)); } - auto x1_t = dets.select(1, 0).contiguous(); - auto y1_t = dets.select(1, 1).contiguous(); - auto x2_t = dets.select(1, 2).contiguous(); - auto y2_t = dets.select(1, 3).contiguous(); - - at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); - auto order_t = std::get<1>( scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); @@ -38,11 +45,6 @@ at::Tensor nms_kernel_impl( auto suppressed = suppressed_t.data_ptr(); auto keep = keep_t.data_ptr(); auto order = order_t.data_ptr(); - auto x1 = x1_t.data_ptr(); - auto y1 = y1_t.data_ptr(); - auto x2 = x2_t.data_ptr(); - auto y2 = y2_t.data_ptr(); - auto areas = areas_t.data_ptr(); int64_t num_to_keep = 0; @@ -52,26 +54,14 @@ at::Tensor nms_kernel_impl( continue; } keep[num_to_keep++] = i; - auto ix1 = x1[i]; - auto iy1 = y1[i]; - auto ix2 = x2[i]; - auto iy2 = y2[i]; - auto iarea = areas[i]; for (int64_t _j = _i + 1; _j < ndets; _j++) { auto j = order[_j]; if (suppressed[j] == 1) { continue; } - auto xx1 = std::max(ix1, x1[j]); - auto yy1 = std::max(iy1, y1[j]); - auto xx2 = std::min(ix2, x2[j]); - auto yy2 = std::min(iy2, y2[j]); - - auto w = std::max(static_cast(0), xx2 - xx1); - auto h = std::max(static_cast(0), yy2 - yy1); - auto inter = w * h; - auto ovr = inter / (iarea + areas[j] - inter); + + auto ovr = iou_func(dets, i, j); if (ovr > iou_threshold) { suppressed[j] = 1; } @@ -80,6 +70,44 @@ at::Tensor nms_kernel_impl( return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); } +template +struct AABBIoU { + scalar_t operator()(const at::Tensor& dets, int64_t i, int64_t j) const { + auto x1 = dets.data_ptr(); + auto ncols = dets.size(1); + + auto ix1 = x1[i * ncols + 0]; + auto iy1 = x1[i * ncols + 1]; + auto ix2 = x1[i * ncols + 2]; + auto iy2 = x1[i * ncols + 3]; + auto iarea = (ix2 - ix1) * (iy2 - iy1); + + auto jx1 = x1[j * ncols + 0]; + auto jy1 = x1[j * ncols + 1]; + auto jx2 = x1[j * ncols + 2]; + auto jy2 = x1[j * ncols + 3]; + auto jarea = (jx2 - jx1) * (jy2 - jy1); + + auto xx1 = std::max(ix1, jx1); + auto yy1 = std::max(iy1, jy1); + auto xx2 = std::min(ix2, jx2); + auto yy2 = std::min(iy2, jy2); + + auto w = std::max(static_cast(0), xx2 - xx1); + auto h = std::max(static_cast(0), yy2 - yy1); + auto inter = w * h; + return inter / (iarea + jarea - inter); + } +}; + +template +struct RotatedIoU { + scalar_t operator()(const at::Tensor& dets, int64_t i, int64_t j) const { + return single_box_iou_rotated( + dets[i].data_ptr(), dets[j].data_ptr()); + } +}; + at::Tensor nms_kernel( const at::Tensor& dets, const at::Tensor& scores, @@ -106,7 +134,40 @@ at::Tensor nms_kernel( auto result = at::empty({0}, dets.options()); AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { - result = nms_kernel_impl(dets, scores, iou_threshold); + result = nms_kernel_impl( + dets, scores, iou_threshold, AABBIoU{}); + }); + return result; +} + +at::Tensor nms_rotated_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 5, + "boxes should have 5 elements in dimension 1, got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)); + + auto result = at::empty({0}, dets.options()); + + AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated_kernel", [&] { + result = nms_kernel_impl( + dets, scores, iou_threshold, RotatedIoU{}); }); return result; } @@ -115,6 +176,9 @@ at::Tensor nms_kernel( TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::nms_rotated"), + TORCH_FN(nms_rotated_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cpu/nms_rotated_kernel.cpp b/torchvision/csrc/ops/cpu/nms_rotated_kernel.cpp deleted file mode 100644 index c0a3b617eae..00000000000 --- a/torchvision/csrc/ops/cpu/nms_rotated_kernel.cpp +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. -// -// This file contains code adapted from Detectron2's nms_rotated -// implementation, which is licensed under the Apache License, Version 2.0. -// Original source: https://github.com/facebookresearch/detectron2 -// License: https://github.com/facebookresearch/detectron2/blob/main/LICENSE - -#include "../box_iou_rotated_utils.h" - -#include -#include - -namespace vision { -namespace ops { - -namespace { - -template -at::Tensor nms_rotated_cpu_kernel( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold) { - TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); - TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); - TORCH_CHECK( - dets.scalar_type() == scores.scalar_type(), - "dets should have the same type as scores"); - - if (dets.numel() == 0) { - return at::empty({0}, dets.options().dtype(at::kLong)); - } - - auto order_t = std::get<1>( - scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); - - auto ndets = dets.size(0); - at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); - at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); - - auto suppressed = suppressed_t.data_ptr(); - auto keep = keep_t.data_ptr(); - auto order = order_t.data_ptr(); - - int64_t num_to_keep = 0; - - for (int64_t _i = 0; _i < ndets; _i++) { - auto i = order[_i]; - if (suppressed[i] == 1) { - continue; - } - - keep[num_to_keep++] = i; - - for (int64_t _j = _i + 1; _j < ndets; _j++) { - auto j = order[_j]; - if (suppressed[j] == 1) { - continue; - } - - auto ovr = single_box_iou_rotated( - dets[i].data_ptr(), dets[j].data_ptr()); - if (ovr >= iou_threshold) { - suppressed[j] = 1; - } - } - } - return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); -} - -at::Tensor nms_rotated_cpu( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold) { - TORCH_CHECK( - dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK( - dets.size(1) == 5, - "boxes should have 5 elements in dimension 1 (cx, cy, w, h, angle), got ", - dets.size(1)); - TORCH_CHECK( - scores.dim() == 1, - "scores should be a 1d tensor, got ", - scores.dim(), - "D"); - TORCH_CHECK( - dets.size(0) == scores.size(0), - "boxes and scores should have same number of elements in dimension 0, got ", - dets.size(0), - " and ", - scores.size(0)); - - auto result = at::empty({0}, dets.options()); - - AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated_cpu", [&] { - result = nms_rotated_cpu_kernel(dets, scores, iou_threshold); - }); - return result; -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchvision, CPU, m) { - m.impl( - TORCH_SELECTIVE_NAME("torchvision::nms_rotated"), - TORCH_FN(nms_rotated_cpu)); -} - -} // namespace ops -} // namespace vision diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 7e48945545e..827505b842d 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -10,7 +10,6 @@ generalized_box_iou, masks_to_boxes, nms, - nms_rotated, remove_small_boxes, ) from .ciou_loss import complete_box_iou_loss @@ -36,7 +35,6 @@ "deform_conv2d", "DeformConv2d", "nms", - "nms_rotated", "batched_nms", "remove_small_boxes", "clip_boxes_to_image", diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 6779de7b760..4c6be7b7711 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -17,7 +17,7 @@ from ._utils import _upcast -def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: +def nms(boxes: Tensor, scores: Tensor, iou_threshold: float, fmt: str = "xyxy") -> Tensor: """ Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU). @@ -32,47 +32,38 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: to the behavior of argsort in PyTorch when repeated values are present. Args: - boxes (Tensor[N, 4])): boxes to perform NMS on. They - are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and - ``0 <= y1 < y2``. + boxes (Tensor[N, K])): boxes to perform NMS on. They + are expected to be in the format specified by ``fmt``. scores (Tensor[N]): scores for each one of the boxes iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold + fmt (str): Format of the input boxes. + Default is "xyxy" to preserve backward compatibility. - Returns: - Tensor: int64 tensor with the indices of the elements that have been kept - by NMS, sorted in decreasing order of scores - """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(nms) - _assert_has_ops() - return torch.ops.torchvision.nms(boxes, scores, iou_threshold) - + Supported axis-aligned format (K=4): -def nms_rotated(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: - """ - Performs non-maximum suppression (NMS) on the rotated boxes according - to their intersection-over-union (IoU). + - ``'xyxy'``: boxes are represented via (x1, y1, x2, y2) corner coordinates. - NMS iteratively removes lower scoring boxes which have an - IoU greater than ``iou_threshold`` with another (higher scoring) - box. + Supported rotated format (K=5): - Args: - boxes (Tensor[N, 5])): rotated boxes to perform NMS on. They - are expected to be in ``(cx, cy, w, h, angle)`` format where - ``(cx, cy)`` is the center, ``(w, h)`` is width and height, - and ``angle`` is the rotation angle in degrees. - scores (Tensor[N]): scores for each one of the boxes - iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold + - ``'cxcywhr'``: boxes are represented via center, width, height, and rotation angle. + (cx, cy) is the center, (w, h) is width and height, r is rotation angle in degrees. Returns: Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted in decreasing order of scores """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(nms_rotated) + _log_api_usage_once(nms) _assert_has_ops() - return torch.ops.torchvision.nms_rotated(boxes, scores, iou_threshold) + + if fmt == "xyxy": + return torch.ops.torchvision.nms(boxes, scores, iou_threshold) + elif fmt == "cxcywhr": + return torch.ops.torchvision.nms_rotated(boxes, scores, iou_threshold) + else: + raise ValueError( + f"Unsupported format '{fmt}'. " f"Supported formats: 'xyxy' (axis-aligned), 'cxcywhr' (rotated)." + ) def batched_nms( @@ -350,14 +341,17 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple elif fmt == "xywh": lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2] rb = torch.min( - boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] + boxes1[..., None, :2] + boxes1[..., None, 2:], + boxes2[..., None, :, :2] + boxes2[..., None, :, 2:], ) # [...,N,M,2] else: # fmt == "cxcywh": lt = torch.max( - boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2 + boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, + boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2, ) # [N,M,2] rb = torch.min( - boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2 + boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, + boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2, ) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] From 667e1b4a52689a7e797ffa5647a6814730499f6e Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 24 Mar 2026 01:30:45 -0700 Subject: [PATCH 04/25] Remove Detectron2 license header from nms_kernel.cpp since the NMS algorithm is standard TorchVision code; attribution already in box_iou_rotated_utils.h --- torchvision/csrc/ops/cpu/nms_kernel.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 681f005786e..590e8cbc119 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -3,11 +3,6 @@ // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -// -// The rotated box IoU computation is adapted from Detectron2's implementation, -// which is licensed under the Apache License, Version 2.0. -// Original source: https://github.com/facebookresearch/detectron2 -// License: https://github.com/facebookresearch/detectron2/blob/main/LICENSE #include #include From f01ca8bf2640ea7c6c00082c8ab5452793a35452 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 24 Mar 2026 01:42:45 -0700 Subject: [PATCH 05/25] revert the unnecessary linting --- torchvision/ops/boxes.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 4c6be7b7711..55532befb52 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -341,17 +341,14 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple elif fmt == "xywh": lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2] rb = torch.min( - boxes1[..., None, :2] + boxes1[..., None, 2:], - boxes2[..., None, :, :2] + boxes2[..., None, :, 2:], + boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] ) # [...,N,M,2] else: # fmt == "cxcywh": lt = torch.max( - boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, - boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2, + boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2 ) # [N,M,2] rb = torch.min( - boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, - boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2, + boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2 ) # [N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] From 40adc7faaccb4dffb5873b24508206125aef0c53 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 24 Mar 2026 02:04:19 -0700 Subject: [PATCH 06/25] Preserve original IoU computation in fused NMS implementation --- torchvision/csrc/ops/cpu/nms_kernel.cpp | 49 ++++++++++++++----------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 590e8cbc119..7b1fbfb5548 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -67,31 +67,36 @@ at::Tensor nms_kernel_impl( template struct AABBIoU { - scalar_t operator()(const at::Tensor& dets, int64_t i, int64_t j) const { - auto x1 = dets.data_ptr(); - auto ncols = dets.size(1); - - auto ix1 = x1[i * ncols + 0]; - auto iy1 = x1[i * ncols + 1]; - auto ix2 = x1[i * ncols + 2]; - auto iy2 = x1[i * ncols + 3]; - auto iarea = (ix2 - ix1) * (iy2 - iy1); - - auto jx1 = x1[j * ncols + 0]; - auto jy1 = x1[j * ncols + 1]; - auto jx2 = x1[j * ncols + 2]; - auto jy2 = x1[j * ncols + 3]; - auto jarea = (jx2 - jx1) * (jy2 - jy1); - - auto xx1 = std::max(ix1, jx1); - auto yy1 = std::max(iy1, jy1); - auto xx2 = std::min(ix2, jx2); - auto yy2 = std::min(iy2, jy2); + const scalar_t* x1; + const scalar_t* y1; + const scalar_t* x2; + const scalar_t* y2; + const scalar_t* areas; + at::Tensor x1_t, y1_t, x2_t, y2_t, areas_t; + + AABBIoU(const at::Tensor& dets) { + x1_t = dets.select(1, 0).contiguous(); + y1_t = dets.select(1, 1).contiguous(); + x2_t = dets.select(1, 2).contiguous(); + y2_t = dets.select(1, 3).contiguous(); + areas_t = (x2_t - x1_t) * (y2_t - y1_t); + x1 = x1_t.data_ptr(); + y1 = y1_t.data_ptr(); + x2 = x2_t.data_ptr(); + y2 = y2_t.data_ptr(); + areas = areas_t.data_ptr(); + } + + scalar_t operator()(const at::Tensor& /*dets*/, int64_t i, int64_t j) const { + auto xx1 = std::max(x1[i], x1[j]); + auto yy1 = std::max(y1[i], y1[j]); + auto xx2 = std::min(x2[i], x2[j]); + auto yy2 = std::min(y2[i], y2[j]); auto w = std::max(static_cast(0), xx2 - xx1); auto h = std::max(static_cast(0), yy2 - yy1); auto inter = w * h; - return inter / (iarea + jarea - inter); + return inter / (areas[i] + areas[j] - inter); } }; @@ -130,7 +135,7 @@ at::Tensor nms_kernel( AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { result = nms_kernel_impl( - dets, scores, iou_threshold, AABBIoU{}); + dets, scores, iou_threshold, AABBIoU(dets)); }); return result; } From eb4aefe4871152a2021a7cd88cd87876c56833c2 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 24 Mar 2026 02:16:26 -0700 Subject: [PATCH 07/25] Add box caching in fused NMS to match original memory access pattern --- torchvision/csrc/ops/cpu/nms_kernel.cpp | 43 +++++++++++++++++++------ 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 7b1fbfb5548..ee7a910b078 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -50,13 +50,15 @@ at::Tensor nms_kernel_impl( } keep[num_to_keep++] = i; + iou_func.set_box(i); + for (int64_t _j = _i + 1; _j < ndets; _j++) { auto j = order[_j]; if (suppressed[j] == 1) { continue; } - auto ovr = iou_func(dets, i, j); + auto ovr = iou_func.compare(j); if (ovr > iou_threshold) { suppressed[j] = 1; } @@ -74,6 +76,8 @@ struct AABBIoU { const scalar_t* areas; at::Tensor x1_t, y1_t, x2_t, y2_t, areas_t; + scalar_t ix1, iy1, ix2, iy2, iarea; + AABBIoU(const at::Tensor& dets) { x1_t = dets.select(1, 0).contiguous(); y1_t = dets.select(1, 1).contiguous(); @@ -87,24 +91,43 @@ struct AABBIoU { areas = areas_t.data_ptr(); } - scalar_t operator()(const at::Tensor& /*dets*/, int64_t i, int64_t j) const { - auto xx1 = std::max(x1[i], x1[j]); - auto yy1 = std::max(y1[i], y1[j]); - auto xx2 = std::min(x2[i], x2[j]); - auto yy2 = std::min(y2[i], y2[j]); + void set_box(int64_t i) { + ix1 = x1[i]; + iy1 = y1[i]; + ix2 = x2[i]; + iy2 = y2[i]; + iarea = areas[i]; + } + + scalar_t compare(int64_t j) const { + auto xx1 = std::max(ix1, x1[j]); + auto yy1 = std::max(iy1, y1[j]); + auto xx2 = std::min(ix2, x2[j]); + auto yy2 = std::min(iy2, y2[j]); auto w = std::max(static_cast(0), xx2 - xx1); auto h = std::max(static_cast(0), yy2 - yy1); auto inter = w * h; - return inter / (areas[i] + areas[j] - inter); + return inter / (iarea + areas[j] - inter); } }; template struct RotatedIoU { - scalar_t operator()(const at::Tensor& dets, int64_t i, int64_t j) const { + const at::Tensor* dets_ptr; + + RotatedIoU(const at::Tensor& dets) : dets_ptr(&dets) {} + + int64_t cached_i; + + void set_box(int64_t i) { + cached_i = i; + } + + scalar_t compare(int64_t j) const { return single_box_iou_rotated( - dets[i].data_ptr(), dets[j].data_ptr()); + (*dets_ptr)[cached_i].data_ptr(), + (*dets_ptr)[j].data_ptr()); } }; @@ -167,7 +190,7 @@ at::Tensor nms_rotated_kernel( AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated_kernel", [&] { result = nms_kernel_impl( - dets, scores, iou_threshold, RotatedIoU{}); + dets, scores, iou_threshold, RotatedIoU(dets)); }); return result; } From 5499015e440930e8da4bc3854a4b39cd4a7fde82 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 24 Mar 2026 04:05:28 -0700 Subject: [PATCH 08/25] fix the test failure --- torchvision/csrc/ops/cpu/nms_kernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index ee7a910b078..f8087b67be8 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -126,8 +126,8 @@ struct RotatedIoU { scalar_t compare(int64_t j) const { return single_box_iou_rotated( - (*dets_ptr)[cached_i].data_ptr(), - (*dets_ptr)[j].data_ptr()); + (*dets_ptr)[cached_i].template data_ptr(), + (*dets_ptr)[j].template data_ptr()); } }; From c9403453db28831711058a77d5390b208916f5cf Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 24 Mar 2026 19:44:24 -0700 Subject: [PATCH 09/25] remove the torchscript test and err_msg --- test/test_ops.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index d90acf11c25..991fd74190d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2081,8 +2081,7 @@ def test_nms_rotated_0_degree(self, iou): keep_ref = self._reference_horizontal_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou, fmt="cxcywhr") - err_msg = f"Rotated NMS incompatible with reference implementation for IoU={iou}" - assert self._nms_edit_distance(keep, keep_ref) <= 1, err_msg + assert self._nms_edit_distance(keep, keep_ref) <= 1 @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_rotated_90_degrees(self, iou): @@ -2098,8 +2097,7 @@ def test_nms_rotated_90_degrees(self, iou): keep_ref = self._reference_horizontal_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou, fmt="cxcywhr") - err_msg = f"Rotated NMS incompatible with reference implementation for IoU={iou}" - assert self._nms_edit_distance(keep, keep_ref) <= 1, err_msg + assert self._nms_edit_distance(keep, keep_ref) <= 1 @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_rotated_180_degrees(self, iou): @@ -2114,16 +2112,7 @@ def test_nms_rotated_180_degrees(self, iou): keep_ref = self._reference_horizontal_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou, fmt="cxcywhr") - err_msg = f"Rotated NMS incompatible with reference implementation for IoU={iou}" - assert self._nms_edit_distance(keep, keep_ref) <= 1, err_msg - - def test_nms_rotated_scriptable(self): - class TestingModule(torch.nn.Module): - def forward(self, boxes, scores, threshold): - return ops.nms(boxes, scores, threshold, fmt="cxcywhr") - - m = TestingModule() - _ = torch.jit.script(m) + assert self._nms_edit_distance(keep, keep_ref) <= 1 def get_boxes(dtype, device): From 17a54cc96a5ddb68e8be77294830017a08d89cb3 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 24 Mar 2026 21:09:19 -0700 Subject: [PATCH 10/25] Remove the fmt parameter from nms function --- test/test_ops.py | 6 +++--- torchvision/ops/boxes.py | 27 ++++++++++----------------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 991fd74190d..b7bcd48c4a5 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2080,7 +2080,7 @@ def test_nms_rotated_0_degree(self, iou): rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] keep_ref = self._reference_horizontal_nms(boxes, scores, iou) - keep = ops.nms(rotated_boxes, scores, iou, fmt="cxcywhr") + keep = ops.nms(rotated_boxes, scores, iou) assert self._nms_edit_distance(keep, keep_ref) <= 1 @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @@ -2096,7 +2096,7 @@ def test_nms_rotated_90_degrees(self, iou): rotated_boxes[:, 4] = 90 keep_ref = self._reference_horizontal_nms(boxes, scores, iou) - keep = ops.nms(rotated_boxes, scores, iou, fmt="cxcywhr") + keep = ops.nms(rotated_boxes, scores, iou) assert self._nms_edit_distance(keep, keep_ref) <= 1 @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @@ -2111,7 +2111,7 @@ def test_nms_rotated_180_degrees(self, iou): rotated_boxes[:, 4] = 180 keep_ref = self._reference_horizontal_nms(boxes, scores, iou) - keep = ops.nms(rotated_boxes, scores, iou, fmt="cxcywhr") + keep = ops.nms(rotated_boxes, scores, iou) assert self._nms_edit_distance(keep, keep_ref) <= 1 diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 55532befb52..5650d4c69f1 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -17,7 +17,7 @@ from ._utils import _upcast -def nms(boxes: Tensor, scores: Tensor, iou_threshold: float, fmt: str = "xyxy") -> Tensor: +def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: """ Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU). @@ -32,21 +32,14 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float, fmt: str = "xyxy") to the behavior of argsort in PyTorch when repeated values are present. Args: - boxes (Tensor[N, K])): boxes to perform NMS on. They - are expected to be in the format specified by ``fmt``. + boxes (Tensor[N, K])): boxes to perform NMS on. + If K=4, boxes are expected to be in ``(x1, y1, x2, y2)`` format + with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + If K=5, boxes are expected to be in ``(cx, cy, w, h, angle)`` format + for rotated boxes, where ``(cx, cy)`` is the center, ``(w, h)`` is + width and height, and ``angle`` is the rotation angle in degrees. scores (Tensor[N]): scores for each one of the boxes iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold - fmt (str): Format of the input boxes. - Default is "xyxy" to preserve backward compatibility. - - Supported axis-aligned format (K=4): - - - ``'xyxy'``: boxes are represented via (x1, y1, x2, y2) corner coordinates. - - Supported rotated format (K=5): - - - ``'cxcywhr'``: boxes are represented via center, width, height, and rotation angle. - (cx, cy) is the center, (w, h) is width and height, r is rotation angle in degrees. Returns: Tensor: int64 tensor with the indices of the elements that have been kept @@ -56,13 +49,13 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float, fmt: str = "xyxy") _log_api_usage_once(nms) _assert_has_ops() - if fmt == "xyxy": + if boxes.size(-1) == 4: return torch.ops.torchvision.nms(boxes, scores, iou_threshold) - elif fmt == "cxcywhr": + elif boxes.size(-1) == 5: return torch.ops.torchvision.nms_rotated(boxes, scores, iou_threshold) else: raise ValueError( - f"Unsupported format '{fmt}'. " f"Supported formats: 'xyxy' (axis-aligned), 'cxcywhr' (rotated)." + f"boxes should have 4 (axis-aligned) or 5 (rotated) elements in the last dimension, got {boxes.size(-1)}" ) From 54aee4c5641b21bcb571fc2ce2b45f3247029389 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Wed, 25 Mar 2026 03:30:03 -0700 Subject: [PATCH 11/25] fix the test failures --- test/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index b7bcd48c4a5..b912c4e1dfe 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -825,8 +825,8 @@ def test_nms_ref(self, iou, seed): def test_nms_input_errors(self): with pytest.raises(RuntimeError): ops.nms(torch.rand(4), torch.rand(3), 0.5) - with pytest.raises(RuntimeError): - ops.nms(torch.rand(3, 5), torch.rand(3), 0.5) + with pytest.raises((RuntimeError, ValueError)): + ops.nms(torch.rand(3, 6), torch.rand(3), 0.5) with pytest.raises(RuntimeError): ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5) with pytest.raises(RuntimeError): From c5cf1c425074c58a02776bffcbbc7add036df367 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Fri, 27 Mar 2026 00:36:13 -0700 Subject: [PATCH 12/25] Reuse TestNMS._reference_nms in TestNMSRotated --- test/test_ops.py | 33 +++++---------------------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index b912c4e1dfe..020f9ac0223 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -770,7 +770,8 @@ def test_is_leaf_node(self, device): class TestNMS: - def _reference_nms(self, boxes, scores, iou_threshold): + @classmethod + def _reference_nms(cls, boxes, scores, iou_threshold): """ Args: boxes: boxes in corner-form @@ -2008,30 +2009,6 @@ def test_cuda_cpu_consistency(self): class TestNMSRotated: - def _reference_horizontal_nms(self, boxes, scores, iou_threshold): - """ - Args: - box_scores (N, 5): boxes in corner-form and probabilities. - (Note here 5 == 4 + 1, i.e., 4-dim horizontal box + 1-dim prob) - iou_threshold: intersection over union threshold. - Returns: - picked: a list of indexes of the kept boxes - """ - picked = [] - _, indexes = scores.sort(descending=True) - while len(indexes) > 0: - current = indexes[0] - picked.append(current.item()) - if len(indexes) == 1: - break - current_box = boxes[current, :] - indexes = indexes[1:] - rest_boxes = boxes[indexes, :] - iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1) - indexes = indexes[iou <= iou_threshold] - - return torch.as_tensor(picked) - @staticmethod def _nms_edit_distance(keep1, keep2): """ @@ -2079,7 +2056,7 @@ def test_nms_rotated_0_degree(self, iou): rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] - keep_ref = self._reference_horizontal_nms(boxes, scores, iou) + keep_ref = TestNMS._reference_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) assert self._nms_edit_distance(keep, keep_ref) <= 1 @@ -2095,7 +2072,7 @@ def test_nms_rotated_90_degrees(self, iou): rotated_boxes[:, 3] = boxes[:, 2] - boxes[:, 0] rotated_boxes[:, 4] = 90 - keep_ref = self._reference_horizontal_nms(boxes, scores, iou) + keep_ref = TestNMS._reference_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) assert self._nms_edit_distance(keep, keep_ref) <= 1 @@ -2110,7 +2087,7 @@ def test_nms_rotated_180_degrees(self, iou): rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] rotated_boxes[:, 4] = 180 - keep_ref = self._reference_horizontal_nms(boxes, scores, iou) + keep_ref = TestNMS._reference_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) assert self._nms_edit_distance(keep, keep_ref) <= 1 From 9c96e2681f5084ff0b24fcdf71708aaaead8d0a4 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Fri, 27 Mar 2026 01:37:32 -0700 Subject: [PATCH 13/25] address the comment --- test/test_ops.py | 45 ++++++++++++--------------------------------- 1 file changed, 12 insertions(+), 33 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 020f9ac0223..f3dc75d020f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2009,36 +2009,6 @@ def test_cuda_cpu_consistency(self): class TestNMSRotated: - @staticmethod - def _nms_edit_distance(keep1, keep2): - """ - Compare the "keep" result of two nms call. - They are allowed to be different in terms of edit distance - due to floating point precision issues, e.g., - if a box happen to have an IoU of 0.5 with another box, - one implementation may choose to keep it while another may discard it. - """ - keep1, keep2 = keep1.cpu(), keep2.cpu() - if torch.equal(keep1, keep2): - # they should be equal most of the time - return 0 - keep1, keep2 = tuple(keep1), tuple(keep2) - m, n = len(keep1), len(keep2) - - # edit distance with DP - f = [np.arange(n + 1), np.arange(n + 1)] - for i in range(m): - cur_row = i % 2 - other_row = (i + 1) % 2 - f[other_row][0] = i + 1 - for j in range(n): - f[other_row][j + 1] = ( - f[cur_row][j] - if keep1[i] == keep2[j] - else min(min(f[cur_row][j], f[cur_row][j + 1]), f[other_row][j]) + 1 - ) - return f[m % 2][n] - @staticmethod def _create_tensors(N, device="cpu"): boxes = torch.rand(N, 4, device=device) * 200 @@ -2048,6 +2018,7 @@ def _create_tensors(N, device="cpu"): @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_rotated_0_degree(self, iou): + torch.manual_seed(0) N = 1000 boxes, scores = self._create_tensors(N) rotated_boxes = torch.zeros(N, 5) @@ -2058,10 +2029,13 @@ def test_nms_rotated_0_degree(self, iou): keep_ref = TestNMS._reference_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) - assert self._nms_edit_distance(keep, keep_ref) <= 1 + torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) + keep_standard_nms = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_standard_nms, atol=0, rtol=0) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_rotated_90_degrees(self, iou): + torch.manual_seed(0) N = 1000 boxes, scores = self._create_tensors(N) rotated_boxes = torch.zeros(N, 5) @@ -2074,10 +2048,13 @@ def test_nms_rotated_90_degrees(self, iou): keep_ref = TestNMS._reference_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) - assert self._nms_edit_distance(keep, keep_ref) <= 1 + torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) + keep_standard_nms = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_standard_nms, atol=0, rtol=0) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_rotated_180_degrees(self, iou): + torch.manual_seed(0) N = 1000 boxes, scores = self._create_tensors(N) rotated_boxes = torch.zeros(N, 5) @@ -2089,7 +2066,9 @@ def test_nms_rotated_180_degrees(self, iou): keep_ref = TestNMS._reference_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) - assert self._nms_edit_distance(keep, keep_ref) <= 1 + torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) + keep_standard_nms = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_standard_nms, atol=0, rtol=0) def get_boxes(dtype, device): From 4a993c80cff38e7c057c3018ee9d0b37ceb2bedf Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Fri, 27 Mar 2026 01:42:05 -0700 Subject: [PATCH 14/25] change the variable name --- test/test_ops.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index f3dc75d020f..1cd8ee6b7c0 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2030,8 +2030,8 @@ def test_nms_rotated_0_degree(self, iou): keep_ref = TestNMS._reference_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) - keep_standard_nms = ops.nms(boxes, scores, iou) - torch.testing.assert_close(keep, keep_standard_nms, atol=0, rtol=0) + keep_non_rotated = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_rotated_90_degrees(self, iou): @@ -2049,8 +2049,8 @@ def test_nms_rotated_90_degrees(self, iou): keep_ref = TestNMS._reference_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) - keep_standard_nms = ops.nms(boxes, scores, iou) - torch.testing.assert_close(keep, keep_standard_nms, atol=0, rtol=0) + keep_non_rotated = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_rotated_180_degrees(self, iou): @@ -2067,8 +2067,8 @@ def test_nms_rotated_180_degrees(self, iou): keep_ref = TestNMS._reference_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) - keep_standard_nms = ops.nms(boxes, scores, iou) - torch.testing.assert_close(keep, keep_standard_nms, atol=0, rtol=0) + keep_non_rotated = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) def get_boxes(dtype, device): From 84e960c4c1bdfa640b0feefd8873f5999a7bf7f5 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Fri, 27 Mar 2026 01:55:05 -0700 Subject: [PATCH 15/25] address more comments on the file torchvision/csrc/ops/cpu/nms_kernel.cpp --- torchvision/csrc/ops/cpu/nms_kernel.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index f8087b67be8..9a5eb4f242f 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -58,7 +58,7 @@ at::Tensor nms_kernel_impl( continue; } - auto ovr = iou_func.compare(j); + auto ovr = iou_func.compute(j); if (ovr > iou_threshold) { suppressed[j] = 1; } @@ -68,7 +68,7 @@ at::Tensor nms_kernel_impl( } template -struct AABBIoU { +struct NonRotatedIoU { const scalar_t* x1; const scalar_t* y1; const scalar_t* x2; @@ -78,7 +78,7 @@ struct AABBIoU { scalar_t ix1, iy1, ix2, iy2, iarea; - AABBIoU(const at::Tensor& dets) { + NonRotatedIoU(const at::Tensor& dets) { x1_t = dets.select(1, 0).contiguous(); y1_t = dets.select(1, 1).contiguous(); x2_t = dets.select(1, 2).contiguous(); @@ -99,7 +99,7 @@ struct AABBIoU { iarea = areas[i]; } - scalar_t compare(int64_t j) const { + scalar_t compute(int64_t j) const { auto xx1 = std::max(ix1, x1[j]); auto yy1 = std::max(iy1, y1[j]); auto xx2 = std::min(ix2, x2[j]); @@ -118,15 +118,15 @@ struct RotatedIoU { RotatedIoU(const at::Tensor& dets) : dets_ptr(&dets) {} - int64_t cached_i; + int64_t i; void set_box(int64_t i) { - cached_i = i; + this->i = i; } - scalar_t compare(int64_t j) const { + scalar_t compute(int64_t j) const { return single_box_iou_rotated( - (*dets_ptr)[cached_i].template data_ptr(), + (*dets_ptr)[i].template data_ptr(), (*dets_ptr)[j].template data_ptr()); } }; @@ -158,7 +158,7 @@ at::Tensor nms_kernel( AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { result = nms_kernel_impl( - dets, scores, iou_threshold, AABBIoU(dets)); + dets, scores, iou_threshold, NonRotatedIoU(dets)); }); return result; } From 9ef0c0b4c89081fd7f455552ec9eddf70c6d6071 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Fri, 27 Mar 2026 03:07:30 -0700 Subject: [PATCH 16/25] add batched_nms and the test --- test/test_ops.py | 18 ++++++++++++++++++ torchvision/ops/boxes.py | 12 ++++++++---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1cd8ee6b7c0..74bd3ee3522 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2070,6 +2070,24 @@ def test_nms_rotated_180_degrees(self, iou): keep_non_rotated = ops.nms(boxes, scores, iou) torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_batched_nms_rotated_0_degree(self, iou): + torch.manual_seed(0) + N = 2000 + num_classes = 50 + boxes, scores = self._create_tensors(N) + idxs = torch.randint(0, num_classes, (N,)) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + backup = rotated_boxes.clone() + keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou) + keep = ops.batched_nms(rotated_boxes, scores, idxs, iou) + assert torch.allclose(rotated_boxes, backup) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + def get_boxes(dtype, device): box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 5650d4c69f1..43692538232 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -72,9 +72,9 @@ def batched_nms( will not be applied between elements of different categories. Args: - boxes (Tensor[N, 4]): boxes where NMS will be performed. They - are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and - ``0 <= y1 < y2``. + boxes (Tensor[N, K]): boxes where NMS will be performed. + If K=4, boxes are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + If K=5, boxes are expected to be in ``(cx, cy, w, h, angle)`` format. scores (Tensor[N]): scores for each one of the boxes idxs (Tensor[N]): indices of the categories for each one of the boxes. iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold @@ -109,7 +109,11 @@ def _batched_nms_coordinate_trick( return torch.empty((0,), dtype=torch.int64, device=boxes.device) max_coordinate = boxes.max() offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) - boxes_for_nms = boxes + offsets[:, None] + if boxes.size(-1) == 4: + boxes_for_nms = boxes + offsets[:, None] + else: + boxes_for_nms = boxes.clone() + boxes_for_nms[..., :2] = boxes[..., :2] + offsets[:, None] keep = nms(boxes_for_nms, scores, iou_threshold) return keep From 4da74a89c42aa4802175abd896896b176cfd6949 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 31 Mar 2026 10:56:30 -0700 Subject: [PATCH 17/25] Rename _reference_nms to _reference_aligned_nms --- test/test_ops.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 74bd3ee3522..3d73e8ca5d4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -771,7 +771,7 @@ def test_is_leaf_node(self, device): class TestNMS: @classmethod - def _reference_nms(cls, boxes, scores, iou_threshold): + def _reference_aligned_nms(cls, boxes, scores, iou_threshold): """ Args: boxes: boxes in corner-form @@ -819,7 +819,7 @@ def test_nms_ref(self, iou, seed): torch.random.manual_seed(seed) err_msg = "NMS incompatible between CPU and reference implementation for IoU={}" boxes, scores = self._create_tensors_with_iou(1000, iou) - keep_ref = self._reference_nms(boxes, scores, iou) + keep_ref = self._reference_aligned_nms(boxes, scores, iou) keep = ops.nms(boxes, scores, iou) torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou)) @@ -2027,7 +2027,7 @@ def test_nms_rotated_0_degree(self, iou): rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] - keep_ref = TestNMS._reference_nms(boxes, scores, iou) + keep_ref = TestNMS._reference_aligned_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) keep_non_rotated = ops.nms(boxes, scores, iou) @@ -2046,7 +2046,7 @@ def test_nms_rotated_90_degrees(self, iou): rotated_boxes[:, 3] = boxes[:, 2] - boxes[:, 0] rotated_boxes[:, 4] = 90 - keep_ref = TestNMS._reference_nms(boxes, scores, iou) + keep_ref = TestNMS._reference_aligned_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) keep_non_rotated = ops.nms(boxes, scores, iou) @@ -2064,7 +2064,7 @@ def test_nms_rotated_180_degrees(self, iou): rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] rotated_boxes[:, 4] = 180 - keep_ref = TestNMS._reference_nms(boxes, scores, iou) + keep_ref = TestNMS._reference_aligned_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) keep_non_rotated = ops.nms(boxes, scores, iou) From 2ea193b04cd6db61d2b3bb91e340dd2bdd973207 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 31 Mar 2026 12:28:52 -0700 Subject: [PATCH 18/25] Merge TestNMSRotated into TestNMS --- test/test_ops.py | 160 +++++++++++++++++++++++------------------------ 1 file changed, 78 insertions(+), 82 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 3d73e8ca5d4..45e6434cef5 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -770,7 +770,6 @@ def test_is_leaf_node(self, device): class TestNMS: - @classmethod def _reference_aligned_nms(cls, boxes, scores, iou_threshold): """ Args: @@ -946,6 +945,84 @@ def test_batched_nms_implementations(self, seed): empty = torch.empty((0,), dtype=torch.int64) torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None)) + def _create_tensors(self, N, device="cpu"): + boxes = torch.rand(N, 4, device=device) * 200 + boxes[:, 2:] += boxes[:, :2] + scores = torch.rand(N, device=device) + return boxes, scores + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_rotated_0_degree(self, iou): + torch.manual_seed(0) + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + + keep_ref = self._reference_aligned_nms(boxes, scores, iou) + keep = ops.nms(rotated_boxes, scores, iou) + torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) + keep_non_rotated = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_rotated_90_degrees(self, iou): + torch.manual_seed(0) + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + # Swap width and height for 90 degrees so reference horizontal NMS can be used + rotated_boxes[:, 2] = boxes[:, 3] - boxes[:, 1] + rotated_boxes[:, 3] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 4] = 90 + + keep_ref = self._reference_aligned_nms(boxes, scores, iou) + keep = ops.nms(rotated_boxes, scores, iou) + torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) + keep_non_rotated = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_rotated_180_degrees(self, iou): + torch.manual_seed(0) + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + rotated_boxes[:, 4] = 180 + + keep_ref = self._reference_aligned_nms(boxes, scores, iou) + keep = ops.nms(rotated_boxes, scores, iou) + torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) + keep_non_rotated = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_batched_nms_rotated_0_degree(self, iou): + torch.manual_seed(0) + N = 2000 + num_classes = 50 + boxes, scores = self._create_tensors(N) + idxs = torch.randint(0, num_classes, (N,)) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + backup = rotated_boxes.clone() + keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou) + keep = ops.batched_nms(rotated_boxes, scores, idxs, iou) + assert torch.allclose(rotated_boxes, backup) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + optests.generate_opcheck_tests( testcase=TestNMS, @@ -2008,87 +2085,6 @@ def test_cuda_cpu_consistency(self): torch.testing.assert_close(iou_cpu, iou_cuda.cpu(), atol=1e-5, rtol=1e-5) -class TestNMSRotated: - @staticmethod - def _create_tensors(N, device="cpu"): - boxes = torch.rand(N, 4, device=device) * 200 - boxes[:, 2:] += boxes[:, :2] - scores = torch.rand(N, device=device) - return boxes, scores - - @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_nms_rotated_0_degree(self, iou): - torch.manual_seed(0) - N = 1000 - boxes, scores = self._create_tensors(N) - rotated_boxes = torch.zeros(N, 5) - rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 - rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 - rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] - rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] - - keep_ref = TestNMS._reference_aligned_nms(boxes, scores, iou) - keep = ops.nms(rotated_boxes, scores, iou) - torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) - keep_non_rotated = ops.nms(boxes, scores, iou) - torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) - - @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_nms_rotated_90_degrees(self, iou): - torch.manual_seed(0) - N = 1000 - boxes, scores = self._create_tensors(N) - rotated_boxes = torch.zeros(N, 5) - rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 - rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 - # Swap width and height for 90 degrees so reference horizontal NMS can be used - rotated_boxes[:, 2] = boxes[:, 3] - boxes[:, 1] - rotated_boxes[:, 3] = boxes[:, 2] - boxes[:, 0] - rotated_boxes[:, 4] = 90 - - keep_ref = TestNMS._reference_aligned_nms(boxes, scores, iou) - keep = ops.nms(rotated_boxes, scores, iou) - torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) - keep_non_rotated = ops.nms(boxes, scores, iou) - torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) - - @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_nms_rotated_180_degrees(self, iou): - torch.manual_seed(0) - N = 1000 - boxes, scores = self._create_tensors(N) - rotated_boxes = torch.zeros(N, 5) - rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 - rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 - rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] - rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] - rotated_boxes[:, 4] = 180 - - keep_ref = TestNMS._reference_aligned_nms(boxes, scores, iou) - keep = ops.nms(rotated_boxes, scores, iou) - torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) - keep_non_rotated = ops.nms(boxes, scores, iou) - torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) - - @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_batched_nms_rotated_0_degree(self, iou): - torch.manual_seed(0) - N = 2000 - num_classes = 50 - boxes, scores = self._create_tensors(N) - idxs = torch.randint(0, num_classes, (N,)) - rotated_boxes = torch.zeros(N, 5) - rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 - rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 - rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] - rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] - backup = rotated_boxes.clone() - keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou) - keep = ops.batched_nms(rotated_boxes, scores, idxs, iou) - assert torch.allclose(rotated_boxes, backup) - torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) - - def get_boxes(dtype, device): box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) From c249d352e9e1118ebf221d23c8af487700ce36a3 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 31 Mar 2026 21:30:24 -0700 Subject: [PATCH 19/25] combine tests, refactor duplicated parts and add a test for NMS with different rotation angles per box --- test/test_ops.py | 83 +++++++++++++++++------------------------------- 1 file changed, 30 insertions(+), 53 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 45e6434cef5..5dc777a208d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -945,60 +945,29 @@ def test_batched_nms_implementations(self, seed): empty = torch.empty((0,), dtype=torch.int64) torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None)) - def _create_tensors(self, N, device="cpu"): + def _create_rotated_boxes(self, N, angle=0, device="cpu"): boxes = torch.rand(N, 4, device=device) * 200 boxes[:, 2:] += boxes[:, :2] scores = torch.rand(N, device=device) - return boxes, scores - - @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_nms_rotated_0_degree(self, iou): - torch.manual_seed(0) - N = 1000 - boxes, scores = self._create_tensors(N) - rotated_boxes = torch.zeros(N, 5) - rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 - rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 - rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] - rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] - - keep_ref = self._reference_aligned_nms(boxes, scores, iou) - keep = ops.nms(rotated_boxes, scores, iou) - torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) - keep_non_rotated = ops.nms(boxes, scores, iou) - torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) - - @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_nms_rotated_90_degrees(self, iou): - torch.manual_seed(0) - N = 1000 - boxes, scores = self._create_tensors(N) - rotated_boxes = torch.zeros(N, 5) - rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 - rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 - # Swap width and height for 90 degrees so reference horizontal NMS can be used - rotated_boxes[:, 2] = boxes[:, 3] - boxes[:, 1] - rotated_boxes[:, 3] = boxes[:, 2] - boxes[:, 0] - rotated_boxes[:, 4] = 90 - - keep_ref = self._reference_aligned_nms(boxes, scores, iou) - keep = ops.nms(rotated_boxes, scores, iou) - torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) - keep_non_rotated = ops.nms(boxes, scores, iou) - torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + cxcywh = ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh") + rotated_boxes = torch.zeros(N, 5, device=device) + rotated_boxes[:, 0] = cxcywh[:, 0] + rotated_boxes[:, 1] = cxcywh[:, 1] + if angle == 90: + rotated_boxes[:, 2] = cxcywh[:, 3] + rotated_boxes[:, 3] = cxcywh[:, 2] + else: + rotated_boxes[:, 2] = cxcywh[:, 2] + rotated_boxes[:, 3] = cxcywh[:, 3] + rotated_boxes[:, 4] = angle + return boxes, rotated_boxes, scores @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_nms_rotated_180_degrees(self, iou): + @pytest.mark.parametrize("angle", (0, 90, 180)) + def test_nms_rotated(self, iou, angle): torch.manual_seed(0) N = 1000 - boxes, scores = self._create_tensors(N) - rotated_boxes = torch.zeros(N, 5) - rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 - rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 - rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] - rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] - rotated_boxes[:, 4] = 180 - + boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle) keep_ref = self._reference_aligned_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) @@ -1010,19 +979,27 @@ def test_batched_nms_rotated_0_degree(self, iou): torch.manual_seed(0) N = 2000 num_classes = 50 - boxes, scores = self._create_tensors(N) + boxes, rotated_boxes, scores = self._create_rotated_boxes(N) idxs = torch.randint(0, num_classes, (N,)) - rotated_boxes = torch.zeros(N, 5) - rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 - rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 - rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] - rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] backup = rotated_boxes.clone() keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou) keep = ops.batched_nms(rotated_boxes, scores, idxs, iou) assert torch.allclose(rotated_boxes, backup) torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_rotated_different_angles(self, iou): + torch.manual_seed(0) + N = 1000 + boxes, rotated_boxes, scores = self._create_rotated_boxes(N) + rotated_boxes[:, 4] = torch.rand(N) * 360 + keep = ops.nms(rotated_boxes, scores, iou) + assert keep.dtype == torch.int64 + assert keep.dim() == 1 + assert keep.numel() <= N + assert (keep >= 0).all() and (keep < N).all() + assert (scores[keep][:-1] >= scores[keep][1:]).all() + optests.generate_opcheck_tests( testcase=TestNMS, From 50afe5c51b6332966b39f7f4b0b81fee686573d8 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 31 Mar 2026 22:33:12 -0700 Subject: [PATCH 20/25] add test_nms_rotated_specific_angles --- test/test_ops.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 5dc777a208d..1986db7776b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1000,6 +1000,21 @@ def test_nms_rotated_different_angles(self, iou): assert (keep >= 0).all() and (keep < N).all() assert (scores[keep][:-1] >= scores[keep][1:]).all() + def test_nms_rotated_specific_angles(self): + boxes = torch.tensor( + [ + [0, 0, 10, 10, 0], + [0, 0, 10, 10, 45], + [100, 100, 10, 10, 30], + ], + dtype=torch.float32, + ) + scores = torch.tensor([0.9, 0.8, 0.7]) + keep = ops.nms(boxes, scores, iou_threshold=0.5) + assert 0 in keep.tolist() + assert 1 not in keep.tolist() + assert 2 in keep.tolist() + optests.generate_opcheck_tests( testcase=TestNMS, From 825947104d24c76e92487b73a4c1c33986bbbdba Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Tue, 31 Mar 2026 23:57:49 -0700 Subject: [PATCH 21/25] parametrize the existing test_batched_nms_implementations test over the format --- test/test_ops.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1986db7776b..eed54a7d29e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -920,19 +920,23 @@ def test_nms_float16(self, device): assert_equal(keep32, keep16) @pytest.mark.parametrize("seed", range(10)) + @pytest.mark.parametrize("rotated", (False, True)) @pytest.mark.opcheck_only_one() - def test_batched_nms_implementations(self, seed): + def test_batched_nms_implementations(self, seed, rotated): """Make sure that both implementations of batched_nms yield identical results""" torch.random.manual_seed(seed) num_boxes = 1000 iou_threshold = 0.9 - boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1) - assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2 - assert max(boxes[:, 1]) < min(boxes[:, 3]) # y1 < y2 + if rotated: + _, boxes, scores = self._create_rotated_boxes(num_boxes) + else: + boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1) + assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2 + assert max(boxes[:, 1]) < min(boxes[:, 3]) # y1 < y2 + scores = torch.rand(num_boxes) - scores = torch.rand(num_boxes) idxs = torch.randint(0, 4, size=(num_boxes,)) keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold) keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold) From bdd502deab8840b8dbfe7493566cbf0e64b4f5a9 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Thu, 2 Apr 2026 21:50:33 -0700 Subject: [PATCH 22/25] address the comments by using the torch.testing.assert_close and remove boxes in test_nms_rotated_different_angles function --- test/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index eed54a7d29e..294619c36ef 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -988,14 +988,14 @@ def test_batched_nms_rotated_0_degree(self, iou): backup = rotated_boxes.clone() keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou) keep = ops.batched_nms(rotated_boxes, scores, idxs, iou) - assert torch.allclose(rotated_boxes, backup) + torch.testing.assert_close(rotated_boxes, backup) torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_rotated_different_angles(self, iou): torch.manual_seed(0) N = 1000 - boxes, rotated_boxes, scores = self._create_rotated_boxes(N) + _, rotated_boxes, scores = self._create_rotated_boxes(N) rotated_boxes[:, 4] = torch.rand(N) * 360 keep = ops.nms(rotated_boxes, scores, iou) assert keep.dtype == torch.int64 From 7f5d72b6500fe2d351568e023af5522804ce41d5 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Fri, 3 Apr 2026 00:20:40 -0700 Subject: [PATCH 23/25] parametrizing the angles for the test_batched_nms_rotated function --- test/test_ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 294619c36ef..61d98af4c92 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -979,11 +979,12 @@ def test_nms_rotated(self, iou, angle): torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_batched_nms_rotated_0_degree(self, iou): + @pytest.mark.parametrize("angle", (0, 90, 180)) + def test_batched_nms_rotated(self, iou, angle): torch.manual_seed(0) N = 2000 num_classes = 50 - boxes, rotated_boxes, scores = self._create_rotated_boxes(N) + boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle) idxs = torch.randint(0, num_classes, (N,)) backup = rotated_boxes.clone() keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou) From aea073c4fcec26f27151311a9f51322bb0883266 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Fri, 3 Apr 2026 00:43:17 -0700 Subject: [PATCH 24/25] add explanation for angle = 90 and move it out of the _create_rotated_boxes --- test/test_ops.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 61d98af4c92..74e870b1f65 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -955,14 +955,7 @@ def _create_rotated_boxes(self, N, angle=0, device="cpu"): scores = torch.rand(N, device=device) cxcywh = ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh") rotated_boxes = torch.zeros(N, 5, device=device) - rotated_boxes[:, 0] = cxcywh[:, 0] - rotated_boxes[:, 1] = cxcywh[:, 1] - if angle == 90: - rotated_boxes[:, 2] = cxcywh[:, 3] - rotated_boxes[:, 3] = cxcywh[:, 2] - else: - rotated_boxes[:, 2] = cxcywh[:, 2] - rotated_boxes[:, 3] = cxcywh[:, 3] + rotated_boxes[:, :4] = cxcywh rotated_boxes[:, 4] = angle return boxes, rotated_boxes, scores @@ -972,6 +965,13 @@ def test_nms_rotated(self, iou, angle): torch.manual_seed(0) N = 1000 boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle) + if angle == 90: + # widths and heights are intentionally swapped here for 90 degrees case + # so that the reference horizontal nms could be used + rotated_boxes[:, 2], rotated_boxes[:, 3] = ( + rotated_boxes[:, 3].clone(), + rotated_boxes[:, 2].clone(), + ) keep_ref = self._reference_aligned_nms(boxes, scores, iou) keep = ops.nms(rotated_boxes, scores, iou) torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) @@ -985,6 +985,13 @@ def test_batched_nms_rotated(self, iou, angle): N = 2000 num_classes = 50 boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle) + if angle == 90: + # widths and heights are intentionally swapped here for 90 degrees case + # so that the reference horizontal nms could be used + rotated_boxes[:, 2], rotated_boxes[:, 3] = ( + rotated_boxes[:, 3].clone(), + rotated_boxes[:, 2].clone(), + ) idxs = torch.randint(0, num_classes, (N,)) backup = rotated_boxes.clone() keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou) From 2d22de0b86512b690110f0a249a95c270652cbe4 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Thu, 16 Apr 2026 00:49:25 -0700 Subject: [PATCH 25/25] Rotated bounding box NMS implementation for GPU --- test/test_ops.py | 35 ++++--- torchvision/csrc/ops/cuda/nms_kernel.cu | 133 ++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 15 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 74e870b1f65..ce1403ff16c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -961,10 +961,11 @@ def _create_rotated_boxes(self, N, angle=0, device="cpu"): @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @pytest.mark.parametrize("angle", (0, 90, 180)) - def test_nms_rotated(self, iou, angle): + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_nms_rotated(self, iou, angle, device): torch.manual_seed(0) N = 1000 - boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle) + boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle, device=device) if angle == 90: # widths and heights are intentionally swapped here for 90 degrees case # so that the reference horizontal nms could be used @@ -972,19 +973,20 @@ def test_nms_rotated(self, iou, angle): rotated_boxes[:, 3].clone(), rotated_boxes[:, 2].clone(), ) - keep_ref = self._reference_aligned_nms(boxes, scores, iou) + keep_ref = self._reference_aligned_nms(boxes.cpu(), scores.cpu(), iou) keep = ops.nms(rotated_boxes, scores, iou) - torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) + torch.testing.assert_close(keep.cpu(), keep_ref, atol=0, rtol=0) keep_non_rotated = ops.nms(boxes, scores, iou) - torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + torch.testing.assert_close(keep.cpu(), keep_non_rotated.cpu(), atol=0, rtol=0) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @pytest.mark.parametrize("angle", (0, 90, 180)) - def test_batched_nms_rotated(self, iou, angle): + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_batched_nms_rotated(self, iou, angle, device): torch.manual_seed(0) N = 2000 num_classes = 50 - boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle) + boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle, device=device) if angle == 90: # widths and heights are intentionally swapped here for 90 degrees case # so that the reference horizontal nms could be used @@ -992,27 +994,29 @@ def test_batched_nms_rotated(self, iou, angle): rotated_boxes[:, 3].clone(), rotated_boxes[:, 2].clone(), ) - idxs = torch.randint(0, num_classes, (N,)) + idxs = torch.randint(0, num_classes, (N,), device=device) backup = rotated_boxes.clone() keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou) keep = ops.batched_nms(rotated_boxes, scores, idxs, iou) torch.testing.assert_close(rotated_boxes, backup) - torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + torch.testing.assert_close(keep.cpu(), keep_non_rotated.cpu(), atol=0, rtol=0) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) - def test_nms_rotated_different_angles(self, iou): + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_nms_rotated_different_angles(self, iou, device): torch.manual_seed(0) N = 1000 - _, rotated_boxes, scores = self._create_rotated_boxes(N) - rotated_boxes[:, 4] = torch.rand(N) * 360 + _, rotated_boxes, scores = self._create_rotated_boxes(N, device=device) + rotated_boxes[:, 4] = torch.rand(N, device=device) * 360 keep = ops.nms(rotated_boxes, scores, iou) assert keep.dtype == torch.int64 assert keep.dim() == 1 assert keep.numel() <= N - assert (keep >= 0).all() and (keep < N).all() + assert (keep.cpu() >= 0).all() and (keep.cpu() < N).all() assert (scores[keep][:-1] >= scores[keep][1:]).all() - def test_nms_rotated_specific_angles(self): + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_nms_rotated_specific_angles(self, device): boxes = torch.tensor( [ [0, 0, 10, 10, 0], @@ -1020,8 +1024,9 @@ def test_nms_rotated_specific_angles(self): [100, 100, 10, 10, 30], ], dtype=torch.float32, + device=device, ) - scores = torch.tensor([0.9, 0.8, 0.7]) + scores = torch.tensor([0.9, 0.8, 0.7], device=device) keep = ops.nms(boxes, scores, iou_threshold=0.5) assert 0 in keep.tolist() assert 1 not in keep.tolist() diff --git a/torchvision/csrc/ops/cuda/nms_kernel.cu b/torchvision/csrc/ops/cuda/nms_kernel.cu index 44ce8db6b8e..5221ce8fc65 100644 --- a/torchvision/csrc/ops/cuda/nms_kernel.cu +++ b/torchvision/csrc/ops/cuda/nms_kernel.cu @@ -4,6 +4,7 @@ #include #include +#include "../box_iou_rotated_utils.h" #include "cuda_helpers.h" namespace vision { @@ -199,10 +200,142 @@ at::Tensor nms_kernel( return order_t.masked_select(keep); } +template +__global__ void nms_rotated_kernel_impl( + int n_boxes, + double iou_threshold, + const T* dev_boxes, + unsigned long long* dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + // Compared to nms_kernel_impl, where each box is represented with 4 values + // (x1, y1, x2, y2), each rotated box is represented with 5 values + // (x_center, y_center, width, height, angle_degrees) here. + __shared__ T block_boxes[threadsPerBlock * 5]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const T* cur_box = dev_boxes + cur_box_idx * 5; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + // Instead of devIoU used by nms_kernel_impl, here + // we use the single_box_iou_rotated function from box_iou_rotated_utils.h + if (single_box_iou_rotated(cur_box, block_boxes + i * 5) > + iou_threshold) { + t |= 1ULL << i; + } + } + const int col_blocks = ceil_div(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +at::Tensor nms_rotated_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); + TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor"); + TORCH_CHECK( + dets.scalar_type() == scores.scalar_type(), + "dets and scores must have the same dtype"); + + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 5, + "boxes should have 5 elements in dimension 1, got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)); + + at::cuda::CUDAGuard device_guard(dets.device()); + + if (dets.numel() == 0) { + return at::empty({0}, dets.options().dtype(at::kLong)); + } + + auto order_t = std::get<1>( + scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + auto dets_sorted = dets.index_select(0, order_t).contiguous(); + + int dets_num = dets.size(0); + + const int col_blocks = ceil_div(dets_num, threadsPerBlock); + + at::Tensor mask = + at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + + dim3 blocks(col_blocks, col_blocks); + dim3 threads(threadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES( + dets_sorted.scalar_type(), "nms_rotated_kernel", [&] { + nms_rotated_kernel_impl<<>>( + dets_num, + iou_threshold, + dets_sorted.data_ptr(), + (unsigned long long*)mask.data_ptr()); + }); + + at::Tensor keep = + at::zeros({dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA)); + + gather_keep_from_mask<<< + 1, + min(col_blocks, threadsPerBlock), + col_blocks * sizeof(unsigned long long), + stream>>>( + keep.data_ptr(), + (unsigned long long*)mask.data_ptr(), + dets_num); + + AT_CUDA_CHECK(cudaGetLastError()); + return order_t.masked_select(keep); +} + } // namespace TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::nms_rotated"), + TORCH_FN(nms_rotated_kernel)); } } // namespace ops