make @llnl.util.lang.memoized support kwargs (#21722)

* make memoized() support kwargs

* add testing for @memoized
This commit is contained in:
Danny McClanahan 2022-03-02 19:12:15 +00:00 committed by GitHub
parent 916c94fd65
commit 2c331a1d7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 10 deletions

View File

@ -13,9 +13,10 @@
import sys import sys
from datetime import datetime, timedelta from datetime import datetime, timedelta
import six
from six import string_types from six import string_types
from llnl.util.compat import Hashable, MutableMapping, zip_longest from llnl.util.compat import MutableMapping, zip_longest
# Ignore emacs backups when listing modules # Ignore emacs backups when listing modules
ignore_modules = [r'^\.#', '~$'] ignore_modules = [r'^\.#', '~$']
@ -165,6 +166,19 @@ def union_dicts(*dicts):
return result return result
# Used as a sentinel that disambiguates tuples passed in *args from coincidentally
# matching tuples formed from kwargs item pairs.
_kwargs_separator = (object(),)
def stable_args(*args, **kwargs):
"""A key factory that performs a stable sort of the parameters."""
key = args
if kwargs:
key += _kwargs_separator + tuple(sorted(kwargs.items()))
return key
def memoized(func): def memoized(func):
"""Decorator that caches the results of a function, storing them in """Decorator that caches the results of a function, storing them in
an attribute of that function. an attribute of that function.
@ -172,15 +186,23 @@ def memoized(func):
func.cache = {} func.cache = {}
@functools.wraps(func) @functools.wraps(func)
def _memoized_function(*args): def _memoized_function(*args, **kwargs):
if not isinstance(args, Hashable): key = stable_args(*args, **kwargs)
# Not hashable, so just call the function.
return func(*args)
if args not in func.cache: try:
func.cache[args] = func(*args) return func.cache[key]
except KeyError:
return func.cache[args] ret = func(*args, **kwargs)
func.cache[key] = ret
return ret
except TypeError as e:
# TypeError is raised when indexing into a dict if the key is unhashable.
raise six.raise_from(
UnhashableArguments(
"args + kwargs '{}' was not hashable for function '{}'"
.format(key, func.__name__),
),
e)
return _memoized_function return _memoized_function
@ -930,3 +952,7 @@ def nullcontext(*args, **kwargs):
TODO: replace with contextlib.nullcontext() if we ever require python 3.7. TODO: replace with contextlib.nullcontext() if we ever require python 3.7.
""" """
yield yield
class UnhashableArguments(TypeError):
"""Raise when an @memoized function receives unhashable arg or kwarg values."""

View File

@ -10,7 +10,7 @@
import pytest import pytest
import llnl.util.lang import llnl.util.lang
from llnl.util.lang import match_predicate, pretty_date from llnl.util.lang import match_predicate, memoized, pretty_date, stable_args
@pytest.fixture() @pytest.fixture()
@ -205,3 +205,63 @@ def _cmp_key(self):
assert hash(a) == hash(a2) assert hash(a) == hash(a2)
assert hash(b) == hash(b) assert hash(b) == hash(b)
assert hash(b) == hash(b2) assert hash(b) == hash(b2)
@pytest.mark.parametrize(
"args1,kwargs1,args2,kwargs2",
[
# Ensure tuples passed in args are disambiguated from equivalent kwarg items.
(('a', 3), {}, (), {'a': 3})
],
)
def test_unequal_args(args1, kwargs1, args2, kwargs2):
assert stable_args(*args1, **kwargs1) != stable_args(*args2, **kwargs2)
@pytest.mark.parametrize(
"args1,kwargs1,args2,kwargs2",
[
# Ensure that kwargs are stably sorted.
((), {'a': 3, 'b': 4}, (), {'b': 4, 'a': 3}),
],
)
def test_equal_args(args1, kwargs1, args2, kwargs2):
assert stable_args(*args1, **kwargs1) == stable_args(*args2, **kwargs2)
@pytest.mark.parametrize(
"args, kwargs",
[
((1,), {}),
((), {'a': 3}),
((1,), {'a': 3}),
],
)
def test_memoized(args, kwargs):
@memoized
def f(*args, **kwargs):
return 'return-value'
assert f(*args, **kwargs) == 'return-value'
key = stable_args(*args, **kwargs)
assert list(f.cache.keys()) == [key]
assert f.cache[key] == 'return-value'
@pytest.mark.parametrize(
"args, kwargs",
[
(([1],), {}),
((), {'a': [1]})
],
)
def test_memoized_unhashable(args, kwargs):
"""Check that an exception is raised clearly"""
@memoized
def f(*args, **kwargs):
return None
with pytest.raises(llnl.util.lang.UnhashableArguments) as exc_info:
f(*args, **kwargs)
exc_msg = str(exc_info.value)
key = stable_args(*args, **kwargs)
assert str(key) in exc_msg
assert "function 'f'" in exc_msg