Skip to content

Commit

Permalink
Moving the _get method length checks into the dtype method.
Browse files Browse the repository at this point in the history
  • Loading branch information
scott-griffiths committed Dec 31, 2023
1 parent c3be1a9 commit 4da10e9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 20 deletions.
20 changes: 2 additions & 18 deletions bitstring/bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,14 +780,10 @@ def _getintle(self) -> int:
return bs.slice_to_int()

def _gete4m3float(self) -> float:
if len(self) != 8:
raise InterpretError(f"A e4m3float must be 8 bits long, not {len(self)} bits.")
u = self._getuint()
return e4m3float_fmt.lut_int8_to_float[u]

def _gete5m2float(self) -> float:
if len(self) != 8:
raise InterpretError(f"A e5m2float must be 8 bits long, not {len(self)} bits.")
u = self._getuint()
return e5m2float_fmt.lut_int8_to_float[u]

Expand All @@ -800,10 +796,7 @@ def _setfloatbe(self, f: float, length: Optional[int] = None, _offset: None = No

def _getfloatbe(self) -> float:
"""Interpret the whole bitstring as a big-endian float."""
try:
fmt = {16: '>e', 32: '>f', 64: '>d'}[len(self)]
except KeyError:
raise InterpretError(f"Floats can only be 16, 32 or 64 bits long, not {len(self)} bits")
fmt = {16: '>e', 32: '>f', 64: '>d'}[len(self)]
return struct.unpack(fmt, self._bitstore.tobytes())[0]

def _setfloatle(self, f: float, length: Optional[int] = None, _offset: None = None) -> None:
Expand All @@ -815,15 +808,10 @@ def _setfloatle(self, f: float, length: Optional[int] = None, _offset: None = No

def _getfloatle(self) -> float:
"""Interpret the whole bitstring as a little-endian float."""
try:
fmt = {16: '<e', 32: '<f', 64: '<d'}[len(self)]
except KeyError:
raise InterpretError(f"Floats can only be 16, 32 or 64 bits long, not {len(self)} bits")
fmt = {16: '<e', 32: '<f', 64: '<d'}[len(self)]
return struct.unpack(fmt, self._bitstore.tobytes())[0]

def _getbfloatbe(self) -> float:
if len(self) != 16:
raise InterpretError(f"bfloats must be length 16, received a length of {len(self)} bits.")
zero_padded = self + Bits(16)
return zero_padded._getfloatbe()

Expand All @@ -833,8 +821,6 @@ def _setbfloatbe(self, f: Union[float, str], length: Optional[int] = None, _offs
self._bitstore = bfloat2bitstore(f)

def _getbfloatle(self) -> float:
if len(self) != 16:
raise InterpretError(f"bfloats must be length 16, received a length of {len(self)} bits.")
zero_padded = Bits(16) + self
return zero_padded._getfloatle()

Expand Down Expand Up @@ -983,8 +969,6 @@ def _setbool(self, value: Union[bool, str], length: Optional[int] = None, _offse
raise CreationError(f"Cannot initialise boolean with {value}.")

def _getbool(self) -> bool:
if len(self) != 1:
raise InterpretError(f"For a bool interpretation a bitstring must be 1 bit long, not {len(self)} bits.")
return self[0]

def _getpad(self) -> None:
Expand Down
17 changes: 15 additions & 2 deletions bitstring/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,21 @@ def __init__(self, name: str, set_fn, get_fn, return_type: Any = Any, is_signed:
self.multiplier = multiplier

self.set_fn = set_fn
self.get_fn = get_fn # Interpret everything

if self.fixed_length:
if len(self.fixed_length) == 1:
def length_checked_get_fn(bs):
if len(bs) != self.fixed_length[0]:
raise InterpretError(f"'{self.name}' dtypes must have a length of {self.fixed_length[0]}, but received a length of {len(bs)}.")
return get_fn(bs)
else:
def length_checked_get_fn(bs):
if len(bs) not in self.fixed_length:
raise InterpretError(f"'{self.name}' dtypes must have one of the lengths {self.fixed_length}, but received a length of {len(bs)}.")
return get_fn(bs)
self.get_fn = length_checked_get_fn # Interpret everything and check the length
else:
self.get_fn = get_fn # Interpret everything

# Create a reading function from the get_fn.
if self.is_unknown_length:
Expand All @@ -121,7 +135,6 @@ def read_fn(bs, start):
def read_fn(bs, start, length):
return self.get_fn(bs[start:start + length])
self.read_fn = read_fn

self.bitlength2chars_fn = bitlength2chars_fn

def getDtype(self, length: Optional[int] = None) -> Dtype:
Expand Down

0 comments on commit 4da10e9

Please sign in to comment.