Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add server handler for PROXY (v1) protocol #24

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aiorpcx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .curio import *
from .framing import *
from .jsonrpc import *
from .proxy_rawsocket import *
from .rawsocket import *
from .socks import *
from .session import *
Expand All @@ -15,6 +16,7 @@
__all__ = (curio.__all__ +
framing.__all__ +
jsonrpc.__all__ +
proxy_rawsocket.__all__ +
rawsocket.__all__ +
socks.__all__ +
session.__all__ +
Expand Down
161 changes: 161 additions & 0 deletions aiorpcx/proxy_rawsocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
'''Transport implementation for PROXY protocol'''

__all__ = ('serve_pv1rs',)


import asyncio
from functools import partial

from aiorpcx.curio import Queue
from aiorpcx.rawsocket import RSTransport, ConnectionLostError
from aiorpcx.session import SessionKind
from aiorpcx.util import NetAddress


class ProxyProtocolError(Exception):
pass


class ProxyHeaderData:
'''Holds data supplied using the PROXY protocol header.

source and dst may be None or a NetAddress instance'''
def __init__(self):
self.source = None
self.dst = None


class ProxyHeaderV1Data(ProxyHeaderData):
PROTOCOL_MAGIC = b'PROXY'
NETWORK_PROTOCOLS = {'TCP4', 'TCP6', 'UNKNOWN'}

def __init__(self, version, net_proto, source, dst):
self.verify_version(version)
self.verify_net_proto(net_proto)
assert source is None or isinstance(source, NetAddress)
assert dst is None or isinstance(dst, NetAddress)
self.version = version
self.net_proto = net_proto
self.source = source
self.dst = dst

@classmethod
def verify_version(cls, version):
if version != 1:
raise ProxyProtocolError(f"unknown PROXY protocol version {version}")

@classmethod
def verify_net_proto(cls, net_proto):
if net_proto not in cls.NETWORK_PROTOCOLS:
raise ProxyProtocolError(f"unknown PROXY network protocol {net_proto}")

@classmethod
def from_bytes(cls, data):
header, remainder = data.split(b' ', maxsplit=1)
assert header == cls.PROTOCOL_MAGIC

if b' ' in remainder:
proto, src_ip, dst_ip, src_pt, dst_pt = remainder.split(b' ')
src = NetAddress(src_ip.decode('ascii'), int(src_pt))
dst = NetAddress(dst_ip.decode('ascii'), int(dst_pt))
else:
proto = remainder
src = None
dst = None

return cls(1, proto.decode('ascii'), src, dst)


class ProxyProtocolV1Processor:
'''Receives incoming data and separates the PROXY protocol header.'''
max_size = 107 # maximum frame size for PROXY v1 protocol

def __init__(self):
self.queue = Queue()
self.received_bytes = self.queue.put_nowait
self.residual = b''

async def receive_message(self):
'''Collects bytes until complete PROXY header has been received.'''
parts = []
buffer_size = 0
while True:
part = self.residual
self.residual = b''
new_part = b''
if not part:
new_part = await self.queue.get()

joined = b''.join(parts)
parts = [joined]
part = joined + new_part
npos = part.find(b'\r\n')
if npos == -1:
parts.append(new_part)
buffer_size += len(new_part)
# Ignore over-sized messages
if buffer_size <= self.max_size or self.max_size == 0:
continue
raise ProxyProtocolError(f"Expected PROXY v1 protocol header")

tail, self.residual = new_part[:npos], new_part[npos + 2:]
parts.append(tail)
return ProxyHeaderV1Data.from_bytes(b''.join(parts))


class ProxyProtocolMixinBase:
'''Base class for handling PROXY-wrapped connections.'''
PROXY_PROCESSOR = None

def __init__(self, *args, **kwargs):
self.process_messages = self._process_messages_proxy_init
self._proxy_processor = self.PROXY_PROCESSOR()
super().__init__(*args, **kwargs)

async def _process_messages_proxy_init(self):
'''Process the inital PROXY protocol header'''
try:
excess_bytes = await self._receive_message_proxy_header()
except (ConnectionLostError, ProxyProtocolError):
self._closed_event.set()
await self.session.connection_lost()
except Exception as e:
self._closed_event.set()
await self.session.connection_lost()
raise e
else:
self._proxy_init_done(excess_bytes)

async def _receive_message_proxy_header(self):
proxy_data = await self._proxy_processor.receive_message()
if proxy_data.source is not None:
self._remote_address = proxy_data.source
return self._proxy_processor.residual

def data_received(self, data):
self._proxy_processor.received_bytes(data)

def _proxy_init_done(self, excess_bytes):
'''Enable the underlying protocol handler and re-send extra data
received by the PROXY protocol handler.'''
self.data_received = super().data_received
self.data_received(excess_bytes)
while not self._proxy_processor.queue.empty():
self.data_received(self._proxy_processor.queue.get_nowait())
self.process_messages = super().process_messages
self._process_messages_task = self.loop.create_task(self.process_messages())
self._proxy_processor = None


class ProxyProtocolV1Mixin(ProxyProtocolMixinBase):
PROXY_PROCESSOR = ProxyProtocolV1Processor


class ProxyV1RSTransport(ProxyProtocolV1Mixin, RSTransport):
pass


async def serve_pv1rs(session_factory, host=None, port=None, *, framer=None, loop=None, **kwargs):
loop = loop or asyncio.get_event_loop()
protocol_factory = partial(ProxyV1RSTransport, session_factory, framer, SessionKind.SERVER)
return await loop.create_server(protocol_factory, host, port, **kwargs)
160 changes: 160 additions & 0 deletions tests/test_proxy_rawsocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import pytest

import asyncio

from aiorpcx.proxy_rawsocket import ProxyProtocolError, ProxyHeaderV1Data,\
ProxyProtocolV1Processor, serve_pv1rs
from aiorpcx.rawsocket import RSClient
from aiorpcx.util import NetAddress
from aiorpcx import timeout_after

from test_session import MyServerSession


def test_ProxyHeaderV1Data_parse_ip4():
src_ip = '1.2.3.4'
src_pt = 123
dst_ip = '11.22.33.44'
dst_pt = 12345
byte_data = f'PROXY TCP4 {src_ip} {dst_ip} {src_pt} {dst_pt}'
proxy_data = ProxyHeaderV1Data.from_bytes(byte_data.encode('ascii'))

assert proxy_data.version == 1
assert proxy_data.net_proto == 'TCP4'
assert proxy_data.source == NetAddress(src_ip, src_pt)
assert proxy_data.dst == NetAddress(dst_ip, dst_pt)


def test_ProxyHeaderV1Data_parse_ip6():
src_ip = 'ff:abcd::3:2:1'
src_pt = 123
dst_ip = '::1'
dst_pt = 12345
byte_data = f'PROXY TCP6 {src_ip} {dst_ip} {src_pt} {dst_pt}'
proxy_data = ProxyHeaderV1Data.from_bytes(byte_data.encode('ascii'))

assert proxy_data.version == 1
assert proxy_data.net_proto == 'TCP6'
assert proxy_data.source == NetAddress(src_ip, src_pt)
assert proxy_data.dst == NetAddress(dst_ip, dst_pt)


def test_ProxyHeaderV1Data_parse_unknown():
byte_data = 'PROXY UNKNOWN'
proxy_data = ProxyHeaderV1Data.from_bytes(byte_data.encode('ascii'))

assert proxy_data.version == 1
assert proxy_data.net_proto == 'UNKNOWN'
assert proxy_data.source is None
assert proxy_data.dst is None


@pytest.mark.asyncio
async def test_ProxyProtocolV1Processor_simple():
processor = ProxyProtocolV1Processor()
byte_data = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345\r\n'
processor.received_bytes(byte_data)
proxy_data = await processor.receive_message()
assert proxy_data.source is not None
assert processor.residual == b''


@pytest.mark.asyncio
async def test_ProxyProtocolV1Processor_remaining():
processor = ProxyProtocolV1Processor()
byte_data = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345\r\ntest'
processor.received_bytes(byte_data)
proxy_data = await processor.receive_message()
assert proxy_data.source is not None
assert processor.residual == b'test'


@pytest.mark.asyncio
async def test_ProxyProtocolV1Processor_incomplete():
processor = ProxyProtocolV1Processor()
byte_data = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345'
processor.received_bytes(byte_data)
async with timeout_after(0.5):
with pytest.raises(asyncio.CancelledError):
await processor.receive_message()


@pytest.mark.asyncio
async def test_ProxyProtocolV1Processor_garbage():
processor = ProxyProtocolV1Processor()
byte_data = b'PROXY ' * 100
processor.received_bytes(byte_data)
with pytest.raises(ProxyProtocolError):
await processor.receive_message()


@pytest.mark.asyncio
async def test_ProxyProtocolV1Processor_chunked():
processor = ProxyProtocolV1Processor()
byte_data = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345\r\n'
for byte in byte_data:
processor.received_bytes(bytes([byte]))
proxy_data = await processor.receive_message()
assert proxy_data.source is not None
assert processor.residual == b''


@pytest.mark.asyncio
async def test_ProxyProtocolV1Processor_remaining_chunked():
processor = ProxyProtocolV1Processor()
byte_data = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345\r\ntest'
processor.received_bytes(byte_data)
processor.received_bytes(b'moretest')
proxy_data = await processor.receive_message()
assert proxy_data.source is not None
leftover = processor.residual
while not processor.queue.empty():
leftover += processor.queue.get_nowait()
assert leftover == b'testmoretest'


class ProxyRSClient(RSClient):
async def __aenter__(self):
_transport, protocol = await self.create_connection()
self.session = protocol.session
msg = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345\r\n'
self.session.transport._asyncio_transport.write(msg)
return self.session


class ProxyServerSession(MyServerSession):
async def on_remote_addr(self):
return str(self.transport._remote_address)


@pytest.fixture
def server_port(unused_tcp_port, event_loop):
coro = serve_pv1rs(ProxyServerSession, 'localhost', unused_tcp_port, loop=event_loop)
server = event_loop.run_until_complete(coro)
yield unused_tcp_port
if hasattr(asyncio, 'all_tasks'):
tasks = asyncio.all_tasks(event_loop)
else:
tasks = asyncio.Task.all_tasks(loop=event_loop)
async def close_all():
server.close()
await server.wait_closed()
if tasks:
await asyncio.wait(tasks)
event_loop.run_until_complete(close_all())


@pytest.mark.asyncio
async def test_send_request(server_port):
async with ProxyRSClient('localhost', server_port) as session:
assert await session.send_request('echo', [23]) == 23
assert session.transport._closed_event.is_set()
assert session.transport._process_messages_task.done()


@pytest.mark.asyncio
async def test_remote_address(server_port):
async with ProxyRSClient('localhost', server_port) as session:
assert await session.send_request('remote_addr') == '1.2.3.4:123'
assert session.transport._closed_event.is_set()
assert session.transport._process_messages_task.done()