Skip to content

Commit

Permalink
feat: EIP-7549 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
madlabman committed Oct 31, 2024
1 parent 8cdfaee commit 8e9e4b8
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 22 deletions.
47 changes: 31 additions & 16 deletions src/modules/csm/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from dataclasses import dataclass
from itertools import batched
from threading import Lock
from typing import Iterable, Sequence
from typing import Iterable, Sequence, TypeGuard

from src import variables
from src.constants import SLOTS_PER_HISTORICAL_ROOT
from src.metrics.prometheus.csm import CSM_UNPROCESSED_EPOCHS_COUNT, CSM_MIN_UNPROCESSED_EPOCH
from src.metrics.prometheus.csm import CSM_MIN_UNPROCESSED_EPOCH, CSM_UNPROCESSED_EPOCHS_COUNT
from src.modules.csm.state import State
from src.providers.consensus.client import ConsensusClient
from src.providers.consensus.types import BlockAttestation
from src.providers.consensus.types import BlockAttestationElectra, BlockAttestationPhase0
from src.types import BlockRoot, BlockStamp, EpochNumber, SlotNumber, ValidatorIndex
from src.utils.range import sequence
from src.utils.timeit import timeit
Expand All @@ -20,8 +20,10 @@
lock = Lock()


class MinStepIsNotReached(Exception):
...
class MinStepIsNotReached(Exception): ...


type BlockAttestation = BlockAttestationPhase0 | BlockAttestationElectra


@dataclass
Expand Down Expand Up @@ -103,7 +105,9 @@ def _is_min_step_reached(self):
return False


type Committees = dict[tuple[str, str], list[ValidatorDuty]]
type Slot = str
type CommitteeIndex = str
type Committees = dict[tuple[Slot, CommitteeIndex], list[ValidatorDuty]]


class FrameCheckpointProcessor:
Expand Down Expand Up @@ -228,19 +232,30 @@ def _prepare_committees(self, epoch: EpochNumber) -> Committees:


def process_attestations(attestations: Iterable[BlockAttestation], committees: Committees) -> None:
def _is_attested(bits: Sequence[bool], index: int) -> bool:
return bits[index]

for attestation in attestations:
committee_id = (attestation.data.slot, attestation.data.index)
committee = committees.get(committee_id, [])
att_bits = _to_bits(attestation.aggregation_bits)
for index_in_committee, validator_duty in enumerate(committee):
validator_duty.included = validator_duty.included or _is_attested(att_bits, index_in_committee)
committee_offset = 0
for committee_idx in get_committee_indices(attestation):
committee = committees.get((attestation.data.slot, committee_idx), [])
att_bits = hex_bitvector_to_list(attestation.aggregation_bits)[committee_offset:][: len(committee)]
for index_in_committee, validator_duty in enumerate(committee):
validator_duty.included = validator_duty.included or _is_attested(att_bits, index_in_committee)
committee_offset += len(committee)


def get_committee_indices(attestation: BlockAttestation) -> list[CommitteeIndex]:
if not is_electra_attestation(attestation):
return [attestation.data.index]
return [str(idx) for (idx, bit) in enumerate(hex_bitvector_to_list(attestation.committee_bits)) if bit]


def _is_attested(bits: Sequence[bool], index: int) -> bool:
return bits[index]
def is_electra_attestation(attestation: BlockAttestation) -> TypeGuard[BlockAttestationElectra]:
return getattr(attestation, "committee_bits") is not None and attestation.data.index == "0"


def _to_bits(aggregation_bits: str) -> Sequence[bool]:
def hex_bitvector_to_list(bitvector: str) -> list[bool]:
# copied from https://github.com/ethereum/py-ssz/blob/main/ssz/sedes/bitvector.py#L66
att_bytes = bytes.fromhex(aggregation_bits[2:])
return [bool((att_bytes[bit_index // 8] >> bit_index % 8) % 2) for bit_index in range(len(att_bytes) * 8)]
bytes_ = bytes.fromhex(bitvector[2:]) if bitvector.startswith("0x") else bytes.fromhex(bitvector)
return [bool((bytes_[bit_index // 8] >> bit_index % 8) % 2) for bit_index in range(len(bytes_) * 8)]
17 changes: 13 additions & 4 deletions src/providers/consensus/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from src.metrics.logging import logging
from src.metrics.prometheus.basic import CL_REQUESTS_DURATION
from src.providers.consensus.types import (
BlockAttestationElectra,
BlockAttestationPhase0,
BlockDetailsResponse,
BlockHeaderFullResponse,
BlockHeaderResponseData,
Expand Down Expand Up @@ -108,7 +110,9 @@ def get_block_details(self, state_id: SlotNumber | BlockRoot) -> BlockDetailsRes
return BlockDetailsResponse.from_response(**data)

@lru_cache(maxsize=256)
def get_block_attestations(self, state_id: SlotNumber | BlockRoot) -> list[BlockAttestation]:
def get_block_attestations(
self, state_id: SlotNumber | BlockRoot
) -> list[BlockAttestationPhase0 | BlockAttestationElectra]:
"""Spec: https://ethereum.github.io/beacon-APIs/#/Beacon/getBlockAttestations"""
data, _ = self._get(
self.API_GET_BLOCK_ATTESTATIONS,
Expand All @@ -124,20 +128,25 @@ def get_attestation_committees(
self,
blockstamp: BlockStamp,
epoch: EpochNumber | None = None,
index: int | None = None,
committee_index: int | None = None,
slot: SlotNumber | None = None
) -> list[SlotAttestationCommittee]:
"""Spec: https://ethereum.github.io/beacon-APIs/#/Beacon/getEpochCommittees"""
try:
data, _ = self._get(
self.API_GET_ATTESTATION_COMMITTEES,
path_params=(blockstamp.state_root,),
query_params={'epoch': epoch, 'index': index, 'slot': slot},
query_params={'epoch': epoch, 'index': committee_index, 'slot': slot},
force_raise=self.__raise_on_prysm_error
)
except NotOkResponse as error:
if self.PRYSM_STATE_NOT_FOUND_ERROR in error.text:
data = self._get_attestation_committees_with_prysm(blockstamp, epoch, index, slot)
data = self._get_attestation_committees_with_prysm(
blockstamp,
epoch,
committee_index,
slot,
)
else:
raise error
return cast(list[SlotAttestationCommittee], data)
Expand Down
15 changes: 13 additions & 2 deletions src/providers/consensus/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Protocol

from src.types import BlockHash, BlockRoot, StateRoot
from src.utils.dataclass import Nested, FromResponse
from src.utils.dataclass import FromResponse, Nested


@dataclass
Expand Down Expand Up @@ -75,7 +76,7 @@ class Checkpoint:
@dataclass
class AttestationData(Nested, FromResponse):
slot: str
index: str
index: str | Literal["0"]
beacon_block_root: BlockRoot
source: Checkpoint
target: Checkpoint
Expand All @@ -85,6 +86,16 @@ class AttestationData(Nested, FromResponse):
class BlockAttestation(Nested, FromResponse):
aggregation_bits: str
data: AttestationData
committee_bits: str | None = None


class BlockAttestationPhase0(Protocol):
aggregation_bits: str
data: AttestationData


class BlockAttestationElectra(BlockAttestationPhase0):
committee_bits: str


@dataclass
Expand Down
1 change: 1 addition & 0 deletions tests/factory/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class BlockAttestationFactory(Web3Factory):
__model__ = BlockAttestation

aggregation_bits = "0x"
committee_bits = None
data = AttestationData(
slot="0",
index="0",
Expand Down

0 comments on commit 8e9e4b8

Please sign in to comment.