diff --git a/constantine/math_compiler/impl_fields_sat.nim b/constantine/math_compiler/impl_fields_sat.nim index cd52f96f2..d5a913c27 100644 --- a/constantine/math_compiler/impl_fields_sat.nim +++ b/constantine/math_compiler/impl_fields_sat.nim @@ -78,49 +78,6 @@ import const SectionName = "ctt.fields" -proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, M, carry: ValueRef) = - ## If a >= Modulus: r <- a-M - ## else: r <- a - ## - ## This is constant-time straightline code. - ## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU. - ## - ## To be used when the final substraction can - ## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256) - - # Mask: contains 0xFFFF or 0x0000 - let (_, mask) = asy.br.subborrow(fd.zero, fd.zero, carry) - - # Now substract the modulus, and test a < M - # (underflow) with the last borrow - let (borrow, a_minus_M) = asy.br.llvm_sub_overflow(a, M) - - # If it underflows here, it means that it was - # smaller than the modulus and we don't need `a-M` - let (ctl, _) = asy.br.subborrow(mask, fd.zero, borrow) - - let t = asy.br.select(ctl, a, a_minus_M) - asy.store(rr, t) - -proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, M: ValueRef) = - ## If a >= Modulus: r <- a-M - ## else: r <- a - ## - ## This is constant-time straightline code. - ## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU. - ## - ## To be used when the modulus does not use the full bitwidth of the storing words - ## (say using 255 bits for the modulus out of 256 available in words) - - # Now substract the modulus, and test a < M - # (underflow) with the last borrow - let (borrow, a_minus_M) = asy.br.llvm_sub_overflow(a, M) - - # If it underflows here, it means that it was - # smaller than the modulus and we don't need `a-M` - let t = asy.br.select(borrow, a, a_minus_M) - asy.store(rr, t) - proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) = ## Generate an optimized modular addition kernel ## with parameters `a, b, modulus: Limbs -> Limbs` @@ -142,11 +99,70 @@ proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) = let b = asy.load2(fd.intBufTy, bb, "b") let M = asy.load2(fd.intBufTy, MM, "M") - let (carry, apb) = asy.br.llvm_add_overflow(a, b) if fd.spareBits >= 1: - asy.finalSubNoOverflow(fd, rr, apb, M) + let apb = asy.br.add(a, b, "a_plus_b") + if false: + # 33% more instructions + # https://github.com/llvm/llvm-project/issues/103717 + + # Now substract the modulus, and test apb < M + # (underflow) with the last borrow + let (borrow, apb_minus_M) = asy.br.llvm_sub_overflow(apb, M) + + # If it underflows here, it means that it was + # smaller than the modulus and we don't need `a-M` + let t = asy.br.select(borrow, apb, apb_minus_M) + asy.store(rr, t) + + else: + # 1 or 2 extra instructions + # https://github.com/llvm/llvm-project/issues/103841 + # https://github.com/llvm/llvm-project/issues/103855 + + let s = constInt(fd.intBufTy, fd.w * fd.numWords - 1) + + let apb_minus_M = asy.br.sub(apb, M) + let underflow = asy.br.lshr(apb_minus_M, s) + let borrow = asy.br.trunc(underflow, asy.ctx.int1_t()) + + let t = asy.br.select(borrow, apb, apb_minus_M) + asy.store(rr, t) else: - asy.finalSubMayOverflow(fd, rr, apb, M, carry) + if false: + let (carry, apb) = asy.br.llvm_add_overflow(a, b, "a_plus_b") + + # Mask: contains 0xFFFF or 0x0000 + let (_, mask) = asy.br.subborrow(fd.zero, fd.zero, carry) + + # Now substract the modulus, and test a < M + # (underflow) with the last borrow + let (borrow, apb_minus_M) = asy.br.llvm_sub_overflow(apb, M) + + # If it underflows here, it means that it was + # smaller than the modulus and we don't need `a-M` + let (ctl, _) = asy.br.subborrow(mask, fd.zero, borrow) + + let t = asy.br.select(ctl, apb, apb_minus_M) + asy.store(rr, t) + else: + let biggerIntBits = fd.w * (fd.numWords+1) + let biggerInt = asy.ctx.int_t(biggerIntBits) + + let ax = asy.br.zext(a, biggerInt, "ax") + let bx = asy.br.zext(b, biggerInt, "bx") + + let apb = asy.br.add(ax, bx, "a_plus_b") + + let mx = asy.br.zext(M, biggerInt, "mx") + let apb_minus_M = asy.br.sub(apb, mx, "apb_minus_M") + + let s = constInt(biggerInt, biggerIntBits - 1) + let underflow = asy.br.lshr(apb_minus_M, s) + let borrow = asy.br.trunc(underflow, asy.ctx.int1_t()) + + let tLarge = asy.br.select(borrow, apb, apb_minus_M) + let t = asy.br.trunc(tLarge, fd.intBufTy) + asy.store(rr, t) asy.br.retVoid()