From 91b61926327676309bb04c49dd8914bebfbad0cf Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Mon, 22 Dec 2025 10:02:23 +0100 Subject: [PATCH 01/14] Add SIMD optimization for int_to_float conversion Add SIMD fast paths for converting custom bit-depth floats to f32: - 32-bit float passthrough: Simple bitcast using SIMD - 16-bit float (f16/half-precision): SIMD conversion with scalar fallback for subnormal values The 16-bit float SIMD path handles normal, zero, and inf/nan cases directly, falling back to scalar for the rare subnormal case which requires variable-iteration normalization. Also adds BitDepth::f16() test helper and comprehensive unit tests for the conversion functions. --- jxl/src/headers/bit_depth.rs | 8 + jxl/src/render/stages/convert.rs | 336 ++++++++++++++++++++++++++++++- 2 files changed, 337 insertions(+), 7 deletions(-) diff --git a/jxl/src/headers/bit_depth.rs b/jxl/src/headers/bit_depth.rs index 898e721e5..e83217dab 100644 --- a/jxl/src/headers/bit_depth.rs +++ b/jxl/src/headers/bit_depth.rs @@ -65,6 +65,14 @@ impl BitDepth { exponent_bits_per_sample: 8, } } + #[cfg(test)] + pub fn f16() -> BitDepth { + BitDepth { + floating_point_sample: true, + bits_per_sample: 16, + exponent_bits_per_sample: 5, + } + } pub fn bits_per_sample(&self) -> u32 { self.bits_per_sample } diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs index 3ded74f2b..279006758 100644 --- a/jxl/src/render/stages/convert.rs +++ b/jxl/src/render/stages/convert.rs @@ -8,7 +8,7 @@ use crate::{ headers::bit_depth::BitDepth, render::{Channels, ChannelsMut, RenderPipelineInOutStage}, }; -use jxl_simd::{F32SimdVec, simd_function}; +use jxl_simd::{F32SimdVec, I32SimdVec, SimdMask, shl, simd_function}; pub struct ConvertU8F32Stage { channel: usize, @@ -135,20 +135,184 @@ impl std::fmt::Display for ConvertModularToF32Stage { } } +// SIMD 32-bit float passthrough (bitcast i32 to f32) +simd_function!( + int_to_float_32bit_simd_dispatch, + d: D, + fn int_to_float_32bit_simd(input: &[i32], output: &mut [f32], xsize: usize) { + let simd_width = D::I32Vec::LEN; + let num_full_chunks = xsize / simd_width; + + // Process full SIMD chunks + for (in_chunk, out_chunk) in input + .chunks_exact(simd_width) + .zip(output.chunks_exact_mut(simd_width)) + .take(num_full_chunks) + { + let val = D::I32Vec::load(d, in_chunk); + val.bitcast_to_f32().store(out_chunk); + } + + // Handle remainder with scalar + let remainder_start = num_full_chunks * simd_width; + for i in remainder_start..xsize { + output[i] = f32::from_bits(input[i] as u32); + } + } +); + +// SIMD 16-bit float (half-precision) to 32-bit float conversion +// This handles IEEE 754 binary16 format: 1 sign bit, 5 exponent bits, 10 mantissa bits +simd_function!( + int_to_float_16bit_simd_dispatch, + d: D, + fn int_to_float_16bit_simd(input: &[i32], output: &mut [f32], xsize: usize) { + let simd_width = D::I32Vec::LEN; + let num_full_chunks = xsize / simd_width; + + // Constants for 16-bit float (exp_bits=5, mant_bits=10) + let abs_mask = D::I32Vec::splat(d, 0x7FFF); // Mask for absolute value + let exp_mask = D::I32Vec::splat(d, 0x7C00); // Exponent bits in f16 + let mant_mask = D::I32Vec::splat(d, 0x03FF); // Mantissa bits in f16 + let exp_max = D::I32Vec::splat(d, 0x7C00); // Max exponent (inf/nan) + let exp_bias_adjust = D::I32Vec::splat(d, (127 - 15) << 23); // Bias adjustment shifted + let f32_inf_exp = D::I32Vec::splat(d, 0x7F80_0000_u32 as i32); + + for (in_chunk, out_chunk) in input + .chunks_exact(simd_width) + .zip(output.chunks_exact_mut(simd_width)) + .take(num_full_chunks) + { + let val = D::I32Vec::load(d, in_chunk); + + // Extract components + let abs_val = val & abs_mask; // Absolute value (exp + mantissa) + let exp_bits = val & exp_mask; // Exponent bits + let mant_bits = val & mant_mask; // Mantissa bits + + // Check for zero + let is_zero = abs_val.eq_zero(); + + // Check for inf/nan (exponent all 1s) + let is_inf_nan = exp_bits.eq(exp_max); + + // Check for subnormal (exponent is 0 but mantissa non-zero) + // Use andnot: !mant_is_zero & exp_is_zero + let exp_is_zero = exp_bits.eq_zero(); + let mant_is_zero = mant_bits.eq_zero(); + // is_subnormal = exp_is_zero AND NOT mant_is_zero + let is_subnormal = mant_is_zero.andnot(exp_is_zero); + + // Normal case: shift exponent and mantissa, adjust bias + // Sign bit at position 15 goes to position 31: shift left by 16 + // f16 exponent at bits 10-14 goes to f32 exponent at bits 23-30 + // f16 mantissa at bits 0-9 goes to f32 mantissa at bits 13-22 (shift left by 13) + let sign_shifted = shl!(val, 16) & D::I32Vec::splat(d, 0x8000_0000_u32 as i32); + let normal_exp = shl!(exp_bits, 13); + let normal_mant = shl!(mant_bits, 13); + let normal_result = sign_shifted | (normal_exp + exp_bias_adjust) | normal_mant; + + // Inf/NaN case: preserve mantissa pattern, set f32 inf exponent + let inf_nan_result = sign_shifted | f32_inf_exp | normal_mant; + + // Zero case: just the sign bit + let zero_result = sign_shifted; + + // Select result based on conditions + // Start with normal result, then override special cases + let result = is_inf_nan.if_then_else_i32(inf_nan_result, normal_result); + let result = is_zero.if_then_else_i32(zero_result, result); + + // For subnormals, fall back to scalar (rare case) + // maskz_i32 returns 0 where mask is true, so if any subnormal exists, + // there will be a 0 in subnormal_check, meaning eq_zero().all() would be true + // only if ALL elements are subnormal. We want to check if ANY are subnormal. + // So we check the inverse: if NOT eq_zero for all (meaning no subnormals), use SIMD. + let subnormal_check = is_subnormal.maskz_i32(D::I32Vec::splat(d, 1)); + // subnormal_check is 0 where is_subnormal=true, 1 where is_subnormal=false + // If all elements are 1 (no subnormals), eq(splat(1)).all() is true + let no_subnormals = subnormal_check.eq(D::I32Vec::splat(d, 1)); + if no_subnormals.all() { + // No subnormals - use SIMD result + result.bitcast_to_f32().store(out_chunk); + } else { + // At least one subnormal - process this chunk scalar + for (&in_val, out_val) in in_chunk.iter().zip(out_chunk.iter_mut()) { + *out_val = int_to_float_16bit_scalar(in_val); + } + } + } + + // Handle remainder with scalar + let remainder_start = num_full_chunks * simd_width; + for i in remainder_start..xsize { + output[i] = int_to_float_16bit_scalar(input[i]); + } + } +); + +// Scalar fallback for 16-bit float conversion (handles subnormals) +#[inline] +fn int_to_float_16bit_scalar(in_val: i32) -> f32 { + let mut f = in_val as u32; + let signbit = (f >> 15) != 0; + f &= 0x7FFF; + if f == 0 { + return if signbit { -0.0 } else { 0.0 }; + } + let mut exp = (f >> 10) as i32; + let mut mantissa = f & 0x3FF; + if exp == 31 { + // NaN or infinity + f = if signbit { 0x80000000 } else { 0 }; + f |= 0xFF << 23; + f |= mantissa << 13; + return f32::from_bits(f); + } + mantissa <<= 13; + if exp == 0 { + // subnormal number - normalize + while (mantissa & 0x800000) == 0 { + mantissa <<= 1; + exp -= 1; + } + exp += 1; + mantissa &= 0x7fffff; + } + exp = exp - 15 + 127; + f = if signbit { 0x80000000 } else { 0 }; + f |= (exp as u32) << 23; + f |= mantissa; + f32::from_bits(f) +} + // Converts custom [bits]-bit float (with [exp_bits] exponent bits) stored as // int back to binary32 float. -// TODO(sboukortt): SIMD fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth) { assert_eq!(input.len(), output.len()); let bits = bit_depth.bits_per_sample(); let exp_bits = bit_depth.exponent_bits_per_sample(); - if bits == 32 { - assert_eq!(exp_bits, 8); - for (&in_val, out_val) in input.iter().zip(output) { - *out_val = f32::from_bits(in_val as u32); - } + let xsize = input.len(); + + // Use SIMD fast paths for common formats + if bits == 32 && exp_bits == 8 { + // 32-bit float passthrough + int_to_float_32bit_simd_dispatch(input, output, xsize); return; } + + if bits == 16 && exp_bits == 5 { + // IEEE 754 half-precision (f16) - common HDR format + int_to_float_16bit_simd_dispatch(input, output, xsize); + return; + } + + // Generic scalar path for other custom float formats + int_to_float_generic(input, output, bits, exp_bits); +} + +// Generic scalar conversion for arbitrary bit-depth floats +fn int_to_float_generic(input: &[i32], output: &mut [f32], bits: u32, exp_bits: u32) { let exp_bias = (1 << (exp_bits - 1)) - 1; let sign_shift = bits - 1; let mant_bits = bits - exp_bits - 1; @@ -419,6 +583,7 @@ impl RenderPipelineInOutStage for ConvertF32ToF16Stage { mod test { use super::*; use crate::error::Result; + use crate::headers::bit_depth::BitDepth; use test_log::test; #[test] @@ -467,4 +632,161 @@ mod test { 1, ) } + + #[test] + fn test_int_to_float_32bit() { + // Test 32-bit float passthrough + let bit_depth = BitDepth::f32(); + let test_values: Vec = vec![ + 0.0, + 1.0, + -1.0, + 0.5, + -0.5, + f32::INFINITY, + f32::NEG_INFINITY, + 1e-30, + 1e30, + ]; + let input: Vec = test_values.iter().map(|&f| f.to_bits() as i32).collect(); + let mut output = vec![0.0f32; input.len()]; + + int_to_float(&input, &mut output, &bit_depth); + + for (i, (&expected, &actual)) in test_values.iter().zip(output.iter()).enumerate() { + if expected.is_nan() { + assert!(actual.is_nan(), "index {}: expected NaN, got {}", i, actual); + } else { + assert_eq!(expected, actual, "index {}: mismatch", i); + } + } + } + + #[test] + fn test_int_to_float_16bit_normal() { + // Test 16-bit float (f16) conversion for normal values + let bit_depth = BitDepth::f16(); + + // f16 format: 1 sign, 5 exp, 10 mantissa + // Test cases: (f16_bits, expected_f32) + let test_cases: Vec<(u16, f32)> = vec![ + (0x0000, 0.0), // +0 + (0x8000, -0.0), // -0 + (0x3C00, 1.0), // 1.0 + (0xBC00, -1.0), // -1.0 + (0x3800, 0.5), // 0.5 + (0x4000, 2.0), // 2.0 + (0x4400, 4.0), // 4.0 + (0x7BFF, 65504.0), // max normal f16 + ]; + + let input: Vec = test_cases.iter().map(|(bits, _)| *bits as i32).collect(); + let mut output = vec![0.0f32; input.len()]; + + int_to_float(&input, &mut output, &bit_depth); + + for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() { + assert!( + (expected - actual).abs() < 1e-6 + || (expected.is_sign_negative() == actual.is_sign_negative() + && *expected == 0.0 + && actual == 0.0), + "index {}: expected {}, got {}", + i, + expected, + actual + ); + } + } + + #[test] + fn test_int_to_float_16bit_special() { + // Test 16-bit float conversion for special values (inf, nan) + let bit_depth = BitDepth::f16(); + + let test_cases: Vec<(u16, f32)> = vec![ + (0x7C00, f32::INFINITY), // +inf + (0xFC00, f32::NEG_INFINITY), // -inf + ]; + + let input: Vec = test_cases.iter().map(|(bits, _)| *bits as i32).collect(); + let mut output = vec![0.0f32; input.len()]; + + int_to_float(&input, &mut output, &bit_depth); + + for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() { + assert_eq!( + *expected, actual, + "index {}: expected {}, got {}", + i, expected, actual + ); + } + } + + #[test] + fn test_int_to_float_16bit_subnormal() { + // Test 16-bit float conversion for subnormal values + let bit_depth = BitDepth::f16(); + + // Verify bit_depth is set correctly + assert_eq!(bit_depth.bits_per_sample(), 16); + assert_eq!(bit_depth.exponent_bits_per_sample(), 5); + assert!(bit_depth.floating_point_sample()); + + // Smallest subnormal: 2^-24 ≈ 5.96e-8 + // Largest subnormal: (2^10 - 1) * 2^-24 ≈ 6.10e-5 + let test_cases: Vec<(u16, f32)> = vec![ + (0x0001, 5.960464477539063e-8), // smallest positive subnormal + (0x03FF, 6.097555160522461e-5), // largest positive subnormal + (0x8001, -5.960464477539063e-8), // smallest negative subnormal + ]; + + // First test the scalar function directly + for (bits, expected) in &test_cases { + let scalar_result = int_to_float_16bit_scalar(*bits as i32); + let rel_err = ((expected - scalar_result) / expected).abs(); + assert!( + rel_err < 1e-6, + "scalar: bits=0x{:04X}, expected {}, got {}, rel_err {}", + bits, + expected, + scalar_result, + rel_err + ); + } + + // Test through int_to_float_generic (which should match scalar) + let input: Vec = test_cases.iter().map(|(bits, _)| *bits as i32).collect(); + let mut generic_output = vec![0.0f32; input.len()]; + int_to_float_generic(&input, &mut generic_output, 16, 5); + for (i, ((_, expected), &actual)) in + test_cases.iter().zip(generic_output.iter()).enumerate() + { + let rel_err = ((expected - actual) / expected).abs(); + assert!( + rel_err < 1e-6, + "generic index {}: expected {}, got {}, rel_err {}", + i, + expected, + actual, + rel_err + ); + } + + // Now test through the main function (uses SIMD dispatch) + let mut output = vec![0.0f32; input.len()]; + int_to_float(&input, &mut output, &bit_depth); + + for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() { + let rel_err = ((expected - actual) / expected).abs(); + assert!( + rel_err < 1e-6, + "simd index {}: expected {}, got {}, rel_err {}", + i, + expected, + actual, + rel_err + ); + } + } } From 081c15647ad24e9a8cdb1cc6f6c9bb9cba40d7f9 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Mon, 22 Dec 2025 10:14:55 +0100 Subject: [PATCH 02/14] Fix clippy excessive precision warnings in f16 tests --- jxl/src/render/stages/convert.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs index 279006758..08a8f1cb2 100644 --- a/jxl/src/render/stages/convert.rs +++ b/jxl/src/render/stages/convert.rs @@ -736,9 +736,9 @@ mod test { // Smallest subnormal: 2^-24 ≈ 5.96e-8 // Largest subnormal: (2^10 - 1) * 2^-24 ≈ 6.10e-5 let test_cases: Vec<(u16, f32)> = vec![ - (0x0001, 5.960464477539063e-8), // smallest positive subnormal - (0x03FF, 6.097555160522461e-5), // largest positive subnormal - (0x8001, -5.960464477539063e-8), // smallest negative subnormal + (0x0001, 5.960_464_5e-8), // smallest positive subnormal + (0x03FF, 6.097_555e-5), // largest positive subnormal + (0x8001, -5.960_464_5e-8), // smallest negative subnormal ]; // First test the scalar function directly From ece9ccab37fccf8bc244570da728081643508035 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Fri, 16 Jan 2026 19:35:13 +0100 Subject: [PATCH 03/14] Refactor f16 conversion to use F32Vec::load_f16_bits/store_f16 Address veluca93 review: add load_f16_bits() and store_f16() methods to F32SimdVec trait instead of implementing conversion in convert.rs. - AVX2+F16C: Hardware _mm256_cvtph_ps/_mm256_cvtps_ph - AVX-512: Hardware _mm512_cvtph_ps/_mm512_cvtps_ph - SSE4.2/NEON/Scalar: Scalar fallback Simplifies convert.rs by ~100 lines. --- jxl/src/render/stages/convert.rs | 118 ++++--------------------------- jxl_simd/src/aarch64/neon.rs | 24 +++++++ jxl_simd/src/lib.rs | 12 +++- jxl_simd/src/scalar.rs | 90 +++++++++++++++++++++++ jxl_simd/src/x86_64/avx.rs | 51 +++++++++++++ jxl_simd/src/x86_64/avx512.rs | 33 +++++++++ jxl_simd/src/x86_64/sse42.rs | 22 ++++++ 7 files changed, 245 insertions(+), 105 deletions(-) diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs index 08a8f1cb2..0184a1499 100644 --- a/jxl/src/render/stages/convert.rs +++ b/jxl/src/render/stages/convert.rs @@ -8,7 +8,7 @@ use crate::{ headers::bit_depth::BitDepth, render::{Channels, ChannelsMut, RenderPipelineInOutStage}, }; -use jxl_simd::{F32SimdVec, I32SimdVec, SimdMask, shl, simd_function}; +use jxl_simd::{F32SimdVec, I32SimdVec, simd_function}; pub struct ConvertU8F32Stage { channel: usize, @@ -162,130 +162,40 @@ simd_function!( ); // SIMD 16-bit float (half-precision) to 32-bit float conversion -// This handles IEEE 754 binary16 format: 1 sign bit, 5 exponent bits, 10 mantissa bits +// Uses hardware F16C/NEON instructions when available via F32Vec::load_f16_bits() simd_function!( int_to_float_16bit_simd_dispatch, d: D, fn int_to_float_16bit_simd(input: &[i32], output: &mut [f32], xsize: usize) { - let simd_width = D::I32Vec::LEN; + let simd_width = D::F32Vec::LEN; let num_full_chunks = xsize / simd_width; - // Constants for 16-bit float (exp_bits=5, mant_bits=10) - let abs_mask = D::I32Vec::splat(d, 0x7FFF); // Mask for absolute value - let exp_mask = D::I32Vec::splat(d, 0x7C00); // Exponent bits in f16 - let mant_mask = D::I32Vec::splat(d, 0x03FF); // Mantissa bits in f16 - let exp_max = D::I32Vec::splat(d, 0x7C00); // Max exponent (inf/nan) - let exp_bias_adjust = D::I32Vec::splat(d, (127 - 15) << 23); // Bias adjustment shifted - let f32_inf_exp = D::I32Vec::splat(d, 0x7F80_0000_u32 as i32); + // Temporary buffer for converting i32 -> u16 + // Stack-allocated for common SIMD widths (up to 16 elements for AVX-512) + let mut u16_buf = [0u16; 16]; for (in_chunk, out_chunk) in input .chunks_exact(simd_width) .zip(output.chunks_exact_mut(simd_width)) .take(num_full_chunks) { - let val = D::I32Vec::load(d, in_chunk); - - // Extract components - let abs_val = val & abs_mask; // Absolute value (exp + mantissa) - let exp_bits = val & exp_mask; // Exponent bits - let mant_bits = val & mant_mask; // Mantissa bits - - // Check for zero - let is_zero = abs_val.eq_zero(); - - // Check for inf/nan (exponent all 1s) - let is_inf_nan = exp_bits.eq(exp_max); - - // Check for subnormal (exponent is 0 but mantissa non-zero) - // Use andnot: !mant_is_zero & exp_is_zero - let exp_is_zero = exp_bits.eq_zero(); - let mant_is_zero = mant_bits.eq_zero(); - // is_subnormal = exp_is_zero AND NOT mant_is_zero - let is_subnormal = mant_is_zero.andnot(exp_is_zero); - - // Normal case: shift exponent and mantissa, adjust bias - // Sign bit at position 15 goes to position 31: shift left by 16 - // f16 exponent at bits 10-14 goes to f32 exponent at bits 23-30 - // f16 mantissa at bits 0-9 goes to f32 mantissa at bits 13-22 (shift left by 13) - let sign_shifted = shl!(val, 16) & D::I32Vec::splat(d, 0x8000_0000_u32 as i32); - let normal_exp = shl!(exp_bits, 13); - let normal_mant = shl!(mant_bits, 13); - let normal_result = sign_shifted | (normal_exp + exp_bias_adjust) | normal_mant; - - // Inf/NaN case: preserve mantissa pattern, set f32 inf exponent - let inf_nan_result = sign_shifted | f32_inf_exp | normal_mant; - - // Zero case: just the sign bit - let zero_result = sign_shifted; - - // Select result based on conditions - // Start with normal result, then override special cases - let result = is_inf_nan.if_then_else_i32(inf_nan_result, normal_result); - let result = is_zero.if_then_else_i32(zero_result, result); - - // For subnormals, fall back to scalar (rare case) - // maskz_i32 returns 0 where mask is true, so if any subnormal exists, - // there will be a 0 in subnormal_check, meaning eq_zero().all() would be true - // only if ALL elements are subnormal. We want to check if ANY are subnormal. - // So we check the inverse: if NOT eq_zero for all (meaning no subnormals), use SIMD. - let subnormal_check = is_subnormal.maskz_i32(D::I32Vec::splat(d, 1)); - // subnormal_check is 0 where is_subnormal=true, 1 where is_subnormal=false - // If all elements are 1 (no subnormals), eq(splat(1)).all() is true - let no_subnormals = subnormal_check.eq(D::I32Vec::splat(d, 1)); - if no_subnormals.all() { - // No subnormals - use SIMD result - result.bitcast_to_f32().store(out_chunk); - } else { - // At least one subnormal - process this chunk scalar - for (&in_val, out_val) in in_chunk.iter().zip(out_chunk.iter_mut()) { - *out_val = int_to_float_16bit_scalar(in_val); - } + // Convert i32 values to u16 (f16 bit patterns are in lower 16 bits) + for (i, &val) in in_chunk.iter().enumerate() { + u16_buf[i] = val as u16; } + // Use hardware f16->f32 conversion + let result = D::F32Vec::load_f16_bits(d, &u16_buf[..simd_width]); + result.store(out_chunk); } // Handle remainder with scalar let remainder_start = num_full_chunks * simd_width; for i in remainder_start..xsize { - output[i] = int_to_float_16bit_scalar(input[i]); + output[i] = jxl_simd::scalar::f16_to_f32(input[i] as u16); } } ); -// Scalar fallback for 16-bit float conversion (handles subnormals) -#[inline] -fn int_to_float_16bit_scalar(in_val: i32) -> f32 { - let mut f = in_val as u32; - let signbit = (f >> 15) != 0; - f &= 0x7FFF; - if f == 0 { - return if signbit { -0.0 } else { 0.0 }; - } - let mut exp = (f >> 10) as i32; - let mut mantissa = f & 0x3FF; - if exp == 31 { - // NaN or infinity - f = if signbit { 0x80000000 } else { 0 }; - f |= 0xFF << 23; - f |= mantissa << 13; - return f32::from_bits(f); - } - mantissa <<= 13; - if exp == 0 { - // subnormal number - normalize - while (mantissa & 0x800000) == 0 { - mantissa <<= 1; - exp -= 1; - } - exp += 1; - mantissa &= 0x7fffff; - } - exp = exp - 15 + 127; - f = if signbit { 0x80000000 } else { 0 }; - f |= (exp as u32) << 23; - f |= mantissa; - f32::from_bits(f) -} - // Converts custom [bits]-bit float (with [exp_bits] exponent bits) stored as // int back to binary32 float. fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth) { @@ -743,7 +653,7 @@ mod test { // First test the scalar function directly for (bits, expected) in &test_cases { - let scalar_result = int_to_float_16bit_scalar(*bits as i32); + let scalar_result = jxl_simd::scalar::f16_to_f32(*bits); let rel_err = ((expected - scalar_result) / expected).abs(); assert!( rel_err < 1e-6, diff --git a/jxl_simd/src/aarch64/neon.rs b/jxl_simd/src/aarch64/neon.rs index 66f20d152..53b4a717f 100644 --- a/jxl_simd/src/aarch64/neon.rs +++ b/jxl_simd/src/aarch64/neon.rs @@ -441,6 +441,30 @@ unsafe impl F32SimdVec for F32VecNeon { vst1_u16(dest.as_mut_ptr(), u16s); } } + + fn store_f16(this: F32VecNeon, dest: &mut [u16]) { + assert!(dest.len() >= F32VecNeon::LEN); + // TODO: Use vcvt_f16_f32 once Rust stdarch fix lands + // For now, use scalar conversion + let mut tmp = [0.0f32; 4]; + this.store(&mut tmp); + for i in 0..4 { + dest[i] = crate::scalar::f32_to_f16(tmp[i]); + } + } + } + + #[inline(always)] + fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self { + assert!(mem.len() >= Self::LEN); + // TODO: Use vcvt_f32_f16 once Rust stdarch fix lands + // (currently requires fp16 target feature incorrectly) + // For now, use scalar conversion + let mut result = [0.0f32; 4]; + for i in 0..4 { + result[i] = crate::scalar::f16_to_f32(mem[i]); + } + Self::load(d, &result) } #[inline(always)] diff --git a/jxl_simd/src/lib.rs b/jxl_simd/src/lib.rs index c9aa8fb91..bf4a7326e 100644 --- a/jxl_simd/src/lib.rs +++ b/jxl_simd/src/lib.rs @@ -20,7 +20,7 @@ mod x86_64; #[cfg(target_arch = "aarch64")] mod aarch64; -mod scalar; +pub mod scalar; #[cfg(all(target_arch = "x86_64", feature = "avx"))] pub use x86_64::avx::AvxDescriptor; @@ -270,6 +270,16 @@ pub unsafe trait F32SimdVec: /// Transposes the Self::LEN x Self::LEN matrix formed by array elements /// `data[stride * i]` for i = 0..Self::LEN. fn transpose_square(d: Self::Descriptor, data: &mut [Self::UnderlyingArray], stride: usize); + + /// Loads f16 values (stored as u16 bit patterns) and converts them to f32. + /// Uses hardware conversion instructions when available (F16C on x86, NEON fp16 on ARM). + /// Requires `mem.len() >= Self::LEN` or it will panic. + fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self; + + /// Converts f32 values to f16 and stores as u16 bit patterns. + /// Uses hardware conversion instructions when available (F16C on x86, NEON fp16 on ARM). + /// Requires `dest.len() >= Self::LEN` or it will panic. + fn store_f16(self, dest: &mut [u16]); } pub trait I32SimdVec: diff --git a/jxl_simd/src/scalar.rs b/jxl_simd/src/scalar.rs index 9667e7e39..684affe9a 100644 --- a/jxl_simd/src/scalar.rs +++ b/jxl_simd/src/scalar.rs @@ -10,6 +10,86 @@ use crate::{U32SimdVec, impl_f32_array_interface}; use super::{F32SimdVec, I32SimdVec, SimdDescriptor, SimdMask}; +/// Convert f16 bits (as u16) to f32. +#[inline(always)] +pub fn f16_to_f32(bits: u16) -> f32 { + let sign = ((bits >> 15) & 1) as u32; + let exp = ((bits >> 10) & 0x1F) as i32; + let mant = (bits & 0x3FF) as u32; + + if exp == 0 { + if mant == 0 { + // Zero (positive or negative) + f32::from_bits(sign << 31) + } else { + // Subnormal: normalize + let mut m = mant; + let mut e = -14i32; + while (m & 0x400) == 0 { + m <<= 1; + e -= 1; + } + m &= 0x3FF; + let f32_exp = ((e + 127) as u32) << 23; + let f32_mant = m << 13; + f32::from_bits((sign << 31) | f32_exp | f32_mant) + } + } else if exp == 31 { + // Inf or NaN + let f32_mant = mant << 13; + f32::from_bits((sign << 31) | 0x7F800000 | f32_mant) + } else { + // Normal number + let f32_exp = ((exp - 15 + 127) as u32) << 23; + let f32_mant = mant << 13; + f32::from_bits((sign << 31) | f32_exp | f32_mant) + } +} + +/// Convert f32 to f16 bits (as u16). +#[inline(always)] +pub fn f32_to_f16(val: f32) -> u16 { + let bits = val.to_bits(); + let sign = ((bits >> 31) & 1) as u16; + let exp = ((bits >> 23) & 0xFF) as i32; + let mant = bits & 0x7FFFFF; + + if exp == 0 { + // Zero or subnormal f32 -> zero in f16 + sign << 15 + } else if exp == 255 { + // Inf or NaN + if mant == 0 { + // Infinity + (sign << 15) | 0x7C00 + } else { + // NaN - preserve some mantissa bits + (sign << 15) | 0x7C00 | ((mant >> 13) as u16).max(1) + } + } else { + // Normal number + let new_exp = exp - 127 + 15; + if new_exp >= 31 { + // Overflow -> infinity + (sign << 15) | 0x7C00 + } else if new_exp <= 0 { + // Underflow -> subnormal or zero + if new_exp < -10 { + // Too small -> zero + sign << 15 + } else { + // Subnormal + let m = (mant | 0x800000) >> (1 - new_exp + 13); + (sign << 15) | (m as u16) + } + } else { + // Normal f16 + let f16_mant = (mant >> 13) as u16; + (sign << 15) | ((new_exp as u16) << 10) | f16_mant + } + } +} + #[derive(Clone, Copy, Debug)] pub struct ScalarDescriptor; @@ -213,6 +293,16 @@ unsafe impl F32SimdVec for f32 { dest[0] = self.round() as u16; } + #[inline(always)] + fn load_f16_bits(_d: Self::Descriptor, mem: &[u16]) -> Self { + f16_to_f32(mem[0]) + } + + #[inline(always)] + fn store_f16(self, dest: &mut [u16]) { + dest[0] = f32_to_f16(self); + } + impl_f32_array_interface!(); #[inline(always)] diff --git a/jxl_simd/src/x86_64/avx.rs b/jxl_simd/src/x86_64/avx.rs index 6de9f53a8..3822e47fb 100644 --- a/jxl_simd/src/x86_64/avx.rs +++ b/jxl_simd/src/x86_64/avx.rs @@ -668,6 +668,57 @@ unsafe impl F32SimdVec for F32VecAvx { impl_f32_array_interface!(); + #[inline(always)] + fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self { + assert!(mem.len() >= Self::LEN); + // Check for F16C at runtime and use hardware conversion if available + if is_x86_feature_detected!("f16c") { + #[target_feature(enable = "avx2,f16c")] + #[inline] + unsafe fn load_f16_f16c(d: AvxDescriptor, mem: &[u16]) -> F32VecAvx { + unsafe { + let bits = _mm_loadu_si128(mem.as_ptr() as *const __m128i); + F32VecAvx(_mm256_cvtph_ps(bits), d) + } + } + // SAFETY: we just checked f16c is available + unsafe { load_f16_f16c(d, mem) } + } else { + // Fallback to scalar conversion + let mut result = [0.0f32; 8]; + for i in 0..8 { + result[i] = crate::scalar::f16_to_f32(mem[i]); + } + Self::load(d, &result) + } + } + + #[inline(always)] + fn store_f16(self, dest: &mut [u16]) { + assert!(dest.len() >= Self::LEN); + // Check for F16C at runtime and use hardware conversion if available + if is_x86_feature_detected!("f16c") { + #[target_feature(enable = "avx2,f16c")] + #[inline] + unsafe fn store_f16_f16c(v: __m256, dest: &mut [u16]) { + unsafe { + // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC = 0 + let bits = _mm256_cvtps_ph::<0>(v); + _mm_storeu_si128(dest.as_mut_ptr() as *mut __m128i, bits); + } + } + // SAFETY: we just checked f16c is available + unsafe { store_f16_f16c(self.0, dest) } + } else { + // Fallback to scalar conversion + let mut tmp = [0.0f32; 8]; + self.store(&mut tmp); + for i in 0..8 { + dest[i] = crate::scalar::f32_to_f16(tmp[i]); + } + } + } + #[inline(always)] fn transpose_square(d: Self::Descriptor, data: &mut [Self::UnderlyingArray], stride: usize) { #[target_feature(enable = "avx2")] diff --git a/jxl_simd/src/x86_64/avx512.rs b/jxl_simd/src/x86_64/avx512.rs index e64c5f086..f649cdbdc 100644 --- a/jxl_simd/src/x86_64/avx512.rs +++ b/jxl_simd/src/x86_64/avx512.rs @@ -730,6 +730,39 @@ unsafe impl F32SimdVec for F32VecAvx512 { impl_f32_array_interface!(); + #[inline(always)] + fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self { + assert!(mem.len() >= Self::LEN); + // AVX512 implies F16C, so we can always use hardware conversion + #[target_feature(enable = "avx512f")] + #[inline] + unsafe fn load_f16_impl(d: Avx512Descriptor, mem: &[u16]) -> F32VecAvx512 { + unsafe { + let bits = _mm256_loadu_si256(mem.as_ptr() as *const __m256i); + F32VecAvx512(_mm512_cvtph_ps(bits), d) + } + } + // SAFETY: avx512f is available from the safety invariant on the descriptor + unsafe { load_f16_impl(d, mem) } + } + + #[inline(always)] + fn store_f16(self, dest: &mut [u16]) { + assert!(dest.len() >= Self::LEN); + // AVX512 implies F16C, so we can always use hardware conversion + #[target_feature(enable = "avx512f")] + #[inline] + unsafe fn store_f16_impl(v: __m512, dest: &mut [u16]) { + unsafe { + // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC = 0 + let bits = _mm512_cvtps_ph::<0>(v); + _mm256_storeu_si256(dest.as_mut_ptr() as *mut __m256i, bits); + } + } + // SAFETY: avx512f is available from the safety invariant on the descriptor + unsafe { store_f16_impl(self.0, dest) } + } + #[inline(always)] fn transpose_square(d: Self::Descriptor, data: &mut [Self::UnderlyingArray], stride: usize) { #[target_feature(enable = "avx512f")] diff --git a/jxl_simd/src/x86_64/sse42.rs b/jxl_simd/src/x86_64/sse42.rs index 7fef8c923..e654bafc0 100644 --- a/jxl_simd/src/x86_64/sse42.rs +++ b/jxl_simd/src/x86_64/sse42.rs @@ -609,6 +609,28 @@ unsafe impl F32SimdVec for F32VecSse42 { impl_f32_array_interface!(); + #[inline(always)] + fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self { + assert!(mem.len() >= Self::LEN); + // SSE4.2 doesn't have F16C, use scalar conversion + let mut result = [0.0f32; 4]; + for i in 0..4 { + result[i] = crate::scalar::f16_to_f32(mem[i]); + } + Self::load(d, &result) + } + + #[inline(always)] + fn store_f16(self, dest: &mut [u16]) { + assert!(dest.len() >= Self::LEN); + // SSE4.2 doesn't have F16C, use scalar conversion + let mut tmp = [0.0f32; 4]; + self.store(&mut tmp); + for i in 0..4 { + dest[i] = crate::scalar::f32_to_f16(tmp[i]); + } + } + #[inline(always)] fn transpose_square(d: Self::Descriptor, data: &mut [Self::UnderlyingArray], stride: usize) { #[target_feature(enable = "sse4.2")] From fd05d5664c6a9b1c6865fe7b96aa5c181da1a923 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Fri, 16 Jan 2026 20:30:50 +0100 Subject: [PATCH 04/14] Add SAFETY comments to unsafe blocks --- jxl_simd/src/x86_64/avx.rs | 2 ++ jxl_simd/src/x86_64/avx512.rs | 2 ++ 2 files changed, 4 insertions(+) diff --git a/jxl_simd/src/x86_64/avx.rs b/jxl_simd/src/x86_64/avx.rs index 3822e47fb..7341fe4b8 100644 --- a/jxl_simd/src/x86_64/avx.rs +++ b/jxl_simd/src/x86_64/avx.rs @@ -676,6 +676,7 @@ unsafe impl F32SimdVec for F32VecAvx { #[target_feature(enable = "avx2,f16c")] #[inline] unsafe fn load_f16_f16c(d: AvxDescriptor, mem: &[u16]) -> F32VecAvx { + // SAFETY: mem.len() >= 8 is checked by caller, and f16c is available unsafe { let bits = _mm_loadu_si128(mem.as_ptr() as *const __m128i); F32VecAvx(_mm256_cvtph_ps(bits), d) @@ -701,6 +702,7 @@ unsafe impl F32SimdVec for F32VecAvx { #[target_feature(enable = "avx2,f16c")] #[inline] unsafe fn store_f16_f16c(v: __m256, dest: &mut [u16]) { + // SAFETY: dest.len() >= 8 is checked by caller, and f16c is available unsafe { // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC = 0 let bits = _mm256_cvtps_ph::<0>(v); diff --git a/jxl_simd/src/x86_64/avx512.rs b/jxl_simd/src/x86_64/avx512.rs index f649cdbdc..3ca4381c1 100644 --- a/jxl_simd/src/x86_64/avx512.rs +++ b/jxl_simd/src/x86_64/avx512.rs @@ -737,6 +737,7 @@ unsafe impl F32SimdVec for F32VecAvx512 { #[target_feature(enable = "avx512f")] #[inline] unsafe fn load_f16_impl(d: Avx512Descriptor, mem: &[u16]) -> F32VecAvx512 { + // SAFETY: mem.len() >= 16 is checked by caller, and avx512f is available unsafe { let bits = _mm256_loadu_si256(mem.as_ptr() as *const __m256i); F32VecAvx512(_mm512_cvtph_ps(bits), d) @@ -753,6 +754,7 @@ unsafe impl F32SimdVec for F32VecAvx512 { #[target_feature(enable = "avx512f")] #[inline] unsafe fn store_f16_impl(v: __m512, dest: &mut [u16]) { + // SAFETY: dest.len() >= 16 is checked by caller, and avx512f is available unsafe { // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC = 0 let bits = _mm512_cvtps_ph::<0>(v); From d9832f290a252a9b9fa8bd3d32c2d323731788a2 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Tue, 20 Jan 2026 22:00:01 +0100 Subject: [PATCH 05/14] Address PR review comments for SIMD f16 conversion - AVX: Always require f16c for AVX2 path (removes runtime check) - AVX512: Restructure inner functions to not be unsafe, only wrap memory operations in unsafe blocks with SAFETY comments - NEON: Use inline ASM for f16 conversion (fcvtl/fcvtn) since stdarch incorrectly requires fp16 feature for basic conversion - Add f16 type module to jxl_simd and use it instead of u16/standalone functions throughout the crate --- jxl/src/render/stages/convert.rs | 4 +- jxl_simd/src/aarch64/neon.rs | 40 ++-- jxl_simd/src/float16.rs | 312 +++++++++++++++++++++++++++++++ jxl_simd/src/lib.rs | 3 + jxl_simd/src/scalar.rs | 86 +-------- jxl_simd/src/x86_64/avx.rs | 78 +++----- jxl_simd/src/x86_64/avx512.rs | 25 +-- jxl_simd/src/x86_64/sse42.rs | 4 +- 8 files changed, 387 insertions(+), 165 deletions(-) create mode 100644 jxl_simd/src/float16.rs diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs index 0184a1499..7c1ecf851 100644 --- a/jxl/src/render/stages/convert.rs +++ b/jxl/src/render/stages/convert.rs @@ -191,7 +191,7 @@ simd_function!( // Handle remainder with scalar let remainder_start = num_full_chunks * simd_width; for i in remainder_start..xsize { - output[i] = jxl_simd::scalar::f16_to_f32(input[i] as u16); + output[i] = jxl_simd::f16::from_bits(input[i] as u16).to_f32(); } } ); @@ -653,7 +653,7 @@ mod test { // First test the scalar function directly for (bits, expected) in &test_cases { - let scalar_result = jxl_simd::scalar::f16_to_f32(*bits); + let scalar_result = jxl_simd::f16::from_bits(*bits).to_f32(); let rel_err = ((expected - scalar_result) / expected).abs(); assert!( rel_err < 1e-6, diff --git a/jxl_simd/src/aarch64/neon.rs b/jxl_simd/src/aarch64/neon.rs index 53b4a717f..b5c73f354 100644 --- a/jxl_simd/src/aarch64/neon.rs +++ b/jxl_simd/src/aarch64/neon.rs @@ -444,12 +444,18 @@ unsafe impl F32SimdVec for F32VecNeon { fn store_f16(this: F32VecNeon, dest: &mut [u16]) { assert!(dest.len() >= F32VecNeon::LEN); - // TODO: Use vcvt_f16_f32 once Rust stdarch fix lands - // For now, use scalar conversion - let mut tmp = [0.0f32; 4]; - this.store(&mut tmp); - for i in 0..4 { - dest[i] = crate::scalar::f32_to_f16(tmp[i]); + // Use inline asm because Rust stdarch incorrectly requires fp16 target feature + // for vcvt_f16_f32 (fixed in https://github.com/rust-lang/stdarch/pull/1978) + let f16_bits: uint16x4_t; + // SAFETY: NEON is available (guaranteed by descriptor), dest has enough space + unsafe { + std::arch::asm!( + "fcvtn {out:v}.4h, {inp:v}.4s", + inp = in(vreg) this.0, + out = out(vreg) f16_bits, + options(pure, nomem, nostack), + ); + vst1_u16(dest.as_mut_ptr(), f16_bits); } } } @@ -457,14 +463,20 @@ unsafe impl F32SimdVec for F32VecNeon { #[inline(always)] fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self { assert!(mem.len() >= Self::LEN); - // TODO: Use vcvt_f32_f16 once Rust stdarch fix lands - // (currently requires fp16 target feature incorrectly) - // For now, use scalar conversion - let mut result = [0.0f32; 4]; - for i in 0..4 { - result[i] = crate::scalar::f16_to_f32(mem[i]); - } - Self::load(d, &result) + // Use inline asm because Rust stdarch incorrectly requires fp16 target feature + // for vcvt_f32_f16 (fixed in https://github.com/rust-lang/stdarch/pull/1978) + let result: float32x4_t; + // SAFETY: NEON is available (guaranteed by descriptor), mem has enough space + unsafe { + let f16_bits = vld1_u16(mem.as_ptr()); + std::arch::asm!( + "fcvtl {out:v}.4s, {inp:v}.4h", + inp = in(vreg) f16_bits, + out = out(vreg) result, + options(pure, nomem, nostack), + ); + } + F32VecNeon(result, d) } #[inline(always)] diff --git a/jxl_simd/src/float16.rs b/jxl_simd/src/float16.rs new file mode 100644 index 000000000..8fb07e0f1 --- /dev/null +++ b/jxl_simd/src/float16.rs @@ -0,0 +1,312 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//! IEEE 754 half-precision (binary16) floating-point type. +//! +//! This is a minimal implementation providing only the operations needed for JPEG XL decoding, +//! avoiding external dependencies like `half` which pulls in `zerocopy`. + +/// IEEE 754 binary16 half-precision floating-point type. +/// +/// Format: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits. +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct f16(u16); + +impl f16 { + /// Positive zero. + pub const ZERO: Self = Self(0); + + /// Creates an f16 from its raw bit representation. + #[inline] + pub const fn from_bits(bits: u16) -> Self { + Self(bits) + } + + /// Returns the raw bit representation. + #[inline] + pub const fn to_bits(self) -> u16 { + self.0 + } + + /// Converts to f32. + #[inline] + pub fn to_f32(self) -> f32 { + let bits = self.0; + let sign = ((bits >> 15) & 1) as u32; + let exp = ((bits >> 10) & 0x1F) as u32; + let mant = (bits & 0x3FF) as u32; + + let f32_bits = if exp == 0 { + if mant == 0 { + // Zero (signed) + sign << 31 + } else { + // Denormal f16 -> normalized f32 + // Find the leading 1 bit in mantissa + let mut m = mant; + let mut e = 0u32; + while (m & 0x400) == 0 { + m <<= 1; + e += 1; + } + m &= 0x3FF; // Remove the implicit leading 1 + // f16 denormal exponent is -14 (not -15), adjust by shift count + let new_exp = 127 - 14 - e; + (sign << 31) | (new_exp << 23) | (m << 13) + } + } else if exp == 31 { + // Infinity or NaN + if mant == 0 { + // Infinity + (sign << 31) | (0xFF << 23) + } else { + // NaN - preserve some payload bits, ensure quiet NaN + (sign << 31) | (0xFF << 23) | (mant << 13) | 0x0040_0000 + } + } else { + // Normal number + // Rebias: f16 uses bias 15, f32 uses bias 127 + // new_exp = exp - 15 + 127 = exp + 112 + let new_exp = exp + 112; + (sign << 31) | (new_exp << 23) | (mant << 13) + }; + + f32::from_bits(f32_bits) + } + + /// Creates an f16 from an f32. + #[inline] + pub fn from_f32(f: f32) -> Self { + let bits = f.to_bits(); + let sign = ((bits >> 31) & 1) as u16; + let exp = ((bits >> 23) & 0xFF) as i32; + let mant = bits & 0x007F_FFFF; + + let h_bits = if exp == 0 { + // Zero or f32 denormal -> f16 zero (too small) + sign << 15 + } else if exp == 255 { + // Infinity or NaN + if mant == 0 { + (sign << 15) | (0x1F << 10) // Infinity + } else { + (sign << 15) | (0x1F << 10) | 0x0200 // Quiet NaN + } + } else { + let unbiased = exp - 127; + + if unbiased < -24 { + // Too small, underflow to zero + sign << 15 + } else if unbiased < -14 { + // Denormal f16 + let shift = (-14 - unbiased) as u32; + let m = ((mant | 0x0080_0000) >> (shift + 14)) as u16; + (sign << 15) | m + } else if unbiased > 15 { + // Overflow to infinity + (sign << 15) | (0x1F << 10) + } else { + // Normal f16 + let h_exp = (unbiased + 15) as u16; + let h_mant = (mant >> 13) as u16; + + // Round to nearest, ties to even + let round_bit = (mant >> 12) & 1; + let sticky = mant & 0x0FFF; + let h_mant = if round_bit == 1 && (sticky != 0 || (h_mant & 1) == 1) { + h_mant + 1 + } else { + h_mant + }; + + // Handle mantissa overflow from rounding + if h_mant > 0x3FF { + if h_exp >= 30 { + // Overflow to infinity + (sign << 15) | (0x1F << 10) + } else { + (sign << 15) | ((h_exp + 1) << 10) + } + } else { + (sign << 15) | (h_exp << 10) | h_mant + } + } + }; + + Self(h_bits) + } + + /// Creates an f16 from an f64. + #[inline] + pub fn from_f64(f: f64) -> Self { + // Convert via f32 - sufficient precision for f16 + Self::from_f32(f as f32) + } + + /// Converts to f64. + #[inline] + pub fn to_f64(self) -> f64 { + self.to_f32() as f64 + } + + /// Returns true if this is neither infinite nor NaN. + #[inline] + pub fn is_finite(self) -> bool { + // Exponent of 31 means infinity or NaN + ((self.0 >> 10) & 0x1F) != 31 + } + + /// Returns the bytes in little-endian order. + #[inline] + pub const fn to_le_bytes(self) -> [u8; 2] { + self.0.to_le_bytes() + } + + /// Returns the bytes in big-endian order. + #[inline] + pub const fn to_be_bytes(self) -> [u8; 2] { + self.0.to_be_bytes() + } +} + +impl From for f32 { + #[inline] + fn from(f: f16) -> f32 { + f.to_f32() + } +} + +impl From for f64 { + #[inline] + fn from(f: f16) -> f64 { + f.to_f64() + } +} + +impl core::fmt::Debug for f16 { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}", self.to_f32()) + } +} + +impl core::fmt::Display for f16 { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}", self.to_f32()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_zero() { + let z = f16::ZERO; + assert_eq!(z.to_bits(), 0); + assert_eq!(z.to_f32(), 0.0); + assert!(z.is_finite()); + } + + #[test] + fn test_one() { + // 1.0 in f16: sign=0, exp=15 (biased), mant=0 -> 0x3C00 + let one = f16::from_bits(0x3C00); + assert!((one.to_f32() - 1.0).abs() < 1e-6); + assert!(one.is_finite()); + } + + #[test] + fn test_negative_one() { + // -1.0 in f16: sign=1, exp=15, mant=0 -> 0xBC00 + let neg_one = f16::from_bits(0xBC00); + assert!((neg_one.to_f32() - (-1.0)).abs() < 1e-6); + } + + #[test] + fn test_infinity() { + // +Inf: sign=0, exp=31, mant=0 -> 0x7C00 + let inf = f16::from_bits(0x7C00); + assert!(inf.to_f32().is_infinite()); + assert!(!inf.is_finite()); + + // -Inf: 0xFC00 + let neg_inf = f16::from_bits(0xFC00); + assert!(neg_inf.to_f32().is_infinite()); + assert!(!neg_inf.is_finite()); + } + + #[test] + fn test_nan() { + // NaN: exp=31, mant!=0 -> 0x7C01 (or any mant != 0) + let nan = f16::from_bits(0x7C01); + assert!(nan.to_f32().is_nan()); + assert!(!nan.is_finite()); + } + + #[test] + fn test_denormal() { + // Smallest positive denormal: 0x0001 + let tiny = f16::from_bits(0x0001); + let val = tiny.to_f32(); + assert!(val > 0.0); + assert!(val < 1e-6); + assert!(tiny.is_finite()); + } + + #[test] + fn test_roundtrip_normal() { + let test_values: [f32; 8] = [0.5, 1.0, 2.0, 100.0, 0.001, -0.5, -1.0, -100.0]; + for &v in &test_values { + let h = f16::from_f32(v); + let back = h.to_f32(); + // f16 has limited precision, allow ~0.1% error for normal values + let rel_err = ((v - back) / v).abs(); + assert!( + rel_err < 0.002, + "Roundtrip failed for {}: got {}, rel_err {}", + v, + back, + rel_err + ); + } + } + + #[test] + fn test_roundtrip_special() { + // Zero + assert_eq!(f16::from_f32(0.0).to_f32(), 0.0); + + // Infinity + assert!(f16::from_f32(f32::INFINITY).to_f32().is_infinite()); + assert!(f16::from_f32(f32::NEG_INFINITY).to_f32().is_infinite()); + + // NaN + assert!(f16::from_f32(f32::NAN).to_f32().is_nan()); + } + + #[test] + fn test_overflow_to_infinity() { + // f16 max is ~65504, values above should overflow to infinity + let big = f16::from_f32(100000.0); + assert!(big.to_f32().is_infinite()); + } + + #[test] + fn test_underflow_to_zero() { + // Very small values should underflow to zero + let tiny = f16::from_f32(1e-10); + assert_eq!(tiny.to_f32(), 0.0); + } + + #[test] + fn test_bytes() { + let h = f16::from_bits(0x1234); + assert_eq!(h.to_le_bytes(), [0x34, 0x12]); + assert_eq!(h.to_be_bytes(), [0x12, 0x34]); + } +} diff --git a/jxl_simd/src/lib.rs b/jxl_simd/src/lib.rs index bf4a7326e..27df0944a 100644 --- a/jxl_simd/src/lib.rs +++ b/jxl_simd/src/lib.rs @@ -20,8 +20,11 @@ mod x86_64; #[cfg(target_arch = "aarch64")] mod aarch64; +pub mod float16; pub mod scalar; +pub use float16::f16; + #[cfg(all(target_arch = "x86_64", feature = "avx"))] pub use x86_64::avx::AvxDescriptor; #[cfg(all(target_arch = "x86_64", feature = "avx512"))] diff --git a/jxl_simd/src/scalar.rs b/jxl_simd/src/scalar.rs index 684affe9a..f971ff84d 100644 --- a/jxl_simd/src/scalar.rs +++ b/jxl_simd/src/scalar.rs @@ -6,90 +6,10 @@ use std::mem::MaybeUninit; use std::num::Wrapping; -use crate::{U32SimdVec, impl_f32_array_interface}; +use crate::{U32SimdVec, f16, impl_f32_array_interface}; use super::{F32SimdVec, I32SimdVec, SimdDescriptor, SimdMask}; -/// Convert f16 bits (as u16) to f32. -#[inline(always)] -pub fn f16_to_f32(bits: u16) -> f32 { - let sign = ((bits >> 15) & 1) as u32; - let exp = ((bits >> 10) & 0x1F) as i32; - let mant = (bits & 0x3FF) as u32; - - if exp == 0 { - if mant == 0 { - // Zero (positive or negative) - f32::from_bits(sign << 31) - } else { - // Subnormal: normalize - let mut m = mant; - let mut e = -14i32; - while (m & 0x400) == 0 { - m <<= 1; - e -= 1; - } - m &= 0x3FF; - let f32_exp = ((e + 127) as u32) << 23; - let f32_mant = m << 13; - f32::from_bits((sign << 31) | f32_exp | f32_mant) - } - } else if exp == 31 { - // Inf or NaN - let f32_mant = mant << 13; - f32::from_bits((sign << 31) | 0x7F800000 | f32_mant) - } else { - // Normal number - let f32_exp = ((exp - 15 + 127) as u32) << 23; - let f32_mant = mant << 13; - f32::from_bits((sign << 31) | f32_exp | f32_mant) - } -} - -/// Convert f32 to f16 bits (as u16). -#[inline(always)] -pub fn f32_to_f16(val: f32) -> u16 { - let bits = val.to_bits(); - let sign = ((bits >> 31) & 1) as u16; - let exp = ((bits >> 23) & 0xFF) as i32; - let mant = bits & 0x7FFFFF; - - if exp == 0 { - // Zero or subnormal f32 -> zero in f16 - sign << 15 - } else if exp == 255 { - // Inf or NaN - if mant == 0 { - // Infinity - (sign << 15) | 0x7C00 - } else { - // NaN - preserve some mantissa bits - (sign << 15) | 0x7C00 | ((mant >> 13) as u16).max(1) - } - } else { - // Normal number - let new_exp = exp - 127 + 15; - if new_exp >= 31 { - // Overflow -> infinity - (sign << 15) | 0x7C00 - } else if new_exp <= 0 { - // Underflow -> subnormal or zero - if new_exp < -10 { - // Too small -> zero - sign << 15 - } else { - // Subnormal - let m = (mant | 0x800000) >> (1 - new_exp + 13); - (sign << 15) | (m as u16) - } - } else { - // Normal f16 - let f16_mant = (mant >> 13) as u16; - (sign << 15) | ((new_exp as u16) << 10) | f16_mant - } - } -} - #[derive(Clone, Copy, Debug)] pub struct ScalarDescriptor; @@ -295,12 +215,12 @@ unsafe impl F32SimdVec for f32 { #[inline(always)] fn load_f16_bits(_d: Self::Descriptor, mem: &[u16]) -> Self { - f16_to_f32(mem[0]) + f16::from_bits(mem[0]).to_f32() } #[inline(always)] fn store_f16(self, dest: &mut [u16]) { - dest[0] = f32_to_f16(self); + dest[0] = f16::from_f32(self).to_bits(); } impl_f32_array_interface!(); diff --git a/jxl_simd/src/x86_64/avx.rs b/jxl_simd/src/x86_64/avx.rs index 7341fe4b8..581e48057 100644 --- a/jxl_simd/src/x86_64/avx.rs +++ b/jxl_simd/src/x86_64/avx.rs @@ -96,13 +96,13 @@ fn transpose_8x8_core( (c0, c1, c2, c3, c4, c5, c6, c7) } -// Safety invariant: this type is only ever constructed if avx2 and fma are available. +// Safety invariant: this type is only ever constructed if avx2, fma, and f16c are available. #[derive(Clone, Copy, Debug)] pub struct AvxDescriptor(()); impl AvxDescriptor { /// # Safety - /// The caller must guarantee that the "avx2" and "fma" target features are available. + /// The caller must guarantee that the "avx2", "fma", and "f16c" target features are available. pub unsafe fn new_unchecked() -> Self { Self(()) } @@ -139,8 +139,11 @@ impl SimdDescriptor for AvxDescriptor { } fn new() -> Option { - if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") { - // SAFETY: we just checked avx2 and fma. + if is_x86_feature_detected!("avx2") + && is_x86_feature_detected!("fma") + && is_x86_feature_detected!("f16c") + { + // SAFETY: we just checked avx2, fma, and f16c. Some(unsafe { Self::new_unchecked() }) } else { None @@ -148,12 +151,12 @@ impl SimdDescriptor for AvxDescriptor { } fn call(self, f: impl FnOnce(Self) -> R) -> R { - #[target_feature(enable = "avx2,fma")] + #[target_feature(enable = "avx2,fma,f16c")] #[inline(never)] unsafe fn inner(d: AvxDescriptor, f: impl FnOnce(AvxDescriptor) -> R) -> R { f(d) } - // SAFETY: the safety invariant on `self` guarantees avx2 and fma. + // SAFETY: the safety invariant on `self` guarantees avx2, fma, and f16c. unsafe { inner(self, f) } } } @@ -165,12 +168,12 @@ macro_rules! fn_avx { fn $name:ident($($arg:ident: $ty:ty),* $(,)?) $(-> $ret:ty )? $body: block) => { #[inline(always)] fn $name(self: $self_ty, $($arg: $ty),*) $(-> $ret)? { - #[target_feature(enable = "fma,avx2")] + #[target_feature(enable = "fma,avx2,f16c")] #[inline] fn inner($this: $self_ty, $($arg: $ty),*) $(-> $ret)? { $body } - // SAFETY: `self.1` is constructed iff avx2 and fma are available. + // SAFETY: `self.1` is constructed iff avx2, fma, and f16c are available. unsafe { inner(self, $($arg),*) } } }; @@ -671,54 +674,31 @@ unsafe impl F32SimdVec for F32VecAvx { #[inline(always)] fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self { assert!(mem.len() >= Self::LEN); - // Check for F16C at runtime and use hardware conversion if available - if is_x86_feature_detected!("f16c") { - #[target_feature(enable = "avx2,f16c")] - #[inline] - unsafe fn load_f16_f16c(d: AvxDescriptor, mem: &[u16]) -> F32VecAvx { - // SAFETY: mem.len() >= 8 is checked by caller, and f16c is available - unsafe { - let bits = _mm_loadu_si128(mem.as_ptr() as *const __m128i); - F32VecAvx(_mm256_cvtph_ps(bits), d) - } - } - // SAFETY: we just checked f16c is available - unsafe { load_f16_f16c(d, mem) } - } else { - // Fallback to scalar conversion - let mut result = [0.0f32; 8]; - for i in 0..8 { - result[i] = crate::scalar::f16_to_f32(mem[i]); - } - Self::load(d, &result) + // f16c is guaranteed by the safety invariant on AvxDescriptor + #[target_feature(enable = "avx2,f16c")] + #[inline] + fn load_f16_impl(d: AvxDescriptor, mem: &[u16]) -> F32VecAvx { + // SAFETY: mem.len() >= 8 is checked by caller + let bits = unsafe { _mm_loadu_si128(mem.as_ptr() as *const __m128i) }; + F32VecAvx(_mm256_cvtph_ps(bits), d) } + // SAFETY: f16c is available from the safety invariant on the descriptor + unsafe { load_f16_impl(d, mem) } } #[inline(always)] fn store_f16(self, dest: &mut [u16]) { assert!(dest.len() >= Self::LEN); - // Check for F16C at runtime and use hardware conversion if available - if is_x86_feature_detected!("f16c") { - #[target_feature(enable = "avx2,f16c")] - #[inline] - unsafe fn store_f16_f16c(v: __m256, dest: &mut [u16]) { - // SAFETY: dest.len() >= 8 is checked by caller, and f16c is available - unsafe { - // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC = 0 - let bits = _mm256_cvtps_ph::<0>(v); - _mm_storeu_si128(dest.as_mut_ptr() as *mut __m128i, bits); - } - } - // SAFETY: we just checked f16c is available - unsafe { store_f16_f16c(self.0, dest) } - } else { - // Fallback to scalar conversion - let mut tmp = [0.0f32; 8]; - self.store(&mut tmp); - for i in 0..8 { - dest[i] = crate::scalar::f32_to_f16(tmp[i]); - } + // f16c is guaranteed by the safety invariant on AvxDescriptor + #[target_feature(enable = "avx2,f16c")] + #[inline] + fn store_f16_impl(v: __m256, dest: &mut [u16]) { + let bits = _mm256_cvtps_ph::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(v); + // SAFETY: dest.len() >= 8 is checked by caller + unsafe { _mm_storeu_si128(dest.as_mut_ptr() as *mut __m128i, bits) }; } + // SAFETY: f16c is available from the safety invariant on the descriptor + unsafe { store_f16_impl(self.0, dest) } } #[inline(always)] diff --git a/jxl_simd/src/x86_64/avx512.rs b/jxl_simd/src/x86_64/avx512.rs index 3ca4381c1..778244c07 100644 --- a/jxl_simd/src/x86_64/avx512.rs +++ b/jxl_simd/src/x86_64/avx512.rs @@ -732,16 +732,14 @@ unsafe impl F32SimdVec for F32VecAvx512 { #[inline(always)] fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self { - assert!(mem.len() >= Self::LEN); // AVX512 implies F16C, so we can always use hardware conversion #[target_feature(enable = "avx512f")] #[inline] - unsafe fn load_f16_impl(d: Avx512Descriptor, mem: &[u16]) -> F32VecAvx512 { - // SAFETY: mem.len() >= 16 is checked by caller, and avx512f is available - unsafe { - let bits = _mm256_loadu_si256(mem.as_ptr() as *const __m256i); - F32VecAvx512(_mm512_cvtph_ps(bits), d) - } + fn load_f16_impl(d: Avx512Descriptor, mem: &[u16]) -> F32VecAvx512 { + assert!(mem.len() >= F32VecAvx512::LEN); + // SAFETY: mem.len() >= 16 is checked above + let bits = unsafe { _mm256_loadu_si256(mem.as_ptr() as *const __m256i) }; + F32VecAvx512(_mm512_cvtph_ps(bits), d) } // SAFETY: avx512f is available from the safety invariant on the descriptor unsafe { load_f16_impl(d, mem) } @@ -749,17 +747,14 @@ unsafe impl F32SimdVec for F32VecAvx512 { #[inline(always)] fn store_f16(self, dest: &mut [u16]) { - assert!(dest.len() >= Self::LEN); // AVX512 implies F16C, so we can always use hardware conversion #[target_feature(enable = "avx512f")] #[inline] - unsafe fn store_f16_impl(v: __m512, dest: &mut [u16]) { - // SAFETY: dest.len() >= 16 is checked by caller, and avx512f is available - unsafe { - // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC = 0 - let bits = _mm512_cvtps_ph::<0>(v); - _mm256_storeu_si256(dest.as_mut_ptr() as *mut __m256i, bits); - } + fn store_f16_impl(v: __m512, dest: &mut [u16]) { + assert!(dest.len() >= F32VecAvx512::LEN); + let bits = _mm512_cvtps_ph::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(v); + // SAFETY: dest.len() >= 16 is checked above + unsafe { _mm256_storeu_si256(dest.as_mut_ptr() as *mut __m256i, bits) }; } // SAFETY: avx512f is available from the safety invariant on the descriptor unsafe { store_f16_impl(self.0, dest) } diff --git a/jxl_simd/src/x86_64/sse42.rs b/jxl_simd/src/x86_64/sse42.rs index e654bafc0..bbca88c22 100644 --- a/jxl_simd/src/x86_64/sse42.rs +++ b/jxl_simd/src/x86_64/sse42.rs @@ -615,7 +615,7 @@ unsafe impl F32SimdVec for F32VecSse42 { // SSE4.2 doesn't have F16C, use scalar conversion let mut result = [0.0f32; 4]; for i in 0..4 { - result[i] = crate::scalar::f16_to_f32(mem[i]); + result[i] = crate::f16::from_bits(mem[i]).to_f32(); } Self::load(d, &result) } @@ -627,7 +627,7 @@ unsafe impl F32SimdVec for F32VecSse42 { let mut tmp = [0.0f32; 4]; self.store(&mut tmp); for i in 0..4 { - dest[i] = crate::scalar::f32_to_f16(tmp[i]); + dest[i] = crate::f16::from_f32(tmp[i]).to_bits(); } } From e1f973b8df5f07eec76a0e8b143887b2718cac55 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Tue, 20 Jan 2026 22:40:14 +0100 Subject: [PATCH 06/14] Address reviewer feedback: add I32Vec::store_u16 and remove scalar remainder - Add I32Vec::store_u16() method to extract lower 16 bits from each i32 lane and store as u16 values, implemented for all SIMD backends - Remove scalar remainder handling in int_to_float functions since render pipeline buffers are always padded to SIMD width - Use div_ceil pattern consistent with other SIMD functions in convert.rs --- jxl/src/render/stages/convert.rs | 30 ++++++++---------------------- jxl_simd/src/aarch64/neon.rs | 11 +++++++++++ jxl_simd/src/lib.rs | 4 ++++ jxl_simd/src/scalar.rs | 5 +++++ jxl_simd/src/x86_64/avx.rs | 17 +++++++++++++++++ jxl_simd/src/x86_64/avx512.rs | 17 +++++++++++++++++ jxl_simd/src/x86_64/sse42.rs | 20 ++++++++++++++++++++ 7 files changed, 82 insertions(+), 22 deletions(-) diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs index 7c1ecf851..891d3695e 100644 --- a/jxl/src/render/stages/convert.rs +++ b/jxl/src/render/stages/convert.rs @@ -141,23 +141,16 @@ simd_function!( d: D, fn int_to_float_32bit_simd(input: &[i32], output: &mut [f32], xsize: usize) { let simd_width = D::I32Vec::LEN; - let num_full_chunks = xsize / simd_width; - // Process full SIMD chunks + // Process SIMD vectors using div_ceil (buffers are padded) for (in_chunk, out_chunk) in input .chunks_exact(simd_width) .zip(output.chunks_exact_mut(simd_width)) - .take(num_full_chunks) + .take(xsize.div_ceil(simd_width)) { let val = D::I32Vec::load(d, in_chunk); val.bitcast_to_f32().store(out_chunk); } - - // Handle remainder with scalar - let remainder_start = num_full_chunks * simd_width; - for i in remainder_start..xsize { - output[i] = f32::from_bits(input[i] as u32); - } } ); @@ -168,31 +161,24 @@ simd_function!( d: D, fn int_to_float_16bit_simd(input: &[i32], output: &mut [f32], xsize: usize) { let simd_width = D::F32Vec::LEN; - let num_full_chunks = xsize / simd_width; - // Temporary buffer for converting i32 -> u16 + // Temporary buffer for i32->u16 conversion via SIMD // Stack-allocated for common SIMD widths (up to 16 elements for AVX-512) let mut u16_buf = [0u16; 16]; + // Process SIMD vectors using div_ceil (buffers are padded) for (in_chunk, out_chunk) in input .chunks_exact(simd_width) .zip(output.chunks_exact_mut(simd_width)) - .take(num_full_chunks) + .take(xsize.div_ceil(simd_width)) { - // Convert i32 values to u16 (f16 bit patterns are in lower 16 bits) - for (i, &val) in in_chunk.iter().enumerate() { - u16_buf[i] = val as u16; - } + // Use SIMD to extract lower 16 bits from each i32 lane + let i32_vec = D::I32Vec::load(d, in_chunk); + i32_vec.store_u16(&mut u16_buf[..simd_width]); // Use hardware f16->f32 conversion let result = D::F32Vec::load_f16_bits(d, &u16_buf[..simd_width]); result.store(out_chunk); } - - // Handle remainder with scalar - let remainder_start = num_full_chunks * simd_width; - for i in remainder_start..xsize { - output[i] = jxl_simd::f16::from_bits(input[i] as u16).to_f32(); - } } ); diff --git a/jxl_simd/src/aarch64/neon.rs b/jxl_simd/src/aarch64/neon.rs index b5c73f354..7ea7eae57 100644 --- a/jxl_simd/src/aarch64/neon.rs +++ b/jxl_simd/src/aarch64/neon.rs @@ -689,6 +689,17 @@ impl I32SimdVec for I32VecNeon { // SAFETY: We know neon is available from the safety invariant on `self.1`. unsafe { Self(vshrq_n_s32::(self.0), self.1) } } + + #[inline(always)] + fn store_u16(self, dest: &mut [u16]) { + assert!(dest.len() >= Self::LEN); + // SAFETY: We know neon is available from the safety invariant on `self.1`. + unsafe { + // vmovn narrows i32 to i16 by taking the lower 16 bits + let narrowed = vmovn_s32(self.0); + vst1_u16(dest.as_mut_ptr(), vreinterpret_u16_s16(narrowed)); + } + } } impl Add for I32VecNeon { diff --git a/jxl_simd/src/lib.rs b/jxl_simd/src/lib.rs index 27df0944a..86d532ea7 100644 --- a/jxl_simd/src/lib.rs +++ b/jxl_simd/src/lib.rs @@ -340,6 +340,10 @@ pub trait I32SimdVec: fn shr(self) -> Self; fn mul_wide_take_high(self, rhs: Self) -> Self; + + /// Stores the lower 16 bits of each i32 lane as u16 values. + /// Requires `dest.len() >= Self::LEN` or it will panic. + fn store_u16(self, dest: &mut [u16]); } pub trait U32SimdVec: Sized + Copy + Debug + Send + Sync { diff --git a/jxl_simd/src/scalar.rs b/jxl_simd/src/scalar.rs index f971ff84d..a7fae8fd7 100644 --- a/jxl_simd/src/scalar.rs +++ b/jxl_simd/src/scalar.rs @@ -305,6 +305,11 @@ impl I32SimdVec for Wrapping { fn mul_wide_take_high(self, rhs: Self) -> Self { Wrapping(((self.0 as i64 * rhs.0 as i64) >> 32) as i32) } + + #[inline(always)] + fn store_u16(self, dest: &mut [u16]) { + dest[0] = self.0 as u16; + } } impl U32SimdVec for Wrapping { diff --git a/jxl_simd/src/x86_64/avx.rs b/jxl_simd/src/x86_64/avx.rs index 581e48057..e87675272 100644 --- a/jxl_simd/src/x86_64/avx.rs +++ b/jxl_simd/src/x86_64/avx.rs @@ -879,6 +879,23 @@ impl I32SimdVec for I32VecAvx { let p1 = _mm256_unpackhi_epi32(l, h); I32VecAvx(_mm256_unpackhi_epi64(p0, p1), this.1) }); + + #[inline(always)] + fn store_u16(self, dest: &mut [u16]) { + assert!(dest.len() >= Self::LEN); + #[target_feature(enable = "avx2")] + #[inline] + fn store_u16_impl(v: __m256i, dest: &mut [u16]) { + let mut tmp = [0i32; 8]; + // SAFETY: tmp has 8 elements, matching LEN + unsafe { _mm256_storeu_si256(tmp.as_mut_ptr() as *mut __m256i, v) }; + for i in 0..8 { + dest[i] = tmp[i] as u16; + } + } + // SAFETY: avx2 is available from the safety invariant on the descriptor. + unsafe { store_u16_impl(self.0, dest) } + } } impl Add for I32VecAvx { diff --git a/jxl_simd/src/x86_64/avx512.rs b/jxl_simd/src/x86_64/avx512.rs index 778244c07..5e9767d31 100644 --- a/jxl_simd/src/x86_64/avx512.rs +++ b/jxl_simd/src/x86_64/avx512.rs @@ -1055,6 +1055,23 @@ impl I32SimdVec for I32VecAvx512 { let idx = _mm512_setr_epi32(1, 17, 3, 19, 5, 21, 7, 23, 9, 25, 11, 27, 13, 29, 15, 31); I32VecAvx512(_mm512_permutex2var_epi32(l, idx, h), this.1) }); + + #[inline(always)] + fn store_u16(self, dest: &mut [u16]) { + assert!(dest.len() >= Self::LEN); + #[target_feature(enable = "avx512f")] + #[inline] + fn store_u16_impl(v: __m512i, dest: &mut [u16]) { + let mut tmp = [0i32; 16]; + // SAFETY: tmp has 16 elements, matching LEN + unsafe { _mm512_storeu_si512(tmp.as_mut_ptr() as *mut __m512i, v) }; + for i in 0..16 { + dest[i] = tmp[i] as u16; + } + } + // SAFETY: avx512f is available from the safety invariant on the descriptor. + unsafe { store_u16_impl(self.0, dest) } + } } impl Add for I32VecAvx512 { diff --git a/jxl_simd/src/x86_64/sse42.rs b/jxl_simd/src/x86_64/sse42.rs index bbca88c22..5532e0840 100644 --- a/jxl_simd/src/x86_64/sse42.rs +++ b/jxl_simd/src/x86_64/sse42.rs @@ -811,6 +811,26 @@ impl I32SimdVec for I32VecSse42 { let p1 = _mm_unpackhi_epi32(l, h); I32VecSse42(_mm_unpackhi_epi64(p0, p1), this.1) }); + + #[inline(always)] + fn store_u16(self, dest: &mut [u16]) { + assert!(dest.len() >= Self::LEN); + // Pack i32 to i16 with signed saturation, then store lower 64 bits + // _mm_packs_epi32 saturates i32 to i16, which preserves low 16 bits for values in range + #[target_feature(enable = "sse4.2")] + #[inline] + fn store_u16_impl(v: __m128i, dest: &mut [u16]) { + // Use scalar loop since _mm_packs_epi32 would saturate incorrectly for unsigned values + let mut tmp = [0i32; 4]; + // SAFETY: tmp has 4 elements, matching LEN + unsafe { _mm_storeu_si128(tmp.as_mut_ptr() as *mut __m128i, v) }; + for i in 0..4 { + dest[i] = tmp[i] as u16; + } + } + // SAFETY: sse4.2 is available from the safety invariant on the descriptor. + unsafe { store_u16_impl(self.0, dest) } + } } impl Add for I32VecSse42 { From 01b06007594ee03ae2a7fd12c5adc78b43e33337 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Tue, 20 Jan 2026 22:42:24 +0100 Subject: [PATCH 07/14] Add TODO for SIMD optimization of generic float format conversion --- jxl/src/render/stages/convert.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs index 891d3695e..f4c26f528 100644 --- a/jxl/src/render/stages/convert.rs +++ b/jxl/src/render/stages/convert.rs @@ -208,6 +208,7 @@ fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth) { } // Generic scalar conversion for arbitrary bit-depth floats +// TODO: SIMD optimization for custom float formats fn int_to_float_generic(input: &[i32], output: &mut [f32], bits: u32, exp_bits: u32) { let exp_bias = (1 << (exp_bits - 1)) - 1; let sign_shift = bits - 1; From 718a85d869e688897a5db35ca0bc48c6730ca0c6 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Tue, 20 Jan 2026 22:44:04 +0100 Subject: [PATCH 08/14] Move assert inside inner function for load_f16_bits/store_f16 in AVX --- jxl_simd/src/x86_64/avx.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jxl_simd/src/x86_64/avx.rs b/jxl_simd/src/x86_64/avx.rs index e87675272..5a11cdd9f 100644 --- a/jxl_simd/src/x86_64/avx.rs +++ b/jxl_simd/src/x86_64/avx.rs @@ -673,12 +673,12 @@ unsafe impl F32SimdVec for F32VecAvx { #[inline(always)] fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self { - assert!(mem.len() >= Self::LEN); // f16c is guaranteed by the safety invariant on AvxDescriptor #[target_feature(enable = "avx2,f16c")] #[inline] fn load_f16_impl(d: AvxDescriptor, mem: &[u16]) -> F32VecAvx { - // SAFETY: mem.len() >= 8 is checked by caller + assert!(mem.len() >= F32VecAvx::LEN); + // SAFETY: mem.len() >= 8 is checked above let bits = unsafe { _mm_loadu_si128(mem.as_ptr() as *const __m128i) }; F32VecAvx(_mm256_cvtph_ps(bits), d) } @@ -688,13 +688,13 @@ unsafe impl F32SimdVec for F32VecAvx { #[inline(always)] fn store_f16(self, dest: &mut [u16]) { - assert!(dest.len() >= Self::LEN); // f16c is guaranteed by the safety invariant on AvxDescriptor #[target_feature(enable = "avx2,f16c")] #[inline] fn store_f16_impl(v: __m256, dest: &mut [u16]) { + assert!(dest.len() >= F32VecAvx::LEN); let bits = _mm256_cvtps_ph::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(v); - // SAFETY: dest.len() >= 8 is checked by caller + // SAFETY: dest.len() >= 8 is checked above unsafe { _mm_storeu_si128(dest.as_mut_ptr() as *mut __m128i, bits) }; } // SAFETY: f16c is available from the safety invariant on the descriptor From db05b2f20190fdda323a68456f8388f76681f8c9 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Tue, 20 Jan 2026 22:46:47 +0100 Subject: [PATCH 09/14] Remove unnecessary f16c comments in AVX load/store --- jxl_simd/src/x86_64/avx.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/jxl_simd/src/x86_64/avx.rs b/jxl_simd/src/x86_64/avx.rs index 5a11cdd9f..aef688784 100644 --- a/jxl_simd/src/x86_64/avx.rs +++ b/jxl_simd/src/x86_64/avx.rs @@ -673,7 +673,6 @@ unsafe impl F32SimdVec for F32VecAvx { #[inline(always)] fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self { - // f16c is guaranteed by the safety invariant on AvxDescriptor #[target_feature(enable = "avx2,f16c")] #[inline] fn load_f16_impl(d: AvxDescriptor, mem: &[u16]) -> F32VecAvx { @@ -688,7 +687,6 @@ unsafe impl F32SimdVec for F32VecAvx { #[inline(always)] fn store_f16(self, dest: &mut [u16]) { - // f16c is guaranteed by the safety invariant on AvxDescriptor #[target_feature(enable = "avx2,f16c")] #[inline] fn store_f16_impl(v: __m256, dest: &mut [u16]) { From 469ee6461b04e6685cbfb0509d99e6a8c669b9b5 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Tue, 20 Jan 2026 23:02:26 +0100 Subject: [PATCH 10/14] Rename store_f16 to store_f16_bits for consistency The method takes &mut [u16] (raw bits), so the name should match load_f16_bits for consistency. --- jxl_simd/src/aarch64/neon.rs | 2 +- jxl_simd/src/lib.rs | 2 +- jxl_simd/src/scalar.rs | 2 +- jxl_simd/src/x86_64/avx.rs | 6 +++--- jxl_simd/src/x86_64/avx512.rs | 6 +++--- jxl_simd/src/x86_64/sse42.rs | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/jxl_simd/src/aarch64/neon.rs b/jxl_simd/src/aarch64/neon.rs index 7ea7eae57..c6e5f7ecf 100644 --- a/jxl_simd/src/aarch64/neon.rs +++ b/jxl_simd/src/aarch64/neon.rs @@ -442,7 +442,7 @@ unsafe impl F32SimdVec for F32VecNeon { } } - fn store_f16(this: F32VecNeon, dest: &mut [u16]) { + fn store_f16_bits(this: F32VecNeon, dest: &mut [u16]) { assert!(dest.len() >= F32VecNeon::LEN); // Use inline asm because Rust stdarch incorrectly requires fp16 target feature // for vcvt_f16_f32 (fixed in https://github.com/rust-lang/stdarch/pull/1978) diff --git a/jxl_simd/src/lib.rs b/jxl_simd/src/lib.rs index 86d532ea7..2992bc954 100644 --- a/jxl_simd/src/lib.rs +++ b/jxl_simd/src/lib.rs @@ -282,7 +282,7 @@ pub unsafe trait F32SimdVec: /// Converts f32 values to f16 and stores as u16 bit patterns. /// Uses hardware conversion instructions when available (F16C on x86, NEON fp16 on ARM). /// Requires `dest.len() >= Self::LEN` or it will panic. - fn store_f16(self, dest: &mut [u16]); + fn store_f16_bits(self, dest: &mut [u16]); } pub trait I32SimdVec: diff --git a/jxl_simd/src/scalar.rs b/jxl_simd/src/scalar.rs index a7fae8fd7..f0444c34b 100644 --- a/jxl_simd/src/scalar.rs +++ b/jxl_simd/src/scalar.rs @@ -219,7 +219,7 @@ unsafe impl F32SimdVec for f32 { } #[inline(always)] - fn store_f16(self, dest: &mut [u16]) { + fn store_f16_bits(self, dest: &mut [u16]) { dest[0] = f16::from_f32(self).to_bits(); } diff --git a/jxl_simd/src/x86_64/avx.rs b/jxl_simd/src/x86_64/avx.rs index aef688784..872b1f27d 100644 --- a/jxl_simd/src/x86_64/avx.rs +++ b/jxl_simd/src/x86_64/avx.rs @@ -686,17 +686,17 @@ unsafe impl F32SimdVec for F32VecAvx { } #[inline(always)] - fn store_f16(self, dest: &mut [u16]) { + fn store_f16_bits(self, dest: &mut [u16]) { #[target_feature(enable = "avx2,f16c")] #[inline] - fn store_f16_impl(v: __m256, dest: &mut [u16]) { + fn store_f16_bits_impl(v: __m256, dest: &mut [u16]) { assert!(dest.len() >= F32VecAvx::LEN); let bits = _mm256_cvtps_ph::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(v); // SAFETY: dest.len() >= 8 is checked above unsafe { _mm_storeu_si128(dest.as_mut_ptr() as *mut __m128i, bits) }; } // SAFETY: f16c is available from the safety invariant on the descriptor - unsafe { store_f16_impl(self.0, dest) } + unsafe { store_f16_bits_impl(self.0, dest) } } #[inline(always)] diff --git a/jxl_simd/src/x86_64/avx512.rs b/jxl_simd/src/x86_64/avx512.rs index 5e9767d31..02abfabbb 100644 --- a/jxl_simd/src/x86_64/avx512.rs +++ b/jxl_simd/src/x86_64/avx512.rs @@ -746,18 +746,18 @@ unsafe impl F32SimdVec for F32VecAvx512 { } #[inline(always)] - fn store_f16(self, dest: &mut [u16]) { + fn store_f16_bits(self, dest: &mut [u16]) { // AVX512 implies F16C, so we can always use hardware conversion #[target_feature(enable = "avx512f")] #[inline] - fn store_f16_impl(v: __m512, dest: &mut [u16]) { + fn store_f16_bits_impl(v: __m512, dest: &mut [u16]) { assert!(dest.len() >= F32VecAvx512::LEN); let bits = _mm512_cvtps_ph::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(v); // SAFETY: dest.len() >= 16 is checked above unsafe { _mm256_storeu_si256(dest.as_mut_ptr() as *mut __m256i, bits) }; } // SAFETY: avx512f is available from the safety invariant on the descriptor - unsafe { store_f16_impl(self.0, dest) } + unsafe { store_f16_bits_impl(self.0, dest) } } #[inline(always)] diff --git a/jxl_simd/src/x86_64/sse42.rs b/jxl_simd/src/x86_64/sse42.rs index 5532e0840..bea4e12f4 100644 --- a/jxl_simd/src/x86_64/sse42.rs +++ b/jxl_simd/src/x86_64/sse42.rs @@ -621,7 +621,7 @@ unsafe impl F32SimdVec for F32VecSse42 { } #[inline(always)] - fn store_f16(self, dest: &mut [u16]) { + fn store_f16_bits(self, dest: &mut [u16]) { assert!(dest.len() >= Self::LEN); // SSE4.2 doesn't have F16C, use scalar conversion let mut tmp = [0.0f32; 4]; From 74640c3aa383ae4b778bb252144396771b57aae0 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Tue, 20 Jan 2026 23:15:19 +0100 Subject: [PATCH 11/14] Fix SIMD int_to_float to handle non-SIMD-aligned sizes The SIMD conversion functions were using chunks_exact() which only processes complete SIMD vectors, leaving remainder elements unprocessed. This caused test failures when the row size wasn't divisible by the SIMD width (e.g., 244 pixels with AVX2 width of 8). Fix by adding scalar fallback loops to handle remainder elements for both 32-bit float passthrough and 16-bit float conversion paths. Also use const assert to verify the buffer size assumption at compile time rather than runtime. --- jxl/src/render/stages/convert.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs index f4c26f528..cc483a4be 100644 --- a/jxl/src/render/stages/convert.rs +++ b/jxl/src/render/stages/convert.rs @@ -142,15 +142,20 @@ simd_function!( fn int_to_float_32bit_simd(input: &[i32], output: &mut [f32], xsize: usize) { let simd_width = D::I32Vec::LEN; - // Process SIMD vectors using div_ceil (buffers are padded) + // Process complete SIMD vectors for (in_chunk, out_chunk) in input .chunks_exact(simd_width) .zip(output.chunks_exact_mut(simd_width)) - .take(xsize.div_ceil(simd_width)) { let val = D::I32Vec::load(d, in_chunk); val.bitcast_to_f32().store(out_chunk); } + + // Handle remainder with scalar fallback + let processed = (xsize / simd_width) * simd_width; + for i in processed..xsize { + output[i] = f32::from_bits(input[i] as u32); + } } ); @@ -163,14 +168,15 @@ simd_function!( let simd_width = D::F32Vec::LEN; // Temporary buffer for i32->u16 conversion via SIMD - // Stack-allocated for common SIMD widths (up to 16 elements for AVX-512) + // Note: Using constant 16 (max AVX-512 width) because D::F32Vec::LEN + // cannot be used as array size in Rust (const generics limitation) + const { assert!(D::F32Vec::LEN <= 16) } let mut u16_buf = [0u16; 16]; - // Process SIMD vectors using div_ceil (buffers are padded) + // Process complete SIMD vectors for (in_chunk, out_chunk) in input .chunks_exact(simd_width) .zip(output.chunks_exact_mut(simd_width)) - .take(xsize.div_ceil(simd_width)) { // Use SIMD to extract lower 16 bits from each i32 lane let i32_vec = D::I32Vec::load(d, in_chunk); @@ -179,6 +185,12 @@ simd_function!( let result = D::F32Vec::load_f16_bits(d, &u16_buf[..simd_width]); result.store(out_chunk); } + + // Handle remainder with scalar f16 conversion + let processed = (xsize / simd_width) * simd_width; + for i in processed..xsize { + output[i] = jxl_simd::f16::from_bits(input[i] as u16).to_f32(); + } } ); From ec6ae4b070fd55e447356f3eb4f4b638c1e5a44e Mon Sep 17 00:00:00 2001 From: Luca Versari Date: Wed, 21 Jan 2026 10:59:47 +0100 Subject: [PATCH 12/14] Add f16c to target features defined on dispatch functions. --- jxl_simd/src/x86_64/mod.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jxl_simd/src/x86_64/mod.rs b/jxl_simd/src/x86_64/mod.rs index 939dd24a9..1a5463b17 100644 --- a/jxl_simd/src/x86_64/mod.rs +++ b/jxl_simd/src/x86_64/mod.rs @@ -62,16 +62,16 @@ macro_rules! simd_function_body_sse42 { #[macro_export] macro_rules! simd_function_body_avx { ($name:ident($($arg:ident: $ty:ty),* $(,)?) $(-> $ret:ty )?; ($($val:expr),* $(,)?)) => { - if cfg!(all(target_feature = "avx2", target_feature = "fma")) { - // SAFETY: we just checked for avx2 and fma. + if cfg!(all(target_feature = "avx2", target_feature = "fma", target_feature = "f16c")) { + // SAFETY: we just checked for avx2, fma and f16c. let d = unsafe { $crate::AvxDescriptor::new_unchecked() }; return $name(d, $($val),*); } else if let Some(d) = $crate::AvxDescriptor::new() { - #[target_feature(enable = "avx2,fma")] + #[target_feature(enable = "avx2,fma,f16c")] fn avx(d: $crate::AvxDescriptor, $($arg: $ty),*) $(-> $ret)? { $name(d, $($val),*) } - // SAFETY: we just checked for avx2 and fma. + // SAFETY: we just checked for avx2, fma and f16c. return unsafe { avx(d, $($arg),*) }; } }; @@ -170,11 +170,11 @@ macro_rules! test_avx { fn [<$name _avx>]() { use $crate::SimdDescriptor; let Some(d) = $crate::AvxDescriptor::new() else { return; }; - #[target_feature(enable = "avx2,fma")] + #[target_feature(enable = "avx2,fma,f16c")] fn inner(d: $crate::AvxDescriptor) { $name(d) } - // SAFETY: we just checked for avx2 and fma. + // SAFETY: we just checked for avx2, fma and f16c. return unsafe { inner(d) }; } } From 2b41c519fd2a64dd2eb74fbb3ad038574ec2efa6 Mon Sep 17 00:00:00 2001 From: Luca Versari Date: Wed, 21 Jan 2026 17:49:47 +0100 Subject: [PATCH 13/14] Remove scalar fallback and fix tests & pipeline code --- jxl/src/render/stages/convert.rs | 165 +++++++------------------------ 1 file changed, 38 insertions(+), 127 deletions(-) diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs index cc483a4be..1ddd46f59 100644 --- a/jxl/src/render/stages/convert.rs +++ b/jxl/src/render/stages/convert.rs @@ -146,16 +146,11 @@ simd_function!( for (in_chunk, out_chunk) in input .chunks_exact(simd_width) .zip(output.chunks_exact_mut(simd_width)) + .take(xsize.div_ceil(simd_width)) { let val = D::I32Vec::load(d, in_chunk); val.bitcast_to_f32().store(out_chunk); } - - // Handle remainder with scalar fallback - let processed = (xsize / simd_width) * simd_width; - for i in processed..xsize { - output[i] = f32::from_bits(input[i] as u32); - } } ); @@ -177,6 +172,7 @@ simd_function!( for (in_chunk, out_chunk) in input .chunks_exact(simd_width) .zip(output.chunks_exact_mut(simd_width)) + .take(xsize.div_ceil(simd_width)) { // Use SIMD to extract lower 16 bits from each i32 lane let i32_vec = D::I32Vec::load(d, in_chunk); @@ -185,22 +181,15 @@ simd_function!( let result = D::F32Vec::load_f16_bits(d, &u16_buf[..simd_width]); result.store(out_chunk); } - - // Handle remainder with scalar f16 conversion - let processed = (xsize / simd_width) * simd_width; - for i in processed..xsize { - output[i] = jxl_simd::f16::from_bits(input[i] as u16).to_f32(); - } } ); // Converts custom [bits]-bit float (with [exp_bits] exponent bits) stored as // int back to binary32 float. -fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth) { +fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth, xsize: usize) { assert_eq!(input.len(), output.len()); let bits = bit_depth.bits_per_sample(); let exp_bits = bit_depth.exponent_bits_per_sample(); - let xsize = input.len(); // Use SIMD fast paths for common formats if bits == 32 && exp_bits == 8 { @@ -288,12 +277,9 @@ impl RenderPipelineInOutStage for ConvertModularToF32Stage { ) { let input = &input_rows[0]; if self.bit_depth.floating_point_sample() { - int_to_float( - &input[0][..xsize], - &mut output_rows[0][0][..xsize], - &self.bit_depth, - ); + int_to_float(input[0], output_rows[0][0], &self.bit_depth, xsize); } else { + // TODO(veluca): SIMDfy this code. let scale = 1.0 / ((1u64 << self.bit_depth.bits_per_sample()) - 1) as f32; for i in 0..xsize { output_rows[0][0][i] = input[0][i] as f32 * scale; @@ -557,10 +543,15 @@ mod test { 1e-30, 1e30, ]; - let input: Vec = test_values.iter().map(|&f| f.to_bits() as i32).collect(); - let mut output = vec![0.0f32; input.len()]; + let input: Vec = test_values + .iter() + .map(|&f| f.to_bits() as i32) + .chain(std::iter::repeat(0)) + .take(16) + .collect(); + let mut output = vec![0.0f32; 16]; - int_to_float(&input, &mut output, &bit_depth); + int_to_float(&input, &mut output, &bit_depth, test_values.len()); for (i, (&expected, &actual)) in test_values.iter().zip(output.iter()).enumerate() { if expected.is_nan() { @@ -572,33 +563,44 @@ mod test { } #[test] - fn test_int_to_float_16bit_normal() { + fn test_int_to_float_16bit() { // Test 16-bit float (f16) conversion for normal values let bit_depth = BitDepth::f16(); // f16 format: 1 sign, 5 exp, 10 mantissa // Test cases: (f16_bits, expected_f32) let test_cases: Vec<(u16, f32)> = vec![ - (0x0000, 0.0), // +0 - (0x8000, -0.0), // -0 - (0x3C00, 1.0), // 1.0 - (0xBC00, -1.0), // -1.0 - (0x3800, 0.5), // 0.5 - (0x4000, 2.0), // 2.0 - (0x4400, 4.0), // 4.0 - (0x7BFF, 65504.0), // max normal f16 + (0x0000, 0.0), // +0 + (0x8000, -0.0), // -0 + (0x3C00, 1.0), // 1.0 + (0xBC00, -1.0), // -1.0 + (0x3800, 0.5), // 0.5 + (0x4000, 2.0), // 2.0 + (0x4400, 4.0), // 4.0 + (0x7BFF, 65504.0), // max normal f16 + (0x7C00, f32::INFINITY), // +inf + (0xFC00, f32::NEG_INFINITY), // -inf + (0x0001, 5.960_464_5e-8), // smallest positive subnormal + (0x03FF, 6.097_555e-5), // largest positive subnormal + (0x8001, -5.960_464_5e-8), // smallest negative subnormal ]; - let input: Vec = test_cases.iter().map(|(bits, _)| *bits as i32).collect(); - let mut output = vec![0.0f32; input.len()]; + let input: Vec = test_cases + .iter() + .map(|(bits, _)| *bits as i32) + .chain(std::iter::repeat(0)) + .take(16) + .collect(); + let mut output = vec![0.0f32; 16]; - int_to_float(&input, &mut output, &bit_depth); + int_to_float(&input, &mut output, &bit_depth, test_cases.len()); - for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() { + for (i, (&(_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() { assert!( (expected - actual).abs() < 1e-6 + || expected == actual || (expected.is_sign_negative() == actual.is_sign_negative() - && *expected == 0.0 + && expected == 0.0 && actual == 0.0), "index {}: expected {}, got {}", i, @@ -607,95 +609,4 @@ mod test { ); } } - - #[test] - fn test_int_to_float_16bit_special() { - // Test 16-bit float conversion for special values (inf, nan) - let bit_depth = BitDepth::f16(); - - let test_cases: Vec<(u16, f32)> = vec![ - (0x7C00, f32::INFINITY), // +inf - (0xFC00, f32::NEG_INFINITY), // -inf - ]; - - let input: Vec = test_cases.iter().map(|(bits, _)| *bits as i32).collect(); - let mut output = vec![0.0f32; input.len()]; - - int_to_float(&input, &mut output, &bit_depth); - - for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() { - assert_eq!( - *expected, actual, - "index {}: expected {}, got {}", - i, expected, actual - ); - } - } - - #[test] - fn test_int_to_float_16bit_subnormal() { - // Test 16-bit float conversion for subnormal values - let bit_depth = BitDepth::f16(); - - // Verify bit_depth is set correctly - assert_eq!(bit_depth.bits_per_sample(), 16); - assert_eq!(bit_depth.exponent_bits_per_sample(), 5); - assert!(bit_depth.floating_point_sample()); - - // Smallest subnormal: 2^-24 ≈ 5.96e-8 - // Largest subnormal: (2^10 - 1) * 2^-24 ≈ 6.10e-5 - let test_cases: Vec<(u16, f32)> = vec![ - (0x0001, 5.960_464_5e-8), // smallest positive subnormal - (0x03FF, 6.097_555e-5), // largest positive subnormal - (0x8001, -5.960_464_5e-8), // smallest negative subnormal - ]; - - // First test the scalar function directly - for (bits, expected) in &test_cases { - let scalar_result = jxl_simd::f16::from_bits(*bits).to_f32(); - let rel_err = ((expected - scalar_result) / expected).abs(); - assert!( - rel_err < 1e-6, - "scalar: bits=0x{:04X}, expected {}, got {}, rel_err {}", - bits, - expected, - scalar_result, - rel_err - ); - } - - // Test through int_to_float_generic (which should match scalar) - let input: Vec = test_cases.iter().map(|(bits, _)| *bits as i32).collect(); - let mut generic_output = vec![0.0f32; input.len()]; - int_to_float_generic(&input, &mut generic_output, 16, 5); - for (i, ((_, expected), &actual)) in - test_cases.iter().zip(generic_output.iter()).enumerate() - { - let rel_err = ((expected - actual) / expected).abs(); - assert!( - rel_err < 1e-6, - "generic index {}: expected {}, got {}, rel_err {}", - i, - expected, - actual, - rel_err - ); - } - - // Now test through the main function (uses SIMD dispatch) - let mut output = vec![0.0f32; input.len()]; - int_to_float(&input, &mut output, &bit_depth); - - for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() { - let rel_err = ((expected - actual) / expected).abs(); - assert!( - rel_err < 1e-6, - "simd index {}: expected {}, got {}, rel_err {}", - i, - expected, - actual, - rel_err - ); - } - } } From c1f53216e6a7f3fcfae8e1cf712232d4b675b585 Mon Sep 17 00:00:00 2001 From: Luca Versari Date: Wed, 21 Jan 2026 18:16:53 +0100 Subject: [PATCH 14/14] Improve store_u16 impl & some safety comments. --- jxl/src/image/internal.rs | 3 +-- jxl_simd/src/aarch64/neon.rs | 6 ++++-- jxl_simd/src/lib.rs | 36 +++++++++++++++++++++++++++++++++++ jxl_simd/src/x86_64/avx.rs | 27 ++++++++++++++++---------- jxl_simd/src/x86_64/avx512.rs | 14 ++++++-------- jxl_simd/src/x86_64/sse42.rs | 2 +- 6 files changed, 65 insertions(+), 23 deletions(-) diff --git a/jxl/src/image/internal.rs b/jxl/src/image/internal.rs index 822842c90..0c77b4b30 100644 --- a/jxl/src/image/internal.rs +++ b/jxl/src/image/internal.rs @@ -165,8 +165,7 @@ impl RawImageBuffer { // invariant. let start = unsafe { self.buf.add(start) }; // SAFETY: due to the struct safety invariant, we know the entire slice is in a range of - // memory valid for writes. Moreover, the caller promises not to write uninitialized data - // in the returned slice. Finally, the caller guarantees aliasing rules will not be violated. + // memory valid for reads. The caller guarantees aliasing rules will not be violated. unsafe { std::slice::from_raw_parts(start, self.bytes_per_row) } } diff --git a/jxl_simd/src/aarch64/neon.rs b/jxl_simd/src/aarch64/neon.rs index c6e5f7ecf..c0d649939 100644 --- a/jxl_simd/src/aarch64/neon.rs +++ b/jxl_simd/src/aarch64/neon.rs @@ -486,7 +486,8 @@ unsafe impl F32SimdVec for F32VecNeon { fn prepare_impl(table: &[f32; 8]) -> uint8x16_t { // Convert f32 table to BF16 packed in 128 bits (16 bytes for 8 entries) // BF16 is the high 16 bits of f32 - // SAFETY: neon is available from target_feature + // SAFETY: neon is available from target_feature, and `table` is large + // enough for the loads. let (table_lo, table_hi) = unsafe { (vld1q_f32(table.as_ptr()), vld1q_f32(table.as_ptr().add(4))) }; @@ -693,7 +694,8 @@ impl I32SimdVec for I32VecNeon { #[inline(always)] fn store_u16(self, dest: &mut [u16]) { assert!(dest.len() >= Self::LEN); - // SAFETY: We know neon is available from the safety invariant on `self.1`. + // SAFETY: We know neon is available from the safety invariant on `self.1`, + // and we just checked that `dest` has enough space. unsafe { // vmovn narrows i32 to i16 by taking the lower 16 bits let narrowed = vmovn_s32(self.0); diff --git a/jxl_simd/src/lib.rs b/jxl_simd/src/lib.rs index 2992bc954..4f06dbddc 100644 --- a/jxl_simd/src/lib.rs +++ b/jxl_simd/src/lib.rs @@ -1179,4 +1179,40 @@ mod test { } } test_all_instruction_sets!(test_i32_mul_all_elements); + + fn test_store_u16(d: D) { + let data = [ + 0xbabau32 as i32, + 0x1234u32 as i32, + 0xdeadbabau32 as i32, + 0xdead1234u32 as i32, + 0x1111babau32 as i32, + 0x11111234u32 as i32, + 0x76543210u32 as i32, + 0x01234567u32 as i32, + 0x00000000u32 as i32, + 0xffffffffu32 as i32, + 0x23949289u32 as i32, + 0xf9371913u32 as i32, + 0xdeadbeefu32 as i32, + 0xbeefdeadu32 as i32, + 0xaaaaaaaau32 as i32, + 0xbbbbbbbbu32 as i32, + ]; + let mut output = [0u16; 16]; + for i in (0..16).step_by(D::I32Vec::LEN) { + let vec = D::I32Vec::load(d, &data[i..]); + vec.store_u16(&mut output[i..]); + } + + for i in 0..16 { + let expected = data[i] as u16; + assert_eq!( + output[i], expected, + "store_u16 failed at index {}: expected {}, got {}", + i, expected, output[i] + ); + } + } + test_all_instruction_sets!(test_store_u16); } diff --git a/jxl_simd/src/x86_64/avx.rs b/jxl_simd/src/x86_64/avx.rs index 872b1f27d..0da8ec9f0 100644 --- a/jxl_simd/src/x86_64/avx.rs +++ b/jxl_simd/src/x86_64/avx.rs @@ -607,7 +607,8 @@ unsafe impl F32SimdVec for F32VecAvx { #[inline(always)] fn prepare_table_bf16_8(_d: AvxDescriptor, table: &[f32; 8]) -> Bf16Table8Avx { // For AVX2, vpermps is exact and fast, so we just load the table as-is - // SAFETY: avx2 is available from the safety invariant on the descriptor + // SAFETY: avx2 is available from the safety invariant on the descriptor, + // and `table` has 8 elements, exactly as many as we load. Bf16Table8Avx(unsafe { _mm256_loadu_ps(table.as_ptr()) }) } @@ -681,7 +682,7 @@ unsafe impl F32SimdVec for F32VecAvx { let bits = unsafe { _mm_loadu_si128(mem.as_ptr() as *const __m128i) }; F32VecAvx(_mm256_cvtph_ps(bits), d) } - // SAFETY: f16c is available from the safety invariant on the descriptor + // SAFETY: avx2 and f16c are available from the safety invariant on the descriptor unsafe { load_f16_impl(d, mem) } } @@ -695,7 +696,7 @@ unsafe impl F32SimdVec for F32VecAvx { // SAFETY: dest.len() >= 8 is checked above unsafe { _mm_storeu_si128(dest.as_mut_ptr() as *mut __m128i, bits) }; } - // SAFETY: f16c is available from the safety invariant on the descriptor + // SAFETY: avx2 and f16c are available from the safety invariant on the descriptor unsafe { store_f16_bits_impl(self.0, dest) } } @@ -880,16 +881,22 @@ impl I32SimdVec for I32VecAvx { #[inline(always)] fn store_u16(self, dest: &mut [u16]) { - assert!(dest.len() >= Self::LEN); #[target_feature(enable = "avx2")] #[inline] fn store_u16_impl(v: __m256i, dest: &mut [u16]) { - let mut tmp = [0i32; 8]; - // SAFETY: tmp has 8 elements, matching LEN - unsafe { _mm256_storeu_si256(tmp.as_mut_ptr() as *mut __m256i, v) }; - for i in 0..8 { - dest[i] = tmp[i] as u16; - } + assert!(dest.len() >= I32VecAvx::LEN); + let tmp = _mm256_shuffle_epi8( + v, + _mm256_setr_epi8( + 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, + ), + ); + let tmp = _mm256_permute4x64_epi64(tmp, 0xD8); + // SAFETY: we just checked that `dest` has enough space. + unsafe { + _mm_storeu_si128(dest.as_mut_ptr().cast(), _mm256_extracti128_si256::<0>(tmp)) + }; } // SAFETY: avx2 is available from the safety invariant on the descriptor. unsafe { store_u16_impl(self.0, dest) } diff --git a/jxl_simd/src/x86_64/avx512.rs b/jxl_simd/src/x86_64/avx512.rs index 02abfabbb..89086c50c 100644 --- a/jxl_simd/src/x86_64/avx512.rs +++ b/jxl_simd/src/x86_64/avx512.rs @@ -665,7 +665,8 @@ unsafe impl F32SimdVec for F32VecAvx512 { #[target_feature(enable = "avx512f")] #[inline] fn prepare_impl(table: &[f32; 8]) -> __m512 { - // SAFETY: avx512f is available from target_feature + // SAFETY: avx512f is available from target_feature, and we load 8 elements, + // exactly as many as are present in `table`. let table_256 = unsafe { _mm256_loadu_ps(table.as_ptr()) }; // Zero-extend to 512-bit; vpermutexvar with indices 0-7 only reads first 256 bits _mm512_castps256_ps512(table_256) @@ -1058,16 +1059,13 @@ impl I32SimdVec for I32VecAvx512 { #[inline(always)] fn store_u16(self, dest: &mut [u16]) { - assert!(dest.len() >= Self::LEN); #[target_feature(enable = "avx512f")] #[inline] fn store_u16_impl(v: __m512i, dest: &mut [u16]) { - let mut tmp = [0i32; 16]; - // SAFETY: tmp has 16 elements, matching LEN - unsafe { _mm512_storeu_si512(tmp.as_mut_ptr() as *mut __m512i, v) }; - for i in 0..16 { - dest[i] = tmp[i] as u16; - } + assert!(dest.len() >= I32VecAvx512::LEN); + let tmp = _mm512_cvtepi32_epi16(v); + // SAFETY: We just checked `dst` has enough space. + unsafe { _mm256_storeu_epi32(dest.as_mut_ptr().cast(), tmp) }; } // SAFETY: avx512f is available from the safety invariant on the descriptor. unsafe { store_u16_impl(self.0, dest) } diff --git a/jxl_simd/src/x86_64/sse42.rs b/jxl_simd/src/x86_64/sse42.rs index bea4e12f4..b4021570c 100644 --- a/jxl_simd/src/x86_64/sse42.rs +++ b/jxl_simd/src/x86_64/sse42.rs @@ -814,12 +814,12 @@ impl I32SimdVec for I32VecSse42 { #[inline(always)] fn store_u16(self, dest: &mut [u16]) { - assert!(dest.len() >= Self::LEN); // Pack i32 to i16 with signed saturation, then store lower 64 bits // _mm_packs_epi32 saturates i32 to i16, which preserves low 16 bits for values in range #[target_feature(enable = "sse4.2")] #[inline] fn store_u16_impl(v: __m128i, dest: &mut [u16]) { + assert!(dest.len() >= I32VecSse42::LEN); // Use scalar loop since _mm_packs_epi32 would saturate incorrectly for unsigned values let mut tmp = [0i32; 4]; // SAFETY: tmp has 4 elements, matching LEN