package_hash: fix handling of multimethods and add tests

Package hashing was not properly handling multimethods. In particular, it was removing
any functions that had decorators from the output, so we'd miss things like
`@run_after("install")`, etc.

There were also problems with handling multiple `@when`'s in a single file, and with
handling `@when` functions that *had* to be evaluated dynamically.

- [x] Rework static `@when` resolution for package hash
- [x] Ensure that functions with decorators are not removed from output
- [x] Add tests for many different @when scenarios (multiple @when's,
      combining with other decorators, default/no default, etc.)

Co-authored-by: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com>
This commit is contained in:
Todd Gamblin 2021-12-23 00:53:44 -08:00 committed by Greg Becker
parent 93a6c51d88
commit 800229a448
2 changed files with 260 additions and 34 deletions

View File

@ -150,3 +150,133 @@ def test_remove_directives():
for name in spack.directives.directive_names: for name in spack.directives.directive_names:
assert name not in unparsed assert name not in unparsed
# Package source used by test_multimethod_resolution. ``foo`` has a default
# implementation plus three @when variants (one of them only decidable at
# runtime); ``some_function`` carries a non-@when decorator and must survive.
# The embedded source has to be valid, indented Python because it is parsed
# with ``ast`` when the canonical source is computed.
many_multimethods = """\
class Pkg:
    def foo(self):
        print("ONE")

    @when("@1.0")
    def foo(self):
        print("TWO")

    @when("@2.0")
    @when(sys.platform == "darwin")
    def foo(self):
        print("THREE")

    @when("@3.0")
    def foo(self):
        print("FOUR")

    # this one should always stay
    @run_after("install")
    def some_function(self):
        print("FIVE")
"""
def test_multimethod_resolution(tmpdir):
    """Check which ``foo`` implementations survive in the canonical source.

    The package text in ``many_multimethods`` is written to a temporary
    ``pkg.py`` and canonicalized for several versions; each implementation
    prints a distinct label, so presence of a label means the implementation
    was kept.
    """
    pkg_file = tmpdir.join("pkg.py")
    with pkg_file.open("w") as stream:
        stream.write(many_multimethods)

    labels = ("ONE", "TWO", "THREE", "FOUR", "FIVE")

    # (spec, labels expected to remain); FIVE is never a multimethod and
    # therefore always stays.
    expectations = [
        # all are false but the default
        ("pkg@4.0", ("ONE", "FIVE")),
        # we know first @when overrides default and others are false
        ("pkg@1.0", ("TWO", "FIVE")),
        # we know last @when overrides default and others are false
        ("pkg@3.0", ("FOUR", "FIVE")),
        # we don't know if default or THREE will win, include both
        ("pkg@2.0", ("ONE", "THREE", "FIVE")),
    ]

    for spec, kept in expectations:
        filtered = ph.canonical_source(spec, str(pkg_file))
        for label in labels:
            if label in kept:
                assert label in filtered
            else:
                assert label not in filtered
# Package source used by test_more_dynamic_multimethod_resolution. There is no
# default ``foo`` here: the first definition is guarded by a condition that can
# only be evaluated at runtime, so it may win over any later definition.
# The embedded source has to be valid, indented Python because it is parsed
# with ``ast`` when the canonical source is computed.
more_dynamic_multimethods = """\
class Pkg:
    @when(sys.platform == "darwin")
    def foo(self):
        print("ONE")

    @when("@1.0")
    def foo(self):
        print("TWO")

    # this one isn't dynamic, but an int fails the Spec parse,
    # so it's kept because it has to be evaluated at runtime.
    @when("@2.0")
    @when(1)
    def foo(self):
        print("THREE")

    @when("@3.0")
    def foo(self):
        print("FOUR")

    # this one should always stay
    @run_after("install")
    def some_function(self):
        print("FIVE")
"""
def test_more_dynamic_multimethod_resolution(tmpdir):
    """Check multimethod resolution when the first ``foo`` is dynamic.

    The package text in ``more_dynamic_multimethods`` is written to a
    temporary ``pkg.py`` and canonicalized for several versions; each
    implementation prints a distinct label, so presence of a label means the
    implementation was kept.
    """
    pkg_file = tmpdir.join("pkg.py")
    with pkg_file.open("w") as stream:
        stream.write(more_dynamic_multimethods)

    labels = ("ONE", "TWO", "THREE", "FOUR", "FIVE")

    # (spec, labels expected to remain); FIVE is never a multimethod and
    # therefore always stays.
    expectations = [
        # we know the first one is the only one that can win.
        ("pkg@4.0", ("ONE", "FIVE")),
        # now we have to include ONE and TWO because ONE may win dynamically.
        ("pkg@1.0", ("ONE", "TWO", "FIVE")),
        # we know FOUR is true and TWO and THREE are false, but ONE may
        # still win dynamically.
        ("pkg@3.0", ("ONE", "FOUR", "FIVE")),
        # TWO and FOUR can't be satisfied, but ONE or THREE could win
        ("pkg@2.0", ("ONE", "THREE", "FIVE")),
    ]

    for spec, kept in expectations:
        filtered = ph.canonical_source(spec, str(pkg_file))
        for label in labels:
            if label in kept:
                assert label in filtered
            else:
                assert label not in filtered

View File

@ -11,7 +11,9 @@
import spack.package import spack.package
import spack.repo import spack.repo
import spack.spec import spack.spec
import spack.util.hash
import spack.util.naming import spack.util.naming
from spack.util.unparse import unparse
class RemoveDocstrings(ast.NodeTransformer): class RemoveDocstrings(ast.NodeTransformer):
@ -82,70 +84,164 @@ def visit_ClassDef(self, node): # noqa
class TagMultiMethods(ast.NodeVisitor):
    """Tag @when-decorated methods in a package AST.

    After a tree has been visited, ``self.methods`` maps each function name to
    a list of ``(function_node, conditions)`` tuples in definition order.
    ``conditions`` has one entry per ``@when(...)`` decorator on that
    definition: the value returned by ``spec.satisfies(cond, strict=True)``
    when the condition could be evaluated, or ``None`` when it has to be
    evaluated at runtime. A definition without any ``@when`` has an empty
    condition list.

    Note that functions are keyed by name only, and that nested function
    definitions are not visited.
    """

    def __init__(self, spec):
        self.spec = spec
        # map from function name to (implementation, condition_list) tuples
        self.methods = {}

    @staticmethod
    def _is_when(dec):
        """Return True if ``dec`` is a call to a bare name ``when``.

        Decorators such as ``@a.b(...)`` have an ``ast.Attribute`` as their
        callee, which has no ``id``; checking for ``ast.Name`` first keeps
        them from raising AttributeError.
        """
        return (
            isinstance(dec, ast.Call)
            and isinstance(dec.func, ast.Name)
            and dec.func.id == 'when'
        )

    @staticmethod
    def _literal_value(node):
        """Return the literal held by ``node``.

        Raises AttributeError if ``node`` is not a literal of a kind that
        carried an ``s`` attribute before ``ast.Constant`` was introduced.
        """
        constant = getattr(ast, 'Constant', None)
        if constant is not None and isinstance(node, constant):
            # ``Constant.s`` is a deprecated alias of ``Constant.value``
            return node.value
        return node.s

    def visit_FunctionDef(self, func):  # noqa
        """Record ``func`` together with the outcome of each of its @when's.

        Returns ``func`` itself; the AST is not modified.
        """
        conditions = []
        for dec in func.decorator_list:
            if not self._is_when(dec):
                continue

            if not dec.args:
                # no positional condition to look at; leave it to runtime
                conditions.append(None)
                continue

            try:
                # evaluate spec condition for any when's
                cond = self._literal_value(dec.args[0])
                conditions.append(self.spec.satisfies(cond, strict=True))
            except AttributeError:
                # In this case the condition for the 'when' decorator is
                # not a string literal (for example it may be a Python
                # variable name). We append None because we don't know
                # whether the constraint applies or not, and it should be
                # included unless some other constraint is False.
                conditions.append(None)

        # anything defined without conditions will overwrite prior definitions
        if not conditions:
            self.methods[func.name] = []

        # add all discovered conditions on this node to the node list
        impl_conditions = self.methods.setdefault(func.name, [])
        impl_conditions.append((func, conditions))

        # don't modify the AST -- return the untouched function node
        return func
class ResolveMultiMethods(ast.NodeTransformer):
    """Remove multi-methods when we know statically that they won't be used.

    Say we have multi-methods like this::

        class SomePackage:
            def foo(self): print("implementation 1")

            @when("@1.0")
            def foo(self): print("implementation 2")

            @when("@2.0")
            @when(sys.platform == "darwin")
            def foo(self): print("implementation 3")

            @when("@3.0")
            def foo(self): print("implementation 4")

    The multimethod that will be chosen at runtime depends on the package spec and on
    whether we're on the darwin platform *at build time* (the darwin condition for
    implementation 3 is dynamic). We know the package spec statically; we don't know
    statically what the runtime environment will be. We need to include things that can
    possibly affect package behavior in the package hash, and we want to exclude things
    when we know that they will not affect package behavior.

    If we're at version 4.0, we know that implementation 1 will win, because some @when
    for 2, 3, and 4 will be `False`. We should only include implementation 1.

    If we're at version 1.0, we know that implementation 2 will win, because it
    overrides implementation 1. We should only include implementation 2.

    If we're at version 3.0, we know that implementation 4 will win, because it
    overrides implementation 1 (the default), and some @when on all others will be
    False.

    If we're at version 2.0, it's a bit more complicated. We know we can remove
    implementations 2 and 4, because their @when's will never be satisfied. But, the
    choice between implementations 1 and 3 will happen at runtime (this is a bad example
    because the spec itself has platform information, and we should prefer to use that,
    but we allow arbitrary boolean expressions in @when's, so this example suffices).
    For this case, we end up needing to include *both* implementation 1 and 3 in the
    package hash, because either could be chosen.

    ``methods`` is a mapping from function name to a list of
    ``(function_node, conditions)`` tuples in definition order, where each
    condition is ``True``, ``False``, or ``None`` (not known statically).
    """

    def __init__(self, methods):
        self.methods = methods

    @staticmethod
    def _is_when(dec):
        """Return True if ``dec`` is a call to a bare name ``when``.

        Decorators such as ``@a.b(...)`` have an ``ast.Attribute`` as their
        callee, which has no ``id``; checking for ``ast.Name`` first keeps
        them from raising AttributeError.
        """
        return (
            isinstance(dec, ast.Call)
            and isinstance(dec.func, ast.Name)
            and dec.func.id == 'when'
        )

    def resolve(self, impl_conditions):
        """Given list of nodes and conditions, figure out which nodes may be chosen.

        Returns the list of implementations that could be selected at
        runtime, in definition order.
        """
        result = []
        default = None
        for impl, conditions in impl_conditions:
            # if there's a default implementation with no conditions, remember that.
            if not conditions:
                default = impl
                result.append(default)
                continue

            # any known-false @when means the method won't be used
            if any(c is False for c in conditions):
                continue

            # An implementation whose conditions are all known-true ends the
            # search: nothing defined after it can be picked, and the default
            # is never needed. Implementations defined *before* it whose
            # conditions are only known at runtime may still win, so they
            # have to be kept alongside it.
            if all(c is True for c in conditions):
                earlier = [r for r in result if r is not default]
                return earlier + [impl]

            # anything else has to be determined dynamically, so add it to a list
            result.append(impl)

        # nothing is known to win statically: every remaining candidate
        # (including the default, if there is one) may be chosen.
        return result

    def visit_FunctionDef(self, func):  # noqa
        """Drop ``func`` if it cannot be chosen; otherwise strip its @when's."""
        # if the function def wasn't visited on the first traversal there is a problem
        assert func.name in self.methods, "Inconsistent package traversal!"

        # if the function is a multimethod, need to resolve it statically
        impl_conditions = self.methods[func.name]

        resolutions = self.resolve(impl_conditions)
        if not any(r is func for r in resolutions):
            # multimethod did not resolve to this function; remove it
            return None

        # if we get here, this function is a possible resolution for a multi-method.
        # it might be the only one, or there might be several that have to be evaluated
        # dynamically. Either way, we include the function.

        # strip the when decorators (preserve the rest)
        func.decorator_list = [
            dec for dec in func.decorator_list
            if not self._is_when(dec)
        ]
        return func
def package_content(spec):
    """Return the ``ast.dump`` text of the AST produced by ``package_ast(spec)``."""
    tree = package_ast(spec)
    return ast.dump(tree)
def canonical_source(spec, filename=None):
    """Return the unparsed source of the AST from ``package_ast``.

    ``filename`` is forwarded to ``package_ast``; the tree is rendered with
    ``unparse(..., py_ver_consistent=True)``.
    """
    tree = package_ast(spec, filename=filename)
    return unparse(tree, py_ver_consistent=True)
def canonical_source_hash(spec, filename=None):
    """Return ``spack.util.hash.b32_hash`` of ``canonical_source(spec, filename)``."""
    return spack.util.hash.b32_hash(canonical_source(spec, filename))
def package_hash(spec, content=None):
    """Return the lowercased SHA-256 digest bytes of a package's content.

    If ``content`` is None it is computed with ``package_content(spec)``;
    the text is encoded as UTF-8 before hashing.
    """
    text = package_content(spec) if content is None else content
    digest = hashlib.sha256(text.encode('utf-8')).digest()
    # NOTE(review): ``digest()`` is raw bytes, so ``lower()`` rewrites any byte
    # in the ASCII range A-Z. This looks unintended (``hexdigest()`` may have
    # been meant), but it is kept because changing it would alter every hash
    # this function returns -- confirm with callers before touching it.
    return digest.lower()
def package_ast(spec): def package_ast(spec, filename=None):
spec = spack.spec.Spec(spec) spec = spack.spec.Spec(spec)
filename = spack.repo.path.filename_for_package_name(spec.name) if not filename:
filename = spack.repo.path.filename_for_package_name(spec.name)
with open(filename) as f: with open(filename) as f:
text = f.read() text = f.read()
root = ast.parse(text) root = ast.parse(text)
@ -154,10 +250,10 @@ def package_ast(spec):
RemoveDirectives(spec).visit(root) RemoveDirectives(spec).visit(root)
fmm = TagMultiMethods(spec) tagger = TagMultiMethods(spec)
fmm.visit(root) tagger.visit(root)
root = ResolveMultiMethods(fmm.methods).visit(root) root = ResolveMultiMethods(tagger.methods).visit(root)
return root return root