From f7e656e3ff00f6714a7b406ff7c2043bb645306a Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Fri, 1 Sep 2023 23:07:14 +0530 Subject: [PATCH] Make ivy.sparse_array a subclass of ivy.array (#22281) --- ivy/data_classes/array/array.py | 4 ++ .../ivy/experimental/sparse_array.py | 36 ++++++++++++++++- .../test_core/test_sparse_array.py | 40 +++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/ivy/data_classes/array/array.py b/ivy/data_classes/array/array.py index f57bec181ac9c..34bcb9d4b9049 100644 --- a/ivy/data_classes/array/array.py +++ b/ivy/data_classes/array/array.py @@ -144,6 +144,10 @@ def _init(self, data, dynamic_backend=None): self._data = data elif isinstance(data, np.ndarray): self._data = ivy.asarray(data)._data + elif ivy.is_ivy_sparse_array(data): + self._data = data._data + elif ivy.is_native_sparse_array(data): + self._data = data._data else: raise ivy.utils.exceptions.IvyException( "data must be ivy array, native array or ndarray" diff --git a/ivy/functional/ivy/experimental/sparse_array.py b/ivy/functional/ivy/experimental/sparse_array.py index 46e6eb64fa428..5c67cc823e7c1 100644 --- a/ivy/functional/ivy/experimental/sparse_array.py +++ b/ivy/functional/ivy/experimental/sparse_array.py @@ -354,7 +354,7 @@ def _is_valid_format( ) -class SparseArray: +class SparseArray(ivy.Array): def __init__( self, data=None, @@ -425,6 +425,9 @@ def __init__( "col_indices, values and dense_shape)." ) + # initialize parent class + super(SparseArray, self).__init__(self) + def _init_data(self, data): if ivy.is_ivy_sparse_array(data): self._data = data.data @@ -535,6 +538,37 @@ def _init_compressed_column_components( self._crow_indices = None self._col_indices = None + def __repr__(self): + if self._dev_str is None: + self._dev_str = ivy.as_ivy_dev(self.device) + self._pre_repr = "ivy.sparse_array" + if "gpu" in self._dev_str: + self._post_repr = ", dev={})".format(self._dev_str) + else: + self._post_repr = ")" + if self._format == "coo": + repr = ( + f"indices={self._coo_indices}, values={self._values}," + f" dense_shape={self._dense_shape}" + ) + elif self._format == "csr" or self._format == "bsr": + repr = ( + f"crow_indices={self._crow_indices}, col_indices={self._col_indices}," + f" values={self._values}, dense_shape={self._dense_shape}" + ) + else: + repr = ( + f"ccol_indices={self._ccol_indices}, row_indices={self._row_indices}," + f" values={self._values}, dense_shape={self._dense_shape}" + ) + return ( + self._pre_repr + + "(" + + repr + + f", format={self._format}" + + self._post_repr.format(ivy.current_backend_str()) + ) + # Properties # # -----------# diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sparse_array.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sparse_array.py index 8082dc60733ce..de62bbd2be3df 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sparse_array.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sparse_array.py @@ -2,6 +2,8 @@ from hypothesis import strategies as st # local +import ivy +import numpy as np import ivy_tests.test_ivy.helpers as helpers from ivy_tests.test_ivy.helpers import handle_method @@ -178,6 +180,44 @@ def _sparse_csr_indices_values_shape(draw): # ------------ # +# adding sparse array to dense array +@handle_method( + init_tree="ivy.array", + method_tree="Array.__add__", + sparse_data=_sparse_coo_indices_values_shape(), +) +def test_array_add_sparse( + sparse_data, + method_name, + class_name, + on_device, +): + coo_ind, val_dtype, val, shp = sparse_data + + # set backed to 'torch' as this is the only backend which supports sparse arrays + ivy.set_backend("torch") + + # initiate a sparse array + sparse_inst = ivy.sparse_array.SparseArray( + coo_indices=coo_ind, + values=val, + dense_shape=shp, + format="coo", + ) + + # create an Array instance + array_class = getattr(ivy, class_name) + x = np.random.random_sample(shp) + x = ivy.array(x, dtype=val_dtype, device=on_device) + + # call add method + add_method = getattr(array_class, method_name) + res = add_method(x, sparse_inst) + + # make sure the result is an Array instance + assert isinstance(res, array_class) + + # bsc - to_dense_array @handle_method( method_tree="SparseArray.to_dense_array",