Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIOHTTPTransport default ssl cert validation add warning #530

2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include tox.ini

include gql/py.typed

recursive-include tests *.py *.graphql *.cnf *.yaml *.pem
recursive-include tests *.py *.graphql *.cnf *.yaml *.pem *.crt
recursive-include docs *.txt *.rst conf.py Makefile make.bat
recursive-include docs/code_examples *.py

Expand Down
30 changes: 27 additions & 3 deletions gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,19 @@
import io
import json
import logging
import warnings
from ssl import SSLContext
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Optional,
Tuple,
Type,
Union,
cast,
)

import aiohttp
from aiohttp.client_exceptions import ClientResponseError
Expand Down Expand Up @@ -46,7 +57,7 @@ def __init__(
headers: Optional[LooseHeaders] = None,
cookies: Optional[LooseCookies] = None,
auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None,
ssl: Union[SSLContext, bool, Fingerprint] = False,
ssl: Union[SSLContext, bool, Fingerprint, str] = "ssl_warning",
timeout: Optional[int] = None,
ssl_close_timeout: Optional[Union[int, float]] = 10,
json_serialize: Callable = json.dumps,
Expand Down Expand Up @@ -77,7 +88,20 @@ def __init__(
self.headers: Optional[LooseHeaders] = headers
self.cookies: Optional[LooseCookies] = cookies
self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth
self.ssl: Union[SSLContext, bool, Fingerprint] = ssl

if ssl == "ssl_warning":
ssl = False
if str(url).startswith("https"):
warnings.warn(
"WARNING: By default, AIOHTTPTransport does not verify"
" ssl certificates. This will be fixed in the next major version."
" You can set ssl=True to force the ssl certificate verification"
" or ssl=False to disable this warning"
)

self.ssl: Union[SSLContext, bool, Fingerprint] = cast(
Union[SSLContext, bool, Fingerprint], ssl
)
self.timeout: Optional[int] = timeout
self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout
self.client_session_args = client_session_args
Expand Down
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,29 @@ def get_localhost_ssl_context():
return (testcert, ssl_context)


def get_localhost_ssl_context_client():
"""
Create a client-side SSL context that verifies the specific self-signed certificate
used for our test.
"""
# Get the certificate from the server setup
cert_path = bytes(pathlib.Path(__file__).with_name("test_localhost_client.crt"))

# Create client SSL context
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

# Load just the certificate part as a trusted CA
ssl_context.load_verify_locations(cafile=cert_path)

# Require certificate verification
ssl_context.verify_mode = ssl.CERT_REQUIRED

# Enable hostname checking for localhost
ssl_context.check_hostname = True

return cert_path, ssl_context


class WebSocketServer:
"""Websocket server on localhost on a free port.

Expand Down
84 changes: 81 additions & 3 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
TransportServerError,
)

from .conftest import TemporaryFile, strip_braces_spaces
from .conftest import (
TemporaryFile,
get_localhost_ssl_context_client,
strip_braces_spaces,
)

query1_str = """
query getContinents {
Expand Down Expand Up @@ -1285,7 +1289,10 @@ async def handler(request):

@pytest.mark.asyncio
@pytest.mark.parametrize("ssl_close_timeout", [0, 10])
async def test_aiohttp_query_https(event_loop, ssl_aiohttp_server, ssl_close_timeout):
@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"])
async def test_aiohttp_query_https(
event_loop, ssl_aiohttp_server, ssl_close_timeout, verify_https
):
from aiohttp import web
from gql.transport.aiohttp import AIOHTTPTransport

Expand All @@ -1300,8 +1307,20 @@ async def handler(request):

assert str(url).startswith("https://")

extra_args = {}

if verify_https == "cert_provided":
_, ssl_context = get_localhost_ssl_context_client()

extra_args["ssl"] = ssl_context
elif verify_https == "disabled":
extra_args["ssl"] = False

transport = AIOHTTPTransport(
url=url, timeout=10, ssl_close_timeout=ssl_close_timeout
url=url,
timeout=10,
ssl_close_timeout=ssl_close_timeout,
**extra_args,
)

async with Client(transport=transport) as session:
Expand All @@ -1318,6 +1337,65 @@ async def handler(request):
assert africa["code"] == "AF"


@pytest.mark.skip(reason="We will change the default to fix this in a future version")
@pytest.mark.asyncio
async def test_aiohttp_query_https_self_cert_fail(event_loop, ssl_aiohttp_server):
"""By default, we should verify the ssl certificate"""
from aiohttp.client_exceptions import ClientConnectorCertificateError
from aiohttp import web
from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await ssl_aiohttp_server(app)

url = server.make_url("/")

assert str(url).startswith("https://")

transport = AIOHTTPTransport(url=url, timeout=10)

with pytest.raises(ClientConnectorCertificateError) as exc_info:
async with Client(transport=transport) as session:
query = gql(query1_str)

# Execute query asynchronously
await session.execute(query)

expected_error = "certificate verify failed: self-signed certificate"

assert expected_error in str(exc_info.value)
assert transport.session is None


@pytest.mark.asyncio
async def test_aiohttp_query_https_self_cert_warn(event_loop, ssl_aiohttp_server):
from aiohttp import web
from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await ssl_aiohttp_server(app)

url = server.make_url("/")

assert str(url).startswith("https://")

expected_warning = (
"WARNING: By default, AIOHTTPTransport does not verify ssl certificates."
" This will be fixed in the next major version."
)

with pytest.warns(Warning, match=expected_warning):
AIOHTTPTransport(url=url, timeout=10)


@pytest.mark.asyncio
async def test_aiohttp_error_fetching_schema(event_loop, aiohttp_server):
from aiohttp import web
Expand Down
63 changes: 57 additions & 6 deletions tests/test_aiohttp_websocket_query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import json
import ssl
import sys
from typing import Dict, Mapping

Expand All @@ -14,7 +13,7 @@
TransportServerError,
)

from .conftest import MS, WebSocketServerHelper
from .conftest import MS, WebSocketServerHelper, get_localhost_ssl_context_client

# Marking all tests in this file with the aiohttp AND websockets marker
pytestmark = pytest.mark.aiohttp
Expand Down Expand Up @@ -92,8 +91,9 @@ async def test_aiohttp_websocket_starting_client_in_context_manager(
@pytest.mark.websockets
@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True)
@pytest.mark.parametrize("ssl_close_timeout", [0, 10])
@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"])
async def test_aiohttp_websocket_using_ssl_connection(
event_loop, ws_ssl_server, ssl_close_timeout
event_loop, ws_ssl_server, ssl_close_timeout, verify_https
):

from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport
Expand All @@ -103,11 +103,19 @@ async def test_aiohttp_websocket_using_ssl_connection(
url = f"wss://{server.hostname}:{server.port}/graphql"
print(f"url = {url}")

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations(ws_ssl_server.testcert)
extra_args = {}

if verify_https == "cert_provided":
_, ssl_context = get_localhost_ssl_context_client()

extra_args["ssl"] = ssl_context
elif verify_https == "disabled":
extra_args["ssl"] = False

transport = AIOHTTPWebsocketsTransport(
url=url, ssl=ssl_context, ssl_close_timeout=ssl_close_timeout
url=url,
ssl_close_timeout=ssl_close_timeout,
**extra_args,
)

async with Client(transport=transport) as session:
Expand All @@ -130,6 +138,49 @@ async def test_aiohttp_websocket_using_ssl_connection(
assert transport.websocket is None


@pytest.mark.asyncio
@pytest.mark.websockets
@pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True)
@pytest.mark.parametrize("ssl_close_timeout", [10])
@pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"])
async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail(
event_loop, ws_ssl_server, ssl_close_timeout, verify_https
):

from aiohttp.client_exceptions import ClientConnectorCertificateError
from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport

server = ws_ssl_server

url = f"wss://{server.hostname}:{server.port}/graphql"
print(f"url = {url}")

extra_args = {}

if verify_https == "explicitely_enabled":
extra_args["ssl"] = True

transport = AIOHTTPWebsocketsTransport(
url=url,
ssl_close_timeout=ssl_close_timeout,
**extra_args,
)

with pytest.raises(ClientConnectorCertificateError) as exc_info:
async with Client(transport=transport) as session:

query1 = gql(query1_str)

await session.execute(query1)

expected_error = "certificate verify failed: self-signed certificate"

assert expected_error in str(exc_info.value)

# Check client is disconnect here
assert transport.websocket is None


@pytest.mark.asyncio
@pytest.mark.websockets
@pytest.mark.parametrize("server", [server1_answers], indirect=True)
Expand Down
Loading
Loading