Skip to content

Commit

Permalink
refactor: allow specification of gateway address via API parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
d70-t committed Oct 14, 2024
1 parent 350c565 commit 729e3ed
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
11 changes: 8 additions & 3 deletions ipfsspec/async_ipfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,17 @@ def gateway_from_file(gateway_path, protocol="ipfs"):


@lru_cache
def get_gateway(protocol="ipfs"):
def get_gateway(protocol="ipfs", gateway_addr=None):
"""
Get IPFS gateway according to IPIP-280
see: https://github.com/ipfs/specs/pull/280
"""

if gateway_addr:
logger.debug("using IPFS gateway as specified via function argument: %s", gateway_addr)
return AsyncIPFSGateway(gateway_addr, protocol)

# IPFS_GATEWAY environment variable should override everything
ipfs_gateway = os.environ.get("IPFS_GATEWAY", "")
if ipfs_gateway:
Expand Down Expand Up @@ -263,19 +267,20 @@ class AsyncIPFSFileSystem(AsyncFileSystem):
sep = "/"
protocol = "ipfs"

def __init__(self, asynchronous=False, loop=None, client_kwargs=None, **storage_options):
def __init__(self, asynchronous=False, loop=None, client_kwargs=None, gateway_addr=None, **storage_options):
super().__init__(self, asynchronous=asynchronous, loop=loop, **storage_options)
self._session = None

self.client_kwargs = client_kwargs or {}
self.get_client = get_client
self.gateway_addr = gateway_addr

if not asynchronous:
sync(self.loop, self.set_session)

@property
def gateway(self):
return get_gateway(self.protocol)
return get_gateway(self.protocol, gateway_addr=self.gateway_addr)

@staticmethod
def close_session(loop, session):
Expand Down
5 changes: 3 additions & 2 deletions test/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ async def get_client(**kwargs):


@pytest_asyncio.fixture
async def fs(get_client):
async def fs(request, get_client):
AsyncIPFSFileSystem.clear_instance_cache() # avoid reusing old event loop
return AsyncIPFSFileSystem(asynchronous=True, loop=asyncio.get_running_loop(), get_client=get_client)
gateway_addr = getattr(request, "param", None)
return AsyncIPFSFileSystem(asynchronous=True, loop=asyncio.get_running_loop(), get_client=get_client, gateway_addr=gateway_addr)


@pytest.mark.parametrize("gw_host", ["http://127.0.0.1:8080"])
Expand Down

0 comments on commit 729e3ed

Please sign in to comment.