Skip to content

Commit

Permalink
Merge pull request #45 from NabuCasa/dev
Browse files Browse the repository at this point in the history
Release 0.19
  • Loading branch information
pvizeli authored Mar 28, 2019
2 parents 031833a + 7c2fda5 commit e9e2a44
Show file tree
Hide file tree
Showing 12 changed files with 146 additions and 39 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup

VERSION = "0.18"
VERSION = "0.19"

setup(
name="snitun",
Expand Down
3 changes: 2 additions & 1 deletion snitun/client/client_peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ async def stop(self) -> None:
"""Stop connection to SniTun server."""
if not self._multiplexer:
raise RuntimeError("No SniTun connection available")
await self._multiplexer.shutdown()
self._multiplexer.shutdown()
await self._multiplexer.wait()

async def _handler(self) -> None:
"""Wait until connection is closed."""
Expand Down
10 changes: 6 additions & 4 deletions snitun/multiplexer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,13 @@ def wait(self) -> asyncio.Task:
"""
return asyncio.shield(self._processing_task)

async def shutdown(self):
def shutdown(self):
"""Shutdown connection."""
if self._processing_task.done():
return

_LOGGER.debug("Cancel connection")
self._processing_task.cancel()

self._graceful_channel_shutdown()

def _graceful_channel_shutdown(self):
Expand All @@ -95,7 +94,7 @@ async def ping(self):

except (OSError, asyncio.TimeoutError):
_LOGGER.error("Ping fails, no response from peer")
self._loop.create_task(self.shutdown())
self._loop.call_soon(self.shutdown)
raise MultiplexerTransportError() from None

async def _runner(self):
Expand Down Expand Up @@ -242,7 +241,10 @@ async def _process_message(self, message: MultiplexerMessage) -> None:

ip_address = bytes_to_ip_address(message.extra[1:5])
channel = MultiplexerChannel(
self._queue, ip_address, channel_id=message.channel_id, throttling=self._throttling
self._queue,
ip_address,
channel_id=message.channel_id,
throttling=self._throttling,
)
self._channels[channel.uuid] = channel
self._loop.create_task(self._new_connections(self, channel))
Expand Down
4 changes: 2 additions & 2 deletions snitun/server/listener_peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ async def handle_connection(
# Connection closed before data received
if not fernet_data:
return

peer = self._peer_manager.register_peer(fernet_data)
peer = self._peer_manager.create_peer(fernet_data)

# Start multiplexer
await peer.init_multiplexer_challenge(reader, writer)

self._peer_manager.add_peer(peer)
while peer.is_connected:
try:
async with async_timeout.timeout(CHECK_VALID_EXPIRE):
Expand Down
2 changes: 1 addition & 1 deletion snitun/server/listener_sni.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ async def _proxy_peer(
await multiplexer.delete_channel(channel)

except asyncio.TimeoutError:
_LOGGER.warning("Close TCP session after timeout for %s", channel.uuid)
_LOGGER.debug("Close TCP session after timeout for %s", channel.uuid)
with suppress(MultiplexerTransportError):
await multiplexer.delete_channel(channel)

Expand Down
6 changes: 3 additions & 3 deletions snitun/server/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import hashlib
import logging
import os
from typing import Optional
from typing import Optional, Coroutine

from ..exceptions import MultiplexerTransportDecrypt, SniTunChallengeError
from ..multiplexer.core import Multiplexer
Expand Down Expand Up @@ -91,10 +91,10 @@ async def init_multiplexer_challenge(
self._crypto, reader, writer, throttling=self._throttling
)

def wait_disconnect(self) -> asyncio.Task:
def wait_disconnect(self) -> Coroutine:
"""Wait until peer is disconnected.
Return awaitable object.
Return a coroutine.
"""
if not self._multiplexer:
raise RuntimeError("No Transport initialize for peer")
Expand Down
25 changes: 16 additions & 9 deletions snitun/server/peer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def __init__(self, fernet_tokens: List[str], throttling: Optional[int] = None):
self._peers = {}

@property
def connections(self):
def connections(self) -> int:
"""Return count of connected devices."""
return len(self._peers)

def register_peer(self, fernet_data: bytes) -> Peer:
def create_peer(self, fernet_data: bytes) -> Peer:
"""Create a new peer from crypt config."""
try:
data = self._fernet.decrypt(fernet_data).decode()
Expand All @@ -46,21 +46,28 @@ def register_peer(self, fernet_data: bytes) -> Peer:
aes_key = bytes.fromhex(config["aes_key"])
aes_iv = bytes.fromhex(config["aes_iv"])

peer = self._peers[hostname] = Peer(
hostname, valid, aes_key, aes_iv, throttling=self._throttling
)
return peer
return Peer(hostname, valid, aes_key, aes_iv, throttling=self._throttling)

def remove_peer(self, peer: Peer):
def add_peer(self, peer: Peer) -> None:
"""Register peer to internal hostname list."""
if self.peer_available(peer.hostname):
_LOGGER.warning("Found stale peer connection")
self._peers[peer.hostname].multiplexer.shutdown()

self._peers[peer.hostname] = peer

def remove_peer(self, peer: Peer) -> None:
"""Remove peer from list."""
self._peers.pop(peer.hostname, None)
if self._peers.get(peer.hostname) != peer:
return
self._peers.pop(peer.hostname)

def peer_available(self, hostname: str) -> bool:
"""Check if peer available and return True or False."""
if hostname in self._peers:
return self._peers[hostname].is_ready
return False

def get_peer(self, hostname: str) -> Peer:
def get_peer(self, hostname: str) -> Optional[Peer]:
"""Get peer."""
return self._peers.get(hostname)
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def mock_new_channel(multiplexer, channel):

yield multiplexer

await multiplexer.shutdown()
multiplexer.shutdown()
client.close.set()


Expand All @@ -118,7 +118,7 @@ async def mock_new_channel(multiplexer, channel):

yield multiplexer

await multiplexer.shutdown()
multiplexer.shutdown()


@pytest.fixture
Expand Down
14 changes: 7 additions & 7 deletions tests/multiplexer/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def test_init_multiplexer_server(test_server, test_client, crypto_transpor

assert multiplexer.is_connected
assert multiplexer._throttling is None
await multiplexer.shutdown()
multiplexer.shutdown()
client.close.set()


Expand All @@ -30,7 +30,7 @@ async def test_init_multiplexer_client(test_client, crypto_transport):

assert multiplexer.is_connected
assert multiplexer._throttling is None
await multiplexer.shutdown()
multiplexer.shutdown()


async def test_init_multiplexer_server_throttling(
Expand All @@ -45,7 +45,7 @@ async def test_init_multiplexer_server_throttling(

assert multiplexer.is_connected
assert multiplexer._throttling == 0.002
await multiplexer.shutdown()
multiplexer.shutdown()
client.close.set()


Expand All @@ -57,15 +57,15 @@ async def test_init_multiplexer_client_throttling(test_client, crypto_transport)

assert multiplexer.is_connected
assert multiplexer._throttling == 0.002
await multiplexer.shutdown()
multiplexer.shutdown()


async def test_multiplexer_server_close(multiplexer_server, multiplexer_client):
"""Test a close from server peers."""
assert multiplexer_server.is_connected
assert multiplexer_client.is_connected

await multiplexer_server.shutdown()
multiplexer_server.shutdown()
await asyncio.sleep(0.1)

assert not multiplexer_server.is_connected
Expand All @@ -77,7 +77,7 @@ async def test_multiplexer_client_close(multiplexer_server, multiplexer_client):
assert multiplexer_server.is_connected
assert multiplexer_client.is_connected

await multiplexer_client.shutdown()
multiplexer_client.shutdown()
await asyncio.sleep(0.1)

assert not multiplexer_server.is_connected
Expand Down Expand Up @@ -269,7 +269,7 @@ async def test_multiplexer_channel_shutdown(
assert not client_read.done()
assert not server_read.done()

await multiplexer_client.shutdown()
multiplexer_client.shutdown()
await asyncio.sleep(0.1)
assert not multiplexer_client._channels
assert client_read.done()
Expand Down
2 changes: 1 addition & 1 deletion tests/server/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def mock_new_channel(multiplexer, channel):
assert peer_address
assert peer_address[0] == IP_ADDR

await multiplexer.shutdown()
multiplexer.shutdown()
await multiplexer.wait()
await asyncio.sleep(0.1)

Expand Down
Loading

0 comments on commit e9e2a44

Please sign in to comment.