Skip to content

Commit

Permalink
Support globbed reads
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Mar 19, 2024
1 parent c13f21e commit d4ca67a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 40 deletions.
39 changes: 37 additions & 2 deletions tests/test_xarrayfits.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,45 @@
from dask.distributed import Client, LocalCluster
import numpy as np
import pytest
import xarray

from xarrayfits import xds_from_fits


@pytest.fixture(scope="session")
def multiple_files(tmp_path_factory):
    """Write three identical FITS files and return a glob pattern matching them.

    Each file holds a single primary HDU containing a 10x10 float64 array
    of ``arange`` values, so every file read back should be identical.

    Returns
    -------
    str
        Glob pattern (``.../data*.fits``) matching the three files.
    """
    path = tmp_path_factory.mktemp("globbing")
    shape = (10, 10)
    data = np.arange(np.prod(shape), dtype=np.float64)
    data = data.reshape(shape)

    for i in range(3):
        filename = str(path / f"data-{i}.fits")
        primary_hdu = fits.PrimaryHDU(data)
        primary_hdu.writeto(filename, overwrite=True)

    # Plain string: the original used an f-string with no placeholders (F541)
    return str(path / "data*.fits")


def test_globbing(multiple_files):
    """Globbed reads yield one dataset per matching file, concatenable on either axis."""
    datasets = xds_from_fits(multiple_files)
    assert len(datasets) == 3

    # All files were written with identical contents (see the multiple_files
    # fixture). Build the reference array once instead of relying on the
    # loop variable leaking out of the for-loop as the original did.
    expected = np.arange(np.prod(datasets[0].hdu0.shape), dtype=np.float64)
    expected = expected.reshape(datasets[0].hdu0.shape)

    for xds in datasets:
        np.testing.assert_array_equal(xds.hdu0.data, expected)

    # Concatenating the per-file datasets along either hdu dimension
    # tiles the per-file data along the corresponding numpy axis.
    combined = xarray.concat(datasets, dim="hdu0-0")
    np.testing.assert_array_equal(
        combined.hdu0.data, np.concatenate([expected] * 3, axis=0)
    )
    combined = xarray.concat(datasets, dim="hdu0-1")
    np.testing.assert_array_equal(
        combined.hdu0.data, np.concatenate([expected] * 3, axis=1)
    )


@pytest.fixture(scope="session")
def beam_cube(tmp_path_factory):
frequency = np.linspace(0.856e9, 0.856e9 * 2, 32, endpoint=True)
Expand Down Expand Up @@ -95,7 +130,7 @@ def beam_cube(tmp_path_factory):


def test_beam_creation(beam_cube):
xds = xds_from_fits(beam_cube)
(xds,) = xds_from_fits(beam_cube)
cmp_data = np.arange(np.prod(xds.hdu0.shape), dtype=np.float64)
cmp_data = cmp_data.reshape(xds.hdu0.shape)
np.testing.assert_array_equal(xds.hdu0.data, cmp_data)
Expand Down Expand Up @@ -136,7 +171,7 @@ def test_distributed(beam_cube):
cluster = stack.enter_context(LocalCluster(n_workers=8, processes=True))
stack.enter_context(Client(cluster))

xds = xds_from_fits(beam_cube, chunks={0: 100, 1: 100, 2: 15})
(xds,) = xds_from_fits(beam_cube, chunks={0: 100, 1: 100, 2: 15})
expected = np.arange(np.prod(xds.hdu0.shape)).reshape(xds.hdu0.shape)
np.testing.assert_array_equal(expected, xds.hdu0.data)
assert xds.hdu0.data.chunks == ((100, 100, 57), (100, 100, 57), (15, 15, 2))
81 changes: 43 additions & 38 deletions xarrayfits/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,44 +194,49 @@ def xds_from_fits(fits_filename, hdus=None, name_prefix="hdu", chunks=None):
Returns
-------
:class:`xarray.Dataset`
xarray Dataset containing DataArray's representing the
list of :class:`xarray.Dataset`
xarray Datasets containing DataArray's representing the
specified HDUs on the FITS file.
"""
fits_proxy = FitsProxy(fits_filename, use_fsspec=True)

# Take all hdus if None specified
if hdus is None:
hdus = list(range(len(fits_proxy.hdu_list)))
# promote to list in case of single integer
elif isinstance(hdus, int):
hdus = [hdus]

if chunks is None:
chunks = [{} for _ in hdus]
# Promote to list in case of single dict
elif isinstance(chunks, dict):
chunks = [chunks]

if not len(hdus) == len(chunks):
raise ValueError(
f"Number of requested hdus ({len(hdus)}) "
f"does not match the number of "
f"chunks ({len(chunks)})"
)

fits_proxy = FitsProxy(fits_filename)

# Generate xarray datavars for each hdu
xarrays = {
f"{name_prefix}{hdu_index}": array_from_fits_hdu(
fits_proxy,
name_prefix,
fits_proxy.hdu_list,
hdu_index,
hdu_chunks,
)
for hdu_index, hdu_chunks in zip(hdus, chunks)
}

return xr.Dataset(xarrays)
openfiles = fsspec.open_files(fits_filename)
datasets = []

for filename in (f.path for f in openfiles):
fits_proxy = FitsProxy(filename, use_fsspec=True)

# Take all hdus if None specified
if hdus is None:
hdus = list(range(len(fits_proxy.hdu_list)))
# promote to list in case of single integer
elif isinstance(hdus, int):
hdus = [hdus]

if chunks is None:
chunks = [{} for _ in hdus]
# Promote to list in case of single dict
elif isinstance(chunks, dict):
chunks = [chunks]

if not len(hdus) == len(chunks):
raise ValueError(
f"Number of requested hdus ({len(hdus)}) "
f"does not match the number of "
f"chunks ({len(chunks)})"
)

# Generate xarray datavars for each hdu
xarrays = {
f"{name_prefix}{hdu_index}": array_from_fits_hdu(
fits_proxy,
name_prefix,
fits_proxy.hdu_list,
hdu_index,
hdu_chunks,
)
for hdu_index, hdu_chunks in zip(hdus, chunks)
}

datasets.append(xr.Dataset(xarrays))

return datasets

0 comments on commit d4ca67a

Please sign in to comment.