make @llnl.util.lang.memoized support kwargs (#21722)
* make memoized() support kwargs * add testing for @memoized
This commit is contained in:
parent
916c94fd65
commit
2c331a1d7f
@ -13,9 +13,10 @@
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import six
|
||||
from six import string_types
|
||||
|
||||
from llnl.util.compat import Hashable, MutableMapping, zip_longest
|
||||
from llnl.util.compat import MutableMapping, zip_longest
|
||||
|
||||
# Ignore emacs backups when listing modules
|
||||
ignore_modules = [r'^\.#', '~$']
|
||||
@ -165,6 +166,19 @@ def union_dicts(*dicts):
|
||||
return result
|
||||
|
||||
|
||||
# Used as a sentinel that disambiguates tuples passed in *args from coincidentally
|
||||
# matching tuples formed from kwargs item pairs.
|
||||
_kwargs_separator = (object(),)
|
||||
|
||||
|
||||
def stable_args(*args, **kwargs):
|
||||
"""A key factory that performs a stable sort of the parameters."""
|
||||
key = args
|
||||
if kwargs:
|
||||
key += _kwargs_separator + tuple(sorted(kwargs.items()))
|
||||
return key
|
||||
|
||||
|
||||
def memoized(func):
|
||||
"""Decorator that caches the results of a function, storing them in
|
||||
an attribute of that function.
|
||||
@ -172,15 +186,23 @@ def memoized(func):
|
||||
func.cache = {}
|
||||
|
||||
@functools.wraps(func)
|
||||
def _memoized_function(*args):
|
||||
if not isinstance(args, Hashable):
|
||||
# Not hashable, so just call the function.
|
||||
return func(*args)
|
||||
def _memoized_function(*args, **kwargs):
|
||||
key = stable_args(*args, **kwargs)
|
||||
|
||||
if args not in func.cache:
|
||||
func.cache[args] = func(*args)
|
||||
|
||||
return func.cache[args]
|
||||
try:
|
||||
return func.cache[key]
|
||||
except KeyError:
|
||||
ret = func(*args, **kwargs)
|
||||
func.cache[key] = ret
|
||||
return ret
|
||||
except TypeError as e:
|
||||
# TypeError is raised when indexing into a dict if the key is unhashable.
|
||||
raise six.raise_from(
|
||||
UnhashableArguments(
|
||||
"args + kwargs '{}' was not hashable for function '{}'"
|
||||
.format(key, func.__name__),
|
||||
),
|
||||
e)
|
||||
|
||||
return _memoized_function
|
||||
|
||||
@ -930,3 +952,7 @@ def nullcontext(*args, **kwargs):
|
||||
TODO: replace with contextlib.nullcontext() if we ever require python 3.7.
|
||||
"""
|
||||
yield
|
||||
|
||||
|
||||
class UnhashableArguments(TypeError):
|
||||
"""Raise when an @memoized function receives unhashable arg or kwarg values."""
|
||||
|
@ -10,7 +10,7 @@
|
||||
import pytest
|
||||
|
||||
import llnl.util.lang
|
||||
from llnl.util.lang import match_predicate, pretty_date
|
||||
from llnl.util.lang import match_predicate, memoized, pretty_date, stable_args
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -205,3 +205,63 @@ def _cmp_key(self):
|
||||
assert hash(a) == hash(a2)
|
||||
assert hash(b) == hash(b)
|
||||
assert hash(b) == hash(b2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args1,kwargs1,args2,kwargs2",
|
||||
[
|
||||
# Ensure tuples passed in args are disambiguated from equivalent kwarg items.
|
||||
(('a', 3), {}, (), {'a': 3})
|
||||
],
|
||||
)
|
||||
def test_unequal_args(args1, kwargs1, args2, kwargs2):
|
||||
assert stable_args(*args1, **kwargs1) != stable_args(*args2, **kwargs2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args1,kwargs1,args2,kwargs2",
|
||||
[
|
||||
# Ensure that kwargs are stably sorted.
|
||||
((), {'a': 3, 'b': 4}, (), {'b': 4, 'a': 3}),
|
||||
],
|
||||
)
|
||||
def test_equal_args(args1, kwargs1, args2, kwargs2):
|
||||
assert stable_args(*args1, **kwargs1) == stable_args(*args2, **kwargs2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args, kwargs",
|
||||
[
|
||||
((1,), {}),
|
||||
((), {'a': 3}),
|
||||
((1,), {'a': 3}),
|
||||
],
|
||||
)
|
||||
def test_memoized(args, kwargs):
|
||||
@memoized
|
||||
def f(*args, **kwargs):
|
||||
return 'return-value'
|
||||
assert f(*args, **kwargs) == 'return-value'
|
||||
key = stable_args(*args, **kwargs)
|
||||
assert list(f.cache.keys()) == [key]
|
||||
assert f.cache[key] == 'return-value'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args, kwargs",
|
||||
[
|
||||
(([1],), {}),
|
||||
((), {'a': [1]})
|
||||
],
|
||||
)
|
||||
def test_memoized_unhashable(args, kwargs):
|
||||
"""Check that an exception is raised clearly"""
|
||||
@memoized
|
||||
def f(*args, **kwargs):
|
||||
return None
|
||||
with pytest.raises(llnl.util.lang.UnhashableArguments) as exc_info:
|
||||
f(*args, **kwargs)
|
||||
exc_msg = str(exc_info.value)
|
||||
key = stable_args(*args, **kwargs)
|
||||
assert str(key) in exc_msg
|
||||
assert "function 'f'" in exc_msg
|
||||
|
Loading…
Reference in New Issue
Block a user