From 7910789a072c555f1a62b8325a3fed9eb131daf9 Mon Sep 17 00:00:00 2001 From: Jeremy Rand Date: Thu, 8 Apr 2021 08:59:13 +0000 Subject: [PATCH] Support SOCKS over Unix domain sockets Fixes https://github.com/kyuupichan/aiorpcX/issues/39 --- aiorpcx/socks.py | 17 ++++++++++++----- aiorpcx/util.py | 27 ++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/aiorpcx/socks.py b/aiorpcx/socks.py index 4a6cefa..e6fe775 100755 --- a/aiorpcx/socks.py +++ b/aiorpcx/socks.py @@ -33,7 +33,7 @@ import struct from functools import partial -from aiorpcx.util import NetAddress +from aiorpcx.util import NetAddress, UnixAddress __all__ = ('SOCKSUserAuth', 'SOCKSRandomAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy', @@ -272,11 +272,11 @@ def _connect_response_rest(self, addr_len): class SOCKSProxy: def __init__(self, address, protocol, auth): - '''A SOCKS proxy at a NetAddress following a SOCKS protocol. + '''A SOCKS proxy at a NetAddress or UnixAddress following a SOCKS protocol. auth is an authentication method to use when connecting, or None. ''' - if not isinstance(address, NetAddress): + if not isinstance(address, (NetAddress, UnixAddress)): address = NetAddress.from_string(address) self.address = address self.protocol = protocol @@ -314,8 +314,15 @@ async def _connect_one(self, remote_address): ''' loop = asyncio.get_event_loop() - for info in await loop.getaddrinfo(str(self.address.host), self.address.port, - type=socket.SOCK_STREAM): + if isinstance(self.address, UnixAddress): + # Unix socket + infos = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, '', self.address.path)] + else: + # IP socket + infos = await loop.getaddrinfo(str(self.address.host), self.address.port, + type=socket.SOCK_STREAM) + + for info in infos: # This object has state so is only good for one connection client = self.protocol(remote_address, self.auth) sock = socket.socket(family=info[0]) diff --git a/aiorpcx/util.py b/aiorpcx/util.py index d8308c1..09493cb 100755 --- a/aiorpcx/util.py +++ b/aiorpcx/util.py @@ -24,7 +24,8 @@ # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. __all__ = ('instantiate_coroutine', 'is_valid_hostname', 'classify_host', - 'validate_port', 'validate_protocol', 'Service', 'ServicePart', 'NetAddress') + 'validate_port', 'validate_protocol', 'Service', 'ServicePart', + 'NetAddress', 'UnixAddress') import asyncio @@ -181,6 +182,30 @@ def default_port(cls, port): return cls.default_host_and_port(None, port) +class UnixAddress: + + def __init__(self, path: str): + '''Construct a UnixAddress from a path.''' + self._path = path + + def __eq__(self, other): + # pylint: disable=protected-access + return isinstance(other, UnixAddress) and self._path == other._path + + def __hash__(self): + return hash((self._path)) + + @property + def path(self): + return self._path + + def __str__(self): + return f'{self.path}' + + def __repr__(self): + return f'UnixAddress({self.path!r})' + + class Service: '''A validated protocol, address pair.'''