Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[lang]: stateless modules should not be initialized #4347

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
12 changes: 12 additions & 0 deletions tests/functional/codegen/features/test_transient.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,15 @@ def foo() -> (uint256[3], uint256, uint256, uint256):

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == ([1, 2, 3], 1, 2, 42)


def test_transient_is_state(make_input_bundle):
lib = """
message: transient(bool)
"""
main = """
import lib
initializes: lib
"""
input_bundle = make_input_bundle({"lib.vy": lib, "main.vy": main})
compile_code(main, input_bundle=input_bundle)
4 changes: 4 additions & 0 deletions tests/functional/codegen/modules/test_nonreentrant.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
def test_export_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
phony: uint32

interface Foo:
def foo() -> uint256: nonpayable

Expand Down Expand Up @@ -38,6 +40,8 @@ def __default__():

def test_internal_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
phony: uint32

interface Foo:
def foo() -> uint256: nonpayable

Expand Down
133 changes: 133 additions & 0 deletions tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,8 @@ def test_uses_skip_import(make_input_bundle):
lib2 = """
import lib1

phony: uint32

@internal
def foo():
pass
Expand Down Expand Up @@ -1418,3 +1420,134 @@ def set_some_mod():
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "module `lib3.vy` is used but never initialized!"
assert e.value._hint is None


storage_var_modules = [
"""
phony: uint32
""",
"""
ended: public(bool)
""",
]


@pytest.mark.parametrize("module", storage_var_modules)
def test_initializes_on_modules_with_state_related_vars(module, make_input_bundle):
main = """
import lib
initializes: lib
"""
input_bundle = make_input_bundle({"lib.vy": module, "main.vy": main})
compile_code(main, input_bundle=input_bundle)


def test_initializes_on_modules_with_immutables(make_input_bundle):
lib = """
foo: immutable(int128)

@deploy
def __init__():
foo = 2
"""

main = """
import lib
initializes: lib

@deploy
def __init__():
lib.__init__()
"""
input_bundle = make_input_bundle({"lib.vy": lib, "main.vy": main})
compile_code(main, input_bundle=input_bundle)


stateless_modules = [
"""
""",
"""
@internal
@pure
def foo(x: uint256, y: uint256) -> uint256:
return unsafe_add(x & y, (x ^ y) >> 1)
""",
"""
FOO: constant(int128) = 128
""",
]


@pytest.mark.parametrize("module", stateless_modules)
def test_forbids_initializes_on_stateless_modules(module, make_input_bundle):
main = """
import lib
initializes: lib
"""
input_bundle = make_input_bundle({"lib.vy": module, "main.vy": main})
with pytest.raises(StructureException):
compile_code(main, input_bundle=input_bundle)


def test_initializes_on_modules_with_uses(make_input_bundle):
lib0 = """
import lib1
uses: lib1

@external
def foo() -> uint32:
return lib1.phony
"""
lib1 = """
phony: uint32
"""
main = """
import lib1
initializes: lib1

import lib0
initializes: lib0[lib1 := lib1]
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib0.vy": lib0, "main.vy": main})
compile_code(main, input_bundle=input_bundle)


def test_initializes_on_modules_with_initializes(make_input_bundle):
lib0 = """
import lib1
initializes: lib1
"""
lib1 = """
phony: uint32
"""
main = """
import lib0
initializes: lib0
"""
input_bundle = make_input_bundle({"lib1.vy": lib1, "lib0.vy": lib0, "main.vy": main})
compile_code(main, input_bundle=input_bundle)


def test_initializes_on_modules_with_init_function(make_input_bundle):
lib = """
interface Foo:
def foo(): payable

@deploy
def __init__():
extcall Foo(self).foo()
"""
main = """
import lib
initializes: lib

@deploy
def __init__():
lib.__init__()

@external
def foo():
pass
"""
input_bundle = make_input_bundle({"lib.vy": lib, "main.vy": main})
compile_code(main, input_bundle=input_bundle)
8 changes: 8 additions & 0 deletions tests/unit/compiler/test_abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ def bar(x: {type}):

def test_exports_abi(make_input_bundle):
lib1 = """
phony: uint32

@external
def foo():
pass
Expand Down Expand Up @@ -330,6 +332,8 @@ def __init__():
def test_event_export_from_init(make_input_bundle):
# test that events get exported when used in init functions
lib1 = """
phony: uint32

event MyEvent:
pass

Expand Down Expand Up @@ -361,6 +365,8 @@ def __init__():
def test_event_export_from_function_export(make_input_bundle):
# test events used in exported functions are exported
lib1 = """
phony: uint32

event MyEvent:
pass

Expand Down Expand Up @@ -396,6 +402,8 @@ def foo():
def test_event_export_unused_function(make_input_bundle):
# test events in unused functions are not exported
lib1 = """
phony: uint32

event MyEvent:
pass

Expand Down
4 changes: 4 additions & 0 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,10 @@ def visit_InitializesDecl(self, node):
module_info = get_expr_info(module_ref).module_info
if module_info is None:
raise StructureException("Not a module!", module_ref)
if module_info.module_t.is_stateless():
raise StructureException(
f"Cannot initialize a stateless module {module_info.alias}!", module_ref
)

used_modules = {i.module_t: i for i in module_info.module_t.used_modules}

Expand Down
15 changes: 15 additions & 0 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,3 +559,18 @@ def immutable_section_bytes(self):
@cached_property
def interface(self):
return InterfaceT.from_ModuleT(self)

def is_stateless(self):
"""
Determine whether ModuleT is stateless by examining its top-level declarations.
A module has state if it contains storage variables, transient variables, or
immutables, or if it includes a "uses" or "initializes" declaration.
"""
for i in self._module.body:
if isinstance(i, (vy_ast.InitializesDecl, vy_ast.UsesDecl)):
return False
if isinstance(i, vy_ast.VariableDecl) and not i.is_constant:
return False
if isinstance(i, vy_ast.FunctionDef) and i.name == "__init__":
return False
return True
Loading