diff --git a/docs/source/index.rst b/docs/source/index.rst index fd29f82..28173ea 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -33,6 +33,8 @@ and async I/O in Python. Features include: `__ library, so your async tests can use property-based testing: just use ``@given`` like you're used to. +* Integration with `pytest-timeout ` + * Support for testing projects that use Trio exclusively and want to use pytest-trio everywhere, and also for testing projects that support multiple async libraries and only want to enable diff --git a/docs/source/reference.rst b/docs/source/reference.rst index 351e8c1..7ae3be3 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -420,3 +420,28 @@ it can be passed directly to the marker. @pytest.mark.trio(run=qtrio.run) async def test(): assert True + + +Configuring timeouts with pytest-timeout +---------------------------------------- + +Timeouts can be configured using the ``@pytest.mark.timeout`` decorator. + +.. code-block:: python + + import pytest + import trio + + @pytest.mark.timeout(10) + async def test(): + await trio.sleep_forever() # will error after 10 seconds + +To get clean stacktraces that cover all tasks running when the timeout was triggered, enable the ``trio_timeout`` option. + +.. code-block:: ini + + # pytest.ini + [pytest] + trio_timeout = true + +This timeout method requires a functioning loop, and hence will not be triggered if your test doesn't yield to the loop. This typically occurs when the test is stuck on some non-async piece of code. diff --git a/newsfragments/53.feature.rst b/newsfragments/53.feature.rst new file mode 100644 index 0000000..b1247b2 --- /dev/null +++ b/newsfragments/53.feature.rst @@ -0,0 +1 @@ +Add support for pytest-timeout using our own timeout method. This timeout method can be enable via the option ``trio_timeout`` in ``pytest.ini`` and will print structured tracebacks of all tasks running when the timeout happened. diff --git a/pytest_trio/plugin.py b/pytest_trio/plugin.py index 1a56a83..c12b902 100644 --- a/pytest_trio/plugin.py +++ b/pytest_trio/plugin.py @@ -1,4 +1,5 @@ """pytest-trio implementation.""" +from __future__ import annotations import sys from functools import wraps, partial from collections.abc import Coroutine, Generator @@ -11,6 +12,8 @@ from trio.abc import Clock, Instrument from trio.testing import MockClock from _pytest.outcomes import Skipped, XFailed +# pytest_timeout_set_timer needs to be imported here for pluggy +from .timeout import set_timeout, pytest_timeout_set_timer as pytest_timeout_set_timer if sys.version_info[:2] < (3, 11): from exceptiongroup import BaseExceptionGroup @@ -41,6 +44,12 @@ def pytest_addoption(parser): type="bool", default=False, ) + parser.addini( + "trio_timeout", + "should pytest-trio handle timeouts on async functions?", + type="bool", + default=False, + ) parser.addini( "trio_run", "what runner should pytest-trio use? [trio, qtrio]", @@ -404,6 +413,9 @@ async def _bootstrap_fixtures_and_run_test(**kwargs): contextvars_ctx = contextvars.copy_context() contextvars_ctx.run(canary.set, "in correct context") + if item is not None: + set_timeout(item) + async with trio.open_nursery() as nursery: for fixture in test.register_and_collect_dependencies(): nursery.start_soon( diff --git a/pytest_trio/timeout.py b/pytest_trio/timeout.py new file mode 100644 index 0000000..e8ecc15 --- /dev/null +++ b/pytest_trio/timeout.py @@ -0,0 +1,103 @@ +from __future__ import annotations +from typing import Optional +import warnings +import threading +import trio +import pytest +import pytest_timeout +from .traceback_format import format_recursive_nursery_stack + + +pytest_timeout_settings = pytest.StashKey[pytest_timeout.Settings]() +send_timeout_callable = None +send_timeout_callable_ready_event = threading.Event() + + +def set_timeout(item: pytest.Item) -> None: + try: + settings = item.stash[pytest_timeout_settings] + except KeyError: + # No timeout or not our timeout + return + + if settings.func_only: + warnings.warn( + "Function only timeouts are not supported for trio based timeouts" + ) + + global send_timeout_callable + + # Shouldn't be racy, as xdist uses different processes + if send_timeout_callable is None: + threading.Thread(target=trio_timeout_thread, daemon=True).start() + + send_timeout_callable_ready_event.wait() + + send_timeout_callable(settings.timeout) + + +@pytest.hookimpl() +def pytest_timeout_set_timer( + item: pytest.Item, settings: pytest_timeout.Settings +) -> Optional[bool]: + if item.get_closest_marker("trio") is not None and item.config.getini("trio_timeout"): + item.stash[pytest_timeout_settings] = settings + return True + + +# No need for pytest_timeout_cancel_timer as we detect that the test loop has exited + + +def trio_timeout_thread(): + async def run_timeouts(): + async with trio.open_nursery() as nursery: + token = trio.lowlevel.current_trio_token() + + async def wait_timeout(token: trio.TrioToken, timeout: float) -> None: + await trio.sleep(timeout) + + try: + token.run_sync_soon( + lambda: trio.lowlevel.spawn_system_task(execute_timeout) + ) + except RuntimeError: + # test has finished + pass + + def send_timeout(timeout: float): + test_token = trio.lowlevel.current_trio_token() + token.run_sync_soon( + lambda: nursery.start_soon(wait_timeout, test_token, timeout) + ) + + global send_timeout_callable + send_timeout_callable = send_timeout + send_timeout_callable_ready_event.set() + + await trio.sleep_forever() + + trio.run(run_timeouts) + + +async def execute_timeout() -> None: + if pytest_timeout.is_debugging(): + return + + nursery = get_test_nursery() + stack = "\n".join(format_recursive_nursery_stack(nursery) + ["Timeout reached"]) + + async def report(): + pytest.fail(stack, pytrace=False) + + nursery.start_soon(report) + + +def get_test_nursery() -> trio.Nursery: + task = trio.lowlevel.current_task().parent_nursery.parent_task + + for nursery in task.child_nurseries: + for task in nursery.child_tasks: + if task.name.startswith("pytest_trio.plugin._trio_test_runner_factory"): + return task.child_nurseries[0] + + raise Exception("Could not find test nursery") diff --git a/pytest_trio/traceback_format.py b/pytest_trio/traceback_format.py new file mode 100644 index 0000000..eb4a962 --- /dev/null +++ b/pytest_trio/traceback_format.py @@ -0,0 +1,70 @@ +from __future__ import annotations +from trio.lowlevel import Task +from itertools import chain +import traceback + + +def format_stack_for_task(task: Task, prefix: str) -> list[str]: + stack = list(task.iter_await_frames()) + + nursery_waiting_children = False + + for i, (frame, lineno) in enumerate(stack): + if frame.f_code.co_name == "_nested_child_finished": + stack = stack[: i - 1] + nursery_waiting_children = True + break + if frame.f_code.co_name == "wait_task_rescheduled": + stack = stack[:i] + break + if frame.f_code.co_name == "checkpoint": + stack = stack[:i] + break + + stack = (frame for frame in stack if "__tracebackhide__" not in frame[0].f_locals) + + ss = traceback.StackSummary.extract(stack) + formated_traceback = list( + map(lambda x: prefix + x[2:], "".join(ss.format()).splitlines()) + ) + + if nursery_waiting_children: + formated_traceback.append(prefix + "Awaiting completion of children") + formated_traceback.append(prefix) + + return formated_traceback + + +def format_task(task: Task, prefix: str = "") -> list[str]: + lines = [] + + subtasks = list( + chain(*(child_nursery.child_tasks for child_nursery in task.child_nurseries)) + ) + + if subtasks: + trace_prefix = prefix + "│" + else: + trace_prefix = prefix + " " + + lines.extend(format_stack_for_task(task, trace_prefix)) + + for i, subtask in enumerate(subtasks): + if (i + 1) != len(subtasks): + lines.append(f"{prefix}├ {subtask.name}") + lines.extend(format_task(subtask, prefix=f"{prefix}│ ")) + else: + lines.append(f"{prefix}└ {subtask.name}") + lines.extend(format_task(subtask, prefix=f"{prefix} ")) + + return lines + + +def format_recursive_nursery_stack(nursery) -> list[str]: + stack = [] + + for task in nursery.child_tasks: + stack.append(task.name) + stack.extend(format_task(task)) + + return stack diff --git a/setup.py b/setup.py index 9fbb81f..98dfcb4 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ "trio >= 0.22.0", # for ExceptionGroup support "outcome >= 1.1.0", "pytest >= 7.2.0", # for ExceptionGroup support + "pytest_timeout", ], keywords=[ "async",