make @llnl.util.lang.memoized support kwargs (#21722)
* make memoized() support kwargs * add testing for @memoized
This commit is contained in:
@@ -13,9 +13,10 @@
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import six
|
||||
from six import string_types
|
||||
|
||||
from llnl.util.compat import Hashable, MutableMapping, zip_longest
|
||||
from llnl.util.compat import MutableMapping, zip_longest
|
||||
|
||||
# Ignore emacs backups when listing modules
|
||||
ignore_modules = [r'^\.#', '~$']
|
||||
@@ -165,6 +166,19 @@ def union_dicts(*dicts):
|
||||
return result
|
||||
|
||||
|
||||
# Used as a sentinel that disambiguates tuples passed in *args from coincidentally
|
||||
# matching tuples formed from kwargs item pairs.
|
||||
_kwargs_separator = (object(),)
|
||||
|
||||
|
||||
def stable_args(*args, **kwargs):
|
||||
"""A key factory that performs a stable sort of the parameters."""
|
||||
key = args
|
||||
if kwargs:
|
||||
key += _kwargs_separator + tuple(sorted(kwargs.items()))
|
||||
return key
|
||||
|
||||
|
||||
def memoized(func):
|
||||
"""Decorator that caches the results of a function, storing them in
|
||||
an attribute of that function.
|
||||
@@ -172,15 +186,23 @@ def memoized(func):
|
||||
func.cache = {}
|
||||
|
||||
@functools.wraps(func)
|
||||
def _memoized_function(*args):
|
||||
if not isinstance(args, Hashable):
|
||||
# Not hashable, so just call the function.
|
||||
return func(*args)
|
||||
def _memoized_function(*args, **kwargs):
|
||||
key = stable_args(*args, **kwargs)
|
||||
|
||||
if args not in func.cache:
|
||||
func.cache[args] = func(*args)
|
||||
|
||||
return func.cache[args]
|
||||
try:
|
||||
return func.cache[key]
|
||||
except KeyError:
|
||||
ret = func(*args, **kwargs)
|
||||
func.cache[key] = ret
|
||||
return ret
|
||||
except TypeError as e:
|
||||
# TypeError is raised when indexing into a dict if the key is unhashable.
|
||||
raise six.raise_from(
|
||||
UnhashableArguments(
|
||||
"args + kwargs '{}' was not hashable for function '{}'"
|
||||
.format(key, func.__name__),
|
||||
),
|
||||
e)
|
||||
|
||||
return _memoized_function
|
||||
|
||||
@@ -930,3 +952,7 @@ def nullcontext(*args, **kwargs):
|
||||
TODO: replace with contextlib.nullcontext() if we ever require python 3.7.
|
||||
"""
|
||||
yield
|
||||
|
||||
|
||||
class UnhashableArguments(TypeError):
|
||||
"""Raise when an @memoized function receives unhashable arg or kwarg values."""
|
||||
|
Reference in New Issue
Block a user