multimethod: slight refactoring, documentation for code review

This commit is contained in:
Gregory Becker 2018-12-05 20:12:47 -08:00 committed by Greg Becker
parent 2621af41d1
commit b072c9b457
5 changed files with 64 additions and 40 deletions

View File

@ -95,32 +95,38 @@ def __get__(self, obj, objtype):
) )
return func return func
def _get_method_by_spec(self, spec):
"""Find the method of this SpecMultiMethod object that satisfies the
given spec, if one exists
"""
for condition, method in self.method_list:
if spec.satisfies(condition):
return method
return self.default or None
def __call__(self, package_self, *args, **kwargs): def __call__(self, package_self, *args, **kwargs):
"""Find the first method with a spec that matches the """Find the first method with a spec that matches the
package's spec. If none is found, call the default package's spec. If none is found, call the default
or if there is none, then raise a NoSuchMethodError. or if there is none, then raise a NoSuchMethodError.
""" """
for spec, method in self.method_list: spec_method = self._get_method_by_spec(package_self.spec)
if package_self.spec.satisfies(spec): if spec_method:
return method(package_self, *args, **kwargs) return spec_method(package_self, *args, **kwargs)
# Unwrap the MRO of `package_self by hand. Note that we can't
if self.default: # use `super()` here, because using `super()` recursively
return self.default(package_self, *args, **kwargs) # requires us to know the class of `package_self`, as well as
# its superclasses for successive calls. We don't have that
else: # information within `SpecMultiMethod`, because it is not
# Unwrap MRO by hand because super binds to the subclass # associated with the package class.
# and causes infinite recursion for inherited methods
for cls in inspect.getmro(package_self.__class__)[1:]: for cls in inspect.getmro(package_self.__class__)[1:]:
superself = cls.__dict__.get(self.__name__, None) superself = cls.__dict__.get(self.__name__, None)
if isinstance(superself, self.__class__): if isinstance(superself, SpecMultiMethod):
# Parent class method is a multimethod # Check parent multimethod for method for spec.
# check it locally for methods, conditional or default superself_method = superself._get_method_by_spec(
# Do not recurse, that will mess up MRO package_self.spec
for spec, method in superself.method_list: )
if package_self.spec.satisfies(spec): if superself_method:
return method(package_self, *args, **kwargs) return superself_method(package_self, *args, **kwargs)
if superself.default:
return superself.default(package_self, *args, **kwargs)
elif superself: elif superself:
return superself(package_self, *args, **kwargs) return superself(package_self, *args, **kwargs)
@ -129,10 +135,6 @@ def __call__(self, package_self, *args, **kwargs):
[m[0] for m in self.method_list] [m[0] for m in self.method_list]
) )
def __str__(self):
return "SpecMultiMethod {\n\tdefault: %s,\n\tspecs: %s\n}" % (
self.default, self.method_list)
class when(object): class when(object):
"""This annotation lets packages declare multiple versions of """This annotation lets packages declare multiple versions of
@ -193,13 +195,11 @@ def install(self, prefix):
around this because of the way decorators work. around this because of the way decorators work.
""" """
def __init__(self, spec): def __init__(self, condition):
if spec is True: if isinstance(condition, bool):
self.spec = Spec() self.spec = Spec() if condition else None
elif spec is not False:
self.spec = Spec(spec)
else: else:
self.spec = None self.spec = Spec(condition)
def __call__(self, method): def __call__(self, method):
# Get the first definition of the method in the calling scope # Get the first definition of the method in the calling scope

View File

@ -3528,7 +3528,6 @@ def spec(self, name):
self.check_identifier(spec_name) self.check_identifier(spec_name)
if self._initial is None: if self._initial is None:
# This will init the spec without calling Spec.__init__
spec = Spec() spec = Spec()
else: else:
# this is used by Spec.__init__ # this is used by Spec.__init__

View File

@ -158,3 +158,9 @@ def test_multimethod_diamond_inheritance():
pkg = spack.repo.get('multimethod-diamond@4.0') pkg = spack.repo.get('multimethod-diamond@4.0')
assert pkg.diamond_inheritance() == 'subclass' assert pkg.diamond_inheritance() == 'subclass'
def test_multimethod_boolean(pkg_name):
pkg = spack.repo.get(pkg_name)
assert pkg.boolean_true_first() == 'True'
assert pkg.boolean_false_first() == 'True'

View File

@ -18,4 +18,4 @@ def diamond_inheritance(self):
@when('@4.0, 2.0') @when('@4.0, 2.0')
def diamond_inheritance(self): def diamond_inheritance(self):
return "should never be reached" return "should never be reached by diamond inheritance test"

View File

@ -148,4 +148,23 @@ def diamond_inheritance(self):
@when('@4.0') @when('@4.0')
def diamond_inheritance(self): def diamond_inheritance(self):
return "should_not_be_reached" return "should_not_be_reached by diamond inheritance test"
#
# Check that multimethods work with boolean values
#
@when(True)
def boolean_true_first(self):
return 'True'
@when(False)
def boolean_true_first(self):
return 'False'
@when(False)
def boolean_false_first(self):
return 'False'
@when(True)
def boolean_false_first(self):
return 'True'