-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add async fetcher (RangeStream only) and tests for async single reque…
…st RangeStream creation
- Loading branch information
Showing
5 changed files
with
291 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ aiostream | |
httpx | ||
python-ranges | ||
pyzstd | ||
tqdm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import time | ||
from asyncio.events import AbstractEventLoop | ||
from functools import partial | ||
from signal import SIGINT, SIGTERM, Signals | ||
from sys import stderr | ||
from typing import TYPE_CHECKING, Callable, Coroutine, Iterator | ||
|
||
from aiostream import stream | ||
from ranges import Range, RangeSet | ||
|
||
MYPY = False # when using mypy will be overrided as True | ||
if MYPY or not TYPE_CHECKING: # pragma: no cover | ||
import httpx # avoid importing to Sphinx type checker | ||
|
||
import tqdm | ||
from tqdm.asyncio import tqdm_asyncio | ||
|
||
from .log_utils import log, set_up_logging | ||
from .stream import RangeStream | ||
|
||
__all__ = ["SignalHaltError", "AsyncFetcher"] | ||
|
||
|
||
class AsyncFetcher: | ||
def __init__( | ||
self, | ||
urls: list[str], | ||
callback: Callable | None = None, | ||
verbose: bool = False, | ||
show_progress_bar: bool = True, | ||
timeout_s: float = 5.0, | ||
client=None, | ||
): | ||
""" | ||
Args: | ||
callback : A function to be passed 3 values: the AsyncFetcher which is calling | ||
it, the awaited RangeStream, and its source URL (a ``httpx.URL``, | ||
which can be coerced to a string). | ||
""" | ||
if urls == []: | ||
raise ValueError("The list of URLs to fetch cannot be empty") | ||
self.url_list = urls | ||
self.callback = callback | ||
self.n = len(urls) | ||
self.verbose = verbose | ||
self.show_progress_bar = show_progress_bar and not self.verbose | ||
self.client = client | ||
self.timeout = httpx.Timeout(timeout=timeout_s) | ||
self.completed = RangeSet() | ||
set_up_logging(quiet=not verbose) | ||
|
||
def make_calls(self): | ||
""" | ||
The method called to run the event loop to fetch URLs, after initialisation | ||
and/or repeatedly upon exitting the loop (i.e. it can recover from errors). | ||
""" | ||
urlset = (u for u in self.filtered_url_list) # single use URL generator | ||
if self.show_progress_bar: | ||
self.set_up_progress_bar() | ||
self.fetch_things(urls=urlset) | ||
if self.show_progress_bar: | ||
self.pbar.close() | ||
|
||
async def process_stream(self, rstream: RangeStream): | ||
""" | ||
Process an awaited RangeStream within an async fetch loop, calling the callback | ||
set on the `~range_streams.async_utils.AsyncFetcher.callback` attribute. | ||
Args: | ||
rstream : The awaited RangeStream | ||
""" | ||
monostream_response = rstream._ranges[rstream.total_range] | ||
resp = monostream_response.request.response # httpx.Response | ||
source_url = resp.history[0].url if resp.history else resp.url | ||
# Map the response back to the thing it came from in the url_list | ||
i = next(i for (i, u) in enumerate(self.url_list) if source_url == u) | ||
if self.callback is not None: | ||
await self.callback(self, stream, source_url) | ||
if self.verbose: | ||
log.debug(f"Processed URL in async callback: {source_url}") | ||
if self.show_progress_bar: | ||
self.pbar.update() | ||
self.completed.add(Range(i, i + 1)) | ||
await resp.aclose() | ||
|
||
@property | ||
def filtered_url_list(self) -> list[str]: | ||
if self.completed.isempty(): | ||
urls = self.url_list | ||
else: | ||
urls = [u for (i, u) in enumerate(self.url_list) if i not in self.completed] | ||
return urls | ||
|
||
def set_up_progress_bar(self): | ||
n_already_fetched = self.n - len(self.filtered_url_list) | ||
self.pbar = tqdm_asyncio(total=self.n) | ||
if n_already_fetched: | ||
self.pbar.update(n_already_fetched) | ||
self.pbar.refresh() | ||
|
||
def fetch_things(self, urls: Iterator[str]): | ||
try: | ||
return asyncio.run(self.async_fetch_urlset(urls)) | ||
except SignalHaltError as exc: | ||
if self.show_progress_bar: | ||
self.pbar.disable = True | ||
self.pbar.close() | ||
|
||
async def fetch(self, client: httpx.AsyncClient, url: httpx.URL) -> RangeStream: | ||
s = RangeStream( | ||
url=str(url), client=client, single_request=True, force_async=True | ||
) | ||
await s.add_async() | ||
return s | ||
|
||
async def async_fetch_urlset( | ||
self, | ||
urls: Iterator[str], | ||
) -> Coroutine: | ||
""" | ||
If the `~range_streams.async_utils.AsyncFetcher.client` is ``None``, create one | ||
in a contextmanager block (i.e. close it immediately after use), otherwise use | ||
the one provided, not in a contextmanager block (i.e. leave it up to the user to | ||
close the client). | ||
""" | ||
await self.set_async_signal_handlers() | ||
if self.client is None: | ||
async with httpx.AsyncClient() as client: | ||
processed = await self.fetch_and_process(urls=urls, client=client) | ||
else: | ||
if self.client.is_closed: | ||
msg = ( | ||
"Cannot use a closed client to fetch.\n\nDid you attempt to retry " | ||
" after using the client in a contextmanager block (which implicitly" | ||
" closes after exiting the block) perhaps?" | ||
) | ||
raise ValueError(msg) | ||
# assert self.client is not None # give mypy a clue | ||
processed = await self.fetch_and_process(urls=urls, client=client) | ||
return processed | ||
|
||
async def fetch_and_process(self, urls: Iterator[str], client): | ||
assert isinstance(client, httpx.AsyncClient) # Not type checked due to Sphinx | ||
client.timeout = self.timeout | ||
ws = stream.repeat(client) | ||
xs = stream.zip(ws, stream.iterate(urls)) | ||
ys = stream.starmap(xs, self.fetch, ordered=False, task_limit=20) | ||
zs = stream.map(ys, self.process_stream) | ||
return await zs | ||
|
||
def immediate_exit(self, signal_enum: Signals, loop: AbstractEventLoop) -> None: | ||
loop.stop() | ||
halt_error = SignalHaltError(signal_enum=signal_enum) | ||
raise halt_error | ||
|
||
async def set_async_signal_handlers(self) -> None: | ||
loop = asyncio.get_running_loop() | ||
for signal_enum in [SIGINT, SIGTERM]: | ||
exit_func = partial(self.immediate_exit, signal_enum=signal_enum, loop=loop) | ||
loop.add_signal_handler(signal_enum, exit_func) | ||
|
||
|
||
class SignalHaltError(SystemExit): | ||
def __init__(self, signal_enum: Signals): | ||
self.signal_enum = signal_enum | ||
print("", file=stderr) # Newline after the signal sequence printed to console | ||
log.critical(msg=repr(self)) | ||
super().__init__(self.exit_code) | ||
|
||
@property | ||
def exit_code(self) -> int: | ||
return self.signal_enum.value | ||
|
||
def __repr__(self) -> str: | ||
return f"Exitted due to {self.signal_enum.name}" | ||
|
||
|
||
# def demo_fetch(url_list): | ||
# fetched = AsyncFetcher(urls=url_list, verbose=False) | ||
# try: | ||
# fetched.make_calls() | ||
# except Exception as exc: | ||
# log.debug("DEBUG ::" + repr(exc)) # Suppress it to log | ||
# print(f"... {exc!r}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import logging | ||
|
||
__all__ = ["log", "set_up_logging"] | ||
|
||
log = logging.getLogger() # Provided for ease of access in other modules | ||
|
||
|
||
def set_up_logging(quiet: bool = True): | ||
""" | ||
Initialise the log | ||
Args: | ||
quiet : Change this flag to True/False to turn off/on console logging | ||
""" | ||
log.setLevel(logging.DEBUG) | ||
log_format = logging.Formatter("[%(asctime)s] [%(levelname)s] - %(message)s") | ||
if not quiet: | ||
console = logging.StreamHandler() | ||
console.setLevel(logging.DEBUG) | ||
console.setFormatter(log_format) | ||
log.addHandler(console) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import asyncio | ||
from signal import SIGINT | ||
|
||
from pytest import fixture, mark, raises | ||
from ranges import Range | ||
|
||
from range_streams import _EXAMPLE_PNG_URL, _EXAMPLE_ZIP_URL, RangeStream | ||
from range_streams.async_utils import AsyncFetcher, SignalHaltError | ||
|
||
from .data import EXAMPLE_FILE_LENGTH, EXAMPLE_URL | ||
|
||
# https://tonybaloney.github.io/posts/async-test-patterns-for-pytest-and-unittest.html | ||
|
||
THREE_URLS = [EXAMPLE_URL, _EXAMPLE_PNG_URL, _EXAMPLE_ZIP_URL] | ||
|
||
|
||
class CallbackMutatedClass: | ||
values = [] | ||
|
||
@classmethod | ||
def reset(cls): | ||
""" | ||
Reset the class attribute where tests store the URLs they called back from | ||
""" | ||
cls.values = [] | ||
|
||
|
||
async def demo_callback_func(fetcher, range_stream, url): | ||
return CallbackMutatedClass.values.append(url) | ||
|
||
|
||
async def sigint_callback_func(fetcher, range_stream, url): | ||
""" | ||
Mimic the act of sending the signal interrupt by raising it in a callback | ||
""" | ||
await demo_callback_func(fetcher, range_stream, url) | ||
# raise KeyboardInterrupt ? | ||
loop = asyncio.get_running_loop() | ||
fetcher.immediate_exit(signal_enum=SIGINT, loop=loop) | ||
|
||
|
||
@mark.parametrize("callback", [None, demo_callback_func]) | ||
@mark.parametrize("verbose", [True, False]) | ||
@mark.parametrize("error_msg", ["The list of URLs to fetch cannot be empty"]) | ||
@mark.parametrize("urls", [([]), (THREE_URLS)]) | ||
def test_fetcher(urls, error_msg, verbose, callback): | ||
""" | ||
Fetch lists of 0 or 3 URLs asynchronously, with/out a callback, verbosely/quietly. | ||
""" | ||
args = dict(callback=callback, urls=urls, verbose=verbose, show_progress_bar=False) | ||
if urls == []: | ||
with raises(ValueError, match=error_msg): | ||
fetched = AsyncFetcher(**args) | ||
else: | ||
fetched = AsyncFetcher(**args) | ||
fetched.make_calls() | ||
expected_values = set() if callback is None else set(urls) | ||
stored_urls = getattr(CallbackMutatedClass, "values") | ||
assert set(stored_urls) == set(expected_values) | ||
CallbackMutatedClass.reset() | ||
|
||
|
||
@mark.parametrize("callback", [sigint_callback_func]) | ||
@mark.parametrize("error_msg", ["The list of URLs to fetch cannot be empty"]) | ||
@mark.parametrize("urls", [(THREE_URLS)]) | ||
def test_fetcher_sigint(urls, error_msg, callback): | ||
""" | ||
Fetch lists of 3 URLs asynchronously, with/out a callback, verbosely/quietly. | ||
Cannot figure out how to emulate passing the SIGINT from this test so can't catch, | ||
best I can do here is to check that the loop is stopped at the first callback when | ||
``immediate_exit`` is called. | ||
""" | ||
args = dict(callback=callback, urls=urls, show_progress_bar=False) | ||
fetched = AsyncFetcher(**args) | ||
# with raises(SignalHaltError, match=error_msg): | ||
fetched.make_calls() | ||
stored_urls = getattr(CallbackMutatedClass, "values") | ||
assert len(stored_urls) == 1 | ||
assert set(stored_urls) < set(urls) | ||
CallbackMutatedClass.reset() |