Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ public Object agg(Object accumulator, Object inputField) {
product = fromBigDecimal(mul, mergeFieldDD.precision(), mergeFieldDD.scale());
break;
case TINYINT:
product = (byte) ((byte) accumulator * (byte) inputField);
product = multiplyExactByte((byte) accumulator, (byte) inputField);
break;
case SMALLINT:
product = (short) ((short) accumulator * (short) inputField);
product = multiplyExactShort((short) accumulator, (short) inputField);
break;
case INTEGER:
product = (int) accumulator * (int) inputField;
product = multiplyExactInt((int) accumulator, (int) inputField);
break;
case BIGINT:
product = (long) accumulator * (long) inputField;
product = multiplyExactLong((long) accumulator, (long) inputField);
break;
case FLOAT:
product = (float) accumulator * (float) inputField;
Expand All @@ -84,6 +84,72 @@ public Object agg(Object accumulator, Object inputField) {
return product;
}

private static byte multiplyExactByte(byte a, byte b) {
int value = a * b;
if (value > Byte.MAX_VALUE || value < Byte.MIN_VALUE) {
throw new ArithmeticException(
String.format("byte overflow: %d * %d = %d", a, b, value));
}
return (byte) value;
}

private static short multiplyExactShort(short a, short b) {
int value = a * b;
if (value > Short.MAX_VALUE || value < Short.MIN_VALUE) {
throw new ArithmeticException(
String.format("short overflow: %d * %d = %d", a, b, value));
}
return (short) value;
}

private static int multiplyExactInt(int a, int b) {
try {
return Math.multiplyExact(a, b);
} catch (ArithmeticException e) {
throw new ArithmeticException(String.format("int overflow: %d * %d", a, b));
}
}

private static long multiplyExactLong(long a, long b) {
try {
return Math.multiplyExact(a, b);
} catch (ArithmeticException e) {
throw new ArithmeticException(String.format("long overflow: %d * %d", a, b));
}
}

private static byte divideExactByte(byte a, byte b) {
int value = a / b;
if (value > Byte.MAX_VALUE || value < Byte.MIN_VALUE) {
throw new ArithmeticException(
String.format("byte overflow: %d / %d = %d", a, b, value));
}
return (byte) value;
}

private static short divideExactShort(short a, short b) {
int value = a / b;
if (value > Short.MAX_VALUE || value < Short.MIN_VALUE) {
throw new ArithmeticException(
String.format("short overflow: %d / %d = %d", a, b, value));
}
return (short) value;
}

private static int divideExactInt(int a, int b) {
if (a == Integer.MIN_VALUE && b == -1) {
throw new ArithmeticException(String.format("int overflow: %d / %d", a, b));
}
return a / b;
}

private static long divideExactLong(long a, long b) {
if (a == Long.MIN_VALUE && b == -1L) {
throw new ArithmeticException(String.format("long overflow: %d / %d", a, b));
}
return a / b;
}

@Override
public Object retract(Object accumulator, Object inputField) {
Object product;
Expand All @@ -105,16 +171,16 @@ public Object retract(Object accumulator, Object inputField) {
product = fromBigDecimal(div, mergeFieldDD.precision(), mergeFieldDD.scale());
break;
case TINYINT:
product = (byte) ((byte) accumulator / (byte) inputField);
product = divideExactByte((byte) accumulator, (byte) inputField);
break;
case SMALLINT:
product = (short) ((short) accumulator / (short) inputField);
product = divideExactShort((short) accumulator, (short) inputField);
break;
case INTEGER:
product = (int) accumulator / (int) inputField;
product = divideExactInt((int) accumulator, (int) inputField);
break;
case BIGINT:
product = (long) accumulator / (long) inputField;
product = divideExactLong((long) accumulator, (long) inputField);
break;
case FLOAT:
product = (float) accumulator / (float) inputField;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,78 @@ public void testFieldProductLongAgg() {
assertThat(fieldProductAgg.retract(null, 5L)).isNull();
}

@Test
public void testFieldProductByteOverflow() {
FieldProductAgg fieldProductAgg =
new FieldProductAggFactory().create(new TinyIntType(), null, null);
assertThatThrownBy(() -> fieldProductAgg.agg((byte) 64, (byte) 2))
.isInstanceOf(ArithmeticException.class);
assertThatThrownBy(() -> fieldProductAgg.agg((byte) -64, (byte) 4))
.isInstanceOf(ArithmeticException.class);
}

@Test
public void testFieldProductShortOverflow() {
FieldProductAgg fieldProductAgg =
new FieldProductAggFactory().create(new SmallIntType(), null, null);
assertThatThrownBy(() -> fieldProductAgg.agg((short) 1000, (short) 100))
.isInstanceOf(ArithmeticException.class);
assertThatThrownBy(() -> fieldProductAgg.agg(Short.MIN_VALUE, (short) 2))
.isInstanceOf(ArithmeticException.class);
}

@Test
public void testFieldProductIntOverflow() {
FieldProductAgg fieldProductAgg =
new FieldProductAggFactory().create(new IntType(), null, null);
assertThatThrownBy(() -> fieldProductAgg.agg(100_000, 100_000))
.isInstanceOf(ArithmeticException.class);
assertThatThrownBy(() -> fieldProductAgg.agg(Integer.MIN_VALUE, -1))
.isInstanceOf(ArithmeticException.class);
}

@Test
public void testFieldProductLongOverflow() {
FieldProductAgg fieldProductAgg =
new FieldProductAggFactory().create(new BigIntType(), null, null);
assertThatThrownBy(() -> fieldProductAgg.agg(Long.MAX_VALUE, 2L))
.isInstanceOf(ArithmeticException.class);
assertThatThrownBy(() -> fieldProductAgg.agg(Long.MIN_VALUE, -1L))
.isInstanceOf(ArithmeticException.class);
}

@Test
public void testFieldProductByteRetractOverflow() {
FieldProductAgg fieldProductAgg =
new FieldProductAggFactory().create(new TinyIntType(), null, null);
assertThatThrownBy(() -> fieldProductAgg.retract(Byte.MIN_VALUE, (byte) -1))
.isInstanceOf(ArithmeticException.class);
}

@Test
public void testFieldProductShortRetractOverflow() {
FieldProductAgg fieldProductAgg =
new FieldProductAggFactory().create(new SmallIntType(), null, null);
assertThatThrownBy(() -> fieldProductAgg.retract(Short.MIN_VALUE, (short) -1))
.isInstanceOf(ArithmeticException.class);
}

@Test
public void testFieldProductIntRetractOverflow() {
FieldProductAgg fieldProductAgg =
new FieldProductAggFactory().create(new IntType(), null, null);
assertThatThrownBy(() -> fieldProductAgg.retract(Integer.MIN_VALUE, -1))
.isInstanceOf(ArithmeticException.class);
}

@Test
public void testFieldProductLongRetractOverflow() {
FieldProductAgg fieldProductAgg =
new FieldProductAggFactory().create(new BigIntType(), null, null);
assertThatThrownBy(() -> fieldProductAgg.retract(Long.MIN_VALUE, -1L))
.isInstanceOf(ArithmeticException.class);
}

@Test
public void testFieldProductFloatAgg() {
FieldProductAgg fieldProductAgg =
Expand Down
Loading