Skip to content

Commit

Permalink
Merge pull request #166 from icecube/fix165
Browse files Browse the repository at this point in the history
Fixes #165
  • Loading branch information
martwo authored Aug 8, 2023
2 parents eb3baf4 + afb90f4 commit 074c508
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 6 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ This file contains a log-book for major changes between releases.

v23.2.1
=======
- Add access operator support for core.dataset.DatasetCollection.

- Individual datasets of a dataset collection (``dsc``) can now be accessed
via ``dsc[name]`` or ``dsc[name1, name2, ...]``.

- Allow the definition of an origin of a dataset via the
core.dataset.DatasetOrigin class and download the dataset automatically from
the origin to the local host. The following transfer methods are provided:
Expand Down
4 changes: 2 additions & 2 deletions doc/sphinx/tutorials/publicdata_ps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The individual data sets ``IC86_II``, ``IC86_III``, ``IC86_IV``, ``IC86_V``, ``IC86_VI``, and ``IC86_VII`` are also available as a single combined data set ``IC86_II-VII``, because these data sets share the same detector simulation and event selection. Hence, we can get a list of data sets via the ``get_datasets`` method of the ``dsc`` instance:"
"The individual data sets ``IC86_II``, ``IC86_III``, ``IC86_IV``, ``IC86_V``, ``IC86_VI``, and ``IC86_VII`` are also available as a single combined data set ``IC86_II-VII``, because these data sets share the same detector simulation and event selection. Hence, we can get a list of data sets via the access operator ``[dataset1, dataset2, ...]`` of the ``dsc`` instance:"
]
},
{
Expand All @@ -167,7 +167,7 @@
"metadata": {},
"outputs": [],
"source": [
"datasets = dsc.get_datasets(['IC40', 'IC59', 'IC79', 'IC86_I', 'IC86_II-VII'])"
"datasets = dsc['IC40', 'IC59', 'IC79', 'IC86_I', 'IC86_II-VII']"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion doc/sphinx/tutorials/publicdata_ps_timedep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
"dsc = create_dataset_collection(\n",
" cfg=cfg,\n",
" base_path=\"/home/mwolf/projects/publicdata_ps/\")\n",
"datasets = dsc.get_datasets([\"IC86_II-VII\"])"
"datasets = dsc[\"IC86_II-VII\", ]"
]
},
{
Expand Down
33 changes: 33 additions & 0 deletions skyllh/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2382,6 +2382,39 @@ def verqualifiers(self):
ds_name = list(self._datasets.keys())[0]
return self._datasets[ds_name].verqualifiers

def __getitem__(
        self,
        key,
):
    """Implementation of the access operator ``[key]``.

    A key that is not a sequence (as decided by ``issequence``) is handed
    unchanged to ``get_dataset`` and a single dataset is returned. A
    sequence key is resolved element by element, in the given order, and
    a list is returned.

    Parameters
    ----------
    key : str | sequence of str
        The name or names of the dataset(s) that should get retrieved from
        this dataset collection.

    Returns
    -------
    datasets : instance of Dataset | list of instance of Dataset
        The dataset instance or the list of dataset instances corresponding
        to the given key.

    Raises
    ------
    TypeError
        If ``key`` is a sequence whose elements are not all str instances.
    """
    if issequence(key):
        # Validate the whole sequence up-front so that no lookup is
        # attempted when one of the names has the wrong type.
        if not issequenceof(key, str):
            raise TypeError(
                'The key for the access operator must be an instance of str or '
                'a sequence of str instances!')

        return [self.get_dataset(ds_name) for ds_name in key]

    # Single key: delegate the lookup to get_dataset.
    return self.get_dataset(key)

def __iadd__(self, ds):
"""Implementation of the ``self += dataset`` operation to add a
Dataset object to this dataset collection.
Expand Down
4 changes: 2 additions & 2 deletions skyllh/datasets/i3/PublicData_10y_ps_wMC.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def create_dataset_collection(
IC86_VI,
IC86_VII,
IC86_II_VII,
) = dsc.get_datasets((
) = dsc[
'IC40',
'IC59',
'IC79',
Expand All @@ -72,7 +72,7 @@ def create_dataset_collection(
'IC86_VI',
'IC86_VII',
'IC86_II-VII',
))
]
IC40.mc_pathfilename_list = 'sim/IC40_MC.npy'
IC59.mc_pathfilename_list = 'sim/IC59_MC.npy'
IC79.mc_pathfilename_list = 'sim/IC79_MC.npy'
Expand Down
27 changes: 27 additions & 0 deletions tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from skyllh.core.dataset import (
get_data_subset,
Dataset,
DatasetData,
DatasetOrigin,
DatasetTransferError,
Expand All @@ -21,9 +22,13 @@
from skyllh.core.storage import (
DataFieldRecordArray,
)

from skyllh.datasets.i3 import (
TestData,
)
from skyllh.datasets.i3.PublicData_10y_ps import (
create_dataset_collection,
)


class TestRSYNCDatasetTransfer(
Expand Down Expand Up @@ -191,5 +196,27 @@ def test_get_data_subset(self):
self.assertAlmostEqual(livetime_subset.livetime, 0.75)


class TestDatasetCollection(
    unittest.TestCase,
):
    """Tests for the access operator ``[]`` of a dataset collection."""

    def setUp(self) -> None:
        # Each test gets a fresh configuration and dataset collection.
        self.cfg = Config()
        self.dsc = create_dataset_collection(cfg=self.cfg)

    def test__getitem__single(self):
        # A single name must yield exactly one Dataset instance.
        dataset = self.dsc['IC40']
        self.assertIsInstance(dataset, Dataset)
        self.assertEqual(dataset.name, 'IC40')

    def test__getitem__multi(self):
        # Several names must yield a list of Dataset instances in the
        # order in which the names were requested.
        requested_names = ('IC59', 'IC40')
        datasets = self.dsc[requested_names]
        self.assertIsInstance(datasets, list)
        self.assertEqual(len(datasets), 2)
        for idx in (0, 1):
            self.assertIsInstance(datasets[idx], Dataset)
        for idx in (0, 1):
            self.assertEqual(datasets[idx].name, requested_names[idx])


# Run all test cases of this module when it is executed as a script.
if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion tests/core/test_signal_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def create_signal_generator(

dataset_name = 'PointSourceTracks_v004p00'
dsc = datasets.data_samples[dataset_name].create_dataset_collection(cfg=cfg)
ds_list = dsc.get_datasets(['IC86, 2018', 'IC86, 2019'])
ds_list = dsc['IC86, 2018', 'IC86, 2019']

data_list = [ds.load_and_prepare_data() for ds in ds_list]

Expand Down

0 comments on commit 074c508

Please sign in to comment.