Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions tests/functional/codegen/test_interface_method_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from vyper.utils import method_id


def test_method_id_of_basic(get_contract):
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def get_method_id() -> bytes4:
return method_id_of(Foo.transfer)
"""
c = get_contract(code)
result = c.get_method_id()
expected = method_id("transfer(address,uint256)")
assert result == expected


def test_method_id_of_view_function(get_contract):
code = """
interface Foo:
def balanceOf(owner: address) -> uint256: view

@external
def get_method_id() -> bytes4:
return method_id_of(Foo.balanceOf)
"""
c = get_contract(code)
result = c.get_method_id()
expected = method_id("balanceOf(address)")
assert result == expected


def test_method_id_of_no_args(get_contract):
code = """
interface Foo:
def totalSupply() -> uint256: view

@external
def get_method_id() -> bytes4:
return method_id_of(Foo.totalSupply)
"""
c = get_contract(code)
result = c.get_method_id()
expected = method_id("totalSupply()")
assert result == expected


def test_method_id_of_in_raw_call(get_contract):
called_code = """
@external
def double(x: uint256) -> uint256:
return x * 2
"""
caller_code = """
interface Doubler:
def double(x: uint256) -> uint256: view

@external
def call_double(target: address, x: uint256) -> uint256:
response: Bytes[32] = raw_call(
target,
concat(method_id_of(Doubler.double), convert(x, bytes32)),
max_outsize=32
)
return convert(convert(response, bytes32), uint256)
"""
callee = get_contract(called_code)
caller = get_contract(caller_code)
assert caller.call_double(callee.address, 5) == 10


def test_method_id_of_assign_to_variable(get_contract):
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def get_method_id() -> bytes4:
m: bytes4 = method_id_of(Foo.transfer)
return m
"""
c = get_contract(code)
result = c.get_method_id()
expected = method_id("transfer(address,uint256)")
assert result == expected


def test_method_id_of_compare(get_contract):
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def check() -> bool:
return method_id_of(Foo.transfer) == method_id('transfer(address,uint256)', output_type=bytes4)
"""
c = get_contract(code)
assert c.check() is True


def test_method_id_of_default_args(get_contract, make_input_bundle):
iface_code = """
@external
def take(auction_id: uint256, max_take_amount: uint256 = ...) -> uint256:
...
"""
input_bundle = make_input_bundle({"ifoo.vyi": iface_code})

code = """
import ifoo as IFoo

@external
def get_full() -> bytes4:
return method_id_of(IFoo.take, n_optional_args=1)

@external
def get_minimal() -> bytes4:
return method_id_of(IFoo.take)

@external
def get_default() -> bytes4:
return method_id_of(IFoo.take, n_optional_args=0)
"""
c = get_contract(code, input_bundle=input_bundle)
# full signature (all args, 1 optional included)
assert c.get_full() == method_id("take(uint256,uint256)")
# minimal signature (positional only, default n_optional_args=0)
assert c.get_minimal() == method_id("take(uint256)")
# explicit n_optional_args=0, same as default
assert c.get_default() == method_id("take(uint256)")


def test_method_id_of_default_args_view(get_contract, make_input_bundle):
iface_code = """
@view
@external
def get_amount(token: address, receiver: address = ...) -> uint256:
...
"""
input_bundle = make_input_bundle({"ifoo.vyi": iface_code})

code = """
import ifoo as IFoo

@external
def get_method_id() -> bytes4:
return method_id_of(IFoo.get_amount)
"""
c = get_contract(code, input_bundle=input_bundle)
result = c.get_method_id()
# default n_optional_args=0, so only positional args
expected = method_id("get_amount(address)")
assert result == expected
134 changes: 134 additions & 0 deletions tests/functional/syntax/test_interface_method_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import pytest

from vyper.compiler import compile_code
from vyper.exceptions import ArgumentException, StructureException


valid_list = [
# basic method_id_of access
"""
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def foo() -> bytes4:
return method_id_of(Foo.transfer)
""",
# use in raw_call
"""
interface Foo:
def bar(x: uint256) -> uint256: view

@external
def foo():
x: Bytes[32] = raw_call(
msg.sender,
concat(method_id_of(Foo.bar), convert(1, bytes32)),
max_outsize=32
)
""",
]


@pytest.mark.parametrize("code", valid_list)
def test_method_id_of_pass(code):
assert compile_code(code) is not None


def test_method_id_of_not_a_function():
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def foo() -> bytes4:
return method_id_of(Foo)
"""
with pytest.raises((ArgumentException, StructureException)):
compile_code(code)


def test_method_id_of_string_not_accepted():
code = """
@external
def foo() -> bytes4:
return method_id_of("transfer(address,uint256)")
"""
with pytest.raises((ArgumentException, StructureException)):
compile_code(code)


def test_method_id_of_n_optional_args_out_of_range(make_input_bundle):
iface_code = """
@external
def take(auction_id: uint256, max_take_amount: uint256 = ...) -> uint256:
...
"""
input_bundle = make_input_bundle({"ifoo.vyi": iface_code})

code = """
import ifoo as IFoo

@external
def foo() -> bytes4:
return method_id_of(IFoo.take, n_optional_args=5)
"""
with pytest.raises(ArgumentException):
compile_code(code, input_bundle=input_bundle)


def test_method_id_of_n_optional_args_zero_no_defaults():
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def foo() -> bytes4:
return method_id_of(Foo.transfer, n_optional_args=0)
"""
assert compile_code(code) is not None


def test_method_id_of_n_optional_args_no_defaults():
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def foo() -> bytes4:
return method_id_of(Foo.transfer, n_optional_args=1)
"""
with pytest.raises(ArgumentException):
compile_code(code)


def test_method_id_of_default_args(make_input_bundle):
iface_code = """
@external
def take(
auction_id: uint256,
max_take_amount: uint256 = ...,
) -> uint256:
...
"""
input_bundle = make_input_bundle({"ifoo.vyi": iface_code})

code = """
import ifoo as IFoo

@external
def foo() -> bytes4:
return method_id_of(IFoo.take)
"""
assert compile_code(code, input_bundle=input_bundle) is not None


def test_interface_function_not_valid_as_type():
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

x: Foo.transfer
"""
with pytest.raises(StructureException):
compile_code(code)
67 changes: 67 additions & 0 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,72 @@ def infer_kwarg_types(self, node):
return {"output_type": TYPE_T(output_type)}


class MethodIDOf(BuiltinFunctionT):
_id = "method_id_of"
_inputs = [("func_ref", TYPE_T.any())]
_kwargs = {"n_optional_args": KwargSettings(UINT256_T, 0, require_literal=True)}
_return_type = BYTES4_T
_modifiability = Modifiability.CONSTANT

def fetch_call_return(self, node):
validate_call_args(node, 1, ["n_optional_args"])
self._validate_func_ref(node)
return BYTES4_T

def infer_arg_types(self, node, expected_return_typ=None):
func_t = self._get_func_t(node)
return [TYPE_T(func_t)]

def _get_func_t(self, node):
from vyper.semantics.analysis.utils import get_exact_type_from_node
from vyper.semantics.types.base import is_type_t
from vyper.semantics.types.function import ContractFunctionT

arg_type = get_exact_type_from_node(node.args[0])
if not is_type_t(arg_type, ContractFunctionT):
raise ArgumentException(
"method_id_of expects a function reference, e.g. method_id_of(IFoo.bar)",
node.args[0],
)
return arg_type.typedef

def _validate_func_ref(self, node):
func_t = self._get_func_t(node)

n_optional = self._get_n_optional_args(node, func_t)
if not (0 <= n_optional <= func_t.n_keyword_args):
raise ArgumentException(
f"n_optional_args must be between 0 and "
f"{func_t.n_keyword_args}, got {n_optional}",
node,
)

def _get_n_optional_args(self, node, func_t):
for kw in node.keywords:
if kw.arg == "n_optional_args":
Comment on lines +793 to +794

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit convoluted, wouldn't node.keywords.get("n_optional_args", 0) or something like it work ?

How do the other built-ins deal with kwargs ?

val = kw.value.get_folded_value()
if not isinstance(val, vy_ast.Int):
raise InvalidType("n_optional_args must be an integer literal", kw.value)
return val.value
# default: 0 optional args (positional only)
return 0

def _compute_method_id(self, node):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this logic a duplicate ?
It seems like we would already have something to compute that

func_t = self._get_func_t(node)
n_optional = self._get_n_optional_args(node, func_t)
n_total = func_t.n_positional_args + n_optional

arg_types = [i.canonical_abi_type for i in func_t.argument_types[:n_total]]
function_sig = f"{func_t.name}({','.join(arg_types)})"
selector = method_id(function_sig)
return selector

def build_IR(self, node, context):
selector = self._compute_method_id(node)
value = fourbytes_to_int(selector) << 224
return IRnode.from_list(value, typ=BYTES4_T)


class ECRecover(BuiltinFunctionT):
_id = "ecrecover"
_inputs = [
Expand Down Expand Up @@ -2544,6 +2610,7 @@ def _try_fold(self, node):
"concat": Concat(),
"sha256": Sha256(),
"method_id": MethodID(),
"method_id_of": MethodIDOf(),
"keccak256": Keccak256(),
"ecrecover": ECRecover(),
"ecadd": ECAdd(),
Expand Down
15 changes: 15 additions & 0 deletions vyper/codegen_venom/builtins/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,20 @@ def lower_breakpoint(node: vy_ast.Call, ctx: VenomCodegenContext) -> IROperand:
return IRLiteral(0)


def lower_method_id_of(node: vy_ast.Call, ctx: VenomCodegenContext) -> IROperand:
"""
method_id_of(IFoo.bar) -> bytes4

Returns the 4-byte function selector, evaluated at compile time.
"""
from vyper.builtins.functions import MethodIDOf

builtin = MethodIDOf()
selector = builtin._compute_method_id(node)
Comment on lines +657 to +658

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
builtin = MethodIDOf()
selector = builtin._compute_method_id(node)
selector = MethodIDOf()._compute_method_id(node)

value = int.from_bytes(selector, "big") << 224
return IRLiteral(value)


# Export handlers
HANDLERS = {
"ecrecover": lower_ecrecover,
Expand All @@ -662,4 +676,5 @@ def lower_breakpoint(node: vy_ast.Call, ctx: VenomCodegenContext) -> IROperand:
"isqrt": lower_isqrt,
"breakpoint": lower_breakpoint,
"print": lower_print,
"method_id_of": lower_method_id_of,
}
Loading