From 729e3eddc91aa83b9a78adf0b209cecd65abece2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20K=C3=B6lling?= Date: Tue, 15 Oct 2024 00:52:41 +0200 Subject: [PATCH] refactor: allow specification of gateway address via API parameter --- ipfsspec/async_ipfs.py | 11 ++++++++--- test/test_async.py | 5 +++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ipfsspec/async_ipfs.py b/ipfsspec/async_ipfs.py index a6dc270..726866d 100644 --- a/ipfsspec/async_ipfs.py +++ b/ipfsspec/async_ipfs.py @@ -180,13 +180,17 @@ def gateway_from_file(gateway_path, protocol="ipfs"): @lru_cache -def get_gateway(protocol="ipfs"): +def get_gateway(protocol="ipfs", gateway_addr=None): """ Get IPFS gateway according to IPIP-280 see: https://github.com/ipfs/specs/pull/280 """ + if gateway_addr: + logger.debug("using IPFS gateway as specified via function argument: %s", gateway_addr) + return AsyncIPFSGateway(gateway_addr, protocol) + # IPFS_GATEWAY environment variable should override everything ipfs_gateway = os.environ.get("IPFS_GATEWAY", "") if ipfs_gateway: @@ -263,19 +267,20 @@ class AsyncIPFSFileSystem(AsyncFileSystem): sep = "/" protocol = "ipfs" - def __init__(self, asynchronous=False, loop=None, client_kwargs=None, **storage_options): + def __init__(self, asynchronous=False, loop=None, client_kwargs=None, gateway_addr=None, **storage_options): super().__init__(self, asynchronous=asynchronous, loop=loop, **storage_options) self._session = None self.client_kwargs = client_kwargs or {} self.get_client = get_client + self.gateway_addr = gateway_addr if not asynchronous: sync(self.loop, self.set_session) @property def gateway(self): - return get_gateway(self.protocol) + return get_gateway(self.protocol, gateway_addr=self.gateway_addr) @staticmethod def close_session(loop, session): diff --git a/test/test_async.py b/test/test_async.py index ebaf6f3..afec6eb 100644 --- a/test/test_async.py +++ b/test/test_async.py @@ -22,9 +22,10 @@ async def get_client(**kwargs): @pytest_asyncio.fixture -async def fs(get_client): +async def fs(request, get_client): AsyncIPFSFileSystem.clear_instance_cache() # avoid reusing old event loop - return AsyncIPFSFileSystem(asynchronous=True, loop=asyncio.get_running_loop(), get_client=get_client) + gateway_addr = getattr(request, "param", None) + return AsyncIPFSFileSystem(asynchronous=True, loop=asyncio.get_running_loop(), get_client=get_client, gateway_addr=gateway_addr) @pytest.mark.parametrize("gw_host", ["http://127.0.0.1:8080"])