From 8f785fba03e00c607550e4587fb8ba896cce99d5 Mon Sep 17 00:00:00 2001 From: Jash Shah Date: Wed, 29 Apr 2026 11:16:49 -0700 Subject: [PATCH] fix(transforms): handle float16/bfloat16 in convert_image_dtype Two bugs when using half-precision floating point dtypes: 1. float16/bfloat16 to integer: 255.999 rounds to 256.0 in float16 due to limited precision, then .to(uint8) wraps to 0. Fix: upcast to float32 before the multiply. 2. int32/int64 to float16: max values exceed float16 range (65504), producing NaN/inf silently. Fix: raise RuntimeError (same pattern as the existing float32->int32 safety check). Fixes #6799. --- test/test_transforms_tensor.py | 19 +++++++++++++++++++ torchvision/transforms/_functional_tensor.py | 6 ++++++ 2 files changed, 25 insertions(+) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index eac52dafc17..628cefb6c11 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -561,6 +561,25 @@ def test_convert_image_dtype_save_load(tmpdir): _test_fn_save_load(fn, tmpdir) +@pytest.mark.parametrize("device", cpu_and_cuda()) +@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("out_dtype", [torch.uint8, torch.int8, torch.int16, torch.int32]) +def test_convert_image_dtype_half_precision(device, in_dtype, out_dtype): + image = torch.tensor([0.0, 0.5, 1.0], dtype=in_dtype, device=device).reshape(1, 1, 3) + result = F.convert_image_dtype(image, out_dtype) + max_val = torch.iinfo(out_dtype).max + assert result[0, 0, 0] == 0, f"0.0 should map to 0, got {result[0, 0, 0]}" + assert result[0, 0, 2] == max_val, f"1.0 should map to {max_val}, got {result[0, 0, 2]}" + + +@pytest.mark.parametrize("device", cpu_and_cuda()) +def test_convert_image_dtype_int_to_float16_raises(device): + for in_dtype in (torch.int32, torch.int64): + image = torch.tensor([0, 1, 2], dtype=in_dtype, device=device).reshape(1, 1, 3) + with pytest.raises(RuntimeError, match=r"cannot be performed safely"): + F.convert_image_dtype(image, torch.float16) + + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("policy", [policy for policy in T.AutoAugmentPolicy]) @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py index 1a9830450d5..ab7240d59b5 100644 --- a/torchvision/transforms/_functional_tensor.py +++ b/torchvision/transforms/_functional_tensor.py @@ -85,6 +85,9 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - # when float is exactly 1.0. # `max + 1 - epsilon` provides more evenly distributed mapping of # ranges of floats to ints. + # float16/bfloat16 lack precision: 255.999 rounds to 256.0, overflowing to 0 on .to(uint8). + if image.dtype == torch.float16 or image.dtype == torch.bfloat16: + image = image.to(torch.float32) eps = 1e-3 max_val = float(_max_value(dtype)) result = image.mul(max_val + 1.0 - eps) @@ -95,6 +98,9 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - # int to float # TODO: replace with dtype.is_floating_point when torchscript supports it if torch.tensor(0, dtype=dtype).is_floating_point(): + if dtype == torch.float16 and image.dtype in (torch.int32, torch.int64): + msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely." + raise RuntimeError(msg) image = image.to(dtype) return image / input_max