Skip to content

Commit

Permalink
Fix regression (#239)
Browse files Browse the repository at this point in the history
  • Loading branch information
mindflayer authored May 31, 2024
1 parent 501088e commit 0c5c07a
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 25 deletions.
67 changes: 48 additions & 19 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import socket
import ssl
from datetime import datetime, timedelta
from io import BytesIO
from json.decoder import JSONDecodeError
from typing import Optional, Tuple

import urllib3
from urllib3.connection import match_hostname as urllib3_match_hostname
Expand All @@ -27,6 +27,7 @@
from .utils import (
SSL_PROTOCOL,
MocketMode,
MocketSocketCore,
get_mocketize,
hexdump,
hexload,
Expand Down Expand Up @@ -73,15 +74,15 @@


class SuperFakeSSLContext:
"""For Python 3.6"""
"""For Python 3.6 and newer."""

class FakeSetter(int):
def __set__(self, *args):
pass

minimum_version = FakeSetter()
options = FakeSetter()
verify_mode = FakeSetter(ssl.CERT_NONE)
verify_mode = FakeSetter()


class FakeSSLContext(SuperFakeSSLContext):
Expand Down Expand Up @@ -177,6 +178,7 @@ class MocketSocket:
_secure_socket = False
_did_handshake = False
_sent_non_empty_bytes = False
_io = None

def __init__(
self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
Expand All @@ -200,10 +202,18 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

@property
def fd(self):
if self._fd is None:
self._fd = BytesIO()
return self._fd
def io(self):
if self._io is None:
self._io = MocketSocketCore((self._host, self._port))
return self._io

def fileno(self):
address = (self._host, self._port)
r_fd, _ = Mocket.get_pair(address)
if not r_fd:
r_fd, w_fd = os.pipe()
Mocket.set_pair(address, (r_fd, w_fd))
return r_fd

def gettimeout(self):
return self.timeout
Expand Down Expand Up @@ -264,19 +274,14 @@ def unwrap(self):
def write(self, data):
return self.send(encode_to_bytes(data))

def fileno(self):
if self.true_socket:
return self.true_socket.fileno()
return self.fd.fileno()

def connect(self, address):
self._address = self._host, self._port = address
Mocket._address = address

def makefile(self, mode="r", bufsize=-1):
self._mode = mode
self._bufsize = bufsize
return self.fd
return self.io

def get_entry(self, data):
return Mocket.get_entry(self._host, self._port, data)
Expand All @@ -292,13 +297,13 @@ def sendall(self, data, entry=None, *args, **kwargs):
response = self.true_sendall(data, *args, **kwargs)

if response is not None:
self.fd.seek(0)
self.fd.write(response)
self.fd.truncate()
self.fd.seek(0)
self.io.seek(0)
self.io.write(response)
self.io.truncate()
self.io.seek(0)

def read(self, buffersize):
rv = self.fd.read(buffersize)
rv = self.io.read(buffersize)
if rv:
self._sent_non_empty_bytes = True
if self._did_handshake and not self._sent_non_empty_bytes:
Expand All @@ -315,6 +320,9 @@ def recv_into(self, buffer, buffersize=None, flags=None):
return len(data)

def recv(self, buffersize, flags=None):
r_fd, _ = Mocket.get_pair((self._host, self._port))
if r_fd:
return os.read(r_fd, buffersize)
data = self.read(buffersize)
if data:
return data
Expand Down Expand Up @@ -416,8 +424,8 @@ def true_sendall(self, data, *args, **kwargs):

def send(self, data, *args, **kwargs): # pragma: no cover
entry = self.get_entry(data)
kwargs["entry"] = entry
if not entry or (entry and self._entry != entry):
kwargs["entry"] = entry
self.sendall(data, *args, **kwargs)
else:
req = Mocket.last_request()
Expand All @@ -441,12 +449,29 @@ def do_nothing(*args, **kwargs):


class Mocket:
_socket_pairs = {}
_address = (None, None)
_entries = collections.defaultdict(list)
_requests = []
_namespace = text_type(id(_entries))
_truesocket_recording_dir = None

@classmethod
def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]:
"""
Given the id() of the caller, return a pair of file descriptors
as a tuple of two integers: (<read_fd>, <write_fd>)
"""
return cls._socket_pairs.get(address, (None, None))

@classmethod
def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None:
"""
Store a pair of file descriptors under the key `id_`
as a tuple of two integers: (<read_fd>, <write_fd>)
"""
cls._socket_pairs[address] = pair

@classmethod
def register(cls, *entries):
for entry in entries:
Expand All @@ -467,6 +492,10 @@ def collect(cls, data):

@classmethod
def reset(cls):
for r_fd, w_fd in cls._socket_pairs.values():
os.close(r_fd)
os.close(w_fd)
cls._socket_pairs = {}
cls._entries = collections.defaultdict(list)
cls._requests = []

Expand Down
2 changes: 1 addition & 1 deletion mocket/mockhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def can_handle(self, data):
"""
try:
requestline, _ = decode_from_bytes(data).split(CRLF, 1)
method, path, version = self._parse_requestline(requestline)
method, path, _ = self._parse_requestline(requestline)
except ValueError:
return self is getattr(Mocket, "_last_entry", None)

Expand Down
17 changes: 17 additions & 0 deletions mocket/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import binascii
import io
import os
import ssl
from typing import TYPE_CHECKING, Any, Callable, ClassVar

Expand All @@ -14,6 +16,21 @@
SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2


class MocketSocketCore(io.BytesIO):
def __init__(self, address) -> None:
self._address = address
super().__init__()

def write(self, content):
from mocket import Mocket

super().write(content)

_, w_fd = Mocket.get_pair(self._address)
if w_fd:
os.write(w_fd, content)


def hexdump(binary_string: bytes) -> str:
r"""
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
Expand Down
5 changes: 0 additions & 5 deletions tests/main/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import glob
import json
import socket
import sys
import tempfile

import aiohttp
Expand Down Expand Up @@ -45,10 +44,6 @@ async def test_asyncio_connection():


@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Looks like https://github.com/aio-libs/aiohttp/issues/5582",
)
@async_mocketize
async def test_aiohttp():
url = "https://bar.foo/"
Expand Down

0 comments on commit 0c5c07a

Please sign in to comment.