Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved H3 for hypercorn. #201

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/hypercorn/asyncio/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Awaitable, Callable, Optional

from ..config import Config
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope, Timer

try:
from asyncio import TaskGroup as AsyncioTaskGroup
Expand All @@ -33,6 +33,44 @@ async def _handle(
await send(None)


LONG_SLEEP = 86400.0

class AsyncioTimer(Timer):
def __init__(self, action: Callable) -> None:
self._action = action
self._done = False
self._wake_up = asyncio.Condition()
self._when: Optional[float] = None

async def schedule(self, when: Optional[float]) -> None:
self._when = when
async with self._wake_up:
self._wake_up.notify()

async def stop(self) -> None:
self._done = True
async with self._wake_up:
self._wake_up.notify()

async def _wait_for_wake_up(self) -> None:
async with self._wake_up:
await self._wake_up.wait()

async def run(self) -> None:
while not self._done:
if self._when is not None and asyncio.get_event_loop().time() >= self._when:
self._when = None
await self._action()
if self._when is not None:
timeout = max(self._when - asyncio.get_event_loop().time(), 0.0)
else:
timeout = LONG_SLEEP
if not self._done:
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except TimeoutError:
pass

class TaskGroup:
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
Expand Down Expand Up @@ -66,6 +104,11 @@ def _call_soon(func: Callable, *args: Any) -> Any:
def spawn(self, func: Callable, *args: Any) -> None:
self._task_group.create_task(func(*args))

def create_timer(self, action: Callable) -> Timer:
timer = AsyncioTimer(action)
self._task_group.create_task(timer.run())
return timer

async def __aenter__(self) -> "TaskGroup":
await self._task_group.__aenter__()
return self
Expand Down
3 changes: 3 additions & 0 deletions src/hypercorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
BYTES = 1
OCTETS = 1
SECONDS = 1.0
DEFAULT_QUIC_MAX_SAVED_SESSIONS = 100

FilePath = Union[AnyStr, os.PathLike]
SocketKind = Union[int, socket.SocketKind]
Expand Down Expand Up @@ -95,6 +96,8 @@ class Config:
max_requests: Optional[int] = None
max_requests_jitter: int = 0
pid_path: Optional[str] = None
quic_retry: bool = True
quic_max_saved_sessions: int = DEFAULT_QUIC_MAX_SAVED_SESSIONS
server_names: List[str] = []
shutdown_timeout = 60 * SECONDS
ssl_handshake_timeout = 60 * SECONDS
Expand Down
152 changes: 123 additions & 29 deletions src/hypercorn/protocol/quic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from functools import partial
from typing import Awaitable, Callable, Dict, Optional, Tuple
from secrets import token_bytes
from typing import Awaitable, Callable, Dict, Optional, Set, Tuple

from aioquic.buffer import Buffer
from aioquic.h3.connection import H3_ALPN
Expand All @@ -15,14 +16,31 @@
)
from aioquic.quic.packet import (
encode_quic_version_negotiation,
encode_quic_retry,
PACKET_TYPE_INITIAL,
pull_quic_header,
)
from aioquic.quic.retry import QuicRetryTokenHandler
from aioquic.tls import SessionTicket

from .h3 import H3Protocol
from ..config import Config
from ..events import Closed, Event, RawData
from ..typing import AppWrapper, TaskGroup, WorkerContext
from ..typing import AppWrapper, TaskGroup, WorkerContext, Timer


class ConnectionState:
def __init__(self, connection: QuicConnection):
self.connection = connection
self.timer: Optional[Timer] = None
self.cids: Set[bytes] = set()
self.h3_protocol: Optional[H3Protocol] = None

def add_cid(self, cid: bytes) -> None:
self.cids.add(cid)

def remove_cid(self, cid: bytes) -> None:
self.cids.remove(cid)


class QuicProtocol:
Expand All @@ -38,18 +56,23 @@ def __init__(
self.app = app
self.config = config
self.context = context
self.connections: Dict[bytes, QuicConnection] = {}
self.http_connections: Dict[QuicConnection, H3Protocol] = {}
self.connections: Dict[bytes, ConnectionState] = {}
self.send = send
self.server = server
self.task_group = task_group

self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False)
self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile)
self.retry: Optional[QuicRetryTokenHandler]
if config.quic_retry:
self.retry = QuicRetryTokenHandler()
else:
self.retry = None
self.session_tickets: Dict[bytes, bytes] = {}

@property
def idle(self) -> bool:
return len(self.connections) == 0 and len(self.http_connections) == 0
return len(self.connections) == 0

async def handle(self, event: Event) -> None:
if isinstance(event, RawData):
Expand All @@ -69,23 +92,71 @@ async def handle(self, event: Event) -> None:
await self.send(RawData(data=data, address=event.address))
return

connection = self.connections.get(header.destination_cid)
state = self.connections.get(header.destination_cid)
if state is not None:
connection = state.connection
else:
connection = None
if (
connection is None
state is None
and len(event.data) >= 1200
and header.packet_type == PACKET_TYPE_INITIAL
and not self.context.terminated.is_set()
):
cid = header.destination_cid
retry_cid = None
if self.retry is not None:
if not header.token:
if header.version is None:
return
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why return here, is a missing version an indication of an error?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Retry is an option because it trades off increased latency for ensuring that you can roundtrip to the client and aren't just being used to amplify an attack (with the start of the TLS handshake) by someone spoofing UDP. I usually turn it on.

Re the missing version... if there is no header token, then this isn't the client retrying, and we have to ask them to retry, but to do that we need the version. If there isn't a version, they've sent us a versionless "short header" packet which is not a sensible thing to do, so we just drop the packet.

source_cid = token_bytes(8)
wire = encode_quic_retry(
version=header.version,
source_cid=source_cid,
destination_cid=header.source_cid,
original_destination_cid=header.destination_cid,
retry_token=self.retry.create_token(
event.address, header.destination_cid, source_cid
),
)
await self.send(RawData(data=wire, address=event.address))
return
else:
try:
(cid, retry_cid) = self.retry.validate_token(
event.address, header.token
)
if self.connections.get(cid) is not None:
# duplicate!
return
except ValueError:
return
fetcher: Optional[Callable]
handler: Optional[Callable]
if self.config.quic_max_saved_sessions > 0:
fetcher = self._get_session_ticket
handler = self._store_session_ticket
else:
fetcher = None
handler = None
connection = QuicConnection(
configuration=self.quic_config,
original_destination_connection_id=header.destination_cid,
original_destination_connection_id=cid,
retry_source_connection_id=retry_cid,
session_ticket_fetcher=fetcher,
session_ticket_handler=handler,
)
self.connections[header.destination_cid] = connection
self.connections[connection.host_cid] = connection
state = ConnectionState(connection)
timer = self.task_group.create_timer(partial(self._timeout, state))
state.timer = timer
state.add_cid(header.destination_cid)
self.connections[header.destination_cid] = state
state.add_cid(connection.host_cid)
self.connections[connection.host_cid] = state

if connection is not None:
connection.receive_datagram(event.data, event.address, now=self.context.time())
await self._handle_events(connection, event.address)
await self._wake_up_timer(state)
elif isinstance(event, Closed):
pass

Expand All @@ -94,42 +165,65 @@ async def send_all(self, connection: QuicConnection) -> None:
await self.send(RawData(data=data, address=address))

async def _handle_events(
self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None
self, state: ConnectionState, client: Optional[Tuple[str, int]] = None
) -> None:
connection = state.connection
event = connection.next_event()
while event is not None:
if isinstance(event, ConnectionTerminated):
pass
await state.timer.stop()
for cid in state.cids:
del self.connections[cid]
state.cids = set()
elif isinstance(event, ProtocolNegotiated):
self.http_connections[connection] = H3Protocol(
state.h3_protocol = H3Protocol(
self.app,
self.config,
self.context,
self.task_group,
client,
self.server,
connection,
partial(self.send_all, connection),
partial(self._wake_up_timer, state),
)
elif isinstance(event, ConnectionIdIssued):
self.connections[event.connection_id] = connection
state.add_cid(event.connection_id)
self.connections[event.connection_id] = state
elif isinstance(event, ConnectionIdRetired):
state.remove_cid(event.connection_id)
del self.connections[event.connection_id]

if connection in self.http_connections:
await self.http_connections[connection].handle(event)
elif state.h3_protocol is not None:
await state.h3_protocol.handle(event)

event = connection.next_event()

async def _wake_up_timer(self, state: ConnectionState) -> None:
# When new output is send, or new input is received, we
# fire the timer right away so we update our state.
await state.timer.schedule(0.0)

async def _timeout(self, state: ConnectionState) -> None:
connection = state.connection
now = self.context.time()
when = connection.get_timer()
if when is not None and now > when:
connection.handle_timer(now)
await self._handle_events(state, None)
await self.send_all(connection)

timer = connection.get_timer()
if timer is not None:
self.task_group.spawn(self._handle_timer, timer, connection)

async def _handle_timer(self, timer: float, connection: QuicConnection) -> None:
wait = max(0, timer - self.context.time())
await self.context.sleep(wait)
if connection._close_at is not None:
connection.handle_timer(now=self.context.time())
await self._handle_events(connection, None)
await state.timer.schedule(connection.get_timer())

def _get_session_ticket(self, ticket: bytes) -> None:
try:
self.session_tickets.pop(ticket)
except KeyError:
return None

def _store_session_ticket(self, session_ticket: SessionTicket) -> None:
self.session_tickets[session_ticket.ticket] = session_ticket
# Implement a simple FIFO remembering the self.config.quic_max_saved_sessions
# most recent sessions.
while len(self.session_tickets) > self.config.quic_max_saved_sessions:
# Grab the first key
key = next(iter(self.session_tickets.keys()))
del self.session_tickets[key]
41 changes: 40 additions & 1 deletion src/hypercorn/trio/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import trio

from ..config import Config
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope
from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope, Timer

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
Expand Down Expand Up @@ -39,6 +39,40 @@ async def _handle(
await send(None)


LONG_SLEEP = 86400.0

class TrioTimer(Timer):
def __init__(self, action: Callable) -> None:
self._action = action
self._done = False
self._wake_up = trio.Condition()
self._when: Optional[float] = None

async def schedule(self, when: Optional[float]) -> None:
self._when = when
async with self._wake_up:
self._wake_up.notify()

async def stop(self) -> None:
self._done = True
async with self._wake_up:
self._wake_up.notify()

async def run(self) -> None:
while not self._done:
if self._when is not None and trio.current_time() >= self._when:
self._when = None
await self._action()
if self._when is not None:
timeout = max(self._when - trio.current_time(), 0.0)
else:
timeout = LONG_SLEEP
if not self._done:
with trio.move_on_after(timeout):
async with self._wake_up:
await self._wake_up.wait()


class TaskGroup:
def __init__(self) -> None:
self._nursery: Optional[trio._core._run.Nursery] = None
Expand Down Expand Up @@ -67,6 +101,11 @@ async def spawn_app(
def spawn(self, func: Callable, *args: Any) -> None:
self._nursery.start_soon(func, *args)

def create_timer(self, action: Callable) -> Timer:
timer = TrioTimer(action)
self._nursery.start_soon(timer.run)
return timer

async def __aenter__(self) -> TaskGroup:
self._nursery_manager = trio.open_nursery()
self._nursery = await self._nursery_manager.__aenter__()
Expand Down
Loading