Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions tests/functional/codegen/test_interface_method_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import pytest

from vyper.utils import method_id


def test_interface_method_id_basic(get_contract):
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def get_method_id() -> bytes4:
return Foo.transfer.method_id
"""
c = get_contract(code)
result = c.get_method_id()
expected = method_id("transfer(address,uint256)")
assert result == expected


def test_interface_method_id_view_function(get_contract):
code = """
interface Foo:
def balanceOf(owner: address) -> uint256: view

@external
def get_method_id() -> bytes4:
return Foo.balanceOf.method_id
"""
c = get_contract(code)
result = c.get_method_id()
expected = method_id("balanceOf(address)")
assert result == expected


def test_interface_method_id_no_args(get_contract):
code = """
interface Foo:
def totalSupply() -> uint256: view

@external
def get_method_id() -> bytes4:
return Foo.totalSupply.method_id
"""
c = get_contract(code)
result = c.get_method_id()
expected = method_id("totalSupply()")
assert result == expected


def test_interface_method_id_in_raw_call(get_contract, env):
called_code = """
@external
def double(x: uint256) -> uint256:
return x * 2
"""
caller_code = """
interface Doubler:
def double(x: uint256) -> uint256: view

@external
def call_double(target: address, x: uint256) -> uint256:
response: Bytes[32] = raw_call(
target,
concat(Doubler.double.method_id, convert(x, bytes32)),
max_outsize=32
)
return convert(convert(response, bytes32), uint256)
"""
callee = get_contract(called_code)
caller = get_contract(caller_code)
assert caller.call_double(callee.address, 5) == 10


def test_interface_method_id_assign_to_variable(get_contract):
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def get_method_id() -> bytes4:
m: bytes4 = Foo.transfer.method_id
return m
"""
c = get_contract(code)
result = c.get_method_id()
expected = method_id("transfer(address,uint256)")
assert result == expected


def test_interface_method_id_compare(get_contract):
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def check() -> bool:
return Foo.transfer.method_id == method_id('transfer(address,uint256)', output_type=bytes4)
"""
c = get_contract(code)
assert c.check() is True


def test_interface_method_id_default_args(get_contract, make_input_bundle):
iface_code = """
@external
def take(auction_id: uint256, max_take_amount: uint256 = ...) -> uint256:
...
"""
input_bundle = make_input_bundle({"ifoo.vyi": iface_code})

code = """
import ifoo as IFoo

@external
def get_method_id() -> bytes4:
return IFoo.take.method_id
"""
c = get_contract(code, input_bundle=input_bundle)
result = c.get_method_id()
# should return the full signature selector (all args)
expected = method_id("take(uint256,uint256)")
assert result == expected


def test_interface_method_id_default_args_view(get_contract, make_input_bundle):
iface_code = """
@view
@external
def get_amount(token: address, receiver: address = ...) -> uint256:
...
"""
input_bundle = make_input_bundle({"ifoo.vyi": iface_code})

code = """
import ifoo as IFoo

@external
def get_method_id() -> bytes4:
return IFoo.get_amount.method_id
"""
c = get_contract(code, input_bundle=input_bundle)
result = c.get_method_id()
expected = method_id("get_amount(address,address)")
assert result == expected
69 changes: 69 additions & 0 deletions tests/functional/syntax/test_interface_method_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest

from vyper.compiler import compile_code
from vyper.exceptions import StructureException


valid_list = [
# basic method_id access
"""
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def foo() -> bytes4:
return Foo.transfer.method_id
""",
# use in raw_call
"""
interface Foo:
def bar(x: uint256) -> uint256: view

@external
def foo():
x: Bytes[32] = raw_call(
msg.sender,
concat(Foo.bar.method_id, convert(1, bytes32)),
max_outsize=32
)
""",
]


@pytest.mark.parametrize("code", valid_list)
def test_interface_method_id_pass(code):
assert compile_code(code) is not None


def test_interface_method_id_no_instance_access():
code = """
interface Foo:
def transfer(to: address, amount: uint256): nonpayable

@external
def foo(addr: address) -> bytes4:
return Foo(addr).transfer.method_id
"""
with pytest.raises(StructureException):
compile_code(code)


def test_interface_method_id_default_args(make_input_bundle):
iface_code = """
@external
def take(
auction_id: uint256,
max_take_amount: uint256 = ...,
) -> uint256:
...
"""
input_bundle = make_input_bundle({"ifoo.vyi": iface_code})

code = """
import ifoo as IFoo

@external
def foo() -> bytes4:
return IFoo.take.method_id
"""
assert compile_code(code, input_bundle=input_bundle) is not None
11 changes: 11 additions & 0 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,17 @@ def parse_Attribute(self):
value = 2**flag_id # 0 => 0001, 1 => 0010, 2 => 0100, etc.
return IRnode.from_list(value, typ=typ)

# Interface.function.method_id, e.g. ERC20.transfer.method_id
if self.expr.attr == "method_id":
value_typ = self.expr.value._metadata["type"]
if is_type_t(value_typ, ContractFunctionT):
fn_t = value_typ.typedef
# use [-1] to get the full signature (all args, including defaults)
method_id = list(fn_t.method_ids.values())[-1]
# bytes4 is left-aligned in the 32-byte word
value = method_id << 224
return IRnode.from_list(value, typ=typ)

# x.balance: balance of address x
if self.expr.attr == "balance":
addr = Expr.parse_value_expr(self.expr.value, self.context)
Expand Down
13 changes: 12 additions & 1 deletion vyper/codegen_venom/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,19 @@ def lower_Attribute(self) -> VyperValue:
value = 2**flag_id # 0 => 1, 1 => 2, 2 => 4, etc.
return VyperValue.from_stack_op(IRLiteral(value), typ)

# Case 2: Address properties
# Case 1b: Interface.function.method_id (e.g. ERC20.transfer.method_id)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Case 1 should be modified to Case 1a then

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed here 532e685

attr = node.attr
if attr == "method_id":
value_typ = node.value._metadata.get("type")
if is_type_t(value_typ, ContractFunctionT):
fn_t = value_typ.typedef
# use [-1] to get the full signature (all args, including defaults)
method_id = list(fn_t.method_ids.values())[-1]
# bytes4 is left-aligned in the 32-byte word
value = method_id << 224
return VyperValue.from_stack_op(IRLiteral(value), typ)

# Case 2: Address properties
if attr == "balance":
sub = Expr(node.value, self.ctx).lower_value()
return VyperValue.from_stack_op(self.builder.balance(sub), UINT256_T)
Expand Down
7 changes: 6 additions & 1 deletion vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from vyper.semantics.types.base import KwargSettings, VyperType
from vyper.semantics.types.bytestrings import BytesT
from vyper.semantics.types.primitives import BoolT
from vyper.semantics.types.shortcuts import UINT256_T
from vyper.semantics.types.shortcuts import BYTES4_T, UINT256_T
from vyper.semantics.types.subscriptable import TupleT
from vyper.semantics.types.utils import type_from_abi, type_from_annotation
from vyper.utils import OrderedSet, keccak256
Expand Down Expand Up @@ -662,6 +662,11 @@ def method_ids(self) -> Dict[str, int]:
method_ids.update(_generate_method_id(self.name, arg_types[:i]))
return method_ids

def get_type_member(self, attr, node):
if attr == "method_id":
return BYTES4_T
raise StructureException(f"{self} has no type member '{attr}'", node)

# add more information to type exceptions generated inside calls
def _enhance_call_exception(self, e, ast_node=None):
if ast_node is not None:
Expand Down
4 changes: 3 additions & 1 deletion vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def __init__(
self.decl_node = decl_node

def get_type_member(self, attr, node):
# get an event, struct or flag from this interface
# get a function, event, struct or flag from this interface
if attr in self.functions:
return TYPE_T(self.functions[attr])
Comment on lines +76 to +77

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reject interface function types in annotations

Exposing interface functions from get_type_member makes IFoo.transfer parse as a concrete annotation type via type_from_annotation(...), even though ContractFunctionT is not a valid ABI/storage type. As a result, declarations such as x: public(IFoo.transfer) can get past type parsing and then fail later when ABI/getter generation needs to_abi_arg (function types do not implement abi_type), producing an internal compiler failure instead of a user-facing type error. This member should be gated to expression use (...method_id) and not accepted as a general annotation type.

Useful? React with 👍 / 👎.

return TYPE_T(self._helper.get_member(attr, node))

@property
Expand Down