diff --git a/aiorpcx/__init__.py b/aiorpcx/__init__.py index ef2e0f1..338a7dd 100644 --- a/aiorpcx/__init__.py +++ b/aiorpcx/__init__.py @@ -1,6 +1,7 @@ from .curio import * from .framing import * from .jsonrpc import * +from .proxy_rawsocket import * from .rawsocket import * from .socks import * from .session import * @@ -15,6 +16,7 @@ __all__ = (curio.__all__ + framing.__all__ + jsonrpc.__all__ + + proxy_rawsocket.__all__ + rawsocket.__all__ + socks.__all__ + session.__all__ + diff --git a/aiorpcx/proxy_rawsocket.py b/aiorpcx/proxy_rawsocket.py new file mode 100644 index 0000000..15cfae7 --- /dev/null +++ b/aiorpcx/proxy_rawsocket.py @@ -0,0 +1,161 @@ +'''Transport implementation for PROXY protocol''' + +__all__ = ('serve_pv1rs',) + + +import asyncio +from functools import partial + +from aiorpcx.curio import Queue +from aiorpcx.rawsocket import RSTransport, ConnectionLostError +from aiorpcx.session import SessionKind +from aiorpcx.util import NetAddress + + +class ProxyProtocolError(Exception): + pass + + +class ProxyHeaderData: + '''Holds data supplied using the PROXY protocol header. + + source and dst may be None or a NetAddress instance''' + def __init__(self): + self.source = None + self.dst = None + + +class ProxyHeaderV1Data(ProxyHeaderData): + PROTOCOL_MAGIC = b'PROXY' + NETWORK_PROTOCOLS = {'TCP4', 'TCP6', 'UNKNOWN'} + + def __init__(self, version, net_proto, source, dst): + self.verify_version(version) + self.verify_net_proto(net_proto) + assert source is None or isinstance(source, NetAddress) + assert dst is None or isinstance(dst, NetAddress) + self.version = version + self.net_proto = net_proto + self.source = source + self.dst = dst + + @classmethod + def verify_version(cls, version): + if version != 1: + raise ProxyProtocolError(f"unknown PROXY protocol version {version}") + + @classmethod + def verify_net_proto(cls, net_proto): + if net_proto not in cls.NETWORK_PROTOCOLS: + raise ProxyProtocolError(f"unknown PROXY network protocol {net_proto}") + + @classmethod + def from_bytes(cls, data): + header, remainder = data.split(b' ', maxsplit=1) + assert header == cls.PROTOCOL_MAGIC + + if b' ' in remainder: + proto, src_ip, dst_ip, src_pt, dst_pt = remainder.split(b' ') + src = NetAddress(src_ip.decode('ascii'), int(src_pt)) + dst = NetAddress(dst_ip.decode('ascii'), int(dst_pt)) + else: + proto = remainder + src = None + dst = None + + return cls(1, proto.decode('ascii'), src, dst) + + +class ProxyProtocolV1Processor: + '''Receives incoming data and separates the PROXY protocol header.''' + max_size = 107 # maximum frame size for PROXY v1 protocol + + def __init__(self): + self.queue = Queue() + self.received_bytes = self.queue.put_nowait + self.residual = b'' + + async def receive_message(self): + '''Collects bytes until complete PROXY header has been received.''' + parts = [] + buffer_size = 0 + while True: + part = self.residual + self.residual = b'' + new_part = b'' + if not part: + new_part = await self.queue.get() + + joined = b''.join(parts) + parts = [joined] + part = joined + new_part + npos = part.find(b'\r\n') + if npos == -1: + parts.append(new_part) + buffer_size += len(new_part) + # Ignore over-sized messages + if buffer_size <= self.max_size or self.max_size == 0: + continue + raise ProxyProtocolError(f"Expected PROXY v1 protocol header") + + tail, self.residual = new_part[:npos], new_part[npos + 2:] + parts.append(tail) + return ProxyHeaderV1Data.from_bytes(b''.join(parts)) + + +class ProxyProtocolMixinBase: + '''Base class for handling PROXY-wrapped connections.''' + PROXY_PROCESSOR = None + + def __init__(self, *args, **kwargs): + self.process_messages = self._process_messages_proxy_init + self._proxy_processor = self.PROXY_PROCESSOR() + super().__init__(*args, **kwargs) + + async def _process_messages_proxy_init(self): + '''Process the inital PROXY protocol header''' + try: + excess_bytes = await self._receive_message_proxy_header() + except (ConnectionLostError, ProxyProtocolError): + self._closed_event.set() + await self.session.connection_lost() + except Exception as e: + self._closed_event.set() + await self.session.connection_lost() + raise e + else: + self._proxy_init_done(excess_bytes) + + async def _receive_message_proxy_header(self): + proxy_data = await self._proxy_processor.receive_message() + if proxy_data.source is not None: + self._remote_address = proxy_data.source + return self._proxy_processor.residual + + def data_received(self, data): + self._proxy_processor.received_bytes(data) + + def _proxy_init_done(self, excess_bytes): + '''Enable the underlying protocol handler and re-send extra data + received by the PROXY protocol handler.''' + self.data_received = super().data_received + self.data_received(excess_bytes) + while not self._proxy_processor.queue.empty(): + self.data_received(self._proxy_processor.queue.get_nowait()) + self.process_messages = super().process_messages + self._process_messages_task = self.loop.create_task(self.process_messages()) + self._proxy_processor = None + + +class ProxyProtocolV1Mixin(ProxyProtocolMixinBase): + PROXY_PROCESSOR = ProxyProtocolV1Processor + + +class ProxyV1RSTransport(ProxyProtocolV1Mixin, RSTransport): + pass + + +async def serve_pv1rs(session_factory, host=None, port=None, *, framer=None, loop=None, **kwargs): + loop = loop or asyncio.get_event_loop() + protocol_factory = partial(ProxyV1RSTransport, session_factory, framer, SessionKind.SERVER) + return await loop.create_server(protocol_factory, host, port, **kwargs) diff --git a/tests/test_proxy_rawsocket.py b/tests/test_proxy_rawsocket.py new file mode 100644 index 0000000..0144657 --- /dev/null +++ b/tests/test_proxy_rawsocket.py @@ -0,0 +1,160 @@ +import pytest + +import asyncio + +from aiorpcx.proxy_rawsocket import ProxyProtocolError, ProxyHeaderV1Data,\ + ProxyProtocolV1Processor, serve_pv1rs +from aiorpcx.rawsocket import RSClient +from aiorpcx.util import NetAddress +from aiorpcx import timeout_after + +from test_session import MyServerSession + + +def test_ProxyHeaderV1Data_parse_ip4(): + src_ip = '1.2.3.4' + src_pt = 123 + dst_ip = '11.22.33.44' + dst_pt = 12345 + byte_data = f'PROXY TCP4 {src_ip} {dst_ip} {src_pt} {dst_pt}' + proxy_data = ProxyHeaderV1Data.from_bytes(byte_data.encode('ascii')) + + assert proxy_data.version == 1 + assert proxy_data.net_proto == 'TCP4' + assert proxy_data.source == NetAddress(src_ip, src_pt) + assert proxy_data.dst == NetAddress(dst_ip, dst_pt) + + +def test_ProxyHeaderV1Data_parse_ip6(): + src_ip = 'ff:abcd::3:2:1' + src_pt = 123 + dst_ip = '::1' + dst_pt = 12345 + byte_data = f'PROXY TCP6 {src_ip} {dst_ip} {src_pt} {dst_pt}' + proxy_data = ProxyHeaderV1Data.from_bytes(byte_data.encode('ascii')) + + assert proxy_data.version == 1 + assert proxy_data.net_proto == 'TCP6' + assert proxy_data.source == NetAddress(src_ip, src_pt) + assert proxy_data.dst == NetAddress(dst_ip, dst_pt) + + +def test_ProxyHeaderV1Data_parse_unknown(): + byte_data = 'PROXY UNKNOWN' + proxy_data = ProxyHeaderV1Data.from_bytes(byte_data.encode('ascii')) + + assert proxy_data.version == 1 + assert proxy_data.net_proto == 'UNKNOWN' + assert proxy_data.source is None + assert proxy_data.dst is None + + +@pytest.mark.asyncio +async def test_ProxyProtocolV1Processor_simple(): + processor = ProxyProtocolV1Processor() + byte_data = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345\r\n' + processor.received_bytes(byte_data) + proxy_data = await processor.receive_message() + assert proxy_data.source is not None + assert processor.residual == b'' + + +@pytest.mark.asyncio +async def test_ProxyProtocolV1Processor_remaining(): + processor = ProxyProtocolV1Processor() + byte_data = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345\r\ntest' + processor.received_bytes(byte_data) + proxy_data = await processor.receive_message() + assert proxy_data.source is not None + assert processor.residual == b'test' + + +@pytest.mark.asyncio +async def test_ProxyProtocolV1Processor_incomplete(): + processor = ProxyProtocolV1Processor() + byte_data = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345' + processor.received_bytes(byte_data) + async with timeout_after(0.5): + with pytest.raises(asyncio.CancelledError): + await processor.receive_message() + + +@pytest.mark.asyncio +async def test_ProxyProtocolV1Processor_garbage(): + processor = ProxyProtocolV1Processor() + byte_data = b'PROXY ' * 100 + processor.received_bytes(byte_data) + with pytest.raises(ProxyProtocolError): + await processor.receive_message() + + +@pytest.mark.asyncio +async def test_ProxyProtocolV1Processor_chunked(): + processor = ProxyProtocolV1Processor() + byte_data = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345\r\n' + for byte in byte_data: + processor.received_bytes(bytes([byte])) + proxy_data = await processor.receive_message() + assert proxy_data.source is not None + assert processor.residual == b'' + + +@pytest.mark.asyncio +async def test_ProxyProtocolV1Processor_remaining_chunked(): + processor = ProxyProtocolV1Processor() + byte_data = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345\r\ntest' + processor.received_bytes(byte_data) + processor.received_bytes(b'moretest') + proxy_data = await processor.receive_message() + assert proxy_data.source is not None + leftover = processor.residual + while not processor.queue.empty(): + leftover += processor.queue.get_nowait() + assert leftover == b'testmoretest' + + +class ProxyRSClient(RSClient): + async def __aenter__(self): + _transport, protocol = await self.create_connection() + self.session = protocol.session + msg = b'PROXY TCP4 1.2.3.4 11.22.33.44 123 12345\r\n' + self.session.transport._asyncio_transport.write(msg) + return self.session + + +class ProxyServerSession(MyServerSession): + async def on_remote_addr(self): + return str(self.transport._remote_address) + + +@pytest.fixture +def server_port(unused_tcp_port, event_loop): + coro = serve_pv1rs(ProxyServerSession, 'localhost', unused_tcp_port, loop=event_loop) + server = event_loop.run_until_complete(coro) + yield unused_tcp_port + if hasattr(asyncio, 'all_tasks'): + tasks = asyncio.all_tasks(event_loop) + else: + tasks = asyncio.Task.all_tasks(loop=event_loop) + async def close_all(): + server.close() + await server.wait_closed() + if tasks: + await asyncio.wait(tasks) + event_loop.run_until_complete(close_all()) + + +@pytest.mark.asyncio +async def test_send_request(server_port): + async with ProxyRSClient('localhost', server_port) as session: + assert await session.send_request('echo', [23]) == 23 + assert session.transport._closed_event.is_set() + assert session.transport._process_messages_task.done() + + +@pytest.mark.asyncio +async def test_remote_address(server_port): + async with ProxyRSClient('localhost', server_port) as session: + assert await session.send_request('remote_addr') == '1.2.3.4:123' + assert session.transport._closed_event.is_set() + assert session.transport._process_messages_task.done()