Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ParamSpec for wrapped signatures #508

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
24 changes: 24 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[mypy]
files = async_lru, tests
check_untyped_defs = True
follow_imports_for_stubs = True
disallow_any_decorated = True
asvetlov marked this conversation as resolved.
Show resolved Hide resolved
disallow_any_generics = True
disallow_any_unimported = True
disallow_incomplete_defs = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
disallow_untyped_defs = True
enable_error_code = ignore-without-code, possibly-undefined, redundant-expr, redundant-self, truthy-bool, truthy-iterable, unused-awaitable
implicit_reexport = False
no_implicit_optional = True
pretty = True
show_column_numbers = True
show_error_codes = True
strict_equality = True
warn_incomplete_stub = True
warn_redundant_casts = True
warn_return_any = True
warn_unreachable = True
warn_unused_ignores = True
49 changes: 29 additions & 20 deletions async_lru/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
)


if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


if sys.version_info >= (3, 11):
from typing import Self
else:
Expand All @@ -35,9 +41,10 @@

_T = TypeVar("_T")
_R = TypeVar("_R")
_P = ParamSpec("_P")
_Coro = Coroutine[Any, Any, _R]
_CB = Callable[..., _Coro[_R]]
_CBP = Union[_CB[_R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]
_CB = Callable[_P, _Coro[_R]]
_CBP = Union[_CB[_P, _R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]


@final
Expand All @@ -61,10 +68,10 @@ def cancel(self) -> None:


@final
class _LRUCacheWrapper(Generic[_R]):
class _LRUCacheWrapper(Generic[_P, _R]):
def __init__(
self,
fn: _CB[_R],
fn: _CB[_P, _R],
maxsize: Optional[int],
typed: bool,
ttl: Optional[float],
Expand Down Expand Up @@ -188,7 +195,7 @@ def _task_done_callback(

fut.set_result(task.result())

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
async def __call__(self, /, *fn_args: _P.args, **fn_kwargs: _P.kwargs) -> _R:
if self.__closed:
raise RuntimeError(f"alru_cache is closed for {self}")

Expand All @@ -207,7 +214,7 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:

fut = loop.create_future()
coro = self.__wrapped__(*fn_args, **fn_kwargs)
task: asyncio.Task[_R] = loop.create_task(coro)
task = loop.create_task(coro)
self.__tasks.add(task)
task.add_done_callback(partial(self._task_done_callback, fut, key))

Expand All @@ -222,18 +229,18 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:

def __get__(
self, instance: _T, owner: Optional[Type[_T]]
) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]:
) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_P, _R, _T]"]:
if owner is None:
return self
else:
return _LRUCacheWrapperInstanceMethod(self, instance)


@final
class _LRUCacheWrapperInstanceMethod(Generic[_R, _T]):
class _LRUCacheWrapperInstanceMethod(Generic[_P, _R, _T]):
def __init__(
self,
wrapper: _LRUCacheWrapper[_R],
wrapper: _LRUCacheWrapper[_P, _R],
instance: _T,
) -> None:
try:
Expand Down Expand Up @@ -284,16 +291,16 @@ def cache_info(self) -> _CacheInfo:
def cache_parameters(self) -> _CacheParameters:
return self.__wrapper.cache_parameters()

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs)
async def __call__(self, /, *fn_args: _P.args, **fn_kwargs: _P.kwargs) -> _R:
return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs) # type: ignore[arg-type]


def _make_wrapper(
maxsize: Optional[int],
typed: bool,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
) -> Callable[[_CBP[_P, _R]], _LRUCacheWrapper[_P, _R]]:
def wrapper(fn: _CBP[_P, _R]) -> _LRUCacheWrapper[_P, _R]:
origin = fn

while isinstance(origin, (partial, partialmethod)):
Expand All @@ -306,7 +313,7 @@ def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
if hasattr(fn, "_make_unbound_method"):
fn = fn._make_unbound_method()

return _LRUCacheWrapper(cast(_CB[_R], fn), maxsize, typed, ttl)
return _LRUCacheWrapper(cast(_CB[_P, _R], fn), maxsize, typed, ttl)

return wrapper

Expand All @@ -317,28 +324,30 @@ def alru_cache(
typed: bool = False,
*,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
) -> Callable[[_CBP[_P, _R]], _LRUCacheWrapper[_P, _R]]:
...


@overload
def alru_cache(
maxsize: _CBP[_R],
maxsize: _CBP[_P, _R],
/,
) -> _LRUCacheWrapper[_R]:
) -> _LRUCacheWrapper[_P, _R]:
...


def alru_cache(
maxsize: Union[Optional[int], _CBP[_R]] = 128,
maxsize: Union[Optional[int], _CBP[_P, _R]] = 128,
typed: bool = False,
*,
ttl: Optional[float] = None,
) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]:
) -> Union[
Callable[[_CBP[_P, _R]], _LRUCacheWrapper[_P, _R]], _LRUCacheWrapper[_P, _R]
]:
if maxsize is None or isinstance(maxsize, int):
return _make_wrapper(maxsize, typed, ttl)
else:
fn = cast(_CB[_R], maxsize)
fn = maxsize

if callable(fn) or hasattr(fn, "_make_unbound_method"):
return _make_wrapper(128, False, None)(fn)
Expand Down
5 changes: 0 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,3 @@ junit_family=xunit2
asyncio_mode=auto
timeout=15
xfail_strict = true

[mypy]
strict=True
pretty=True
packages=async_lru, tests
19 changes: 15 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
import sys
from functools import _CacheInfo
from typing import Callable
from typing import Callable, TypeVar

import pytest

from async_lru import _R, _LRUCacheWrapper
from async_lru import _LRUCacheWrapper


if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


_T = TypeVar("_T")
_P = ParamSpec("_P")


@pytest.fixture
def check_lru() -> Callable[..., None]:
def check_lru() -> Callable[..., None]: # type: ignore[misc]
def _check_lru(
wrapped: _LRUCacheWrapper[_R],
wrapped: _LRUCacheWrapper[_P, _T],
*,
hits: int,
misses: int,
Expand Down
Loading