diff --git a/mautrix/api.py b/mautrix/api.py index f6c9f475..0deca39a 100644 --- a/mautrix/api.py +++ b/mautrix/api.py @@ -21,7 +21,12 @@ from yarl import URL from mautrix import __optional_imports__, __version__ as mautrix_version -from mautrix.errors import MatrixConnectionError, MatrixRequestError, make_request_error +from mautrix.errors import ( + MatrixConnectionError, + MatrixRequestError, + MLimitExceeded, + make_request_error, +) from mautrix.util.async_body import AsyncBody, async_iter_bytes from mautrix.util.logging import TraceLogger from mautrix.util.opt_prometheus import Counter @@ -239,7 +244,7 @@ async def _send( ) async with request as response: if response.status < 200 or response.status >= 300: - errcode = unstable_errcode = message = None + response_data = errcode = unstable_errcode = message = None try: response_data = await response.json() errcode = response_data["errcode"] @@ -250,6 +255,7 @@ async def _send( raise make_request_error( http_status=response.status, text=await response.text(), + data=response_data, errcode=errcode, message=message, unstable_errcode=unstable_errcode, @@ -397,6 +403,23 @@ async def request( ) self._log_request_done(path, req_id, time.monotonic() - start, resp.status) return resp_data + except MLimitExceeded as e: + API_CALLS_FAILED.labels(method=metrics_method).inc() + if retry_count > 0: + retry = e.retry_after_ms + if retry is None: + retry = backoff + backoff *= 2 + else: + retry /= 1000 + self.log.info( + f"Request #{req_id} failed with {e.errcode}, " + f"retrying in {retry} seconds" + ) + await asyncio.sleep(retry) + else: + self._log_request_done(path, req_id, time.monotonic() - start, e.http_status) + raise except MatrixRequestError as e: API_CALLS_FAILED.labels(method=metrics_method).inc() if retry_count > 0 and e.http_status in (502, 503, 504): @@ -404,6 +427,8 @@ async def request( f"Request #{req_id} failed with HTTP {e.http_status}, " f"retrying in {backoff} seconds" ) + await asyncio.sleep(backoff) + backoff *= 2 else: self._log_request_done(path, req_id, time.monotonic() - start, e.http_status) raise @@ -413,13 +438,13 @@ async def request( self.log.warning( f"Request #{req_id} failed with {e}, retrying in {backoff} seconds" ) + await asyncio.sleep(backoff) + backoff *= 2 else: raise MatrixConnectionError(str(e)) from e except Exception: API_CALLS_FAILED.labels(method=metrics_method).inc() raise - await asyncio.sleep(backoff) - backoff *= 2 retry_count -= 1 def get_txn_id(self) -> str: diff --git a/mautrix/errors/request.py b/mautrix/errors/request.py index ebff4d76..6949226f 100644 --- a/mautrix/errors/request.py +++ b/mautrix/errors/request.py @@ -5,7 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Callable, Type +from typing import Callable, Type, TypeVar from .base import MatrixError @@ -45,7 +45,7 @@ class MatrixStandardRequestError(MatrixRequestError): errcode: str = None - def __init__(self, http_status: int, message: str = "") -> None: + def __init__(self, http_status: int, message: str = "", **kwargs) -> None: super().__init__(message) self.http_status: int = http_status self.message: str = message @@ -55,9 +55,11 @@ def __init__(self, http_status: int, message: str = "") -> None: ec_map: dict[str, MxSRE] = {} uec_map: dict[str, MxSRE] = {} +T = TypeVar("T", bound=MxSRE) -def standard_error(code: str, unstable: str | None = None) -> Callable[[MxSRE], MxSRE]: - def decorator(cls: MxSRE) -> MxSRE: + +def standard_error(code: str, unstable: str | None = None) -> Callable[[T], T]: + def decorator(cls: T) -> T: cls.errcode = code ec_map[code] = cls if unstable: @@ -71,6 +73,7 @@ def decorator(cls: MxSRE) -> MxSRE: def make_request_error( http_status: int, text: str, + data: dict | None, errcode: str | None, message: str | None, unstable_errcode: str | None = None, @@ -82,6 +85,7 @@ def make_request_error( Args: http_status: The HTTP status code. text: The raw response text. + data: The response JSON. errcode: The errcode field in the response JSON. message: The error field in the response JSON. unstable_errcode: The MSC3848 error code field in the response JSON. @@ -94,7 +98,10 @@ def make_request_error( pass try: ec_class = ec_map[errcode] - return ec_class(http_status, message) + data = data if data else {} + data["http_status"] = http_status + data["message"] = message + return ec_class(**data) except KeyError: return MatrixUnknownRequestError(http_status, text, errcode, message) @@ -172,7 +179,9 @@ class MNotFound(MatrixStandardRequestError): @standard_error("M_LIMIT_EXCEEDED") class MLimitExceeded(MatrixStandardRequestError): - pass + def __init__(self, retry_after_ms: int | None = None, **kwargs) -> None: + super().__init__(**kwargs) + self.retry_after_ms: int | None = retry_after_ms @standard_error("M_UNKNOWN")