diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index ff4d6f8be..6d989eb60 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -95,6 +95,7 @@ get_unique, in_micros, wait_for_connection, + wait_for_pending_wakeups, ) __all__ = [ @@ -190,4 +191,5 @@ "in_micros", "wait_for_connection", "completed_status", + "wait_for_pending_wakeups", ] diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 549da8678..280b990a4 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -36,6 +36,7 @@ Callback, LazyMock, T, + wait_for_pending_wakeups, ) @@ -459,7 +460,7 @@ async def observe_value( item = await asyncio.wait_for(q.get(), timeout) # yield here in case something else is filling the queue # like in test_observe_value_times_out_with_no_external_task() - await asyncio.sleep(0) + await wait_for_pending_wakeups() if done_status and item is done_status: if exc := done_status.exception(): raise exc diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index 2aa4b1c71..466b4e203 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -2,6 +2,7 @@ import asyncio import logging +import warnings from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence from dataclasses import dataclass from enum import Enum, EnumMeta @@ -295,3 +296,26 @@ def __call__(self) -> Mock: if self.parent is not None: self.parent().attach_mock(self._mock, self.name) return self._mock + + +async def wait_for_pending_wakeups(max_yields=10): + """Allow any ready asyncio tasks to be woken up. + + Used in: + - Tests to allow tasks like ``set()`` to start so that signal + puts can be tested + - `observe_value` to allow it to be wrapped in `asyncio.wait_for` + with a timeout + """ + loop = asyncio.get_event_loop() + # If anything has called loop.call_soon or is scheduled a wakeup + # then let it run + for _ in range(max_yields): + await asyncio.sleep(0) + if not loop._ready: # type: ignore # noqa: SLF001 + return + warnings.warn( + f"Tasks still scheduling wakeups after {max_yields} yields", + RuntimeWarning, + stacklevel=2, + ) diff --git a/tests/epics/signal/test_signals.py b/tests/epics/signal/test_signals.py index 5c5ce9bc8..0a3945caa 100644 --- a/tests/epics/signal/test_signals.py +++ b/tests/epics/signal/test_signals.py @@ -945,13 +945,13 @@ async def test_observe_ticking_signal_with_busy_loop(ioc: IOC): async def watch(): async for val in observe_value(sig): - time.sleep(0.15) + time.sleep(0.3) recv.append(val) start = time.time() with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(watch(), timeout=0.2) - assert time.time() - start == pytest.approx(0.3, abs=0.05) + await asyncio.wait_for(watch(), timeout=0.4) + assert time.time() - start == pytest.approx(0.6, abs=0.1) assert len(recv) == 2 # Don't check values as CA and PVA have different algorithms for # dropping updates for slow callbacks diff --git a/tests/epics/test_motor.py b/tests/epics/test_motor.py index 1760f245b..184243582 100644 --- a/tests/epics/test_motor.py +++ b/tests/epics/test_motor.py @@ -15,21 +15,11 @@ set_mock_put_proceeds, set_mock_value, soft_signal_rw, + wait_for_pending_wakeups, ) from ophyd_async.epics import motor -async def wait_for_wakeups(max_yields=10): - loop = asyncio.get_event_loop() - # If anything has called loop.call_soon or is scheduled a wakeup - # then let it run - for _ in range(max_yields): - await asyncio.sleep(0) - if not loop._ready: - return - raise RuntimeError(f"Tasks still scheduling wakeups after {max_yields} yields") - - @pytest.fixture async def sim_motor(): async with DeviceCollector(mock=True): @@ -44,7 +34,7 @@ async def sim_motor(): async def wait_for_eq(item, attribute, comparison, timeout): timeout_time = time.monotonic() + timeout while getattr(item, attribute) != comparison: - await wait_for_wakeups() + await wait_for_pending_wakeups() if time.monotonic() > timeout_time: raise TimeoutError @@ -56,7 +46,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None: s.watch(watcher) done = Mock() s.add_callback(done) - await wait_for_wakeups() + await wait_for_pending_wakeups() await wait_for_eq(watcher, "call_count", 1, 1) assert watcher.call_args == call( name="sim_motor", @@ -86,7 +76,7 @@ async def test_motor_moving_well(sim_motor: motor.Motor) -> None: set_mock_value(sim_motor.motor_done_move, True) set_mock_value(sim_motor.user_readback, 0.55) set_mock_put_proceeds(sim_motor.user_setpoint, True) - await wait_for_wakeups() + await wait_for_pending_wakeups() await wait_for_eq(s, "done", True, 1) done.assert_called_once_with(s) @@ -98,7 +88,7 @@ async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None: s.watch(watcher) done = Mock() s.add_callback(done) - await wait_for_wakeups() + await wait_for_pending_wakeups() assert watcher.call_count == 1 assert watcher.call_args == call( name="sim_motor", @@ -126,7 +116,7 @@ async def test_motor_moving_well_2(sim_motor: motor.Motor) -> None: time_elapsed=pytest.approx(0.1, abs=0.2), ) set_mock_put_proceeds(sim_motor.user_setpoint, True) - await wait_for_wakeups() + await wait_for_pending_wakeups() assert s.done done.assert_called_once_with(s) @@ -165,7 +155,7 @@ async def test_motor_moving_stopped(sim_motor: motor.Motor): assert not s.done await sim_motor.stop() set_mock_put_proceeds(sim_motor.user_setpoint, True) - await wait_for_wakeups() + await wait_for_pending_wakeups() assert s.done assert s.success is False