diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/FieldProductAgg.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/FieldProductAgg.java index 53ccfb94b304..dc14b2917094 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/FieldProductAgg.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/aggregate/FieldProductAgg.java @@ -57,16 +57,16 @@ public Object agg(Object accumulator, Object inputField) { product = fromBigDecimal(mul, mergeFieldDD.precision(), mergeFieldDD.scale()); break; case TINYINT: - product = (byte) ((byte) accumulator * (byte) inputField); + product = multiplyExactByte((byte) accumulator, (byte) inputField); break; case SMALLINT: - product = (short) ((short) accumulator * (short) inputField); + product = multiplyExactShort((short) accumulator, (short) inputField); break; case INTEGER: - product = (int) accumulator * (int) inputField; + product = multiplyExactInt((int) accumulator, (int) inputField); break; case BIGINT: - product = (long) accumulator * (long) inputField; + product = multiplyExactLong((long) accumulator, (long) inputField); break; case FLOAT: product = (float) accumulator * (float) inputField; @@ -84,6 +84,72 @@ public Object agg(Object accumulator, Object inputField) { return product; } + private static byte multiplyExactByte(byte a, byte b) { + int value = a * b; + if (value > Byte.MAX_VALUE || value < Byte.MIN_VALUE) { + throw new ArithmeticException( + String.format("byte overflow: %d * %d = %d", a, b, value)); + } + return (byte) value; + } + + private static short multiplyExactShort(short a, short b) { + int value = a * b; + if (value > Short.MAX_VALUE || value < Short.MIN_VALUE) { + throw new ArithmeticException( + String.format("short overflow: %d * %d = %d", a, b, value)); + } + return (short) value; + } + + private static int multiplyExactInt(int a, int b) { + try { + return Math.multiplyExact(a, b); + } catch (ArithmeticException e) { + throw new ArithmeticException(String.format("int overflow: %d * %d", a, b)); + } + } + + private static long multiplyExactLong(long a, long b) { + try { + return Math.multiplyExact(a, b); + } catch (ArithmeticException e) { + throw new ArithmeticException(String.format("long overflow: %d * %d", a, b)); + } + } + + private static byte divideExactByte(byte a, byte b) { + int value = a / b; + if (value > Byte.MAX_VALUE || value < Byte.MIN_VALUE) { + throw new ArithmeticException( + String.format("byte overflow: %d / %d = %d", a, b, value)); + } + return (byte) value; + } + + private static short divideExactShort(short a, short b) { + int value = a / b; + if (value > Short.MAX_VALUE || value < Short.MIN_VALUE) { + throw new ArithmeticException( + String.format("short overflow: %d / %d = %d", a, b, value)); + } + return (short) value; + } + + private static int divideExactInt(int a, int b) { + if (a == Integer.MIN_VALUE && b == -1) { + throw new ArithmeticException(String.format("int overflow: %d / %d", a, b)); + } + return a / b; + } + + private static long divideExactLong(long a, long b) { + if (a == Long.MIN_VALUE && b == -1L) { + throw new ArithmeticException(String.format("long overflow: %d / %d", a, b)); + } + return a / b; + } + @Override public Object retract(Object accumulator, Object inputField) { Object product; @@ -105,16 +171,16 @@ public Object retract(Object accumulator, Object inputField) { product = fromBigDecimal(div, mergeFieldDD.precision(), mergeFieldDD.scale()); break; case TINYINT: - product = (byte) ((byte) accumulator / (byte) inputField); + product = divideExactByte((byte) accumulator, (byte) inputField); break; case SMALLINT: - product = (short) ((short) accumulator / (short) inputField); + product = divideExactShort((short) accumulator, (short) inputField); break; case INTEGER: - product = (int) accumulator / (int) inputField; + product = divideExactInt((int) accumulator, (int) inputField); break; case BIGINT: - product = (long) accumulator / (long) inputField; + product = divideExactLong((long) accumulator, (long) inputField); break; case FLOAT: product = (float) accumulator / (float) inputField; diff --git a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/aggregate/FieldAggregatorTest.java b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/aggregate/FieldAggregatorTest.java index f040a5eb731a..1bdddcb846e1 100644 --- a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/aggregate/FieldAggregatorTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/aggregate/FieldAggregatorTest.java @@ -503,6 +503,78 @@ public void testFieldProductLongAgg() { assertThat(fieldProductAgg.retract(null, 5L)).isNull(); } + @Test + public void testFieldProductByteOverflow() { + FieldProductAgg fieldProductAgg = + new FieldProductAggFactory().create(new TinyIntType(), null, null); + assertThatThrownBy(() -> fieldProductAgg.agg((byte) 64, (byte) 2)) + .isInstanceOf(ArithmeticException.class); + assertThatThrownBy(() -> fieldProductAgg.agg((byte) -64, (byte) 4)) + .isInstanceOf(ArithmeticException.class); + } + + @Test + public void testFieldProductShortOverflow() { + FieldProductAgg fieldProductAgg = + new FieldProductAggFactory().create(new SmallIntType(), null, null); + assertThatThrownBy(() -> fieldProductAgg.agg((short) 1000, (short) 100)) + .isInstanceOf(ArithmeticException.class); + assertThatThrownBy(() -> fieldProductAgg.agg(Short.MIN_VALUE, (short) 2)) + .isInstanceOf(ArithmeticException.class); + } + + @Test + public void testFieldProductIntOverflow() { + FieldProductAgg fieldProductAgg = + new FieldProductAggFactory().create(new IntType(), null, null); + assertThatThrownBy(() -> fieldProductAgg.agg(100_000, 100_000)) + .isInstanceOf(ArithmeticException.class); + assertThatThrownBy(() -> fieldProductAgg.agg(Integer.MIN_VALUE, -1)) + .isInstanceOf(ArithmeticException.class); + } + + @Test + public void testFieldProductLongOverflow() { + FieldProductAgg fieldProductAgg = + new FieldProductAggFactory().create(new BigIntType(), null, null); + assertThatThrownBy(() -> fieldProductAgg.agg(Long.MAX_VALUE, 2L)) + .isInstanceOf(ArithmeticException.class); + assertThatThrownBy(() -> fieldProductAgg.agg(Long.MIN_VALUE, -1L)) + .isInstanceOf(ArithmeticException.class); + } + + @Test + public void testFieldProductByteRetractOverflow() { + FieldProductAgg fieldProductAgg = + new FieldProductAggFactory().create(new TinyIntType(), null, null); + assertThatThrownBy(() -> fieldProductAgg.retract(Byte.MIN_VALUE, (byte) -1)) + .isInstanceOf(ArithmeticException.class); + } + + @Test + public void testFieldProductShortRetractOverflow() { + FieldProductAgg fieldProductAgg = + new FieldProductAggFactory().create(new SmallIntType(), null, null); + assertThatThrownBy(() -> fieldProductAgg.retract(Short.MIN_VALUE, (short) -1)) + .isInstanceOf(ArithmeticException.class); + } + + @Test + public void testFieldProductIntRetractOverflow() { + FieldProductAgg fieldProductAgg = + new FieldProductAggFactory().create(new IntType(), null, null); + assertThatThrownBy(() -> fieldProductAgg.retract(Integer.MIN_VALUE, -1)) + .isInstanceOf(ArithmeticException.class); + } + + @Test + public void testFieldProductLongRetractOverflow() { + FieldProductAgg fieldProductAgg = + new FieldProductAggFactory().create(new BigIntType(), null, null); + assertThatThrownBy(() -> fieldProductAgg.retract(Long.MIN_VALUE, -1L)) + .isInstanceOf(ArithmeticException.class); + } + @Test public void testFieldProductFloatAgg() { FieldProductAgg fieldProductAgg =