diff --git a/Orange/data/domain.py b/Orange/data/domain.py index 88e3245d819..3191e147646 100644 --- a/Orange/data/domain.py +++ b/Orange/data/domain.py @@ -11,6 +11,7 @@ from Orange.data import ( Unknown, Variable, ContinuousVariable, DiscreteVariable, StringVariable ) +from Orange.misc.cache import IDWeakrefCache from Orange.util import deprecated, OrangeDeprecationWarning __all__ = ["DomainConversion", "Domain"] @@ -169,6 +170,7 @@ def __init__(self, attributes, class_vars=None, metas=None, source=None): self.anonymous = False self._hash = None # cache for __hash__() + self._eq_cache = IDWeakrefCache(_LRS10Dict()) # cache for __eq__() def _ensure_indices(self): if self._indices is None: @@ -185,6 +187,7 @@ def __setstate__(self, state): self._variables = self.attributes + self.class_vars self._indices = None self._hash = None + self._eq_cache = {} def __getstate__(self): # Do not pickle dictionaries because unpickling dictionaries that @@ -195,6 +198,7 @@ def __getstate__(self): del state["_variables"] del state["_indices"] del state["_hash"] + del state["_eq_cache"] return state # noinspection PyPep8Naming @@ -518,11 +522,26 @@ def __eq__(self, other): if not isinstance(other, Domain): return False - return (self.attributes == other.attributes and - self.class_vars == other.class_vars and - self.metas == other.metas) + try: + eq = self._eq_cache[(other,)] + except KeyError: + eq = (self.attributes == other.attributes and + self.class_vars == other.class_vars and + self.metas == other.metas) + self._eq_cache[(other,)] = eq + + return eq def __hash__(self): if self._hash is None: self._hash = hash(self.attributes) ^ hash(self.class_vars) ^ hash(self.metas) return self._hash + + +class _LRS10Dict(dict): + """ A small "least recently stored" (not LRU) dict """ + + def __setitem__(self, key, value): + if len(self) >= 10: + del self[next(iter(self))] + super().__setitem__(key, value) diff --git a/Orange/data/table.py b/Orange/data/table.py index c82f3e10271..c7f470a005a 100644 --- a/Orange/data/table.py +++ b/Orange/data/table.py @@ -29,6 +29,7 @@ from Orange.data.util import SharedComputeValue, \ assure_array_dense, assure_array_sparse, \ assure_column_dense, assure_column_sparse, get_unique_names_duplicates +from Orange.misc.cache import IDWeakrefCache from Orange.misc.collections import frozendict from Orange.statistics.util import bincount, countnans, contingency, \ stats as fast_stats, sparse_has_implicit_zeros, sparse_count_implicit_zeros, \ @@ -305,10 +306,11 @@ def get_columns(self, source, row_indices, out=None, target_indices=None): ) elif not isinstance(col, Integral): if isinstance(col, SharedComputeValue): - shared = _idcache_restore(shared_cache, (col.compute_shared, source)) - if shared is None: + try: + shared = shared_cache[(col.compute_shared, source)] + except KeyError: shared = col.compute_shared(sourceri) - _idcache_save(shared_cache, (col.compute_shared, source), shared) + shared_cache[col.compute_shared, source] = shared col_array = match_density( _compute_column(col, sourceri, shared_data=shared)) else: @@ -439,7 +441,7 @@ def convert(self, source, row_indices, clear_cache_after_part): # clear cache after a part is done if clear_cache_after_part: - _thread_local.conversion_cache = {} + _thread_local.conversion_cache.clear() for array_conv in self.columnwise: res[array_conv.target] = \ @@ -805,12 +807,13 @@ def from_table(cls, domain, source, row_indices=...): new_cache = _thread_local.conversion_cache is None try: if new_cache: - _thread_local.conversion_cache = {} - _thread_local.domain_cache = {} + _thread_local.conversion_cache = IDWeakrefCache({}) + _thread_local.domain_cache = IDWeakrefCache({}) else: - cached = _idcache_restore(_thread_local.conversion_cache, (domain, source)) - if cached is not None: - return cached + try: + return _thread_local.conversion_cache[(domain, source)] + except KeyError: + pass # avoid boolean indices; also convert to slices if possible row_indices = _optimize_indices(row_indices, len(source)) @@ -818,12 +821,12 @@ def from_table(cls, domain, source, row_indices=...): self = cls() self.domain = domain - table_conversion = \ - _idcache_restore(_thread_local.domain_cache, (domain, source.domain)) - if table_conversion is None: + try: + table_conversion = \ + _thread_local.domain_cache[(domain, source.domain)] + except KeyError: table_conversion = _FromTableConversion(source.domain, domain) - _idcache_save(_thread_local.domain_cache, (domain, source.domain), - table_conversion) + _thread_local.domain_cache[(domain, source.domain)] = table_conversion # if an array can be a subarray of the input table, this needs to be done # on the whole table, because this avoids needless copies of contents @@ -838,7 +841,7 @@ def from_table(cls, domain, source, row_indices=...): self.attributes = getattr(source, 'attributes', {}) if new_cache: # only deepcopy attributes for the outermost transformation self.attributes = deepcopy(self.attributes) - _idcache_save(_thread_local.conversion_cache, (domain, source), self) + _thread_local.conversion_cache[(domain, source)] = self return self finally: if new_cache: diff --git a/Orange/misc/cache.py b/Orange/misc/cache.py index ee165934a3f..d813fba5d59 100644 --- a/Orange/misc/cache.py +++ b/Orange/misc/cache.py @@ -1,5 +1,6 @@ """Common caching methods, using `lru_cache` sometimes has its downsides.""" from functools import wraps, lru_cache +from typing import MutableMapping import weakref @@ -54,3 +55,31 @@ def _wrapped_func(self, *args, **kwargs): return _wrapped_func return _decorator + + +class IDWeakrefCache: + """ + Cache that caches keys according to their id() for speed. It also stores + weak references to the keys to ensure that the same keys are being accessed. + """ + + def __init__(self, cache: MutableMapping): + self._cache = cache + + def __setitem__(self, keys, value): + self._cache[tuple(map(id, keys))] = \ + value, [weakref.ref(k) for k in keys] + + def __getitem__(self, keys): + key = tuple(map(id, keys)) + if key not in self._cache: + raise KeyError() + shared, weakrefs = self._cache[key] + for r in weakrefs: + if r() is None: + del self._cache[key] + raise KeyError() + return shared + + def clear(self): + self._cache.clear() diff --git a/Orange/regression/tests/test_pls.py b/Orange/regression/tests/test_pls.py index f289e6907c6..3b72e36e27a 100644 --- a/Orange/regression/tests/test_pls.py +++ b/Orange/regression/tests/test_pls.py @@ -124,6 +124,14 @@ def test_eq_hash(self): self.assertNotEqual(hash(proj1), hash(proj2)) self.assertNotEqual(hash(proj1.domain), hash(proj2.domain)) + def test_eq_hash_fake_same_model(self): + data = Table("housing") + pls1 = PLSRegressionLearner()(data) + pls2 = PLSRegressionLearner()(data) + + proj1 = pls1.project(data) + proj2 = pls2.project(data) + proj2.domain[0].compute_value.compute_shared.pls_model = \ proj1.domain[0].compute_value.compute_shared.pls_model # reset hash caches because object were hacked diff --git a/Orange/tests/test_domain.py b/Orange/tests/test_domain.py index 368a50e6f2d..a5e32660a06 100644 --- a/Orange/tests/test_domain.py +++ b/Orange/tests/test_domain.py @@ -441,24 +441,74 @@ def test_different_domains_with_same_attributes_are_equal(self): self.assertEqual(domain1, domain2) var1 = ContinuousVariable('var1') - domain1.attributes = (var1,) + domain1 = Domain([var1]) self.assertNotEqual(domain1, domain2) - domain2.attributes = (var1,) + domain2 = Domain([var1]) self.assertEqual(domain1, domain2) - domain1.class_vars = (var1,) + var2 = ContinuousVariable('var2') + domain1 = Domain([var1], [var2]) self.assertNotEqual(domain1, domain2) - domain2.class_vars = (var1,) + domain2 = Domain([var1], [var2]) self.assertEqual(domain1, domain2) - domain1._metas = (var1,) + var3 = ContinuousVariable('var3') + domain1 = Domain([var1], [var2], [var3]) self.assertNotEqual(domain1, domain2) - domain2._metas = (var1,) + domain2 = Domain([var1], [var2], [var3]) self.assertEqual(domain1, domain2) + def test_eq_cached(self): + + class ComputeValueEqOnce: + calls = 0 + + def __eq__(self, other): + if self.calls > 0: + raise RuntimeError() + self.calls += 1 + return type(self) is type(other) + + def __hash__(self): + return hash(type(self)) + + var1 = ContinuousVariable('var1', compute_value=ComputeValueEqOnce()) + var1a = ContinuousVariable('var1', compute_value=ComputeValueEqOnce()) + domain1 = Domain([var1]) + domain2 = Domain([var1a]) + self.assertTrue(domain1 == domain2) + + # the second call would crash if __eq__ was not cached + self.assertTrue(domain1 == domain2) + + # modify the cache, see if that has an effect + domain1._eq_cache[(domain2,)] = False # pylint: disable=protected-access + self.assertFalse(domain1 == domain2) + + def test_eq_cache_not_grow(self): + var = ContinuousVariable('var') + domain = Domain([var]) + domains = [Domain([var]) for _ in range(10)] + for d in domains: + self.assertTrue(domain == d) + + # pylint: disable=protected-access,pointless-statement + + # __eq__ results to all ten domains should be cached + for d in domains: + domain._eq_cache[(d,)] + + dn = Domain([var]) + self.assertTrue(domain == dn) + # the last compared domain should be cached + domain._eq_cache[(dn,)] + # but the first compared should be lost in cache + with self.assertRaises(KeyError): + domain._eq_cache[(domains[0],)] + def test_domain_conversion_is_fast_enough(self): attrs = [ContinuousVariable("f%i" % i) for i in range(10000)] class_vars = [ContinuousVariable("c%i" % i) for i in range(10)] diff --git a/Orange/tests/test_pca.py b/Orange/tests/test_pca.py index e5f1c1d7283..ca59483a461 100644 --- a/Orange/tests/test_pca.py +++ b/Orange/tests/test_pca.py @@ -178,6 +178,12 @@ def test_eq_hash(self): self.assertNotEqual(hash(p1), hash(p2)) self.assertNotEqual(hash(p1.domain), hash(p2.domain)) + def test_eq_hash_fake_same_projection(self): + d = np.random.RandomState(0).rand(20, 20) + data = Table.from_numpy(None, d) + p1 = PCA()(data) + p2 = PCA()(data) + # copy projection p2.domain[0].compute_value.compute_shared.projection = \ p1.domain[0].compute_value.compute_shared.projection