make @llnl.util.lang.memoized support kwargs (#21722)

* make memoized() support kwargs

* add testing for @memoized
This commit is contained in:
Danny McClanahan
2022-03-02 19:12:15 +00:00
committed by GitHub
parent 916c94fd65
commit 2c331a1d7f
2 changed files with 96 additions and 10 deletions

View File

@@ -13,9 +13,10 @@
import sys
from datetime import datetime, timedelta
import six
from six import string_types
from llnl.util.compat import Hashable, MutableMapping, zip_longest
from llnl.util.compat import MutableMapping, zip_longest
# Ignore emacs backups when listing modules
ignore_modules = [r'^\.#', '~$']
@@ -165,6 +166,19 @@ def union_dicts(*dicts):
return result
# Used as a sentinel that disambiguates tuples passed in *args from coincidentally
# matching tuples formed from kwargs item pairs.
_kwargs_separator = (object(),)
def stable_args(*args, **kwargs):
"""A key factory that performs a stable sort of the parameters."""
key = args
if kwargs:
key += _kwargs_separator + tuple(sorted(kwargs.items()))
return key
def memoized(func):
"""Decorator that caches the results of a function, storing them in
an attribute of that function.
@@ -172,15 +186,23 @@ def memoized(func):
func.cache = {}
@functools.wraps(func)
def _memoized_function(*args):
if not isinstance(args, Hashable):
# Not hashable, so just call the function.
return func(*args)
def _memoized_function(*args, **kwargs):
key = stable_args(*args, **kwargs)
if args not in func.cache:
func.cache[args] = func(*args)
return func.cache[args]
try:
return func.cache[key]
except KeyError:
ret = func(*args, **kwargs)
func.cache[key] = ret
return ret
except TypeError as e:
# TypeError is raised when indexing into a dict if the key is unhashable.
raise six.raise_from(
UnhashableArguments(
"args + kwargs '{}' was not hashable for function '{}'"
.format(key, func.__name__),
),
e)
return _memoized_function
@@ -930,3 +952,7 @@ def nullcontext(*args, **kwargs):
TODO: replace with contextlib.nullcontext() if we ever require python 3.7.
"""
yield
class UnhashableArguments(TypeError):
"""Raise when an @memoized function receives unhashable arg or kwarg values."""