diff --git a/Orange/widgets/data/owfeatureconstructor.py b/Orange/widgets/data/owfeatureconstructor.py index 72704e9727c..674f1d0fe8d 100644 --- a/Orange/widgets/data/owfeatureconstructor.py +++ b/Orange/widgets/data/owfeatureconstructor.py @@ -21,7 +21,7 @@ from traceback import format_exception_only from collections import namedtuple, OrderedDict from itertools import chain, count, starmap -from typing import List, Dict, Any, Mapping +from typing import List, Dict, Any, Mapping, Optional import numpy as np @@ -1021,6 +1021,7 @@ def bind_variable(descriptor, env, data, use_values): values = {} cast = None + dtype = object if isinstance(descriptor, StringDescriptor) else float if isinstance(descriptor, DiscreteDescriptor): if not descriptor.values: @@ -1038,7 +1039,7 @@ def bind_variable(descriptor, env, data, use_values): cast = DateTimeCast() func = FeatureFunc(descriptor.expression, source_vars, values, cast, - use_values=use_values) + use_values=use_values, dtype=dtype) return descriptor, func @@ -1216,7 +1217,10 @@ class FeatureFunc: A function for casting the expressions result to the appropriate type (e.g. string representation of date/time variables to floats) """ - def __init__(self, expression, args, extra_env=None, cast=None, use_values=False): + dtype: Optional['DType'] = None + + def __init__(self, expression, args, extra_env=None, cast=None, use_values=False, + dtype=None): self.expression = expression self.args = args self.extra_env = dict(extra_env or {}) @@ -1225,6 +1229,7 @@ def __init__(self, expression, args, extra_env=None, cast=None, use_values=False self.cast = cast self.mask_exceptions = True self.use_values = use_values + self.dtype = dtype def __call__(self, table, *_): if isinstance(table, Table): @@ -1252,7 +1257,7 @@ def __call_table(self, table): y = list(starmap(f, args)) if self.cast is not None: y = self.cast(y) - return y + return np.asarray(y, dtype=self.dtype) def __call_instance(self, instance: Instance): table = Table.from_numpy( @@ -1281,7 +1286,8 @@ def extract_column(self, table: Table, var: Variable): def __reduce__(self): return type(self), (self.expression, self.args, - self.extra_env, self.cast, self.use_values) + self.extra_env, self.cast, self.use_values, + self.dtype) def __repr__(self): return "{0.__name__}{1!r}".format(*self.__reduce__()) diff --git a/Orange/widgets/data/tests/test_owfeatureconstructor.py b/Orange/widgets/data/tests/test_owfeatureconstructor.py index 3de43bae642..237b4908ac7 100644 --- a/Orange/widgets/data/tests/test_owfeatureconstructor.py +++ b/Orange/widgets/data/tests/test_owfeatureconstructor.py @@ -8,6 +8,7 @@ from unittest.mock import patch, Mock import numpy as np +from scipy import sparse as sp from orangewidget.settings import Context @@ -148,6 +149,20 @@ def test_unicode_normalization(): construct_variables(desc, data))) np.testing.assert_equal(data.X, data.metas) + def test_transform_sparse(self): + domain = Domain([ContinuousVariable("A")]) + desc = [ + ContinuousDescriptor(name="X", expression="A", number_of_decimals=2) + ] + X = sp.csc_matrix(np.arange(5).reshape(5, 1)) + data = Table.from_numpy(domain, X) + data_ = data.transform(Domain(data.domain.attributes, + [], + construct_variables(desc, data))) + np.testing.assert_equal( + data.get_column_view(0)[0], data_.get_column_view(0)[0] + ) + class TestTools(unittest.TestCase): def test_free_vars(self): @@ -276,7 +291,7 @@ def test_reconstruct(self): def test_repr(self): self.assertEqual(repr(FeatureFunc("a + 1", [("a", 2)])), - "FeatureFunc('a + 1', [('a', 2)], {}, None, False)") + "FeatureFunc('a + 1', [('a', 2)], {}, None, False, None)") def test_call(self): iris = Table("iris") @@ -291,7 +306,7 @@ def test_string_casting(self): f = FeatureFunc("name[0]", [("name", zoo.domain["name"])]) r = f(zoo) - self.assertEqual(r, [x[0] for x in zoo.metas[:, 0]]) + self.assertEqual(list(r), [x[0] for x in zoo.metas[:, 0]]) self.assertEqual(f(zoo[0]), str(zoo[0, "name"])[0]) def test_missing_variable(self): @@ -309,7 +324,7 @@ def test_time_str(self): data = Table.from_numpy(Domain([TimeVariable("T", have_date=True)]), [[0], [0]]) f = FeatureFunc("str(T)", [("T", data.domain[0])]) c = f(data) - self.assertEqual(c, ["1970-01-01", "1970-01-01"]) + self.assertEqual(list(c), ["1970-01-01", "1970-01-01"]) def test_invalid_expression_variable(self): iris = Table("iris")