Skip to content

Commit

Permalink
Enh/storages (#100)
Browse files Browse the repository at this point in the history
* ADD: ArrayStorage

* ADD: according tests

* ENH: finished modifying tests

* ENH: documentatiopn

* ADD: adapted notebook for Storages

* FIX: test
  • Loading branch information
VincentAuriau authored Jun 7, 2024
1 parent e78136c commit 731d0db
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 28 deletions.
147 changes: 145 additions & 2 deletions choice_learn/data/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,67 @@ def __str__(self):
return f"FeatureStorage with name {self.name}"


class FeaturesStorage(Storage):
class FeaturesStorage(object):
"""Base FeaturesStorage class that redirects toward the right one."""

def __new__(
cls,
ids=None,
values=None,
values_names=None,
name=None,
as_one_hot=False,
):
"""Redirects toward the right object.
Parameters
----------
ids : Iterable, optional
IDs to references features with, by default None
values : Iterable, optional
Features to be stored, by default None
values_names : list, optional
List of names for the features to be stored, by default None
name : str, optional
Name of the FeaturesStorage, to be matched in ChoiceDataset, by default None
as_one_hot: bool
Whether features are OneHot representations or not.
Returns
-------
FeaturesStorage
One of ArrayStorage, DictStorage or OneHotStorage
"""
if as_one_hot:
return OneHotStorage(ids=ids, values=values, name=name)

if ids is None and (isinstance(values, np.ndarray) or isinstance(values, list)):
return ArrayStorage(values=values, values_names=values_names, name=name)

if ids is not None:
check_ids = np.unique(ids) == np.arange(len(ids))
if isinstance(check_ids, np.ndarray):
check_ids = check_ids.all()
if check_ids:
values = [values[np.where(np.array(ids) == i)[0][0]] for i in np.arange(len(ids))]
return ArrayStorage(values=values, values_names=values_names, name=name)

return DictStorage(
ids=ids, values=values, values_names=values_names, name=name, indexer=StorageIndexer
)


class DictStorage(Storage):
"""Function to store features with ids."""

def __init__(self, ids=None, values=None, values_names=None, name=None, indexer=StorageIndexer):
def __init__(
self,
ids=None,
values=None,
values_names=None,
name=None,
indexer=StorageIndexer,
):
"""Build the store.
Parameters
Expand All @@ -70,6 +127,7 @@ def __init__(self, ids=None, values=None, values_names=None, name=None, indexer=
name: string, optional
name of the features store
"""
print("DictStorage")
if isinstance(values, dict):
storage = values
lengths = []
Expand Down Expand Up @@ -169,6 +227,91 @@ def batch(self):
return self.indexer


class ArrayStorage(Storage):
"""Function to store features with ids as NumPy Array."""

def __init__(self, values=None, values_names=None, name=None):
"""Build the store.
Parameters
----------
values : array_like
list of values of features to store
values_names : array_like
Iterable of str indicating the name of the features. Must be same length as values.
name: string, optional
name of the features store
"""
if isinstance(values, list):
storage = np.array(values)
elif not isinstance(values, np.ndarray):
raise ValueError("ArrayStorage Values must be a list or a numpy array")

# self.storage = storage
self.values_names = values_names
self.name = name

self.shape = storage.shape
self.indexer = storage

def get_element_from_index(self, index):
"""Getter method over self.sequence.
Returns the features stored at index index. Compared to __getitem__, it does take
the index-th element of sequence but the index-th element of the store.
Parameters
----------
index : (int, list, slice)
index argument of the feature
Returns
-------
array_like
features corresponding to the index index in self.store
"""
return self.batch[index]

def __len__(self):
"""Return the length of the sequence of apparition of the features."""
return self.shape[0]

def __getitem__(self, id_keys):
"""Subset FeaturesStorage, keeping only features which id is in keys.
Parameters
----------
id_keys : Iterable
List of ids to keep.
Returns
-------
FeaturesStorage
Subset of the FeaturesStorage, with only the features whose id is in id_keys
"""
if not isinstance(id_keys, list):
id_keys = [id_keys]
return ArrayStorage(
values=self.batch[id_keys], values_names=self.values_names, name=self.name
)

def get_storage_type(self):
"""Functions to access stored elements dtypes.
Returns
-------
tuple
tuple of dtypes of the stored elements, as returned by np.dtype
"""
element = self.get_element_from_index(0)
return element.dtype

@property
def batch(self):
"""Indexing attribute."""
return self.indexer


class OneHotStorage(Storage):
"""Specific Storage for one hot features storage.
Expand Down
64 changes: 49 additions & 15 deletions notebooks/data/features_byID_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"metadata": {},
"outputs": [],
"source": [
"from choice_learn.data.storage import FeaturesStorage, OneHotStorage\n",
"from choice_learn.data.storage import FeaturesStorage\n",
"from choice_learn.data import ChoiceDataset"
]
},
Expand All @@ -71,7 +71,15 @@
"metadata": {
"keep_output": true
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DictStorage\n"
]
}
],
"source": [
"features = {\"customerA\": [1, 2, 3], \"customerB\": [4, 5, 6], \"customerC\": [7, 8, 9]}\n",
"# dict must be {id: features}\n",
Expand All @@ -88,14 +96,11 @@
},
"outputs": [
{
"data": {
"text/plain": [
"<choice_learn.data.storage.FeaturesStorage at 0x7f14dc11c610>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"DictStorage\n"
]
}
],
"source": [
Expand Down Expand Up @@ -143,6 +148,13 @@
"keep_output": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DictStorage\n"
]
},
{
"data": {
"text/plain": [
Expand Down Expand Up @@ -220,6 +232,13 @@
"keep_output": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DictStorage\n"
]
},
{
"data": {
"text/plain": [
Expand Down Expand Up @@ -249,6 +268,13 @@
"keep_output": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DictStorage\n"
]
},
{
"data": {
"text/plain": [
Expand Down Expand Up @@ -293,7 +319,7 @@
"\n",
"# Here the Storage will map the ids to the values\n",
"# value = 4 means that the fifth value is a one, the rest are zeros\n",
"oh_storage = OneHotStorage(ids=ids, values=values, name=\"OneHotTest\")"
"oh_storage = FeaturesStorage(ids=ids, values=values, as_one_hot=True, name=\"OneHotTest\")"
]
},
{
Expand Down Expand Up @@ -374,7 +400,7 @@
}
],
"source": [
"oh_storage = OneHotStorage(values=values, name=\"OneHotTest\")\n",
"oh_storage = FeaturesStorage(values=values, as_one_hot=True, name=\"OneHotTest\")\n",
"oh_storage.batch[[0, 2, 4]]"
]
},
Expand Down Expand Up @@ -406,7 +432,7 @@
}
],
"source": [
"oh_storage = OneHotStorage(ids=ids, name=\"OneHotTest\")\n",
"oh_storage = FeaturesStorage(ids=ids, as_one_hot=True, name=\"OneHotTest\")\n",
"oh_storage.batch[[0, 2, 4]]\n",
"# Note that here it changes the order !"
]
Expand Down Expand Up @@ -440,7 +466,7 @@
],
"source": [
"values_dict = {k:v for k, v in zip(ids, values)}\n",
"oh_storage = OneHotStorage(values=values_dict, name=\"OneHotTest\")\n",
"oh_storage = FeaturesStorage(values=values_dict, as_one_hot=True, name=\"OneHotTest\")\n",
"oh_storage.batch[[0, 2, 4]]"
]
},
Expand All @@ -463,7 +489,15 @@
"metadata": {
"keep_output": true
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DictStorage\n"
]
}
],
"source": [
"features = {\"customerA\": [1, 2, 3], \"customerB\": [4, 5, 6], \"customerC\": [7, 8, 9]}\n",
"customer_storage = FeaturesStorage(values=features,\n",
Expand Down
Loading

0 comments on commit 731d0db

Please sign in to comment.