diff --git a/src/ophyd_async/core/_device.py b/src/ophyd_async/core/_device.py index 365e5714a..7b331ac2f 100644 --- a/src/ophyd_async/core/_device.py +++ b/src/ophyd_async/core/_device.py @@ -5,7 +5,7 @@ from collections.abc import Coroutine, Iterator, Mapping, MutableMapping from functools import cached_property from logging import LoggerAdapter, getLogger -from typing import Any, Literal, TypeVar +from typing import Any, TypeVar from bluesky.protocols import HasName from bluesky.run_engine import call_in_bluesky_event_loop, in_bluesky_event_loop @@ -63,11 +63,8 @@ class Device(HasName, Connectable): parent: Device | None = None # None if connect hasn't started, a Task if it has _connect_task: asyncio.Task | None = None - _mock: ( - None # If we have never connected - | LazyMock # The mock if we have connected in mock mode - | Literal[False] # If we have not connected in mock mode - ) = None + # The mock if we have connected in mock mode + _mock: LazyMock | None = None def __init__( self, name: str = "", connector: DeviceConnector | None = None @@ -111,16 +108,20 @@ def set_name(self, name: str): child.set_name(child_name) def __setattr__(self, name: str, value: Any) -> None: + # Bear in mind that this function is called *a lot*, so + # we need to make sure nothing expensive happens in it... if name == "parent": if self.parent not in (value, None): raise TypeError( f"Cannot set the parent of {self} to be {value}: " f"it is already a child of {self.parent}" ) + # ...hence not doing an isinstance check for attributes we + # know not to be Devices elif name not in _not_device_attrs and isinstance(value, Device): value.parent = self self._child_devices[name] = value - # Avoid the super call as this happens a lot + # ...and avoiding the super call as we know it resolves to `object` return object.__setattr__(self, name, value) async def connect( @@ -157,7 +158,7 @@ async def connect( and not (self._connect_task.done() and self._connect_task.exception()) ) if force_reconnect or not can_use_previous_connect: - self._mock = False + self._mock = None coro = self._connector.connect_real(self, timeout, force_reconnect) self._connect_task = asyncio.create_task(coro) assert self._connect_task, "Connect task not created, this shouldn't happen" diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index d546cb7d5..ca20d90a3 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -265,6 +265,24 @@ def __call__(self) -> T: class LazyMock: + """A lazily created Mock to be used when connecting in mock mode. + + Creating Mocks is reasonably expensive when each Device (and Signal) + requires its own, and the tree is only used when ``Signal.set()`` is + called. This class allows a tree of lazily connected Mocks to be + constructed so that when the leaf is created, so are its parents. + Any calls to the child are then accessible from the parent mock. + + >>> parent = LazyMock() + >>> child = parent.child("child") + >>> child_mock = child() + >>> child_mock() # doctest: +ELLIPSIS + + >>> parent_mock = parent() + >>> parent_mock.mock_calls + [call.child()] + """ + def __init__(self, name: str = "", parent: LazyMock | None = None) -> None: self.parent = parent self.name = name diff --git a/tests/core/test_device.py b/tests/core/test_device.py index 70620a0bf..1f9c0ffb8 100644 --- a/tests/core/test_device.py +++ b/tests/core/test_device.py @@ -175,11 +175,16 @@ def __init__(self, name: str) -> None: super().__init__(name) -async def test_many_individual_device_connects_not_slow(): +@pytest.mark.parametrize("parallel", (False, True)) +async def test_many_individual_device_connects_not_slow(parallel): start = time.time() - for i in range(100): - bundle = MotorBundle(f"bundle{i}") - await bundle.connect(mock=True) + bundles = [MotorBundle(f"bundle{i}") for i in range(200)] + if parallel: + for bundle in bundles: + await bundle.connect(mock=True) + else: + coros = {bundle.name: bundle.connect(mock=True) for bundle in bundles} + await wait_for_connection(**coros) duration = time.time() - start assert duration < 1