Skip to content

Commit

Permalink
Table.match_type: Don't flatten subarrays
Browse files Browse the repository at this point in the history
  • Loading branch information
nikicc committed Jun 1, 2017
1 parent 80475e7 commit 125a142
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions Orange/data/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,27 @@ def from_table(cls, domain, source, row_indices=...):

def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
is_sparse=False):
def match_type(x):
""" Assure that matrix and column are both dense or sparse. """
def match_type(x, force_1d=False):
""" Assure that matrix and column are both dense or sparse.
Args:
x (np.ndarray, scipy.sparse): data
force_1d (bool): If set, flatten resulting array to 1d.
Returns:
array of correct density.
"""
if is_sparse == sp.issparse(x):
return x
elif is_sparse:
if is_sparse:
x = np.asarray(x)
return sp.csc_matrix(x.reshape(-1, 1).astype(np.float))
else:
return np.ravel(x.toarray())
x = x.toarray()
if force_1d:
x = np.ravel(x)
return x

match_type_1d = lambda x: match_type(x, force_1d=True)

if not len(src_cols):
if is_sparse:
Expand Down Expand Up @@ -314,22 +326,22 @@ def match_type(x):
shared_cache[id(col.compute_shared), id(source)] = col.compute_shared(source)
shared = shared_cache[id(col.compute_shared), id(source)]
if row_indices is not ...:
a[:, i] = match_type(
a[:, i] = match_type_1d(
col(source, shared_data=shared)[row_indices])
else:
a[:, i] = match_type(
a[:, i] = match_type_1d(
col(source, shared_data=shared))
else:
if row_indices is not ...:
a[:, i] = match_type(col(source)[row_indices])
a[:, i] = match_type_1d(col(source)[row_indices])
else:
a[:, i] = match_type(col(source))
a[:, i] = match_type_1d(col(source))
elif col < 0:
a[:, i] = match_type(source.metas[row_indices, -1 - col])
a[:, i] = match_type_1d(source.metas[row_indices, -1 - col])
elif col < n_src_attrs:
a[:, i] = match_type(source.X[row_indices, col])
a[:, i] = match_type_1d(source.X[row_indices, col])
else:
a[:, i] = match_type(
a[:, i] = match_type_1d(
source._Y[row_indices, col - n_src_attrs])

if is_sparse:
Expand Down

0 comments on commit 125a142

Please sign in to comment.