make @llnl.util.lang.memoized support kwargs (#21722)
* make memoized() support kwargs * add testing for @memoized
This commit is contained in:
		| @@ -13,9 +13,10 @@ | ||||
| import sys | ||||
| from datetime import datetime, timedelta | ||||
| 
 | ||||
| import six | ||||
| from six import string_types | ||||
| 
 | ||||
| from llnl.util.compat import Hashable, MutableMapping, zip_longest | ||||
| from llnl.util.compat import MutableMapping, zip_longest | ||||
| 
 | ||||
| # Ignore emacs backups when listing modules | ||||
| ignore_modules = [r'^\.#', '~$'] | ||||
| @@ -165,6 +166,19 @@ def union_dicts(*dicts): | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| # Used as a sentinel that disambiguates tuples passed in *args from coincidentally | ||||
| # matching tuples formed from kwargs item pairs. | ||||
| _kwargs_separator = (object(),) | ||||
| 
 | ||||
| 
 | ||||
| def stable_args(*args, **kwargs): | ||||
|     """A key factory that performs a stable sort of the parameters.""" | ||||
|     key = args | ||||
|     if kwargs: | ||||
|         key += _kwargs_separator + tuple(sorted(kwargs.items())) | ||||
|     return key | ||||
| 
 | ||||
| 
 | ||||
| def memoized(func): | ||||
|     """Decorator that caches the results of a function, storing them in | ||||
|     an attribute of that function. | ||||
| @@ -172,15 +186,23 @@ def memoized(func): | ||||
|     func.cache = {} | ||||
| 
 | ||||
|     @functools.wraps(func) | ||||
|     def _memoized_function(*args): | ||||
|         if not isinstance(args, Hashable): | ||||
|             # Not hashable, so just call the function. | ||||
|             return func(*args) | ||||
|     def _memoized_function(*args, **kwargs): | ||||
|         key = stable_args(*args, **kwargs) | ||||
| 
 | ||||
|         if args not in func.cache: | ||||
|             func.cache[args] = func(*args) | ||||
| 
 | ||||
|         return func.cache[args] | ||||
|         try: | ||||
|             return func.cache[key] | ||||
|         except KeyError: | ||||
|             ret = func(*args, **kwargs) | ||||
|             func.cache[key] = ret | ||||
|             return ret | ||||
|         except TypeError as e: | ||||
|             # TypeError is raised when indexing into a dict if the key is unhashable. | ||||
|             raise six.raise_from( | ||||
|                 UnhashableArguments( | ||||
|                     "args + kwargs '{}' was not hashable for function '{}'" | ||||
|                     .format(key, func.__name__), | ||||
|                 ), | ||||
|                 e) | ||||
| 
 | ||||
|     return _memoized_function | ||||
| 
 | ||||
| @@ -930,3 +952,7 @@ def nullcontext(*args, **kwargs): | ||||
|     TODO: replace with contextlib.nullcontext() if we ever require python 3.7. | ||||
|     """ | ||||
|     yield | ||||
| 
 | ||||
| 
 | ||||
| class UnhashableArguments(TypeError): | ||||
|     """Raise when an @memoized function receives unhashable arg or kwarg values.""" | ||||
|   | ||||
| @@ -10,7 +10,7 @@ | ||||
| import pytest | ||||
| 
 | ||||
| import llnl.util.lang | ||||
| from llnl.util.lang import match_predicate, pretty_date | ||||
| from llnl.util.lang import match_predicate, memoized, pretty_date, stable_args | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture() | ||||
| @@ -205,3 +205,63 @@ def _cmp_key(self): | ||||
|     assert hash(a) == hash(a2) | ||||
|     assert hash(b) == hash(b) | ||||
|     assert hash(b) == hash(b2) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "args1,kwargs1,args2,kwargs2", | ||||
|     [ | ||||
|         # Ensure tuples passed in args are disambiguated from equivalent kwarg items. | ||||
|         (('a', 3), {}, (), {'a': 3}) | ||||
|     ], | ||||
| ) | ||||
| def test_unequal_args(args1, kwargs1, args2, kwargs2): | ||||
|     assert stable_args(*args1, **kwargs1) != stable_args(*args2, **kwargs2) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "args1,kwargs1,args2,kwargs2", | ||||
|     [ | ||||
|         # Ensure that kwargs are stably sorted. | ||||
|         ((), {'a': 3, 'b': 4}, (), {'b': 4, 'a': 3}), | ||||
|     ], | ||||
| ) | ||||
| def test_equal_args(args1, kwargs1, args2, kwargs2): | ||||
|     assert stable_args(*args1, **kwargs1) == stable_args(*args2, **kwargs2) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "args, kwargs", | ||||
|     [ | ||||
|         ((1,), {}), | ||||
|         ((), {'a': 3}), | ||||
|         ((1,), {'a': 3}), | ||||
|     ], | ||||
| ) | ||||
| def test_memoized(args, kwargs): | ||||
|     @memoized | ||||
|     def f(*args, **kwargs): | ||||
|         return 'return-value' | ||||
|     assert f(*args, **kwargs) == 'return-value' | ||||
|     key = stable_args(*args, **kwargs) | ||||
|     assert list(f.cache.keys()) == [key] | ||||
|     assert f.cache[key] == 'return-value' | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "args, kwargs", | ||||
|     [ | ||||
|         (([1],), {}), | ||||
|         ((), {'a': [1]}) | ||||
|     ], | ||||
| ) | ||||
| def test_memoized_unhashable(args, kwargs): | ||||
|     """Check that an exception is raised clearly""" | ||||
|     @memoized | ||||
|     def f(*args, **kwargs): | ||||
|         return None | ||||
|     with pytest.raises(llnl.util.lang.UnhashableArguments) as exc_info: | ||||
|         f(*args, **kwargs) | ||||
|     exc_msg = str(exc_info.value) | ||||
|     key = stable_args(*args, **kwargs) | ||||
|     assert str(key) in exc_msg | ||||
|     assert "function 'f'" in exc_msg | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Danny McClanahan
					Danny McClanahan