multimethod: slight refactoring, documentation for code review
This commit is contained in:
parent
2621af41d1
commit
b072c9b457
@ -95,32 +95,38 @@ def __get__(self, obj, objtype):
|
|||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
def _get_method_by_spec(self, spec):
|
||||||
|
"""Find the method of this SpecMultiMethod object that satisfies the
|
||||||
|
given spec, if one exists
|
||||||
|
"""
|
||||||
|
for condition, method in self.method_list:
|
||||||
|
if spec.satisfies(condition):
|
||||||
|
return method
|
||||||
|
return self.default or None
|
||||||
|
|
||||||
def __call__(self, package_self, *args, **kwargs):
|
def __call__(self, package_self, *args, **kwargs):
|
||||||
"""Find the first method with a spec that matches the
|
"""Find the first method with a spec that matches the
|
||||||
package's spec. If none is found, call the default
|
package's spec. If none is found, call the default
|
||||||
or if there is none, then raise a NoSuchMethodError.
|
or if there is none, then raise a NoSuchMethodError.
|
||||||
"""
|
"""
|
||||||
for spec, method in self.method_list:
|
spec_method = self._get_method_by_spec(package_self.spec)
|
||||||
if package_self.spec.satisfies(spec):
|
if spec_method:
|
||||||
return method(package_self, *args, **kwargs)
|
return spec_method(package_self, *args, **kwargs)
|
||||||
|
# Unwrap the MRO of `package_self by hand. Note that we can't
|
||||||
if self.default:
|
# use `super()` here, because using `super()` recursively
|
||||||
return self.default(package_self, *args, **kwargs)
|
# requires us to know the class of `package_self`, as well as
|
||||||
|
# its superclasses for successive calls. We don't have that
|
||||||
else:
|
# information within `SpecMultiMethod`, because it is not
|
||||||
# Unwrap MRO by hand because super binds to the subclass
|
# associated with the package class.
|
||||||
# and causes infinite recursion for inherited methods
|
|
||||||
for cls in inspect.getmro(package_self.__class__)[1:]:
|
for cls in inspect.getmro(package_self.__class__)[1:]:
|
||||||
superself = cls.__dict__.get(self.__name__, None)
|
superself = cls.__dict__.get(self.__name__, None)
|
||||||
if isinstance(superself, self.__class__):
|
if isinstance(superself, SpecMultiMethod):
|
||||||
# Parent class method is a multimethod
|
# Check parent multimethod for method for spec.
|
||||||
# check it locally for methods, conditional or default
|
superself_method = superself._get_method_by_spec(
|
||||||
# Do not recurse, that will mess up MRO
|
package_self.spec
|
||||||
for spec, method in superself.method_list:
|
)
|
||||||
if package_self.spec.satisfies(spec):
|
if superself_method:
|
||||||
return method(package_self, *args, **kwargs)
|
return superself_method(package_self, *args, **kwargs)
|
||||||
if superself.default:
|
|
||||||
return superself.default(package_self, *args, **kwargs)
|
|
||||||
elif superself:
|
elif superself:
|
||||||
return superself(package_self, *args, **kwargs)
|
return superself(package_self, *args, **kwargs)
|
||||||
|
|
||||||
@ -129,10 +135,6 @@ def __call__(self, package_self, *args, **kwargs):
|
|||||||
[m[0] for m in self.method_list]
|
[m[0] for m in self.method_list]
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "SpecMultiMethod {\n\tdefault: %s,\n\tspecs: %s\n}" % (
|
|
||||||
self.default, self.method_list)
|
|
||||||
|
|
||||||
|
|
||||||
class when(object):
|
class when(object):
|
||||||
"""This annotation lets packages declare multiple versions of
|
"""This annotation lets packages declare multiple versions of
|
||||||
@ -193,13 +195,11 @@ def install(self, prefix):
|
|||||||
around this because of the way decorators work.
|
around this because of the way decorators work.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, spec):
|
def __init__(self, condition):
|
||||||
if spec is True:
|
if isinstance(condition, bool):
|
||||||
self.spec = Spec()
|
self.spec = Spec() if condition else None
|
||||||
elif spec is not False:
|
|
||||||
self.spec = Spec(spec)
|
|
||||||
else:
|
else:
|
||||||
self.spec = None
|
self.spec = Spec(condition)
|
||||||
|
|
||||||
def __call__(self, method):
|
def __call__(self, method):
|
||||||
# Get the first definition of the method in the calling scope
|
# Get the first definition of the method in the calling scope
|
||||||
|
@ -3528,7 +3528,6 @@ def spec(self, name):
|
|||||||
self.check_identifier(spec_name)
|
self.check_identifier(spec_name)
|
||||||
|
|
||||||
if self._initial is None:
|
if self._initial is None:
|
||||||
# This will init the spec without calling Spec.__init__
|
|
||||||
spec = Spec()
|
spec = Spec()
|
||||||
else:
|
else:
|
||||||
# this is used by Spec.__init__
|
# this is used by Spec.__init__
|
||||||
|
@ -158,3 +158,9 @@ def test_multimethod_diamond_inheritance():
|
|||||||
|
|
||||||
pkg = spack.repo.get('multimethod-diamond@4.0')
|
pkg = spack.repo.get('multimethod-diamond@4.0')
|
||||||
assert pkg.diamond_inheritance() == 'subclass'
|
assert pkg.diamond_inheritance() == 'subclass'
|
||||||
|
|
||||||
|
|
||||||
|
def test_multimethod_boolean(pkg_name):
|
||||||
|
pkg = spack.repo.get(pkg_name)
|
||||||
|
assert pkg.boolean_true_first() == 'True'
|
||||||
|
assert pkg.boolean_false_first() == 'True'
|
||||||
|
@ -18,4 +18,4 @@ def diamond_inheritance(self):
|
|||||||
|
|
||||||
@when('@4.0, 2.0')
|
@when('@4.0, 2.0')
|
||||||
def diamond_inheritance(self):
|
def diamond_inheritance(self):
|
||||||
return "should never be reached"
|
return "should never be reached by diamond inheritance test"
|
||||||
|
@ -148,4 +148,23 @@ def diamond_inheritance(self):
|
|||||||
|
|
||||||
@when('@4.0')
|
@when('@4.0')
|
||||||
def diamond_inheritance(self):
|
def diamond_inheritance(self):
|
||||||
return "should_not_be_reached"
|
return "should_not_be_reached by diamond inheritance test"
|
||||||
|
|
||||||
|
#
|
||||||
|
# Check that multimethods work with boolean values
|
||||||
|
#
|
||||||
|
@when(True)
|
||||||
|
def boolean_true_first(self):
|
||||||
|
return 'True'
|
||||||
|
|
||||||
|
@when(False)
|
||||||
|
def boolean_true_first(self):
|
||||||
|
return 'False'
|
||||||
|
|
||||||
|
@when(False)
|
||||||
|
def boolean_false_first(self):
|
||||||
|
return 'False'
|
||||||
|
|
||||||
|
@when(True)
|
||||||
|
def boolean_false_first(self):
|
||||||
|
return 'True'
|
||||||
|
Loading…
Reference in New Issue
Block a user