Skip to content

Commit

Permalink
Merge pull request #6764 from markotoplak/domain-cache-eq
Browse files Browse the repository at this point in the history
[FIX] Avoid slowdowns by caching Domain.__eq__
  • Loading branch information
janezd authored Aug 23, 2024
2 parents 02a0652 + eee8618 commit b981d76
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 24 deletions.
25 changes: 22 additions & 3 deletions Orange/data/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from Orange.data import (
Unknown, Variable, ContinuousVariable, DiscreteVariable, StringVariable
)
from Orange.misc.cache import IDWeakrefCache
from Orange.util import deprecated, OrangeDeprecationWarning

__all__ = ["DomainConversion", "Domain"]
Expand Down Expand Up @@ -169,6 +170,7 @@ def __init__(self, attributes, class_vars=None, metas=None, source=None):
self.anonymous = False

self._hash = None # cache for __hash__()
self._eq_cache = IDWeakrefCache(_LRS10Dict()) # cache for __eq__()

def _ensure_indices(self):
if self._indices is None:
Expand All @@ -185,6 +187,7 @@ def __setstate__(self, state):
self._variables = self.attributes + self.class_vars
self._indices = None
self._hash = None
self._eq_cache = {}

def __getstate__(self):
# Do not pickle dictionaries because unpickling dictionaries that
Expand All @@ -195,6 +198,7 @@ def __getstate__(self):
del state["_variables"]
del state["_indices"]
del state["_hash"]
del state["_eq_cache"]
return state

# noinspection PyPep8Naming
Expand Down Expand Up @@ -518,11 +522,26 @@ def __eq__(self, other):
if not isinstance(other, Domain):
return False

return (self.attributes == other.attributes and
self.class_vars == other.class_vars and
self.metas == other.metas)
try:
eq = self._eq_cache[(other,)]
except KeyError:
eq = (self.attributes == other.attributes and
self.class_vars == other.class_vars and
self.metas == other.metas)
self._eq_cache[(other,)] = eq

return eq

def __hash__(self):
    """Return the domain's hash, computed lazily and memoized in ``self._hash``.

    The hash combines the hashes of ``attributes``, ``class_vars`` and
    ``metas``; ``__setstate__`` resets the cache to ``None`` on unpickling.
    """
    cached = self._hash
    if cached is None:
        cached = hash(self.attributes) ^ hash(self.class_vars) ^ hash(self.metas)
        self._hash = cached
    return cached


class _LRS10Dict(dict):
""" A small "least recently stored" (not LRU) dict """

def __setitem__(self, key, value):
if len(self) >= 10:
del self[next(iter(self))]
super().__setitem__(key, value)
33 changes: 18 additions & 15 deletions Orange/data/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from Orange.data.util import SharedComputeValue, \
assure_array_dense, assure_array_sparse, \
assure_column_dense, assure_column_sparse, get_unique_names_duplicates
from Orange.misc.cache import IDWeakrefCache
from Orange.misc.collections import frozendict
from Orange.statistics.util import bincount, countnans, contingency, \
stats as fast_stats, sparse_has_implicit_zeros, sparse_count_implicit_zeros, \
Expand Down Expand Up @@ -305,10 +306,11 @@ def get_columns(self, source, row_indices, out=None, target_indices=None):
)
elif not isinstance(col, Integral):
if isinstance(col, SharedComputeValue):
shared = _idcache_restore(shared_cache, (col.compute_shared, source))
if shared is None:
try:
shared = shared_cache[(col.compute_shared, source)]
except KeyError:
shared = col.compute_shared(sourceri)
_idcache_save(shared_cache, (col.compute_shared, source), shared)
shared_cache[col.compute_shared, source] = shared
col_array = match_density(
_compute_column(col, sourceri, shared_data=shared))
else:
Expand Down Expand Up @@ -439,7 +441,7 @@ def convert(self, source, row_indices, clear_cache_after_part):

# clear cache after a part is done
if clear_cache_after_part:
_thread_local.conversion_cache = {}
_thread_local.conversion_cache.clear()

for array_conv in self.columnwise:
res[array_conv.target] = \
Expand Down Expand Up @@ -805,25 +807,26 @@ def from_table(cls, domain, source, row_indices=...):
new_cache = _thread_local.conversion_cache is None
try:
if new_cache:
_thread_local.conversion_cache = {}
_thread_local.domain_cache = {}
_thread_local.conversion_cache = IDWeakrefCache({})
_thread_local.domain_cache = IDWeakrefCache({})
else:
cached = _idcache_restore(_thread_local.conversion_cache, (domain, source))
if cached is not None:
return cached
try:
return _thread_local.conversion_cache[(domain, source)]
except KeyError:
pass

# avoid boolean indices; also convert to slices if possible
row_indices = _optimize_indices(row_indices, len(source))

self = cls()
self.domain = domain

table_conversion = \
_idcache_restore(_thread_local.domain_cache, (domain, source.domain))
if table_conversion is None:
try:
table_conversion = \
_thread_local.domain_cache[(domain, source.domain)]
except KeyError:
table_conversion = _FromTableConversion(source.domain, domain)
_idcache_save(_thread_local.domain_cache, (domain, source.domain),
table_conversion)
_thread_local.domain_cache[(domain, source.domain)] = table_conversion

# if an array can be a subarray of the input table, this needs to be done
# on the whole table, because this avoids needless copies of contents
Expand All @@ -838,7 +841,7 @@ def from_table(cls, domain, source, row_indices=...):
self.attributes = getattr(source, 'attributes', {})
if new_cache: # only deepcopy attributes for the outermost transformation
self.attributes = deepcopy(self.attributes)
_idcache_save(_thread_local.conversion_cache, (domain, source), self)
_thread_local.conversion_cache[(domain, source)] = self
return self
finally:
if new_cache:
Expand Down
29 changes: 29 additions & 0 deletions Orange/misc/cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Common caching methods, using `lru_cache` sometimes has its downsides."""
from functools import wraps, lru_cache
from typing import MutableMapping
import weakref


Expand Down Expand Up @@ -54,3 +55,31 @@ def _wrapped_func(self, *args, **kwargs):
return _wrapped_func

return _decorator


class IDWeakrefCache:
    """
    Mapping-like cache that keys entries by the ``id()`` of each key element
    for speed.

    Because an ``id()`` can be reused after its object is garbage collected,
    a weak reference to every key element is stored alongside the value; a
    lookup succeeds only while all original key objects are still alive,
    otherwise the stale entry is dropped and ``KeyError`` is raised.

    Key elements must be weak-referenceable; values are held strongly for
    as long as their entry survives.
    """

    def __init__(self, cache: MutableMapping):
        # The backing mapping determines the eviction policy
        # (e.g. an unbounded dict, or a bounded one such as _LRS10Dict).
        self._cache = cache

    def __setitem__(self, keys, value):
        id_key = tuple(map(id, keys))
        self._cache[id_key] = (value, [weakref.ref(k) for k in keys])

    def __getitem__(self, keys):
        id_key = tuple(map(id, keys))
        try:
            # EAFP: a single lookup instead of an `in` check plus two reads.
            value, refs = self._cache[id_key]
        except KeyError:
            raise KeyError(keys) from None
        if any(ref() is None for ref in refs):
            # At least one original key object died, so its id may now
            # belong to a different object; the entry is no longer valid.
            del self._cache[id_key]
            raise KeyError(keys)
        return value

    def clear(self):
        """Remove all cached entries from the backing mapping."""
        self._cache.clear()
8 changes: 8 additions & 0 deletions Orange/regression/tests/test_pls.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ def test_eq_hash(self):
self.assertNotEqual(hash(proj1), hash(proj2))
self.assertNotEqual(hash(proj1.domain), hash(proj2.domain))

def test_eq_hash_fake_same_model(self):
data = Table("housing")
pls1 = PLSRegressionLearner()(data)
pls2 = PLSRegressionLearner()(data)

proj1 = pls1.project(data)
proj2 = pls2.project(data)

proj2.domain[0].compute_value.compute_shared.pls_model = \
proj1.domain[0].compute_value.compute_shared.pls_model
# reset hash caches because objects were hacked
Expand Down
62 changes: 56 additions & 6 deletions Orange/tests/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,24 +441,74 @@ def test_different_domains_with_same_attributes_are_equal(self):
self.assertEqual(domain1, domain2)

var1 = ContinuousVariable('var1')
domain1.attributes = (var1,)
domain1 = Domain([var1])
self.assertNotEqual(domain1, domain2)

domain2.attributes = (var1,)
domain2 = Domain([var1])
self.assertEqual(domain1, domain2)

domain1.class_vars = (var1,)
var2 = ContinuousVariable('var2')
domain1 = Domain([var1], [var2])
self.assertNotEqual(domain1, domain2)

domain2.class_vars = (var1,)
domain2 = Domain([var1], [var2])
self.assertEqual(domain1, domain2)

domain1._metas = (var1,)
var3 = ContinuousVariable('var3')
domain1 = Domain([var1], [var2], [var3])
self.assertNotEqual(domain1, domain2)

domain2._metas = (var1,)
domain2 = Domain([var1], [var2], [var3])
self.assertEqual(domain1, domain2)

def test_eq_cached(self):
    """Domain.__eq__ must cache its result: comparing the same pair of
    domains twice may evaluate the variables' __eq__ only once."""

    class EqRaisesOnSecondCall:
        calls = 0

        def __eq__(self, other):
            if self.calls:
                raise RuntimeError()
            self.calls += 1
            return type(self) is type(other)

        def __hash__(self):
            return hash(type(self))

    first = ContinuousVariable('var1', compute_value=EqRaisesOnSecondCall())
    second = ContinuousVariable('var1', compute_value=EqRaisesOnSecondCall())
    dom_a = Domain([first])
    dom_b = Domain([second])
    self.assertTrue(dom_a == dom_b)

    # this second comparison would raise RuntimeError if __eq__ were
    # recomputed instead of served from the cache
    self.assertTrue(dom_a == dom_b)

    # poke the cache directly: the cached value must win over recomputation
    dom_a._eq_cache[(dom_b,)] = False  # pylint: disable=protected-access
    self.assertFalse(dom_a == dom_b)

def test_eq_cache_not_grow(self):
    """The per-domain __eq__ cache is bounded: caching an eleventh
    comparison evicts the oldest stored one."""
    attr = ContinuousVariable('var')
    base = Domain([attr])
    others = [Domain([attr]) for _ in range(10)]
    for other in others:
        self.assertTrue(base == other)

    # pylint: disable=protected-access,pointless-statement

    # results for all ten compared domains fit in the cache
    for other in others:
        base._eq_cache[(other,)]

    newcomer = Domain([attr])
    self.assertTrue(base == newcomer)
    # the most recently compared domain is cached ...
    base._eq_cache[(newcomer,)]
    # ... at the price of evicting the first one compared
    with self.assertRaises(KeyError):
        base._eq_cache[(others[0],)]

def test_domain_conversion_is_fast_enough(self):
attrs = [ContinuousVariable("f%i" % i) for i in range(10000)]
class_vars = [ContinuousVariable("c%i" % i) for i in range(10)]
Expand Down
6 changes: 6 additions & 0 deletions Orange/tests/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ def test_eq_hash(self):
self.assertNotEqual(hash(p1), hash(p2))
self.assertNotEqual(hash(p1.domain), hash(p2.domain))

def test_eq_hash_fake_same_projection(self):
d = np.random.RandomState(0).rand(20, 20)
data = Table.from_numpy(None, d)
p1 = PCA()(data)
p2 = PCA()(data)

# copy projection
p2.domain[0].compute_value.compute_shared.projection = \
p1.domain[0].compute_value.compute_shared.projection
Expand Down

0 comments on commit b981d76

Please sign in to comment.