Skip to content

Commit

Permalink
Add async fetcher (RangeStream only) and tests for async single reque…
Browse files Browse the repository at this point in the history
…st RangeStream creation
  • Loading branch information
lmmx committed Aug 11, 2021
1 parent a08e258 commit 68604a7
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ aiostream
httpx
python-ranges
pyzstd
tqdm
2 changes: 2 additions & 0 deletions src/range_streams/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
# Get classes into package namespace but exclude from __all__ so Sphinx can access types

from . import codecs, http_utils, overlaps, range_utils
from .async_utils import AsyncFetcher
from .request import RangeRequest
from .response import RangeResponse
from .stream import RangeStream
Expand All @@ -137,6 +138,7 @@
"overlaps",
"range_utils",
"codecs",
"async_utils",
]

__author__ = "Louis Maddox"
Expand Down
187 changes: 187 additions & 0 deletions src/range_streams/async_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from __future__ import annotations

import asyncio
import time
from asyncio.events import AbstractEventLoop
from functools import partial
from signal import SIGINT, SIGTERM, Signals
from sys import stderr
from typing import TYPE_CHECKING, Callable, Coroutine, Iterator

from aiostream import stream
from ranges import Range, RangeSet

MYPY = False # when using mypy will be overrided as True
if MYPY or not TYPE_CHECKING: # pragma: no cover
import httpx # avoid importing to Sphinx type checker

import tqdm
from tqdm.asyncio import tqdm_asyncio

from .log_utils import log, set_up_logging
from .stream import RangeStream

__all__ = ["SignalHaltError", "AsyncFetcher"]


class AsyncFetcher:
def __init__(
self,
urls: list[str],
callback: Callable | None = None,
verbose: bool = False,
show_progress_bar: bool = True,
timeout_s: float = 5.0,
client=None,
):
"""
Args:
callback : A function to be passed 3 values: the AsyncFetcher which is calling
it, the awaited RangeStream, and its source URL (a ``httpx.URL``,
which can be coerced to a string).
"""
if urls == []:
raise ValueError("The list of URLs to fetch cannot be empty")
self.url_list = urls
self.callback = callback
self.n = len(urls)
self.verbose = verbose
self.show_progress_bar = show_progress_bar and not self.verbose
self.client = client
self.timeout = httpx.Timeout(timeout=timeout_s)
self.completed = RangeSet()
set_up_logging(quiet=not verbose)

def make_calls(self):
"""
The method called to run the event loop to fetch URLs, after initialisation
and/or repeatedly upon exitting the loop (i.e. it can recover from errors).
"""
urlset = (u for u in self.filtered_url_list) # single use URL generator
if self.show_progress_bar:
self.set_up_progress_bar()
self.fetch_things(urls=urlset)
if self.show_progress_bar:
self.pbar.close()

async def process_stream(self, rstream: RangeStream):
"""
Process an awaited RangeStream within an async fetch loop, calling the callback
set on the `~range_streams.async_utils.AsyncFetcher.callback` attribute.
Args:
rstream : The awaited RangeStream
"""
monostream_response = rstream._ranges[rstream.total_range]
resp = monostream_response.request.response # httpx.Response
source_url = resp.history[0].url if resp.history else resp.url
# Map the response back to the thing it came from in the url_list
i = next(i for (i, u) in enumerate(self.url_list) if source_url == u)
if self.callback is not None:
await self.callback(self, stream, source_url)
if self.verbose:
log.debug(f"Processed URL in async callback: {source_url}")
if self.show_progress_bar:
self.pbar.update()
self.completed.add(Range(i, i + 1))
await resp.aclose()

@property
def filtered_url_list(self) -> list[str]:
if self.completed.isempty():
urls = self.url_list
else:
urls = [u for (i, u) in enumerate(self.url_list) if i not in self.completed]
return urls

def set_up_progress_bar(self):
n_already_fetched = self.n - len(self.filtered_url_list)
self.pbar = tqdm_asyncio(total=self.n)
if n_already_fetched:
self.pbar.update(n_already_fetched)
self.pbar.refresh()

def fetch_things(self, urls: Iterator[str]):
try:
return asyncio.run(self.async_fetch_urlset(urls))
except SignalHaltError as exc:
if self.show_progress_bar:
self.pbar.disable = True
self.pbar.close()

async def fetch(self, client: httpx.AsyncClient, url: httpx.URL) -> RangeStream:
s = RangeStream(
url=str(url), client=client, single_request=True, force_async=True
)
await s.add_async()
return s

async def async_fetch_urlset(
self,
urls: Iterator[str],
) -> Coroutine:
"""
If the `~range_streams.async_utils.AsyncFetcher.client` is ``None``, create one
in a contextmanager block (i.e. close it immediately after use), otherwise use
the one provided, not in a contextmanager block (i.e. leave it up to the user to
close the client).
"""
await self.set_async_signal_handlers()
if self.client is None:
async with httpx.AsyncClient() as client:
processed = await self.fetch_and_process(urls=urls, client=client)
else:
if self.client.is_closed:
msg = (
"Cannot use a closed client to fetch.\n\nDid you attempt to retry "
" after using the client in a contextmanager block (which implicitly"
" closes after exiting the block) perhaps?"
)
raise ValueError(msg)
# assert self.client is not None # give mypy a clue
processed = await self.fetch_and_process(urls=urls, client=client)
return processed

async def fetch_and_process(self, urls: Iterator[str], client):
assert isinstance(client, httpx.AsyncClient) # Not type checked due to Sphinx
client.timeout = self.timeout
ws = stream.repeat(client)
xs = stream.zip(ws, stream.iterate(urls))
ys = stream.starmap(xs, self.fetch, ordered=False, task_limit=20)
zs = stream.map(ys, self.process_stream)
return await zs

def immediate_exit(self, signal_enum: Signals, loop: AbstractEventLoop) -> None:
loop.stop()
halt_error = SignalHaltError(signal_enum=signal_enum)
raise halt_error

async def set_async_signal_handlers(self) -> None:
loop = asyncio.get_running_loop()
for signal_enum in [SIGINT, SIGTERM]:
exit_func = partial(self.immediate_exit, signal_enum=signal_enum, loop=loop)
loop.add_signal_handler(signal_enum, exit_func)


class SignalHaltError(SystemExit):
def __init__(self, signal_enum: Signals):
self.signal_enum = signal_enum
print("", file=stderr) # Newline after the signal sequence printed to console
log.critical(msg=repr(self))
super().__init__(self.exit_code)

@property
def exit_code(self) -> int:
return self.signal_enum.value

def __repr__(self) -> str:
return f"Exitted due to {self.signal_enum.name}"


# def demo_fetch(url_list):
# fetched = AsyncFetcher(urls=url_list, verbose=False)
# try:
# fetched.make_calls()
# except Exception as exc:
# log.debug("DEBUG ::" + repr(exc)) # Suppress it to log
# print(f"... {exc!r}")
21 changes: 21 additions & 0 deletions src/range_streams/log_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import logging

__all__ = ["log", "set_up_logging"]

log = logging.getLogger() # Provided for ease of access in other modules


def set_up_logging(quiet: bool = True):
"""
Initialise the log
Args:
quiet : Change this flag to True/False to turn off/on console logging
"""
log.setLevel(logging.DEBUG)
log_format = logging.Formatter("[%(asctime)s] [%(levelname)s] - %(message)s")
if not quiet:
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
console.setFormatter(log_format)
log.addHandler(console)
80 changes: 80 additions & 0 deletions tests/async_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import asyncio
from signal import SIGINT

from pytest import fixture, mark, raises
from ranges import Range

from range_streams import _EXAMPLE_PNG_URL, _EXAMPLE_ZIP_URL, RangeStream
from range_streams.async_utils import AsyncFetcher, SignalHaltError

from .data import EXAMPLE_FILE_LENGTH, EXAMPLE_URL

# https://tonybaloney.github.io/posts/async-test-patterns-for-pytest-and-unittest.html

THREE_URLS = [EXAMPLE_URL, _EXAMPLE_PNG_URL, _EXAMPLE_ZIP_URL]


class CallbackMutatedClass:
values = []

@classmethod
def reset(cls):
"""
Reset the class attribute where tests store the URLs they called back from
"""
cls.values = []


async def demo_callback_func(fetcher, range_stream, url):
return CallbackMutatedClass.values.append(url)


async def sigint_callback_func(fetcher, range_stream, url):
"""
Mimic the act of sending the signal interrupt by raising it in a callback
"""
await demo_callback_func(fetcher, range_stream, url)
# raise KeyboardInterrupt ?
loop = asyncio.get_running_loop()
fetcher.immediate_exit(signal_enum=SIGINT, loop=loop)


@mark.parametrize("callback", [None, demo_callback_func])
@mark.parametrize("verbose", [True, False])
@mark.parametrize("error_msg", ["The list of URLs to fetch cannot be empty"])
@mark.parametrize("urls", [([]), (THREE_URLS)])
def test_fetcher(urls, error_msg, verbose, callback):
"""
Fetch lists of 0 or 3 URLs asynchronously, with/out a callback, verbosely/quietly.
"""
args = dict(callback=callback, urls=urls, verbose=verbose, show_progress_bar=False)
if urls == []:
with raises(ValueError, match=error_msg):
fetched = AsyncFetcher(**args)
else:
fetched = AsyncFetcher(**args)
fetched.make_calls()
expected_values = set() if callback is None else set(urls)
stored_urls = getattr(CallbackMutatedClass, "values")
assert set(stored_urls) == set(expected_values)
CallbackMutatedClass.reset()


@mark.parametrize("callback", [sigint_callback_func])
@mark.parametrize("error_msg", ["The list of URLs to fetch cannot be empty"])
@mark.parametrize("urls", [(THREE_URLS)])
def test_fetcher_sigint(urls, error_msg, callback):
"""
Fetch lists of 3 URLs asynchronously, with/out a callback, verbosely/quietly.
Cannot figure out how to emulate passing the SIGINT from this test so can't catch,
best I can do here is to check that the loop is stopped at the first callback when
``immediate_exit`` is called.
"""
args = dict(callback=callback, urls=urls, show_progress_bar=False)
fetched = AsyncFetcher(**args)
# with raises(SignalHaltError, match=error_msg):
fetched.make_calls()
stored_urls = getattr(CallbackMutatedClass, "values")
assert len(stored_urls) == 1
assert set(stored_urls) < set(urls)
CallbackMutatedClass.reset()

0 comments on commit 68604a7

Please sign in to comment.