diff --git a/test/test_ops.py b/test/test_ops.py index 9521f21a815..ce1403ff16c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -770,7 +770,7 @@ def test_is_leaf_node(self, device): class TestNMS: - def _reference_nms(self, boxes, scores, iou_threshold): + def _reference_aligned_nms(self, boxes, scores, iou_threshold): """ Args: boxes: boxes in corner-form @@ -818,15 +818,15 @@ def test_nms_ref(self, iou, seed): torch.random.manual_seed(seed) err_msg = "NMS incompatible between CPU and reference implementation for IoU={}" boxes, scores = self._create_tensors_with_iou(1000, iou) - keep_ref = self._reference_nms(boxes, scores, iou) + keep_ref = self._reference_aligned_nms(boxes, scores, iou) keep = ops.nms(boxes, scores, iou) torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou)) def test_nms_input_errors(self): with pytest.raises(RuntimeError): ops.nms(torch.rand(4), torch.rand(3), 0.5) - with pytest.raises(RuntimeError): - ops.nms(torch.rand(3, 5), torch.rand(3), 0.5) + with pytest.raises((RuntimeError, ValueError)): + ops.nms(torch.rand(3, 6), torch.rand(3), 0.5) with pytest.raises(RuntimeError): ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5) with pytest.raises(RuntimeError): @@ -920,19 +920,23 @@ def test_nms_float16(self, device): assert_equal(keep32, keep16) @pytest.mark.parametrize("seed", range(10)) + @pytest.mark.parametrize("rotated", (False, True)) @pytest.mark.opcheck_only_one() - def test_batched_nms_implementations(self, seed): + def test_batched_nms_implementations(self, seed, rotated): """Make sure that both implementations of batched_nms yield identical results""" torch.random.manual_seed(seed) num_boxes = 1000 iou_threshold = 0.9 - boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1) - assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2 - assert max(boxes[:, 1]) < min(boxes[:, 3]) # y1 < y2 + if rotated: + _, boxes, scores = self._create_rotated_boxes(num_boxes) + else: + boxes = 
torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1) + assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2 + assert max(boxes[:, 1]) < min(boxes[:, 3]) # y1 < y2 + scores = torch.rand(num_boxes) - scores = torch.rand(num_boxes) idxs = torch.randint(0, 4, size=(num_boxes,)) keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold) keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold) @@ -945,6 +949,89 @@ def test_batched_nms_implementations(self, seed): empty = torch.empty((0,), dtype=torch.int64) torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None)) + def _create_rotated_boxes(self, N, angle=0, device="cpu"): + boxes = torch.rand(N, 4, device=device) * 200 + boxes[:, 2:] += boxes[:, :2] + scores = torch.rand(N, device=device) + cxcywh = ops.box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh") + rotated_boxes = torch.zeros(N, 5, device=device) + rotated_boxes[:, :4] = cxcywh + rotated_boxes[:, 4] = angle + return boxes, rotated_boxes, scores + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + @pytest.mark.parametrize("angle", (0, 90, 180)) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_nms_rotated(self, iou, angle, device): + torch.manual_seed(0) + N = 1000 + boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle, device=device) + if angle == 90: + # widths and heights are intentionally swapped here for 90 degrees case + # so that the reference horizontal nms could be used + rotated_boxes[:, 2], rotated_boxes[:, 3] = ( + rotated_boxes[:, 3].clone(), + rotated_boxes[:, 2].clone(), + ) + keep_ref = self._reference_aligned_nms(boxes.cpu(), scores.cpu(), iou) + keep = ops.nms(rotated_boxes, scores, iou) + torch.testing.assert_close(keep.cpu(), keep_ref, atol=0, rtol=0) + keep_non_rotated = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep.cpu(), keep_non_rotated.cpu(), atol=0, rtol=0) + + 
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + @pytest.mark.parametrize("angle", (0, 90, 180)) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_batched_nms_rotated(self, iou, angle, device): + torch.manual_seed(0) + N = 2000 + num_classes = 50 + boxes, rotated_boxes, scores = self._create_rotated_boxes(N, angle=angle, device=device) + if angle == 90: + # widths and heights are intentionally swapped here for 90 degrees case + # so that the reference horizontal nms could be used + rotated_boxes[:, 2], rotated_boxes[:, 3] = ( + rotated_boxes[:, 3].clone(), + rotated_boxes[:, 2].clone(), + ) + idxs = torch.randint(0, num_classes, (N,), device=device) + backup = rotated_boxes.clone() + keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou) + keep = ops.batched_nms(rotated_boxes, scores, idxs, iou) + torch.testing.assert_close(rotated_boxes, backup) + torch.testing.assert_close(keep.cpu(), keep_non_rotated.cpu(), atol=0, rtol=0) + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_nms_rotated_different_angles(self, iou, device): + torch.manual_seed(0) + N = 1000 + _, rotated_boxes, scores = self._create_rotated_boxes(N, device=device) + rotated_boxes[:, 4] = torch.rand(N, device=device) * 360 + keep = ops.nms(rotated_boxes, scores, iou) + assert keep.dtype == torch.int64 + assert keep.dim() == 1 + assert keep.numel() <= N + assert (keep.cpu() >= 0).all() and (keep.cpu() < N).all() + assert (scores[keep][:-1] >= scores[keep][1:]).all() + + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_nms_rotated_specific_angles(self, device): + boxes = torch.tensor( + [ + [0, 0, 10, 10, 0], + [0, 0, 10, 10, 45], + [100, 100, 10, 10, 30], + ], + dtype=torch.float32, + device=device, + ) + scores = torch.tensor([0.9, 0.8, 0.7], device=device) + keep = ops.nms(boxes, scores, iou_threshold=0.5) + assert 0 in keep.tolist() + assert 1 not in keep.tolist() + assert 2 in keep.tolist() + 
optests.generate_opcheck_tests( testcase=TestNMS, diff --git a/torchvision/_autograd_registrations.py b/torchvision/_autograd_registrations.py index 18d9ced6c54..564657ee35a 100644 --- a/torchvision/_autograd_registrations.py +++ b/torchvision/_autograd_registrations.py @@ -235,6 +235,15 @@ def _autocast_nms(dets, scores, iou_threshold): ) +def _autocast_nms_rotated(dets, scores, iou_threshold): + with torch._C._ExcludeDispatchKeyGuard(_all_autocast_keys): + return torch.ops.torchvision.nms_rotated( + _autocast_cast(dets), + _autocast_cast(scores), + iou_threshold, + ) + + def _autocast_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): orig_dtype = input.dtype with torch._C._ExcludeDispatchKeyGuard(_all_autocast_keys): @@ -358,6 +367,7 @@ def _autocast_deform_conv2d( # nms and roi_align: registered for all autocast device types for _key in ("AutocastCUDA", "AutocastCPU", "AutocastXPU"): _autocast_lib.impl("nms", _autocast_nms, _key) + _autocast_lib.impl("nms_rotated", _autocast_nms_rotated, _key) _autocast_lib.impl("roi_align", _autocast_roi_align, _key) # Other ops: CUDA autocast only diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index f75bfb77a7f..e7a183250e4 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -174,6 +174,20 @@ def meta_nms(dets, scores, iou_threshold): return dets.new_empty(num_to_keep, dtype=torch.long) +@torch.library.register_fake("torchvision::nms_rotated") +def meta_nms_rotated(dets, scores, iou_threshold): + torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") + torch._check(dets.size(1) == 5, lambda: f"boxes should have 5 elements in dimension 1, got {dets.size(1)}") + torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}") + torch._check( + dets.size(0) == scores.size(0), + lambda: f"boxes and scores should have same number of elements in 
dimension 0, got {dets.size(0)} and {scores.size(0)}", + ) + ctx = torch._custom_ops.get_ctx() + num_to_keep = ctx.create_unbacked_symint() + return dets.new_empty(num_to_keep, dtype=torch.long) + + @register_meta("deform_conv2d") def meta_deform_conv2d( input, diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 454ce118a6d..9a5eb4f242f 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -1,16 +1,25 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + #include #include +#include "../box_iou_rotated_utils.h" + namespace vision { namespace ops { namespace { -template +template at::Tensor nms_kernel_impl( const at::Tensor& dets, const at::Tensor& scores, - double iou_threshold) { + double iou_threshold, + IoUFunc iou_func) { TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); TORCH_CHECK( @@ -21,13 +30,6 @@ at::Tensor nms_kernel_impl( return at::empty({0}, dets.options().dtype(at::kLong)); } - auto x1_t = dets.select(1, 0).contiguous(); - auto y1_t = dets.select(1, 1).contiguous(); - auto x2_t = dets.select(1, 2).contiguous(); - auto y2_t = dets.select(1, 3).contiguous(); - - at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); - auto order_t = std::get<1>( scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); @@ -38,11 +40,6 @@ at::Tensor nms_kernel_impl( auto suppressed = suppressed_t.data_ptr(); auto keep = keep_t.data_ptr(); auto order = order_t.data_ptr(); - auto x1 = x1_t.data_ptr(); - auto y1 = y1_t.data_ptr(); - auto x2 = x2_t.data_ptr(); - auto y2 = y2_t.data_ptr(); - auto areas = areas_t.data_ptr(); int64_t num_to_keep = 0; @@ -52,26 +49,16 @@ at::Tensor nms_kernel_impl( continue; } keep[num_to_keep++] = i; - auto ix1 = 
x1[i]; - auto iy1 = y1[i]; - auto ix2 = x2[i]; - auto iy2 = y2[i]; - auto iarea = areas[i]; + + iou_func.set_box(i); for (int64_t _j = _i + 1; _j < ndets; _j++) { auto j = order[_j]; if (suppressed[j] == 1) { continue; } - auto xx1 = std::max(ix1, x1[j]); - auto yy1 = std::max(iy1, y1[j]); - auto xx2 = std::min(ix2, x2[j]); - auto yy2 = std::min(iy2, y2[j]); - - auto w = std::max(static_cast(0), xx2 - xx1); - auto h = std::max(static_cast(0), yy2 - yy1); - auto inter = w * h; - auto ovr = inter / (iarea + areas[j] - inter); + + auto ovr = iou_func.compute(j); if (ovr > iou_threshold) { suppressed[j] = 1; } @@ -80,6 +67,70 @@ at::Tensor nms_kernel_impl( return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); } +template +struct NonRotatedIoU { + const scalar_t* x1; + const scalar_t* y1; + const scalar_t* x2; + const scalar_t* y2; + const scalar_t* areas; + at::Tensor x1_t, y1_t, x2_t, y2_t, areas_t; + + scalar_t ix1, iy1, ix2, iy2, iarea; + + NonRotatedIoU(const at::Tensor& dets) { + x1_t = dets.select(1, 0).contiguous(); + y1_t = dets.select(1, 1).contiguous(); + x2_t = dets.select(1, 2).contiguous(); + y2_t = dets.select(1, 3).contiguous(); + areas_t = (x2_t - x1_t) * (y2_t - y1_t); + x1 = x1_t.data_ptr(); + y1 = y1_t.data_ptr(); + x2 = x2_t.data_ptr(); + y2 = y2_t.data_ptr(); + areas = areas_t.data_ptr(); + } + + void set_box(int64_t i) { + ix1 = x1[i]; + iy1 = y1[i]; + ix2 = x2[i]; + iy2 = y2[i]; + iarea = areas[i]; + } + + scalar_t compute(int64_t j) const { + auto xx1 = std::max(ix1, x1[j]); + auto yy1 = std::max(iy1, y1[j]); + auto xx2 = std::min(ix2, x2[j]); + auto yy2 = std::min(iy2, y2[j]); + + auto w = std::max(static_cast(0), xx2 - xx1); + auto h = std::max(static_cast(0), yy2 - yy1); + auto inter = w * h; + return inter / (iarea + areas[j] - inter); + } +}; + +template +struct RotatedIoU { + const at::Tensor* dets_ptr; + + RotatedIoU(const at::Tensor& dets) : dets_ptr(&dets) {} + + int64_t i; + + void set_box(int64_t i) { + this->i = i; + 
} + + scalar_t compute(int64_t j) const { + return single_box_iou_rotated( + (*dets_ptr)[i].template data_ptr(), + (*dets_ptr)[j].template data_ptr()); + } +}; + at::Tensor nms_kernel( const at::Tensor& dets, const at::Tensor& scores, @@ -106,7 +157,40 @@ at::Tensor nms_kernel( auto result = at::empty({0}, dets.options()); AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { - result = nms_kernel_impl(dets, scores, iou_threshold); + result = nms_kernel_impl( + dets, scores, iou_threshold, NonRotatedIoU(dets)); + }); + return result; +} + +at::Tensor nms_rotated_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 5, + "boxes should have 5 elements in dimension 1, got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)); + + auto result = at::empty({0}, dets.options()); + + AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated_kernel", [&] { + result = nms_kernel_impl( + dets, scores, iou_threshold, RotatedIoU(dets)); }); return result; } @@ -115,6 +199,9 @@ at::Tensor nms_kernel( TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::nms_rotated"), + TORCH_FN(nms_rotated_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/cuda/nms_kernel.cu b/torchvision/csrc/ops/cuda/nms_kernel.cu index 44ce8db6b8e..5221ce8fc65 100644 --- a/torchvision/csrc/ops/cuda/nms_kernel.cu +++ b/torchvision/csrc/ops/cuda/nms_kernel.cu @@ -4,6 +4,7 @@ #include #include +#include "../box_iou_rotated_utils.h" #include "cuda_helpers.h" namespace vision { @@ 
-199,10 +200,142 @@ at::Tensor nms_kernel( return order_t.masked_select(keep); } +template +__global__ void nms_rotated_kernel_impl( + int n_boxes, + double iou_threshold, + const T* dev_boxes, + unsigned long long* dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + // Compared to nms_kernel_impl, where each box is represented with 4 values + // (x1, y1, x2, y2), each rotated box is represented with 5 values + // (x_center, y_center, width, height, angle_degrees) here. + __shared__ T block_boxes[threadsPerBlock * 5]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const T* cur_box = dev_boxes + cur_box_idx * 5; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + // Instead of devIoU used by nms_kernel_impl, here + // we use the single_box_iou_rotated function from box_iou_rotated_utils.h + if (single_box_iou_rotated(cur_box, block_boxes + i * 5) > + iou_threshold) { + t |= 1ULL << i; + } + } + const int col_blocks = ceil_div(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +at::Tensor 
nms_rotated_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); + TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor"); + TORCH_CHECK( + dets.scalar_type() == scores.scalar_type(), + "dets and scores must have the same dtype"); + + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 5, + "boxes should have 5 elements in dimension 1, got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)); + + at::cuda::CUDAGuard device_guard(dets.device()); + + if (dets.numel() == 0) { + return at::empty({0}, dets.options().dtype(at::kLong)); + } + + auto order_t = std::get<1>( + scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + auto dets_sorted = dets.index_select(0, order_t).contiguous(); + + int dets_num = dets.size(0); + + const int col_blocks = ceil_div(dets_num, threadsPerBlock); + + at::Tensor mask = + at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + + dim3 blocks(col_blocks, col_blocks); + dim3 threads(threadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES( + dets_sorted.scalar_type(), "nms_rotated_kernel", [&] { + nms_rotated_kernel_impl<<>>( + dets_num, + iou_threshold, + dets_sorted.data_ptr(), + (unsigned long long*)mask.data_ptr()); + }); + + at::Tensor keep = + at::zeros({dets_num}, dets.options().dtype(at::kBool).device(at::kCUDA)); + + gather_keep_from_mask<<< + 1, + min(col_blocks, threadsPerBlock), + col_blocks * sizeof(unsigned long long), + stream>>>( + keep.data_ptr(), + (unsigned long long*)mask.data_ptr(), + dets_num); + + AT_CUDA_CHECK(cudaGetLastError()); 
+ return order_t.masked_select(keep); +} + } // namespace TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::nms_rotated"), + TORCH_FN(nms_rotated_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/nms_rotated.cpp b/torchvision/csrc/ops/nms_rotated.cpp new file mode 100644 index 00000000000..da619e6c32c --- /dev/null +++ b/torchvision/csrc/ops/nms_rotated.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "nms_rotated.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +at::Tensor nms_rotated( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms_rotated.nms_rotated"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::nms_rotated", "") + .typed(); + return op.call(dets, scores, iou_threshold); +} + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::nms_rotated(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/nms_rotated.h b/torchvision/csrc/ops/nms_rotated.h new file mode 100644 index 00000000000..98bc225f691 --- /dev/null +++ b/torchvision/csrc/ops/nms_rotated.h @@ -0,0 +1,21 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API at::Tensor nms_rotated( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold); + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/ops.h b/torchvision/csrc/ops/ops.h index 9902c3b1ecd..173f2ef77b1 100644 --- a/torchvision/csrc/ops/ops.h +++ b/torchvision/csrc/ops/ops.h @@ -3,6 +3,7 @@ #include "box_iou_rotated.h" #include "deform_conv2d.h" #include "nms.h" +#include "nms_rotated.h" #include "ps_roi_align.h" #include "ps_roi_pool.h" #include "roi_align.h" diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index a089af2c4ad..43692538232 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -32,9 +32,12 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: to the behavior of argsort in PyTorch when repeated values are present. Args: - boxes (Tensor[N, 4])): boxes to perform NMS on. They - are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and - ``0 <= y1 < y2``. + boxes (Tensor[N, K]): boxes to perform NMS on. + If K=4, boxes are expected to be in ``(x1, y1, x2, y2)`` format + with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + If K=5, boxes are expected to be in ``(cx, cy, w, h, angle)`` format + for rotated boxes, where ``(cx, cy)`` is the center, ``(w, h)`` is + width and height, and ``angle`` is the rotation angle in degrees. 
scores (Tensor[N]): scores for each one of the boxes iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold @@ -45,7 +48,15 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(nms) _assert_has_ops() - return torch.ops.torchvision.nms(boxes, scores, iou_threshold) + + if boxes.size(-1) == 4: + return torch.ops.torchvision.nms(boxes, scores, iou_threshold) + elif boxes.size(-1) == 5: + return torch.ops.torchvision.nms_rotated(boxes, scores, iou_threshold) + else: + raise ValueError( + f"boxes should have 4 (axis-aligned) or 5 (rotated) elements in the last dimension, got {boxes.size(-1)}" + ) def batched_nms( @@ -61,9 +72,9 @@ def batched_nms( will not be applied between elements of different categories. Args: - boxes (Tensor[N, 4]): boxes where NMS will be performed. They - are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and - ``0 <= y1 < y2``. + boxes (Tensor[N, K]): boxes where NMS will be performed. + If K=4, boxes are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + If K=5, boxes are expected to be in ``(cx, cy, w, h, angle)`` format. scores (Tensor[N]): scores for each one of the boxes idxs (Tensor[N]): indices of the categories for each one of the boxes. iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold @@ -98,7 +109,17 @@ def _batched_nms_coordinate_trick( return torch.empty((0,), dtype=torch.int64, device=boxes.device) max_coordinate = boxes.max() offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) - boxes_for_nms = boxes + offsets[:, None] + if boxes.size(-1) == 4: + boxes_for_nms = boxes + offsets[:, None] + else: + # Rotated boxes are (cx, cy, w, h, angle): a box reaches up to half its diagonal away + # from its center, so the per-class offset must span the padded center range. + radius = torch.hypot(boxes[:, 2], boxes[:, 3]) / 2 + centers = boxes[:, :2] + extent = (centers + radius[:, None]).max() - (centers - radius[:, None]).min() + offsets = idxs.to(boxes) * (extent + torch.tensor(1).to(boxes)) + boxes_for_nms = boxes.clone() + boxes_for_nms[..., :2] = centers + offsets[:, None] keep = nms(boxes_for_nms, scores, iou_threshold) return keep