make @llnl.util.lang.memoized support kwargs (#21722)
* make memoized() support kwargs * add testing for @memoized
This commit is contained in:
parent
916c94fd65
commit
2c331a1d7f
@ -13,9 +13,10 @@
|
|||||||
import sys
|
import sys
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import six
|
||||||
from six import string_types
|
from six import string_types
|
||||||
|
|
||||||
from llnl.util.compat import Hashable, MutableMapping, zip_longest
|
from llnl.util.compat import MutableMapping, zip_longest
|
||||||
|
|
||||||
# Ignore emacs backups when listing modules
|
# Ignore emacs backups when listing modules
|
||||||
ignore_modules = [r'^\.#', '~$']
|
ignore_modules = [r'^\.#', '~$']
|
||||||
@ -165,6 +166,19 @@ def union_dicts(*dicts):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# Used as a sentinel that disambiguates tuples passed in *args from coincidentally
|
||||||
|
# matching tuples formed from kwargs item pairs.
|
||||||
|
_kwargs_separator = (object(),)
|
||||||
|
|
||||||
|
|
||||||
|
def stable_args(*args, **kwargs):
|
||||||
|
"""A key factory that performs a stable sort of the parameters."""
|
||||||
|
key = args
|
||||||
|
if kwargs:
|
||||||
|
key += _kwargs_separator + tuple(sorted(kwargs.items()))
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
def memoized(func):
|
def memoized(func):
|
||||||
"""Decorator that caches the results of a function, storing them in
|
"""Decorator that caches the results of a function, storing them in
|
||||||
an attribute of that function.
|
an attribute of that function.
|
||||||
@ -172,15 +186,23 @@ def memoized(func):
|
|||||||
func.cache = {}
|
func.cache = {}
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def _memoized_function(*args):
|
def _memoized_function(*args, **kwargs):
|
||||||
if not isinstance(args, Hashable):
|
key = stable_args(*args, **kwargs)
|
||||||
# Not hashable, so just call the function.
|
|
||||||
return func(*args)
|
|
||||||
|
|
||||||
if args not in func.cache:
|
try:
|
||||||
func.cache[args] = func(*args)
|
return func.cache[key]
|
||||||
|
except KeyError:
|
||||||
return func.cache[args]
|
ret = func(*args, **kwargs)
|
||||||
|
func.cache[key] = ret
|
||||||
|
return ret
|
||||||
|
except TypeError as e:
|
||||||
|
# TypeError is raised when indexing into a dict if the key is unhashable.
|
||||||
|
raise six.raise_from(
|
||||||
|
UnhashableArguments(
|
||||||
|
"args + kwargs '{}' was not hashable for function '{}'"
|
||||||
|
.format(key, func.__name__),
|
||||||
|
),
|
||||||
|
e)
|
||||||
|
|
||||||
return _memoized_function
|
return _memoized_function
|
||||||
|
|
||||||
@ -930,3 +952,7 @@ def nullcontext(*args, **kwargs):
|
|||||||
TODO: replace with contextlib.nullcontext() if we ever require python 3.7.
|
TODO: replace with contextlib.nullcontext() if we ever require python 3.7.
|
||||||
"""
|
"""
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
class UnhashableArguments(TypeError):
|
||||||
|
"""Raise when an @memoized function receives unhashable arg or kwarg values."""
|
||||||
|
@ -10,7 +10,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import llnl.util.lang
|
import llnl.util.lang
|
||||||
from llnl.util.lang import match_predicate, pretty_date
|
from llnl.util.lang import match_predicate, memoized, pretty_date, stable_args
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@ -205,3 +205,63 @@ def _cmp_key(self):
|
|||||||
assert hash(a) == hash(a2)
|
assert hash(a) == hash(a2)
|
||||||
assert hash(b) == hash(b)
|
assert hash(b) == hash(b)
|
||||||
assert hash(b) == hash(b2)
|
assert hash(b) == hash(b2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"args1,kwargs1,args2,kwargs2",
|
||||||
|
[
|
||||||
|
# Ensure tuples passed in args are disambiguated from equivalent kwarg items.
|
||||||
|
(('a', 3), {}, (), {'a': 3})
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_unequal_args(args1, kwargs1, args2, kwargs2):
|
||||||
|
assert stable_args(*args1, **kwargs1) != stable_args(*args2, **kwargs2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"args1,kwargs1,args2,kwargs2",
|
||||||
|
[
|
||||||
|
# Ensure that kwargs are stably sorted.
|
||||||
|
((), {'a': 3, 'b': 4}, (), {'b': 4, 'a': 3}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_equal_args(args1, kwargs1, args2, kwargs2):
|
||||||
|
assert stable_args(*args1, **kwargs1) == stable_args(*args2, **kwargs2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"args, kwargs",
|
||||||
|
[
|
||||||
|
((1,), {}),
|
||||||
|
((), {'a': 3}),
|
||||||
|
((1,), {'a': 3}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_memoized(args, kwargs):
|
||||||
|
@memoized
|
||||||
|
def f(*args, **kwargs):
|
||||||
|
return 'return-value'
|
||||||
|
assert f(*args, **kwargs) == 'return-value'
|
||||||
|
key = stable_args(*args, **kwargs)
|
||||||
|
assert list(f.cache.keys()) == [key]
|
||||||
|
assert f.cache[key] == 'return-value'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"args, kwargs",
|
||||||
|
[
|
||||||
|
(([1],), {}),
|
||||||
|
((), {'a': [1]})
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_memoized_unhashable(args, kwargs):
|
||||||
|
"""Check that an exception is raised clearly"""
|
||||||
|
@memoized
|
||||||
|
def f(*args, **kwargs):
|
||||||
|
return None
|
||||||
|
with pytest.raises(llnl.util.lang.UnhashableArguments) as exc_info:
|
||||||
|
f(*args, **kwargs)
|
||||||
|
exc_msg = str(exc_info.value)
|
||||||
|
key = stable_args(*args, **kwargs)
|
||||||
|
assert str(key) in exc_msg
|
||||||
|
assert "function 'f'" in exc_msg
|
||||||
|
Loading…
Reference in New Issue
Block a user