diff --git a/tests/functional/builtins/codegen/test_convert.py b/tests/functional/builtins/codegen/test_convert.py index 9b1cd3028a..13ed97b62a 100644 --- a/tests/functional/builtins/codegen/test_convert.py +++ b/tests/functional/builtins/codegen/test_convert.py @@ -333,7 +333,9 @@ def _py_convert(val, i_typ, o_typ): val_bits = _padconvert(val_bits, _padding_direction(o_typ), n, padding_byte) if getattr(o_typ, "is_signed", False) and isinstance(i_typ, BytesM_T): - n_bits = _bits_of_type(i_typ) + out_size = _bits_of_type(o_typ) + in_size = _bits_of_type(i_typ) + n_bits = max(out_size, in_size) val_bits = _signextend(val_bits, n_bits) try: @@ -426,6 +428,68 @@ def _vyper_literal(val, typ): return str(val) +def test_bytes_to_int_different_sizes(get_contract): + code = r""" +@external +def foo() -> int16: + return convert(b'\xff', int16) + """ + + c = get_contract(code) + assert c.foo() == 255 + + code = r""" +@external +def foo() -> int16: + return convert(b'\x00\xff', int16) + """ + + c = get_contract(code) + assert c.foo() == 255 + + code = r""" +FOO: constant(Bytes[2]) = b'\xff' + +@external +def foo() -> int16: + return convert(FOO, int16) + """ + + c = get_contract(code) + assert c.foo() == 255 + + +def test_bytes_to_int_different_sizes_bytes3(get_contract): + code = r""" +@external +def foo(x: bytes3) -> int96: + return convert(x, int96) + """ + + c = get_contract(code) + assert c.foo(b"\xff\xff\xff") == 0xFF_FF_FF + + +def test_bytes_to_int_different_sizes_runtime(get_contract): + code = """ +@external +def foo(x: Bytes[1]) -> int16: + return convert(x, int16) + """ + + c = get_contract(code) + assert c.foo(b"\xff") == 255 + + code = """ +@external +def foo(x: Bytes[2]) -> int16: + return convert(x, int16) + """ + + c = get_contract(code) + assert c.foo(b"\xff") == 255 + + @pytest.mark.parametrize("i_typ,o_typ,val", generate_passing_cases()) @pytest.mark.fuzzing def test_convert_passing(get_contract, assert_compile_failed, i_typ, o_typ, val): diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index 52e9991f9c..28af218128 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -43,6 +43,7 @@ ) from vyper.semantics.types.bytestrings import _BytestringT from vyper.semantics.types.infinity import is_bounded_length +from vyper.semantics.types.primitives import NumericT from vyper.semantics.types.shortcuts import INT256_T, UINT160_T, UINT256_T from vyper.utils import DECIMAL_DIVISOR, round_towards_zero, unsigned_to_signed @@ -56,7 +57,7 @@ def _FAIL(ityp, otyp, source_expr=None): def _input_types(*allowed_types): def decorator(f): @functools.wraps(f) - def check_input_type(expr, arg, out_typ): + def check_input_type(expr, arg, in_typ, out_typ): # convert arg to out_typ. # (expr is the AST corresponding to `arg`) ok = isinstance(arg.typ, allowed_types) @@ -69,34 +70,55 @@ def check_input_type(expr, arg, out_typ): if arg.typ == out_typ and arg.typ not in (UINT256_T, INT256_T): raise InvalidType(f"value and target are both {out_typ}", expr) - return f(expr, arg, out_typ) + return f(expr, arg, in_typ, out_typ) return check_input_type return decorator -def _bytes_to_num(arg, out_typ, signed): +def _bits_count(typ) -> int | None: + if isinstance(typ, BoolT): + return 8 + elif isinstance(typ, NumericT): + return typ.bits + return None + + +def _bytes_to_num(arg, in_typ, out_typ, signed): # converting a bytestring to a number: # bytestring and bytes_m are right-padded with zeroes, int is left-padded. # convert by shr or sar the number of zero bytes (converted to bits) # e.g. "abcd000000000000" -> bitcast(000000000000abcd, output_type) + out_size = _bits_count(out_typ) + assert out_size is not None + out_size = out_size // 8 - if isinstance(arg.typ, _BytestringT): - if not is_bounded_length(arg.typ.maxlen): + if isinstance(in_typ, _BytestringT): + if not is_bounded_length(in_typ.maxlen): raise CodegenPanic("convert: unbounded bytestring type") _len = get_bytearray_length(arg) + assert isinstance(in_typ, _BytestringT) + if in_typ.maxlen > out_size: + out_size = in_typ.maxlen + arg = LOAD(bytes_data_ptr(arg)) - num_zero_bits = ["mul", 8, ["sub", 32, _len]] - elif is_bytes_m_type(arg.typ): - num_zero_bits = 8 * (32 - arg.typ.m) + runtime_compile_diff = ["sub", out_size, _len] + val = shr(["mul", runtime_compile_diff, 8], arg) + num_zero_bits = 8 * (32 - out_size) + elif is_bytes_m_type(in_typ): + if in_typ.m > out_size: + out_size = in_typ.m + runtime_compile_diff = out_size - in_typ.m + val = shr(["mul", runtime_compile_diff, 8], arg) + num_zero_bits = 8 * (32 - out_size) else: # pragma: nocover raise CompilerPanic("unreachable") if signed: - ret = sar(num_zero_bits, arg) + ret = sar(num_zero_bits, val) else: - ret = shr(num_zero_bits, arg) + ret = shr(num_zero_bits, val) annotation = (f"__intrinsic__byte_array_to_num({out_typ})",) return IRnode.from_list(ret, annotation=annotation) @@ -208,14 +230,17 @@ def _check_bytes(expr, arg, output_type, max_bytes_allowed): # apply sign extension, if expected. note that the sign bit # is always taken to be the first bit of the bytestring. # (e.g. convert(0xff , int16) == -1) -def _signextend(expr, val, arg_typ): +def _signextend(expr, val, arg_typ, out_size): if isinstance(expr, vy_ast.Hex): assert len(expr.value[2:]) // 2 == arg_typ.m n_bits = arg_typ.m_bits else: - assert len(expr.value) == arg_typ.maxlen + assert len(expr.value) <= arg_typ.maxlen n_bits = arg_typ.maxlen * 8 + if n_bits < out_size: + n_bits = out_size + return unsigned_to_signed(val, n_bits) @@ -231,7 +256,7 @@ def _literal_int(expr, arg_typ, out_typ): raise CompilerPanic("unreachable") if isinstance(expr, (vy_ast.Hex, vy_ast.Bytes, vy_ast.HexBytes)) and out_typ.is_signed: - val = _signextend(expr, val, arg_typ) + val = _signextend(expr, val, arg_typ, out_size=_bits_count(out_typ)) lo, hi = out_typ.int_bounds if not (lo <= val <= hi): @@ -258,7 +283,7 @@ def _literal_decimal(expr, arg_typ, out_typ): # apply sign extension, if expected if isinstance(expr, vy_ast.Hex) and out_typ.is_signed: - val = _signextend(expr, val, arg_typ) + val = _signextend(expr, val, arg_typ, out_size=_bits_count(out_typ)) lo, hi = out_typ.int_bounds if not lo <= val <= hi: @@ -269,12 +294,12 @@ def _literal_decimal(expr, arg_typ, out_typ): # any base type or bytes/string @_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BoolT, BytesT, StringT) -def to_bool(expr, arg, out_typ): +def to_bool(expr, arg, in_typ, out_typ): _check_bytes(expr, arg, out_typ, 32) # should we restrict to Bytes[1]? if isinstance(arg.typ, _BytestringT): # no clamp. checks for any nonzero bytes. - arg = _bytes_to_num(arg, out_typ, signed=False) + arg = _bytes_to_num(arg, in_typ, out_typ, signed=False) # NOTE: for decimal, the behavior is x != 0.0, # (we do not issue an `sdiv DECIMAL_DIVISOR`) @@ -283,27 +308,27 @@ def to_bool(expr, arg, out_typ): @_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BoolT, FlagT, BytesT) -def to_int(expr, arg, out_typ): - return _to_int(expr, arg, out_typ) +def to_int(expr, arg, in_typ, out_typ): + return _to_int(expr, arg, in_typ, out_typ) # an internal version of to_int without input validation -def _to_int(expr, arg, out_typ): +def _to_int(expr, arg, in_typ, out_typ): assert out_typ.bits % 8 == 0 _check_bytes(expr, arg, out_typ, 32) if isinstance(expr, vy_ast.Constant): - return _literal_int(expr, arg.typ, out_typ) + return _literal_int(expr, in_typ, out_typ) elif isinstance(arg.typ, BytesT): arg_typ = arg.typ - arg = _bytes_to_num(arg, out_typ, signed=out_typ.is_signed) + arg = _bytes_to_num(arg, in_typ, out_typ, signed=out_typ.is_signed) if arg_typ.maxlen * 8 > out_typ.bits: arg = int_clamp(arg, out_typ.bits, signed=out_typ.is_signed) elif is_bytes_m_type(arg.typ): arg_typ = arg.typ - arg = _bytes_to_num(arg, out_typ, signed=out_typ.is_signed) + arg = _bytes_to_num(arg, in_typ, out_typ, signed=out_typ.is_signed) if arg_typ.m_bits > out_typ.bits: arg = int_clamp(arg, out_typ.bits, signed=out_typ.is_signed) @@ -332,7 +357,7 @@ def _to_int(expr, arg, out_typ): @_input_types(IntegerT, BoolT, BytesM_T, BytesT) -def to_decimal(expr, arg, out_typ): +def to_decimal(expr, arg, in_typ, out_typ): _check_bytes(expr, arg, out_typ, 32) if isinstance(expr, vy_ast.Constant): @@ -340,7 +365,7 @@ def to_decimal(expr, arg, out_typ): if isinstance(arg.typ, BytesT): arg_typ = arg.typ - arg = _bytes_to_num(arg, out_typ, signed=True) + arg = _bytes_to_num(arg, in_typ, out_typ, signed=True) if arg_typ.maxlen * 8 > 168: arg = IRnode.from_list(arg, typ=out_typ) arg = clamp_basetype(arg) @@ -349,7 +374,7 @@ def to_decimal(expr, arg, out_typ): elif is_bytes_m_type(arg.typ): arg_typ = arg.typ - arg = _bytes_to_num(arg, out_typ, signed=True) + arg = _bytes_to_num(arg, in_typ, out_typ, signed=True) if arg_typ.m_bits > 168: arg = IRnode.from_list(arg, typ=out_typ) arg = clamp_basetype(arg) @@ -369,7 +394,7 @@ def to_decimal(expr, arg, out_typ): @_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BytesT, BoolT) -def to_bytes_m(expr, arg, out_typ): +def to_bytes_m(expr, arg, in_typ, out_typ): _check_bytes(expr, arg, out_typ, max_bytes_allowed=out_typ.m) if isinstance(arg.typ, BytesT): @@ -417,13 +442,13 @@ def to_bytes_m(expr, arg, out_typ): @_input_types(BytesM_T, IntegerT, BytesT) -def to_address(expr, arg, out_typ): +def to_address(expr, arg, in_typ, out_typ): # question: should this be allowed? if is_integer_type(arg.typ): if arg.typ.is_signed: _FAIL(arg.typ, out_typ, expr) - ret = _to_int(expr, arg, UINT160_T) + ret = _to_int(expr, arg, in_typ, UINT160_T) return IRnode.from_list(ret, out_typ) @@ -445,17 +470,17 @@ def _cast_bytestring(expr, arg, out_typ): # question: should we allow bytesM -> String? @_input_types(BytesT, StringT) -def to_string(expr, arg, out_typ): +def to_string(expr, arg, in_typ, out_typ): return _cast_bytestring(expr, arg, out_typ) @_input_types(StringT, BytesT) -def to_bytes(expr, arg, out_typ): +def to_bytes(expr, arg, in_typ, out_typ): return _cast_bytestring(expr, arg, out_typ) @_input_types(IntegerT) -def to_flag(expr, arg, out_typ): +def to_flag(expr, arg, in_typ, out_typ): if arg.typ != UINT256_T: _FAIL(arg.typ, out_typ, expr) @@ -468,6 +493,7 @@ def to_flag(expr, arg, out_typ): def convert(expr, context): assert len(expr.args) == 2, "bad typecheck: convert" + in_typ = expr.args[0]._metadata["type"] arg_ast = expr.args[0].reduced() arg = Expr(arg_ast, context).ir_node original_arg = arg @@ -478,21 +504,21 @@ def convert(expr, context): arg = unwrap_location(arg) with arg.cache_when_complex("arg") as (b, arg): if out_typ == BoolT(): - ret = to_bool(arg_ast, arg, out_typ) + ret = to_bool(arg_ast, arg, in_typ, out_typ) elif out_typ == AddressT(): - ret = to_address(arg_ast, arg, out_typ) + ret = to_address(arg_ast, arg, in_typ, out_typ) elif is_flag_type(out_typ): - ret = to_flag(arg_ast, arg, out_typ) + ret = to_flag(arg_ast, arg, in_typ, out_typ) elif is_integer_type(out_typ): - ret = to_int(arg_ast, arg, out_typ) + ret = to_int(arg_ast, arg, in_typ, out_typ) elif is_bytes_m_type(out_typ): - ret = to_bytes_m(arg_ast, arg, out_typ) + ret = to_bytes_m(arg_ast, arg, in_typ, out_typ) elif is_decimal_type(out_typ): - ret = to_decimal(arg_ast, arg, out_typ) + ret = to_decimal(arg_ast, arg, in_typ, out_typ) elif isinstance(out_typ, BytesT): - ret = to_bytes(arg_ast, arg, out_typ) + ret = to_bytes(arg_ast, arg, in_typ, out_typ) elif isinstance(out_typ, StringT): - ret = to_string(arg_ast, arg, out_typ) + ret = to_string(arg_ast, arg, in_typ, out_typ) else: raise StructureException(f"Conversion to {out_typ} is invalid.", arg_ast) diff --git a/vyper/codegen_venom/builtins/convert.py b/vyper/codegen_venom/builtins/convert.py index b9eebef3bf..609b87c906 100644 --- a/vyper/codegen_venom/builtins/convert.py +++ b/vyper/codegen_venom/builtins/convert.py @@ -84,6 +84,12 @@ def _get_folded_value(node: vy_ast.VyperNode): return None +def _bytes_of_numeric_type(out_t: BoolT | IntegerT | DecimalT) -> int: + if isinstance(out_t, BoolT): + return 1 + return out_t.bits // 8 + + def _check_literal_int_bounds(arg_node: vy_ast.VyperNode, out_t: IntegerT) -> None: """ Check if a compile-time constant integer fits in the output type bounds. @@ -177,30 +183,16 @@ def _to_int( # From bytes/string: load data, shift right if isinstance(in_t, _BytestringT): - # Length at val, data at val+32 assert isinstance(val, IRVariable) - length = b.mload(val) - data_ptr = b.add(val, IRLiteral(32)) - data = b.mload(data_ptr) - # Right-shift to convert left-aligned bytes to right-aligned int - # num_zero_bits = (32 - len) * 8 - num_zero_bits = b.mul(b.sub(IRLiteral(32), length), IRLiteral(8)) - if out_t.is_signed: - val = b.sar(num_zero_bits, data) - else: - val = b.shr(num_zero_bits, data) + val = _bytestring_to_num(val, in_t, out_t, out_t.is_signed, ctx) # Clamp if bytes could exceed output range if in_t.maxlen * 8 > out_t.bits: val = _int_clamp(val, out_t, ctx) return val - # From bytesM: right-shift by (32 - M) * 8 bits + # From bytesM: shift to extract value if isinstance(in_t, BytesM_T): - shift_bits = (32 - in_t.m) * 8 - if out_t.is_signed: - val = b.sar(IRLiteral(shift_bits), val) - else: - val = b.shr(IRLiteral(shift_bits), val) + val = _bytes_m_to_num(val, in_t, out_t, out_t.is_signed, ctx) # Clamp if bytesM could exceed output range if in_t.m * 8 > out_t.bits: val = _int_clamp(val, out_t, ctx) @@ -261,21 +253,16 @@ def _to_decimal( # From bytes/string if isinstance(in_t, _BytestringT): assert isinstance(val, IRVariable) - length = b.mload(val) - data_ptr = b.add(val, IRLiteral(32)) - data = b.mload(data_ptr) - num_zero_bits = b.mul(b.sub(IRLiteral(32), length), IRLiteral(8)) - val = b.sar(num_zero_bits, data) + val = _bytestring_to_num(val, in_t, out_t, signed=True, ctx=ctx) # Clamp to decimal bounds if needed - if in_t.maxlen * 8 > 168: # decimal is 168 bits + if in_t.maxlen * 8 > out_t.bits: val = _clamp_basetype(val, out_t, ctx) return val # From bytesM if isinstance(in_t, BytesM_T): - shift_bits = (32 - in_t.m) * 8 - val = b.sar(IRLiteral(shift_bits), val) - if in_t.m * 8 > 168: + val = _bytes_m_to_num(val, in_t, out_t, signed=True, ctx=ctx) + if in_t.m * 8 > out_t.bits: val = _clamp_basetype(val, out_t, ctx) return val @@ -453,6 +440,69 @@ def _check_bytes(in_t, out_t, max_bytes_allowed: int, source_expr: vy_ast.VyperN raise TypeMismatch(f"Can't convert {in_t} to {out_t}", source_expr) +def _bytestring_to_num( + val: IRVariable, + in_t: _BytestringT, + out_t: BoolT | IntegerT | DecimalT, + signed: bool, + ctx: VenomCodegenContext, +) -> IROperand: + """ + Convert a dynamic bytestring into a numeric value. + + Bytestrings are left-aligned and right-padded, while numbers are right-aligned. For + signed outputs, the sign bit is taken from max(input_maxlen, output_size), matching legacy + codegen and literal folding. + """ + b = ctx.builder + length = b.mload(val) + data_ptr = b.add(val, IRLiteral(32)) + data = b.mload(data_ptr) + + assert isinstance(in_t.maxlen, int) + out_size = _bytes_of_numeric_type(out_t) + num_bytes = max(in_t.maxlen, out_size) + + # First shift the runtime value into a compile-time sized window. For example, + # Bytes[2](b"\xff") -> int16 uses a 2-byte window containing 0x00ff, not 0xff00. + runtime_compile_diff = b.sub(IRLiteral(num_bytes), length) + val = b.shr(b.mul(runtime_compile_diff, IRLiteral(8)), data) + + # Then shift/sign-extend that window into a right-aligned integer. + num_zero_bits = IRLiteral(8 * (32 - num_bytes)) + if signed: + return b.sar(num_zero_bits, val) + return b.shr(num_zero_bits, val) + + +def _bytes_m_to_num( + val: IROperand, + in_t: BytesM_T, + out_t: BoolT | IntegerT | DecimalT, + signed: bool, + ctx: VenomCodegenContext, +) -> IROperand: + """ + Convert a bytesM value into a numeric value. + + As with dynamic bytestrings, use max(input_size, output_size) as the sign-extension width. + This means convert(0xff, int16) is 255, while convert(0xff00, int8) still reverts after + clamping. + """ + b = ctx.builder + out_size = _bytes_of_numeric_type(out_t) + num_bytes = max(in_t.m, out_size) + + shift_into_window = 8 * (num_bytes - in_t.m) + if shift_into_window: + val = b.shr(IRLiteral(shift_into_window), val) + + num_zero_bits = IRLiteral(8 * (32 - num_bytes)) + if signed: + return b.sar(num_zero_bits, val) + return b.shr(num_zero_bits, val) + + def _int_clamp(val: IROperand, out_t: IntegerT, ctx: VenomCodegenContext) -> IROperand: """Clamp value to integer type bounds.""" b = ctx.builder