Skip to content
This repository has been archived by the owner on Jul 5, 2024. It is now read-only.

Commit

Permalink
feat/#412 CREATE/CREATE2 Synchronization (#416)
Browse files Browse the repository at this point in the history
* feat: handle empty init_code case

* fix lint

* feat: handle depth, nonce and insufficient balance properly

* feat: handle STATIC_CALL properly in CREATE/CREATE2

* fix: should revert while static_call and fix nonce constraint

* fix review comment

* Update tests/evm/test_create.py

Co-authored-by: Chih Cheng Liang <[email protected]>

---------

Co-authored-by: Chih Cheng Liang <[email protected]>
  • Loading branch information
KimiWu123 and ChihChengLiang authored May 9, 2023
1 parent 0da1195 commit 31201a2
Show file tree
Hide file tree
Showing 3 changed files with 329 additions and 198 deletions.
287 changes: 176 additions & 111 deletions src/zkevm_specs/evm_circuit/execution/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from zkevm_specs.util.param import (
GAS_COST_COPY_SHA3,
GAS_COST_CREATE,
MAX_U64,
N_BYTES_ACCOUNT_ADDRESS,
N_BYTES_GAS,
N_BYTES_MEMORY_ADDRESS,
N_BYTES_MEMORY_SIZE,
N_BYTES_STACK,
N_BYTES_U64,
)
from ...util import (
Expand Down Expand Up @@ -40,64 +42,17 @@ def create(instruction: Instruction):
tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId)
caller_address = instruction.call_context_lookup(CallContextFieldTag.CallerAddress)
nonce, nonce_prev = instruction.account_write(caller_address, AccountFieldTag.Nonce)
_, balance_prev = instruction.account_write(caller_address, AccountFieldTag.Balance)
is_success = instruction.call_context_lookup(CallContextFieldTag.IsSuccess)
is_static = instruction.call_context_lookup(CallContextFieldTag.IsStatic)
reversion_info = instruction.reversion_info()

# calculate contract address
contract_address = (
instruction.generate_contract_address(caller_address, nonce)
if is_create == 1
else instruction.generate_CREAET2_contract_address(
caller_address, salt_word, instruction.next.code_hash
)
)

# verify the equality of input `size` and length of calldata
code_size = instruction.bytecode_length(instruction.next.code_hash)
instruction.constrain_equal(code_size, size)

# verify return contract address
instruction.constrain_equal(
instruction.word_to_fq(return_contract_address_word, N_BYTES_ACCOUNT_ADDRESS),
is_success * contract_address,
)

# Verify depth is less than 1024
instruction.range_lookup(depth, CALL_CREATE_DEPTH)

# ErrNonceUintOverflow constraint
(is_not_overflow, _) = instruction.compare(nonce, nonce_prev, N_BYTES_U64)
instruction.is_zero(is_not_overflow)

# add contract address to access list
instruction.add_account_to_access_list(tx_id, contract_address)

# ErrContractAddressCollision constraint
# code_hash_prev could be either 0 or EMPTY_CODE_HASH
# code_hash should be EMPTY_CODE_HASH to make sure the account is created properly
code_hash, code_hash_prev = instruction.account_write_word(
contract_address, AccountFieldTag.CodeHash
)
instruction.constrain_in_word(
code_hash_prev,
[Word(0), Word(EMPTY_CODE_HASH)],
)
instruction.constrain_equal_word(code_hash, Word(EMPTY_CODE_HASH))

# Propagate is_persistent
callee_reversion_info = instruction.reversion_info(call_id=callee_call_id)
instruction.constrain_equal(
callee_reversion_info.is_persistent,
reversion_info.is_persistent * is_success.expr(),
)
has_init_code = size != FQ(0)

# can't be a STATICCALL
instruction.is_zero(is_static)

# transfer value from caller to contract address
instruction.transfer(caller_address, contract_address, value_word, callee_reversion_info)

### Gas cost calculation
# gas cost of memory expansion
(
next_memory_size,
Expand Down Expand Up @@ -130,71 +85,181 @@ def create(instruction: Instruction):
all_but_one_64th_gas,
)

# copy init_code from memory to bytecode
copy_rwc_inc, _ = instruction.copy_lookup(
instruction.curr.call_id, # src_id
CopyDataTypeTag.Memory, # src_type
instruction.next.code_hash, # dst_id
CopyDataTypeTag.Bytecode, # dst_type
offset, # src_addr
offset + size, # src_addr_boundary
FQ(0), # dst_addr
size, # length
instruction.curr.rw_counter + instruction.rw_counter_offset,
### Do stack depth, nonce and balance pre-check
# ErrDepth constraint
is_error_depth, _ = instruction.compare(FQ(CALL_CREATE_DEPTH), depth, N_BYTES_STACK)
# ErrInsufficientBalance constraint
is_insufficient_balance, _ = instruction.compare_word(Word(balance_prev.expr().n), value_word)
# ErrNonceUintOverflow constraint
is_nonce_in_range, _ = instruction.compare(nonce_prev, FQ(MAX_U64), N_BYTES_U64)

# pass the pre-check if none of above errors happen
is_precheck_ok = (
is_error_depth == FQ(0) and is_insufficient_balance == FQ(0) and is_nonce_in_range == FQ(1)
)
instruction.rw_counter_offset += int(copy_rwc_inc)

# CREATE: 3 pops and 1 push, stack delta = 2
# CREATE2: 4 pops and 1 push, stack delta = 3
stack_pointer_delta = 2 + is_create2
# Save caller's call state
for field_tag, expected_value in [
(CallContextFieldTag.ProgramCounter, instruction.curr.program_counter + 1),
(
CallContextFieldTag.StackPointer,
instruction.curr.stack_pointer + stack_pointer_delta,
),
(CallContextFieldTag.GasLeft, gas_left - gas_cost - callee_gas_left),
(CallContextFieldTag.MemorySize, next_memory_size),
(
CallContextFieldTag.ReversibleWriteCounter,
instruction.curr.reversible_write_counter + 1,
),
]:

# error cases, should end this call and return
# is_static should be false
if not is_precheck_ok:
for field_tag, expected_value in [
(CallContextFieldTag.LastCalleeId, FQ(0)),
(CallContextFieldTag.LastCalleeReturnDataOffset, FQ(0)),
(CallContextFieldTag.LastCalleeReturnDataLength, FQ(0)),
]:
instruction.constrain_equal(
instruction.call_context_lookup(field_tag, RW.Write),
expected_value,
)

instruction.constrain_step_state_transition(
rw_counter=Transition.delta(instruction.rw_counter_offset),
program_counter=Transition.delta(1),
stack_pointer=Transition.delta(2 + is_create2),
reversible_write_counter=Transition.delta(1),
gas_left=Transition.delta(-gas_cost),
memory_word_size=Transition.to(next_memory_size),
# Always stay same
call_id=Transition.same(),
is_root=Transition.same(),
is_create=Transition.same(),
code_hash=Transition.same_word(),
)
else:
# calculate contract address
code_hash = instruction.next.code_hash if has_init_code else Word(EMPTY_CODE_HASH)
contract_address = (
instruction.generate_contract_address(caller_address, nonce)
if is_create == 1
else instruction.generate_CREAET2_contract_address(caller_address, salt_word, code_hash)
)
# add contract address to access list
instruction.add_account_to_access_list(tx_id, contract_address)

if has_init_code:
# verify the equality of input `size` and length of calldata
code_size = instruction.bytecode_length(instruction.next.code_hash)
instruction.constrain_equal(code_size, size)

# verify return contract address
instruction.constrain_equal(
instruction.call_context_lookup(field_tag, RW.Write),
expected_value,
instruction.word_to_fq(return_contract_address_word, N_BYTES_ACCOUNT_ADDRESS),
is_success * contract_address,
)

# ErrContractAddressCollision constraint
# code_hash_prev could be either 0 or EMPTY_CODE_HASH
# code_hash should be EMPTY_CODE_HASH to make sure the account is created properly
code_hash, code_hash_prev = instruction.account_write_word(
contract_address, AccountFieldTag.CodeHash
)
instruction.constrain_in_word(
code_hash_prev,
[Word(0), Word(EMPTY_CODE_HASH)],
)
instruction.constrain_equal_word(code_hash, Word(EMPTY_CODE_HASH))

# Setup next call's context.
for field_tag, expected_value in [
(CallContextFieldTag.CallerId, instruction.curr.call_id),
(CallContextFieldTag.TxId, tx_id),
(CallContextFieldTag.Depth, depth + 1),
(CallContextFieldTag.CallerAddress, caller_address),
(CallContextFieldTag.CalleeAddress, contract_address),
(CallContextFieldTag.IsSuccess, is_success),
(CallContextFieldTag.IsStatic, FQ(False)),
(CallContextFieldTag.IsRoot, FQ(False)),
(CallContextFieldTag.IsCreate, FQ(True)),
]:
# Propagate is_persistent
callee_reversion_info = instruction.reversion_info(call_id=callee_call_id)
instruction.constrain_equal(
instruction.call_context_lookup(field_tag, call_id=callee_call_id),
expected_value,
callee_reversion_info.is_persistent,
reversion_info.is_persistent * is_success.expr(),
)
instruction.constrain_equal_word(
instruction.call_context_lookup_word(CallContextFieldTag.CodeHash, call_id=callee_call_id),
code_hash,
)

instruction.step_state_transition_to_new_context(
rw_counter=Transition.delta(instruction.rw_counter_offset),
call_id=Transition.to(callee_call_id),
is_root=Transition.to(False),
is_create=Transition.to(True),
code_hash=Transition.to_word(instruction.next.code_hash),
gas_left=Transition.to(callee_gas_left),
# `transfer` includes two balance updates
reversible_write_counter=Transition.to(2),
log_id=Transition.same(),
)
# transfer value from caller to contract address
instruction.transfer(caller_address, contract_address, value_word, callee_reversion_info)

# CREATE: 3 pops and 1 push, stack delta = 2
# CREATE2: 4 pops and 1 push, stack delta = 3
stack_pointer_delta = 2 + is_create2

if has_init_code:
# copy init_code from memory to bytecode
copy_rwc_inc, _ = instruction.copy_lookup(
instruction.curr.call_id, # src_id
CopyDataTypeTag.Memory, # src_type
instruction.next.code_hash, # dst_id
CopyDataTypeTag.Bytecode, # dst_type
offset, # src_addr
offset + size, # src_addr_boundary
FQ(0), # dst_addr
size, # length
instruction.curr.rw_counter + instruction.rw_counter_offset,
)
instruction.rw_counter_offset += int(copy_rwc_inc)

# Save caller's call state
for field_tag, expected_value in [
(CallContextFieldTag.ProgramCounter, instruction.curr.program_counter + 1),
(
CallContextFieldTag.StackPointer,
instruction.curr.stack_pointer + stack_pointer_delta,
),
(CallContextFieldTag.GasLeft, gas_left - gas_cost - callee_gas_left),
(CallContextFieldTag.MemorySize, next_memory_size),
(
CallContextFieldTag.ReversibleWriteCounter,
instruction.curr.reversible_write_counter + 1,
),
]:
instruction.constrain_equal(
instruction.call_context_lookup(field_tag, RW.Write),
expected_value,
)
# Setup next call's context.
for field_tag, expected_value in [
(CallContextFieldTag.CallerId, instruction.curr.call_id),
(CallContextFieldTag.TxId, tx_id),
(CallContextFieldTag.Depth, depth + 1),
(CallContextFieldTag.CallerAddress, caller_address),
(CallContextFieldTag.CalleeAddress, contract_address),
(CallContextFieldTag.IsSuccess, is_success),
(CallContextFieldTag.IsStatic, FQ(False)),
(CallContextFieldTag.IsRoot, FQ(False)),
(CallContextFieldTag.IsCreate, FQ(True)),
]:
instruction.constrain_equal(
instruction.call_context_lookup(field_tag, call_id=callee_call_id),
expected_value,
)
instruction.constrain_equal_word(
instruction.call_context_lookup_word(
CallContextFieldTag.CodeHash, call_id=callee_call_id
),
code_hash,
)

instruction.step_state_transition_to_new_context(
rw_counter=Transition.delta(instruction.rw_counter_offset),
call_id=Transition.to(callee_call_id),
is_root=Transition.to(False),
is_create=Transition.to(True),
code_hash=Transition.to_word(instruction.next.code_hash),
gas_left=Transition.to(callee_gas_left),
# `transfer` includes two balance updates
reversible_write_counter=Transition.to(2),
log_id=Transition.same(),
)
else:
for field_tag, expected_value in [
(CallContextFieldTag.LastCalleeId, FQ(0)),
(CallContextFieldTag.LastCalleeReturnDataOffset, FQ(0)),
(CallContextFieldTag.LastCalleeReturnDataLength, FQ(0)),
]:
instruction.constrain_equal(
instruction.call_context_lookup(field_tag, RW.Write),
expected_value,
)

instruction.constrain_step_state_transition(
rw_counter=Transition.delta(instruction.rw_counter_offset),
program_counter=Transition.delta(1),
stack_pointer=Transition.delta(stack_pointer_delta),
gas_left=Transition.delta(-gas_cost),
reversible_write_counter=Transition.delta(3),
memory_word_size=Transition.to(next_memory_size),
# Always stay same
call_id=Transition.same(),
is_root=Transition.same(),
is_create=Transition.same(),
code_hash=Transition.same_word(),
)
6 changes: 4 additions & 2 deletions src/zkevm_specs/evm_circuit/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,8 +1108,10 @@ def memory_gas_cost(self, memory_size: Expression) -> FQ:
return quadratic_cost + linear_cost

def memory_expansion(self, offset: Expression, length: Expression) -> Tuple[FQ, FQ]:
memory_size, _ = self.constant_divmod(
length.expr() + offset.expr() + 31, FQ(32), N_BYTES_MEMORY_SIZE
memory_size, _ = (
self.constant_divmod(length.expr() + offset.expr() + 31, FQ(32), N_BYTES_MEMORY_SIZE)
if length != FQ(0)
else (FQ(0), FQ(0))
)

next_memory_size = self.max(self.curr.memory_word_size, memory_size, N_BYTES_MEMORY_SIZE)
Expand Down
Loading

0 comments on commit 31201a2

Please sign in to comment.