diff --git a/src/zkevm_specs/evm_circuit/execution/create.py b/src/zkevm_specs/evm_circuit/execution/create.py index dd20b4bfb..16a1058a3 100644 --- a/src/zkevm_specs/evm_circuit/execution/create.py +++ b/src/zkevm_specs/evm_circuit/execution/create.py @@ -4,10 +4,12 @@ from zkevm_specs.util.param import ( GAS_COST_COPY_SHA3, GAS_COST_CREATE, + MAX_U64, N_BYTES_ACCOUNT_ADDRESS, N_BYTES_GAS, N_BYTES_MEMORY_ADDRESS, N_BYTES_MEMORY_SIZE, + N_BYTES_STACK, N_BYTES_U64, ) from ...util import ( @@ -40,64 +42,17 @@ def create(instruction: Instruction): tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId) caller_address = instruction.call_context_lookup(CallContextFieldTag.CallerAddress) nonce, nonce_prev = instruction.account_write(caller_address, AccountFieldTag.Nonce) + _, balance_prev = instruction.account_write(caller_address, AccountFieldTag.Balance) is_success = instruction.call_context_lookup(CallContextFieldTag.IsSuccess) is_static = instruction.call_context_lookup(CallContextFieldTag.IsStatic) reversion_info = instruction.reversion_info() - # calculate contract address - contract_address = ( - instruction.generate_contract_address(caller_address, nonce) - if is_create == 1 - else instruction.generate_CREAET2_contract_address( - caller_address, salt_word, instruction.next.code_hash - ) - ) - - # verify the equality of input `size` and length of calldata - code_size = instruction.bytecode_length(instruction.next.code_hash) - instruction.constrain_equal(code_size, size) - - # verify return contract address - instruction.constrain_equal( - instruction.word_to_fq(return_contract_address_word, N_BYTES_ACCOUNT_ADDRESS), - is_success * contract_address, - ) - - # Verify depth is less than 1024 - instruction.range_lookup(depth, CALL_CREATE_DEPTH) - - # ErrNonceUintOverflow constraint - (is_not_overflow, _) = instruction.compare(nonce, nonce_prev, N_BYTES_U64) - instruction.is_zero(is_not_overflow) - - # add contract address to access list - instruction.add_account_to_access_list(tx_id, contract_address) - - # ErrContractAddressCollision constraint - # code_hash_prev could be either 0 or EMPTY_CODE_HASH - # code_hash should be EMPTY_CODE_HASH to make sure the account is created properly - code_hash, code_hash_prev = instruction.account_write_word( - contract_address, AccountFieldTag.CodeHash - ) - instruction.constrain_in_word( - code_hash_prev, - [Word(0), Word(EMPTY_CODE_HASH)], - ) - instruction.constrain_equal_word(code_hash, Word(EMPTY_CODE_HASH)) - - # Propagate is_persistent - callee_reversion_info = instruction.reversion_info(call_id=callee_call_id) - instruction.constrain_equal( - callee_reversion_info.is_persistent, - reversion_info.is_persistent * is_success.expr(), - ) + has_init_code = size != FQ(0) # can't be a STATICCALL instruction.is_zero(is_static) - # transfer value from caller to contract address - instruction.transfer(caller_address, contract_address, value_word, callee_reversion_info) - + ### Gas cost calculation # gas cost of memory expansion ( next_memory_size, @@ -130,71 +85,181 @@ def create(instruction: Instruction): all_but_one_64th_gas, ) - # copy init_code from memory to bytecode - copy_rwc_inc, _ = instruction.copy_lookup( - instruction.curr.call_id, # src_id - CopyDataTypeTag.Memory, # src_type - instruction.next.code_hash, # dst_id - CopyDataTypeTag.Bytecode, # dst_type - offset, # src_addr - offset + size, # src_addr_boundary - FQ(0), # dst_addr - size, # length - instruction.curr.rw_counter + instruction.rw_counter_offset, + ### Do stack depth, nonce and balance pre-check + # ErrDepth constraint + is_error_depth, _ = instruction.compare(FQ(CALL_CREATE_DEPTH), depth, N_BYTES_STACK) + # ErrInsufficientBalance constraint + is_insufficient_balance, _ = instruction.compare_word(Word(balance_prev.expr().n), value_word) + # ErrNonceUintOverflow constraint + is_nonce_in_range, _ = instruction.compare(nonce_prev, FQ(MAX_U64), N_BYTES_U64) + + # pass the pre-check if none of above errors happen + is_precheck_ok = ( + is_error_depth == FQ(0) and is_insufficient_balance == FQ(0) and is_nonce_in_range == FQ(1) ) - instruction.rw_counter_offset += int(copy_rwc_inc) - - # CREATE: 3 pops and 1 push, stack delta = 2 - # CREATE2: 4 pops and 1 push, stack delta = 3 - stack_pointer_delta = 2 + is_create2 - # Save caller's call state - for field_tag, expected_value in [ - (CallContextFieldTag.ProgramCounter, instruction.curr.program_counter + 1), - ( - CallContextFieldTag.StackPointer, - instruction.curr.stack_pointer + stack_pointer_delta, - ), - (CallContextFieldTag.GasLeft, gas_left - gas_cost - callee_gas_left), - (CallContextFieldTag.MemorySize, next_memory_size), - ( - CallContextFieldTag.ReversibleWriteCounter, - instruction.curr.reversible_write_counter + 1, - ), - ]: + + # error cases, should end this call and return + # is_static should be false + if not is_precheck_ok: + for field_tag, expected_value in [ + (CallContextFieldTag.LastCalleeId, FQ(0)), + (CallContextFieldTag.LastCalleeReturnDataOffset, FQ(0)), + (CallContextFieldTag.LastCalleeReturnDataLength, FQ(0)), + ]: + instruction.constrain_equal( + instruction.call_context_lookup(field_tag, RW.Write), + expected_value, + ) + + instruction.constrain_step_state_transition( + rw_counter=Transition.delta(instruction.rw_counter_offset), + program_counter=Transition.delta(1), + stack_pointer=Transition.delta(2 + is_create2), + reversible_write_counter=Transition.delta(1), + gas_left=Transition.delta(-gas_cost), + memory_word_size=Transition.to(next_memory_size), + # Always stay same + call_id=Transition.same(), + is_root=Transition.same(), + is_create=Transition.same(), + code_hash=Transition.same_word(), + ) + else: + # calculate contract address + code_hash = instruction.next.code_hash if has_init_code else Word(EMPTY_CODE_HASH) + contract_address = ( + instruction.generate_contract_address(caller_address, nonce) + if is_create == 1 + else instruction.generate_CREAET2_contract_address(caller_address, salt_word, code_hash) + ) + # add contract address to access list + instruction.add_account_to_access_list(tx_id, contract_address) + + if has_init_code: + # verify the equality of input `size` and length of calldata + code_size = instruction.bytecode_length(instruction.next.code_hash) + instruction.constrain_equal(code_size, size) + + # verify return contract address instruction.constrain_equal( - instruction.call_context_lookup(field_tag, RW.Write), - expected_value, + instruction.word_to_fq(return_contract_address_word, N_BYTES_ACCOUNT_ADDRESS), + is_success * contract_address, + ) + + # ErrContractAddressCollision constraint + # code_hash_prev could be either 0 or EMPTY_CODE_HASH + # code_hash should be EMPTY_CODE_HASH to make sure the account is created properly + code_hash, code_hash_prev = instruction.account_write_word( + contract_address, AccountFieldTag.CodeHash ) + instruction.constrain_in_word( + code_hash_prev, + [Word(0), Word(EMPTY_CODE_HASH)], + ) + instruction.constrain_equal_word(code_hash, Word(EMPTY_CODE_HASH)) - # Setup next call's context. - for field_tag, expected_value in [ - (CallContextFieldTag.CallerId, instruction.curr.call_id), - (CallContextFieldTag.TxId, tx_id), - (CallContextFieldTag.Depth, depth + 1), - (CallContextFieldTag.CallerAddress, caller_address), - (CallContextFieldTag.CalleeAddress, contract_address), - (CallContextFieldTag.IsSuccess, is_success), - (CallContextFieldTag.IsStatic, FQ(False)), - (CallContextFieldTag.IsRoot, FQ(False)), - (CallContextFieldTag.IsCreate, FQ(True)), - ]: + # Propagate is_persistent + callee_reversion_info = instruction.reversion_info(call_id=callee_call_id) instruction.constrain_equal( - instruction.call_context_lookup(field_tag, call_id=callee_call_id), - expected_value, + callee_reversion_info.is_persistent, + reversion_info.is_persistent * is_success.expr(), ) - instruction.constrain_equal_word( - instruction.call_context_lookup_word(CallContextFieldTag.CodeHash, call_id=callee_call_id), - code_hash, - ) - instruction.step_state_transition_to_new_context( - rw_counter=Transition.delta(instruction.rw_counter_offset), - call_id=Transition.to(callee_call_id), - is_root=Transition.to(False), - is_create=Transition.to(True), - code_hash=Transition.to_word(instruction.next.code_hash), - gas_left=Transition.to(callee_gas_left), - # `transfer` includes two balance updates - reversible_write_counter=Transition.to(2), - log_id=Transition.same(), - ) + # transfer value from caller to contract address + instruction.transfer(caller_address, contract_address, value_word, callee_reversion_info) + + # CREATE: 3 pops and 1 push, stack delta = 2 + # CREATE2: 4 pops and 1 push, stack delta = 3 + stack_pointer_delta = 2 + is_create2 + + if has_init_code: + # copy init_code from memory to bytecode + copy_rwc_inc, _ = instruction.copy_lookup( + instruction.curr.call_id, # src_id + CopyDataTypeTag.Memory, # src_type + instruction.next.code_hash, # dst_id + CopyDataTypeTag.Bytecode, # dst_type + offset, # src_addr + offset + size, # src_addr_boundary + FQ(0), # dst_addr + size, # length + instruction.curr.rw_counter + instruction.rw_counter_offset, + ) + instruction.rw_counter_offset += int(copy_rwc_inc) + + # Save caller's call state + for field_tag, expected_value in [ + (CallContextFieldTag.ProgramCounter, instruction.curr.program_counter + 1), + ( + CallContextFieldTag.StackPointer, + instruction.curr.stack_pointer + stack_pointer_delta, + ), + (CallContextFieldTag.GasLeft, gas_left - gas_cost - callee_gas_left), + (CallContextFieldTag.MemorySize, next_memory_size), + ( + CallContextFieldTag.ReversibleWriteCounter, + instruction.curr.reversible_write_counter + 1, + ), + ]: + instruction.constrain_equal( + instruction.call_context_lookup(field_tag, RW.Write), + expected_value, + ) + # Setup next call's context. + for field_tag, expected_value in [ + (CallContextFieldTag.CallerId, instruction.curr.call_id), + (CallContextFieldTag.TxId, tx_id), + (CallContextFieldTag.Depth, depth + 1), + (CallContextFieldTag.CallerAddress, caller_address), + (CallContextFieldTag.CalleeAddress, contract_address), + (CallContextFieldTag.IsSuccess, is_success), + (CallContextFieldTag.IsStatic, FQ(False)), + (CallContextFieldTag.IsRoot, FQ(False)), + (CallContextFieldTag.IsCreate, FQ(True)), + ]: + instruction.constrain_equal( + instruction.call_context_lookup(field_tag, call_id=callee_call_id), + expected_value, + ) + instruction.constrain_equal_word( + instruction.call_context_lookup_word( + CallContextFieldTag.CodeHash, call_id=callee_call_id + ), + code_hash, + ) + + instruction.step_state_transition_to_new_context( + rw_counter=Transition.delta(instruction.rw_counter_offset), + call_id=Transition.to(callee_call_id), + is_root=Transition.to(False), + is_create=Transition.to(True), + code_hash=Transition.to_word(instruction.next.code_hash), + gas_left=Transition.to(callee_gas_left), + # `transfer` includes two balance updates + reversible_write_counter=Transition.to(2), + log_id=Transition.same(), + ) + else: + for field_tag, expected_value in [ + (CallContextFieldTag.LastCalleeId, FQ(0)), + (CallContextFieldTag.LastCalleeReturnDataOffset, FQ(0)), + (CallContextFieldTag.LastCalleeReturnDataLength, FQ(0)), + ]: + instruction.constrain_equal( + instruction.call_context_lookup(field_tag, RW.Write), + expected_value, + ) + + instruction.constrain_step_state_transition( + rw_counter=Transition.delta(instruction.rw_counter_offset), + program_counter=Transition.delta(1), + stack_pointer=Transition.delta(stack_pointer_delta), + gas_left=Transition.delta(-gas_cost), + reversible_write_counter=Transition.delta(3), + memory_word_size=Transition.to(next_memory_size), + # Always stay same + call_id=Transition.same(), + is_root=Transition.same(), + is_create=Transition.same(), + code_hash=Transition.same_word(), + ) diff --git a/src/zkevm_specs/evm_circuit/instruction.py b/src/zkevm_specs/evm_circuit/instruction.py index 8a8bd6765..0bde253f6 100644 --- a/src/zkevm_specs/evm_circuit/instruction.py +++ b/src/zkevm_specs/evm_circuit/instruction.py @@ -1108,8 +1108,10 @@ def memory_gas_cost(self, memory_size: Expression) -> FQ: return quadratic_cost + linear_cost def memory_expansion(self, offset: Expression, length: Expression) -> Tuple[FQ, FQ]: - memory_size, _ = self.constant_divmod( - length.expr() + offset.expr() + 31, FQ(32), N_BYTES_MEMORY_SIZE + memory_size, _ = ( + self.constant_divmod(length.expr() + offset.expr() + 31, FQ(32), N_BYTES_MEMORY_SIZE) + if length != FQ(0) + else (FQ(0), FQ(0)) ) next_memory_size = self.max(self.curr.memory_word_size, memory_size, N_BYTES_MEMORY_SIZE) diff --git a/tests/evm/test_create.py b/tests/evm/test_create.py index 18f55ce53..e10437fac 100644 --- a/tests/evm/test_create.py +++ b/tests/evm/test_create.py @@ -44,7 +44,10 @@ CALLER = Account(address=0xFE, balance=int(1e20), nonce=10) -def gen_bytecode(is_return: bool, offset: int) -> Bytecode: +def gen_bytecode(is_return: bool, offset: int, has_init_code: bool) -> Bytecode: + if not has_init_code: + return Bytecode() + """Generate bytecode that has 64 bytes of memory initialized and returns with `offset` and `length`""" bytecode = ( Bytecode() @@ -98,6 +101,7 @@ def memory_size(offset: int, length: int) -> int: return ( caller_gas_left, callee_gas_left, + gas_cost, next_memory_size, ) @@ -113,14 +117,15 @@ def gen_testing_data(): ] create_contexts = [ CreateContext(gas_left=1_000_000, is_persistent=True), - CreateContext(gas_left=1_000_000, is_persistent=False, rw_counter_end_of_reversion=88), + CreateContext(gas_left=1_000_000, is_persistent=False, rw_counter_end_of_reversion=80), ] stacks = [ Stack(value=int(1e18), offset=64, salt=int(12345)), - Stack(offset=200), - Stack(offset=0), + Stack(value=int(1e25), offset=64), # insufficient balance ] - is_warm_accesss = [True, False] + stack_depth = [1, 1024, 1025] + is_warm_access = [True, False] + has_init_code = [True, False] return [ ( @@ -129,10 +134,18 @@ def gen_testing_data(): is_return, create_contexts, stack, + stack_depth, is_warm_access, + has_init_code, ) - for opcode, is_return, create_contexts, stack, is_warm_access in product( - opcodes, is_return, create_contexts, stacks, is_warm_accesss + for opcode, is_return, create_contexts, stack, stack_depth, is_warm_access, has_init_code in product( + opcodes, + is_return, + create_contexts, + stacks, + stack_depth, + is_warm_access, + has_init_code, ) ] @@ -141,25 +154,29 @@ def gen_testing_data(): @pytest.mark.parametrize( - "opcode, caller, is_return, caller_ctx, stack, is_warm_access", + "opcode, caller, is_return, caller_ctx, stack, stack_depth, is_warm_access, has_init_code", TESTING_DATA, ) def test_create_create2( opcode: Opcode, caller: Account, - is_return: Bytecode, + is_return: bool, caller_ctx: CreateContext, stack: Stack, + stack_depth: int, is_warm_access: bool, + has_init_code: bool, ): randomness_keccak = rand_fq() CURRENT_CALL_ID = 1 + # can't be a static all + is_static = 0 - init_codes = gen_bytecode(is_return, stack.offset) + init_codes = gen_bytecode(is_return, stack.offset, has_init_code) stack = stack._replace(size=len(init_codes.code)) init_codes_hash = Word(init_codes.hash()) - init_bytecode = gen_bytecode(is_return, stack.offset) + init_bytecode = gen_bytecode(is_return, stack.offset, has_init_code) is_create2 = 1 if opcode == Opcode.CREATE2 else 0 if is_create2 == 1: caller_bytecode = init_bytecode.create2( @@ -176,7 +193,7 @@ def test_create_create2( ).stop() caller_bytecode_hash = Word(caller_bytecode.hash()) - (caller_gas_left, callee_gas_left, next_memory_size) = calc_gas_cost( + (caller_gas_left, callee_gas_left, gas_cost, next_memory_size) = calc_gas_cost( opcode, caller_ctx, stack, @@ -209,20 +226,31 @@ def test_create_create2( contract_addr = keccak256(rlp.encode([caller.address.to_bytes(20, "big"), caller.nonce])) contract_address = int.from_bytes(contract_addr[-20:], "big") - # can't be a static all - is_static = 0 - - next_call_id = 65 + next_call_id = 66 rw_counter = next_call_id + # CREATE: 33 * 3(push) + 1(CREATE) + 1(mstore) + 33(PUSH32) + 2(PUSH) + 1(RETURN) # CREATE2: 33 * 4(push) + 1(CREATE2) + 1(mstore) + 33(PUSH32) + 2(PUSH) + 1(RETURN) - next_program_counter = 33 * 4 + 1 + 35 + 1 if is_create2 else 33 * 3 + 1 + 35 + 1 - assert caller_bytecode.code[next_program_counter] == opcode + next_program_counter = 33 * 4 + 1 if is_create2 else 33 * 3 + 1 + if has_init_code: + next_program_counter += 1 + 35 + 1 + assert caller_bytecode.code[next_program_counter - 1] == opcode # CREATE: 1024 - 3 + 1 = 1022 # CREATE2: 1024 - 4 + 1 = 1021 stack_pointer = 1021 - is_create2 + # caller and callee balance + caller_balance_prev = caller.balance + callee_balance_prev = 0 + caller_balance = caller_balance_prev - stack.value + callee_balance = callee_balance_prev + stack.value + + is_precheck_ok = ( + (caller_balance >= stack.value) and (nonce > nonce - 1) and (stack_depth <= 1024) + ) + should_move_to_next_context = is_precheck_ok and not is_static + src_data = dict( [ ( @@ -252,83 +280,105 @@ def test_create_create2( # caller's call context rw_dictionary \ - .call_context_read(CURRENT_CALL_ID, CallContextFieldTag.Depth, 1) \ + .call_context_read(CURRENT_CALL_ID, CallContextFieldTag.Depth, stack_depth) \ .call_context_read(CURRENT_CALL_ID, CallContextFieldTag.TxId, 1) \ .call_context_read(CURRENT_CALL_ID, CallContextFieldTag.CallerAddress, caller.address) \ .account_write(caller.address, AccountFieldTag.Nonce, nonce, nonce - 1) \ + .account_write(caller.address, AccountFieldTag.Balance, caller_balance, caller_balance_prev) \ .call_context_read(CURRENT_CALL_ID, CallContextFieldTag.IsSuccess, is_success) \ .call_context_read(CURRENT_CALL_ID, CallContextFieldTag.IsStatic, is_static) \ .call_context_read(CURRENT_CALL_ID, CallContextFieldTag.RwCounterEndOfReversion, caller_ctx.rw_counter_end_of_reversion) \ - .call_context_read(CURRENT_CALL_ID, CallContextFieldTag.IsPersistent, caller_ctx.is_persistent) \ - .tx_access_list_account_write(CURRENT_CALL_ID, contract_address, True, is_warm_access, rw_counter_of_reversion=None if caller_ctx.is_persistent else caller_ctx.rw_counter_end_of_reversion - caller_ctx.reversible_write_counter) \ - .account_write(contract_address, AccountFieldTag.CodeHash, Word(EMPTY_CODE_HASH), 0) \ - - # callee's reversion_info - rw_dictionary \ - .call_context_read(next_call_id, CallContextFieldTag.RwCounterEndOfReversion, callee_rw_counter_end_of_reversion) \ - .call_context_read(next_call_id, CallContextFieldTag.IsPersistent, callee_is_persistent) - - # For `transfer` invocation. - caller_balance_prev = Word(caller.balance) - callee_balance_prev = Word(0) - caller_balance = Word(caller.balance - stack.value) - callee_balance = Word(stack.value) - rw_dictionary \ - .account_write(caller.address, AccountFieldTag.Balance, caller_balance, caller_balance_prev, rw_counter_of_reversion=None if callee_is_persistent else callee_rw_counter_end_of_reversion) \ - .account_write(contract_address, AccountFieldTag.Balance, callee_balance, callee_balance_prev, rw_counter_of_reversion=None if callee_is_persistent else callee_rw_counter_end_of_reversion - 1) - - # copy_table - copy_circuit = CopyCircuit().copy( - randomness_keccak, - rw_dictionary, - CURRENT_CALL_ID, - CopyDataTypeTag.Memory, - init_codes_hash, - CopyDataTypeTag.Bytecode, - stack.offset, - stack.offset + stack.size, - 0, - stack.size, - src_data, - ) + .call_context_read(CURRENT_CALL_ID, CallContextFieldTag.IsPersistent, caller_ctx.is_persistent) + + if should_move_to_next_context: + rw_dictionary \ + .tx_access_list_account_write(CURRENT_CALL_ID, contract_address, True, is_warm_access, rw_counter_of_reversion=None if caller_ctx.is_persistent else caller_ctx.rw_counter_end_of_reversion - caller_ctx.reversible_write_counter) \ + .account_write(contract_address, AccountFieldTag.CodeHash, Word(EMPTY_CODE_HASH), 0) + + # callee's reversion_info + rw_dictionary \ + .call_context_read(next_call_id, CallContextFieldTag.RwCounterEndOfReversion, callee_rw_counter_end_of_reversion) \ + .call_context_read(next_call_id, CallContextFieldTag.IsPersistent, callee_is_persistent) + + # For `transfer` invocation. + rw_dictionary \ + .account_write(caller.address, AccountFieldTag.Balance, Word(caller_balance), Word(caller_balance_prev), rw_counter_of_reversion=None if callee_is_persistent else callee_rw_counter_end_of_reversion) \ + .account_write(contract_address, AccountFieldTag.Balance, Word(callee_balance), Word(callee_balance_prev), rw_counter_of_reversion=None if callee_is_persistent else callee_rw_counter_end_of_reversion - 1) + + if has_init_code and should_move_to_next_context: + # copy_table + copy_circuit = CopyCircuit().copy( + randomness_keccak, + rw_dictionary, + CURRENT_CALL_ID, + CopyDataTypeTag.Memory, + init_codes_hash, + CopyDataTypeTag.Bytecode, + stack.offset, + stack.offset + stack.size, + 0, + stack.size, + src_data, + ) - # caller's call context - rw_dictionary \ - .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.ProgramCounter, next_program_counter + 1) \ - .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.StackPointer, 1023) \ - .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.GasLeft, caller_gas_left) \ - .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.MemorySize, next_memory_size) \ - .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.ReversibleWriteCounter, caller_ctx.reversible_write_counter + 1) + # caller's call context + rw_dictionary \ + .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.ProgramCounter, next_program_counter) \ + .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.StackPointer, 1023) \ + .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.GasLeft, caller_gas_left) \ + .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.MemorySize, next_memory_size) \ + .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.ReversibleWriteCounter, caller_ctx.reversible_write_counter + 1) + + # callee's call context + rw_dictionary \ + .call_context_read(next_call_id, CallContextFieldTag.CallerId, CURRENT_CALL_ID) \ + .call_context_read(next_call_id, CallContextFieldTag.TxId, 1) \ + .call_context_read(next_call_id, CallContextFieldTag.Depth, stack_depth+1) \ + .call_context_read(next_call_id, CallContextFieldTag.CallerAddress, caller.address) \ + .call_context_read(next_call_id, CallContextFieldTag.CalleeAddress, contract_address) \ + .call_context_read(next_call_id, CallContextFieldTag.IsSuccess, is_success) \ + .call_context_read(next_call_id, CallContextFieldTag.IsStatic, is_static) \ + .call_context_read(next_call_id, CallContextFieldTag.IsRoot, False) \ + .call_context_read(next_call_id, CallContextFieldTag.IsCreate, True) \ + .call_context_read(next_call_id, CallContextFieldTag.CodeHash, Word(EMPTY_CODE_HASH)) - # callee's call context - rw_dictionary \ - .call_context_read(next_call_id, CallContextFieldTag.CallerId, CURRENT_CALL_ID) \ - .call_context_read(next_call_id, CallContextFieldTag.TxId, 1) \ - .call_context_read(next_call_id, CallContextFieldTag.Depth, 2) \ - .call_context_read(next_call_id, CallContextFieldTag.CallerAddress, caller.address) \ - .call_context_read(next_call_id, CallContextFieldTag.CalleeAddress, contract_address) \ - .call_context_read(next_call_id, CallContextFieldTag.IsSuccess, is_success) \ - .call_context_read(next_call_id, CallContextFieldTag.IsStatic, is_static) \ - .call_context_read(next_call_id, CallContextFieldTag.IsRoot, False) \ - .call_context_read(next_call_id, CallContextFieldTag.IsCreate, True) \ - .call_context_read(next_call_id, CallContextFieldTag.CodeHash, Word(EMPTY_CODE_HASH)) + + tables = Tables( + block_table=set(Block().table_assignments()), + tx_table=set(), + bytecode_table=set( + chain( + caller_bytecode.table_assignments(), + init_codes.table_assignments(), + ) + ), + rw_table=set(rw_dictionary.rws), + copy_circuit=copy_circuit.rows, + ) + verify_copy_table(copy_circuit, tables, randomness_keccak) + + else: + # caller's call context + rw_dictionary \ + .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.LastCalleeId, 0) \ + .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.LastCalleeReturnDataOffset, 0) \ + .call_context_write(CURRENT_CALL_ID, CallContextFieldTag.LastCalleeReturnDataLength, 0) + + tables = Tables( + block_table=set(Block().table_assignments()), + tx_table=set(), + bytecode_table=set( + chain( + caller_bytecode.table_assignments(), + init_codes.table_assignments(), + ) + ), + rw_table=set(rw_dictionary.rws),) # fmt: on - tables = Tables( - block_table=set(Block().table_assignments()), - tx_table=set(), - bytecode_table=set( - chain( - caller_bytecode.table_assignments(), - init_codes.table_assignments(), - ) - ), - rw_table=set(rw_dictionary.rws), - copy_circuit=copy_circuit.rows, + reversible_write_counter = caller_ctx.reversible_write_counter + ( + 3 if should_move_to_next_context else 1 ) - - verify_copy_table(copy_circuit, tables, randomness_keccak) - verify_steps( tables=tables, steps=[ @@ -341,7 +391,7 @@ def test_create_create2( is_root=False, is_create=True, code_hash=caller_bytecode_hash, - program_counter=next_program_counter, + program_counter=next_program_counter - 1, stack_pointer=stack_pointer, gas_left=caller_ctx.gas_left, memory_word_size=caller_ctx.memory_word_size, @@ -360,6 +410,20 @@ def test_create_create2( gas_left=callee_gas_left, reversible_write_counter=2, ) + if has_init_code and should_move_to_next_context + else StepState( + execution_state=ExecutionState.PUSH, + rw_counter=rw_dictionary.rw_counter, + call_id=CURRENT_CALL_ID, + is_root=False, + is_create=True, + code_hash=caller_bytecode_hash, + program_counter=next_program_counter, + stack_pointer=1023, + gas_left=caller_ctx.gas_left - gas_cost, + memory_word_size=next_memory_size, + reversible_write_counter=reversible_write_counter, + ) ), ], )