Skip to content

Commit

Permalink
merge(fix): block multiple calls to cipher stream finalization [#19]
Browse files Browse the repository at this point in the history
  • Loading branch information
rmlibre committed Aug 10, 2024
2 parents 51cbd81 + e128390 commit 1bdbfba
Show file tree
Hide file tree
Showing 45 changed files with 272 additions and 165 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ include aiootp/databases/sync_database.py
include aiootp/tor/README_TOR.rst
include aiootp/db/README_DATABASES.rst
include tests/pytest.ini
include tests/conftest.py
include tests/test_aiootp.py
include tests/test_initialization.py
include tests/test_typing.py
include tests/test_Typing_class.py
include tests/test_typing_protocols.py
Expand Down
96 changes: 48 additions & 48 deletions SIGNATURE.txt

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions aiootp/asynchs/concurrency_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def __aenter__(self, /) -> t.Self:
in the 0th position of the queue.
"""
self.queue.append(self.token)
while not compare_digest(self.queue[0], self.token):
while not compare_digest(self.token, self.queue[0]):
await asleep(self.probe_delay)
return self

Expand All @@ -128,7 +128,7 @@ def __enter__(self, /) -> t.Self:
in the 0th position of the queue.
"""
self.queue.append(self.token)
while not compare_digest(self.queue[0], self.token):
while not compare_digest(self.token, self.queue[0]):
sleep(self.probe_delay)
return self

Expand Down
80 changes: 46 additions & 34 deletions aiootp/ciphers/cipher_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import io
from collections import deque
from hmac import compare_digest
from secrets import token_bytes

from aiootp._typing import Typing as t
from aiootp._constants.misc import DEFAULT_AAD
Expand Down Expand Up @@ -69,7 +71,7 @@ class AsyncCipherStream(CipherStreamProperties, metaclass=AsyncInit):
"_byte_count",
"_config",
"_digesting_now",
"_is_finalized",
"_finalizing_now",
"_key_bundle",
"_padding",
"_stream",
Expand Down Expand Up @@ -117,7 +119,7 @@ async def __init__(
self._padding = cipher._padding
self._byte_count = 0
self._digesting_now = deque(maxlen=self._MAX_SIMULTANEOUS_BUFFERS)
self._is_finalized = False
self._finalizing_now = deque() # don't let maxlen remove entries
self._buffer = buffer = deque([self._padding.start_padding()])
self._key_bundle = key_bundle = await cipher._KeyAADBundle(
kdfs=cipher._kdfs, salt=salt, aad=aad
Expand Down Expand Up @@ -188,18 +190,23 @@ async def afinalize(
async for block_id, ciphertext in stream.afinalize(): # <------
session.send_packet(block_id + ciphertext)
"""
self._is_finalized = True
end_padding = await self._padding.aend_padding(self._byte_count)
final_blocks = abatch(
self._buffer.pop() + end_padding, size=self._config.BLOCKSIZE
)
async for block in final_blocks:
self._buffer.append(block)
while self._buffer:
block = await self._stream.asend(None)
block_id = await self.shmac.anext_block_id(block)
yield block_id, block
await self.shmac.afinalize()
self._finalizing_now.append(token := token_bytes(32))
if not compare_digest(token, self._finalizing_now[0]):
raise ConcurrencyGuard.IncoherentConcurrencyState

async with ConcurrencyGuard(self._digesting_now, token=token):
end_padding = await self._padding.aend_padding(self._byte_count)
final_blocks = abatch(
self._buffer.pop() + end_padding,
size=self._config.BLOCKSIZE,
)
async for block in final_blocks:
self._buffer.append(block)
while self._buffer:
block = await self._stream.asend(None)
block_id = await self.shmac.anext_block_id(block)
yield block_id, block
await self.shmac.afinalize()

async def _adigest_data(
self,
Expand Down Expand Up @@ -245,10 +252,9 @@ async def abuffer(self, data: bytes) -> t.Self:
async for block_id, ciphertext in stream.afinalize():
session.send_packet(block_id + ciphertext)
"""
if self._is_finalized:
raise CipherStreamIssue.stream_has_been_closed()

async with ConcurrencyGuard(self._digesting_now):
if self._finalizing_now:
raise CipherStreamIssue.stream_has_been_closed()
self._byte_count += len(data)
data = io.BytesIO(data).read
_buffer, append = self._buffer_shortcuts
Expand Down Expand Up @@ -296,7 +302,7 @@ class CipherStream(CipherStreamProperties):
"_byte_count",
"_config",
"_digesting_now",
"_is_finalized",
"_finalizing_now",
"_key_bundle",
"_padding",
"_stream",
Expand Down Expand Up @@ -344,7 +350,7 @@ def __init__(
self._padding = cipher._padding
self._byte_count = 0
self._digesting_now = deque(maxlen=self._MAX_SIMULTANEOUS_BUFFERS)
self._is_finalized = False
self._finalizing_now = deque() # don't let maxlen remove entries
self._buffer = buffer = deque([self._padding.start_padding()])
self._key_bundle = key_bundle = cipher._KeyAADBundle(
kdfs=cipher._kdfs, salt=salt, aad=aad
Expand Down Expand Up @@ -411,18 +417,25 @@ def finalize(self) -> t.Generator[t.Tuple[bytes, bytes], None, None]:
for block_id, ciphertext in stream.finalize(): # <-------------
session.send_packet(block_id + ciphertext)
"""
self._is_finalized = True
end_padding = self._padding.end_padding(self._byte_count)
final_blocks = batch(
self._buffer.pop() + end_padding, size=self._config.BLOCKSIZE
)
for block in final_blocks:
self._buffer.append(block)
while self._buffer:
block = self._stream.send(None)
block_id = self.shmac.next_block_id(block)
yield block_id, block
self.shmac.finalize()
self._finalizing_now.append(token := token_bytes(32))
if not compare_digest(token, self._finalizing_now[0]):
raise ConcurrencyGuard.IncoherentConcurrencyState

with ConcurrencyGuard(
self._digesting_now, probe_delay=0.0001, token=token
):
end_padding = self._padding.end_padding(self._byte_count)
final_blocks = batch(
self._buffer.pop() + end_padding,
size=self._config.BLOCKSIZE,
)
for block in final_blocks:
self._buffer.append(block)
while self._buffer:
block = self._stream.send(None)
block_id = self.shmac.next_block_id(block)
yield block_id, block
self.shmac.finalize()

def _digest_data(
self,
Expand Down Expand Up @@ -467,10 +480,9 @@ def buffer(self, data: bytes) -> t.Self:
for block_id, ciphertext in stream.finalize():
session.send_packet(block_id + ciphertext)
"""
if self._is_finalized:
raise CipherStreamIssue.stream_has_been_closed()

with ConcurrencyGuard(self._digesting_now, probe_delay=0.0001):
if self._finalizing_now:
raise CipherStreamIssue.stream_has_been_closed()
self._byte_count += len(data)
data = io.BytesIO(data).read
_buffer, append = self._buffer_shortcuts
Expand Down
70 changes: 42 additions & 28 deletions aiootp/ciphers/decipher_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import io
from collections import deque
from hmac import compare_digest
from secrets import token_bytes

from aiootp._typing import Typing as t
from aiootp._constants import DEFAULT_AAD, DEFAULT_TTL
Expand Down Expand Up @@ -69,7 +71,7 @@ class AsyncDecipherStream(CipherStreamProperties, metaclass=AsyncInit):
"_bytes_to_trim",
"_config",
"_digesting_now",
"_is_finalized",
"_finalizing_now",
"_is_streaming",
"_key_bundle",
"_padding",
Expand Down Expand Up @@ -129,8 +131,8 @@ async def __init__(
self._padding = cipher._padding
self._ttl = ttl
self._digesting_now = deque(maxlen=self._MAX_SIMULTANEOUS_BUFFERS)
self._finalizing_now = deque() # don't let maxlen remove entries
self._is_streaming = False
self._is_finalized = False
self._result_queue = deque()
self._buffer = buffer = deque()
self._bytes_to_trim = self._config.INNER_HEADER_BYTES
Expand Down Expand Up @@ -246,16 +248,22 @@ async def afinalize(self) -> t.AsyncGenerator[bytes, None]:
async for plaintext in stream.afinalize():
yield plaintext
"""
self._is_finalized = True
await self.shmac.afinalize()
async for result in self:
yield result
queue = self._result_queue
footer_index = await self._padding.adepadding_end_index(queue[-1])
async for block in abatch(
b"".join(queue)[:footer_index], size=self._config.BLOCKSIZE
):
yield block
self._finalizing_now.append(token := token_bytes(32))
if not compare_digest(token, self._finalizing_now[0]):
raise ConcurrencyGuard.IncoherentConcurrencyState

async with ConcurrencyGuard(self._digesting_now, token=token):
await self.shmac.afinalize()
async for result in self:
yield result
queue = self._result_queue
footer_index = await self._padding.adepadding_end_index(
queue[-1]
)
async for block in abatch(
b"".join(queue)[:footer_index], size=self._config.BLOCKSIZE
):
yield block

async def _adigest_data(
self,
Expand Down Expand Up @@ -306,12 +314,12 @@ async def abuffer(self, data: bytes) -> t.Self:
async for plaintext in stream.afinalize():
yield plaintext
"""
if self._is_finalized:
raise CipherStreamIssue.stream_has_been_closed()
elif not data or len(data) % self.PACKETSIZE:
if not data or len(data) % self.PACKETSIZE:
raise Issue.invalid_length("data", len(data))

async with ConcurrencyGuard(self._digesting_now):
if self._finalizing_now:
raise CipherStreamIssue.stream_has_been_closed()
data = io.BytesIO(data).read
atest_block_id, append = self._buffer_shortcuts
await self._adigest_data(data, atest_block_id, append)
Expand Down Expand Up @@ -358,7 +366,7 @@ class DecipherStream(CipherStreamProperties):
"_config",
"_bytes_to_trim",
"_digesting_now",
"_is_finalized",
"_finalizing_now",
"_is_streaming",
"_key_bundle",
"_padding",
Expand Down Expand Up @@ -418,8 +426,8 @@ def __init__(
self._padding = cipher._padding
self._ttl = ttl
self._digesting_now = deque(maxlen=self._MAX_SIMULTANEOUS_BUFFERS)
self._finalizing_now = deque() # don't let maxlen remove entries
self._is_streaming = False
self._is_finalized = False
self._result_queue = deque()
self._buffer = buffer = deque()
self._bytes_to_trim = self._config.INNER_HEADER_BYTES
Expand Down Expand Up @@ -529,14 +537,20 @@ def finalize(self) -> t.Generator[bytes, None, None]:
for plaintext in stream.finalize():
yield plaintext
"""
self._is_finalized = True
self.shmac.finalize()
yield from self
queue = self._result_queue
footer_index = self._padding.depadding_end_index(queue[-1])
yield from batch(
b"".join(queue)[:footer_index], size=self._config.BLOCKSIZE
)
self._finalizing_now.append(token := token_bytes(32))
if not compare_digest(token, self._finalizing_now[0]):
raise ConcurrencyGuard.IncoherentConcurrencyState

with ConcurrencyGuard(
self._digesting_now, probe_delay=0.0001, token=token
):
self.shmac.finalize()
yield from self
queue = self._result_queue
footer_index = self._padding.depadding_end_index(queue[-1])
yield from batch(
b"".join(queue)[:footer_index], size=self._config.BLOCKSIZE
)

def _digest_data(
self,
Expand Down Expand Up @@ -585,12 +599,12 @@ def buffer(self, data: bytes) -> t.Self:
for plaintext in stream.finalize():
yield plaintext
"""
if self._is_finalized:
raise CipherStreamIssue.stream_has_been_closed()
elif not data or len(data) % self.PACKETSIZE:
if not data or len(data) % self.PACKETSIZE:
raise Issue.invalid_length("data", len(data))

with ConcurrencyGuard(self._digesting_now, probe_delay=0.0001):
if self._finalizing_now:
raise CipherStreamIssue.stream_has_been_closed()
data = io.BytesIO(data).read
atest_block_id, append = self._buffer_shortcuts
self._digest_data(data, atest_block_id, append)
Expand Down
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
"*{config,format}.py" = [
"PLR0913", # config objects can have many arguments
]
"test_initialization.py" = [
"conftest.py" = [
"E402", # there's a wibbly wobbly, timey wimey import sequence here
"F401", # most other tests get their imports from this module
]
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/test_ByteIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from base64 import urlsafe_b64encode

from test_initialization import *
from conftest import *


NON_ZERO_PREFIX = choice(list(range(1, 256))).to_bytes(1, BIG)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_Clock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import time
import warnings

from test_initialization import *
from conftest import *


TIME_RESOLUTION = time.get_clock_info("time").resolution
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ConcurrencyGuard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from collections import deque

from test_initialization import *
from conftest import *

from aiootp.asynchs import ConcurrencyGuard

Expand Down
2 changes: 1 addition & 1 deletion tests/test_Database_AsyncDatabase.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#


from test_initialization import *
from conftest import *


class TestDBKDF:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_GUID.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#


from test_initialization import *
from conftest import *
from test_Clock import TIME_RESOLUTION

from aiootp.randoms.ids.raw_guid_config import RawGUIDContainer
Expand Down
4 changes: 2 additions & 2 deletions tests/test_PackageSigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import platform

from test_initialization import *
from conftest import *

from aiootp._paths import DatabasePath
from aiootp.asynchs import sleep
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_sign_and_verify() -> None:
)

filename_sheet = """
include tests/test_initialization.py
include tests/conftest.py
include tests/test_aiootp.py
include tests/test_generics.py
include tests/test_ByteIO.py
Expand Down
2 changes: 1 addition & 1 deletion tests/test_Padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#


from test_initialization import *
from conftest import *


class TestPlaintextPadding:
Expand Down
Loading

0 comments on commit 1bdbfba

Please sign in to comment.