Skip to content

Commit

Permalink
Make ivy.sparse_array a subclass of ivy.array (ivy-llc#22281)
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored Sep 1, 2023
1 parent 85b4b52 commit f7e656e
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 1 deletion.
4 changes: 4 additions & 0 deletions ivy/data_classes/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ def _init(self, data, dynamic_backend=None):
self._data = data
elif isinstance(data, np.ndarray):
self._data = ivy.asarray(data)._data
elif ivy.is_ivy_sparse_array(data):
self._data = data._data
elif ivy.is_native_sparse_array(data):
self._data = data._data
else:
raise ivy.utils.exceptions.IvyException(
"data must be ivy array, native array or ndarray"
Expand Down
36 changes: 35 additions & 1 deletion ivy/functional/ivy/experimental/sparse_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def _is_valid_format(
)


class SparseArray:
class SparseArray(ivy.Array):
def __init__(
self,
data=None,
Expand Down Expand Up @@ -425,6 +425,9 @@ def __init__(
"col_indices, values and dense_shape)."
)

# initialize parent class
super(SparseArray, self).__init__(self)

def _init_data(self, data):
if ivy.is_ivy_sparse_array(data):
self._data = data.data
Expand Down Expand Up @@ -535,6 +538,37 @@ def _init_compressed_column_components(
self._crow_indices = None
self._col_indices = None

def __repr__(self):
if self._dev_str is None:
self._dev_str = ivy.as_ivy_dev(self.device)
self._pre_repr = "ivy.sparse_array"
if "gpu" in self._dev_str:
self._post_repr = ", dev={})".format(self._dev_str)
else:
self._post_repr = ")"
if self._format == "coo":
repr = (
f"indices={self._coo_indices}, values={self._values},"
f" dense_shape={self._dense_shape}"
)
elif self._format == "csr" or self._format == "bsr":
repr = (
f"crow_indices={self._crow_indices}, col_indices={self._col_indices},"
f" values={self._values}, dense_shape={self._dense_shape}"
)
else:
repr = (
f"ccol_indices={self._ccol_indices}, row_indices={self._row_indices},"
f" values={self._values}, dense_shape={self._dense_shape}"
)
return (
self._pre_repr
+ "("
+ repr
+ f", format={self._format}"
+ self._post_repr.format(ivy.current_backend_str())
)

# Properties #
# -----------#

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from hypothesis import strategies as st

# local
import ivy
import numpy as np
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_method

Expand Down Expand Up @@ -178,6 +180,44 @@ def _sparse_csr_indices_values_shape(draw):
# ------------ #


# adding sparse array to dense array
@handle_method(
init_tree="ivy.array",
method_tree="Array.__add__",
sparse_data=_sparse_coo_indices_values_shape(),
)
def test_array_add_sparse(
sparse_data,
method_name,
class_name,
on_device,
):
coo_ind, val_dtype, val, shp = sparse_data

# set backed to 'torch' as this is the only backend which supports sparse arrays
ivy.set_backend("torch")

# initiate a sparse array
sparse_inst = ivy.sparse_array.SparseArray(
coo_indices=coo_ind,
values=val,
dense_shape=shp,
format="coo",
)

# create an Array instance
array_class = getattr(ivy, class_name)
x = np.random.random_sample(shp)
x = ivy.array(x, dtype=val_dtype, device=on_device)

# call add method
add_method = getattr(array_class, method_name)
res = add_method(x, sparse_inst)

# make sure the result is an Array instance
assert isinstance(res, array_class)


# bsc - to_dense_array
@handle_method(
method_tree="SparseArray.to_dense_array",
Expand Down

0 comments on commit f7e656e

Please sign in to comment.