Skip to content

Commit

Permalink
Support SOCKS over Unix domain sockets
Browse files Browse the repository at this point in the history
Fixes #39
  • Loading branch information
JeremyRand committed Apr 30, 2021
1 parent d3f5ac5 commit 7910789
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
17 changes: 12 additions & 5 deletions aiorpcx/socks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import struct
from functools import partial

from aiorpcx.util import NetAddress
from aiorpcx.util import NetAddress, UnixAddress


__all__ = ('SOCKSUserAuth', 'SOCKSRandomAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy',
Expand Down Expand Up @@ -272,11 +272,11 @@ def _connect_response_rest(self, addr_len):
class SOCKSProxy:

def __init__(self, address, protocol, auth):
'''A SOCKS proxy at a NetAddress following a SOCKS protocol.
'''A SOCKS proxy at a NetAddress or UnixAddress following a SOCKS protocol.
auth is an authentication method to use when connecting, or None.
'''
if not isinstance(address, NetAddress):
if not isinstance(address, (NetAddress, UnixAddress)):
address = NetAddress.from_string(address)
self.address = address
self.protocol = protocol
Expand Down Expand Up @@ -314,8 +314,15 @@ async def _connect_one(self, remote_address):
'''
loop = asyncio.get_event_loop()

for info in await loop.getaddrinfo(str(self.address.host), self.address.port,
type=socket.SOCK_STREAM):
if isinstance(self.address, UnixAddress):
# Unix socket
infos = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, '', self.address.path)]
else:
# IP socket
infos = await loop.getaddrinfo(str(self.address.host), self.address.port,
type=socket.SOCK_STREAM)

for info in infos:
# This object has state so is only good for one connection
client = self.protocol(remote_address, self.auth)
sock = socket.socket(family=info[0])
Expand Down
27 changes: 26 additions & 1 deletion aiorpcx/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

__all__ = ('instantiate_coroutine', 'is_valid_hostname', 'classify_host',
'validate_port', 'validate_protocol', 'Service', 'ServicePart', 'NetAddress')
'validate_port', 'validate_protocol', 'Service', 'ServicePart',
'NetAddress', 'UnixAddress')


import asyncio
Expand Down Expand Up @@ -181,6 +182,30 @@ def default_port(cls, port):
return cls.default_host_and_port(None, port)


class UnixAddress:

def __init__(self, path: str):
'''Construct a UnixAddress from a path.'''
self._path = path

def __eq__(self, other):
# pylint: disable=protected-access
return isinstance(other, UnixAddress) and self._path == other._path

def __hash__(self):
return hash((self._path))

@property
def path(self):
return self._path

def __str__(self):
return f'{self.path}'

def __repr__(self):
return f'UnixAddress({self.path!r})'


class Service:
'''A validated protocol, address pair.'''

Expand Down

0 comments on commit 7910789

Please sign in to comment.