Skip to content

Commit

Permalink
Make 'steal' command atomic (#1144)
Browse files Browse the repository at this point in the history
Either unschedule all requested tests, or none if it's not possible -
if some of the requested tests have already been processed by the time
the request arrives. It may happen if the worker runs tests faster than
the controller receives and processes status updates. But in this case
maybe it's just better to let the worker keep running.

This is a prerequisite for group/scope support in worksteal scheduler -
so they won't be broken up incorrectly.

This change could break schedulers that use "steal" command. However:

1) worksteal scheduler doesn't need any adjustments.

2) I'm not aware of any external schedulers relying on this command yet.

So I think it's better to keep the protocol simple, not complicate it for
imaginary compatibility with some unknown and likely non-existent
schedulers.

Co-authored-by: Bruno Oliveira <[email protected]>
  • Loading branch information
amezin and nicoddemus authored Oct 30, 2024
1 parent 34c5549 commit 9788f12
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 28 deletions.
3 changes: 3 additions & 0 deletions changelog/1144.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The internal `steal` command is now atomic - it unschedules either all requested tests or none.

This is a prerequisite for group/scope support in the `worksteal` scheduler, so test groups won't be broken up incorrectly.
88 changes: 60 additions & 28 deletions src/xdist/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@

from __future__ import annotations

import collections
import contextlib
import enum
import os
import sys
import time
from typing import Any
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Sequence
from typing import TypedDict
from typing import Union
import warnings

from _pytest.config import _prepareconfig
Expand Down Expand Up @@ -66,7 +69,44 @@ def worker_title(title: str) -> None:

class Marker(enum.Enum):
SHUTDOWN = 0
QUEUE_REPLACED = 1


class TestQueue:
"""A simple queue that can be inspected and modified while the lock is held via the ``lock()`` method."""

Item = Union[int, Literal[Marker.SHUTDOWN]]

def __init__(self, execmodel: execnet.gateway_base.ExecModel):
self._items: collections.deque[TestQueue.Item] = collections.deque()
self._lock = execmodel.RLock() # type: ignore[no-untyped-call]
self._has_items_event = execmodel.Event()

def get(self) -> Item:
while True:
with self.lock() as locked_items:
if locked_items:
return locked_items.popleft()

self._has_items_event.wait()

def put(self, item: Item) -> None:
with self.lock() as locked_items:
locked_items.append(item)

def replace(self, iterable: Iterable[Item]) -> None:
with self.lock():
self._items = collections.deque(iterable)

@contextlib.contextmanager
def lock(self) -> Generator[collections.deque[Item], None, None]:
with self._lock:
try:
yield self._items
finally:
if self._items:
self._has_items_event.set()
else:
self._has_items_event.clear()


class WorkerInteractor:
Expand All @@ -77,22 +117,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
self.testrunuid = workerinput["testrunuid"]
self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug)
self.channel = channel
self.torun = self._make_queue()
self.torun = TestQueue(self.channel.gateway.execmodel)
self.nextitem_index: int | None | Literal[Marker.SHUTDOWN] = None
config.pluginmanager.register(self)

def _make_queue(self) -> Any:
return self.channel.gateway.execmodel.queue.Queue()

def _get_next_item_index(self) -> int | Literal[Marker.SHUTDOWN]:
"""Gets the next item from test queue. Handles the case when the queue
is replaced concurrently in another thread.
"""
result = self.torun.get()
while result is Marker.QUEUE_REPLACED:
result = self.torun.get()
return result # type: ignore[no-any-return]

def sendevent(self, name: str, **kwargs: object) -> None:
self.log("sending", name, kwargs)
self.channel.send((name, kwargs))
Expand Down Expand Up @@ -146,30 +174,34 @@ def handle_command(
self.steal(kwargs["indices"])

def steal(self, indices: Sequence[int]) -> None:
indices_set = set(indices)
stolen = []
"""
Remove tests from the queue.
old_queue, self.torun = self.torun, self._make_queue()
Removes either all requested tests, or none, if some of these tests
are not in the queue (for example, if they were processed already).
def old_queue_get_nowait_noraise() -> int | None:
with contextlib.suppress(self.channel.gateway.execmodel.queue.Empty):
return old_queue.get_nowait() # type: ignore[no-any-return]
return None
:param indices: indices of the tests to remove.
"""
requested_set = set(indices)

with self.torun.lock() as locked_queue:
stolen = list(item for item in locked_queue if item in requested_set)

for i in iter(old_queue_get_nowait_noraise, None):
if i in indices_set:
stolen.append(i)
# Stealing only if all requested tests are still pending
if len(stolen) == len(requested_set):
self.torun.replace(
item for item in locked_queue if item not in requested_set
)
else:
self.torun.put(i)
stolen = []

self.sendevent("unscheduled", indices=stolen)
old_queue.put(Marker.QUEUE_REPLACED)

@pytest.hookimpl
def pytest_runtestloop(self, session: pytest.Session) -> bool:
self.log("entering main loop")
self.channel.setcallback(self.handle_command, endmarker=Marker.SHUTDOWN)
self.nextitem_index = self._get_next_item_index()
self.nextitem_index = self.torun.get()
while self.nextitem_index is not Marker.SHUTDOWN:
self.run_one_test()
if session.shouldfail or session.shouldstop:
Expand All @@ -179,7 +211,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool:
def run_one_test(self) -> None:
assert isinstance(self.nextitem_index, int)
self.item_index = self.nextitem_index
self.nextitem_index = self._get_next_item_index()
self.nextitem_index = self.torun.get()

items = self.session.items
item = items[self.item_index]
Expand Down
6 changes: 6 additions & 0 deletions testing/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,12 @@ def test_func4(): pass

worker.sendcommand("steal", indices=[1, 2])
ev = worker.popevent("unscheduled")
# Cannot steal index 1 because it is completed already, so do not steal any.
assert ev.kwargs["indices"] == []

# Index 2 can be stolen, as it is still pending.
worker.sendcommand("steal", indices=[2])
ev = worker.popevent("unscheduled")
assert ev.kwargs["indices"] == [2]

reports = [
Expand Down

0 comments on commit 9788f12

Please sign in to comment.