diff --git a/lib/spack/spack/repo_migrate.py b/lib/spack/spack/repo_migrate.py index 4fdce933cc3..3827254a609 100644 --- a/lib/spack/spack/repo_migrate.py +++ b/lib/spack/spack/repo_migrate.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: (Apache-2.0 OR MIT) import ast +import difflib import os import re import shutil @@ -82,7 +83,8 @@ def migrate_v1_to_v2( errors = False - stack: List[Tuple[str, int]] = [(repo.root, 0)] + stack: List[Tuple[str, int]] = [(repo.packages_path, 0)] + while stack: path, depth = stack.pop() @@ -112,11 +114,7 @@ def migrate_v1_to_v2( continue # check if this is a package - if ( - depth == 1 - and rel_path.startswith(f"{subdirectory}{os.sep}") - and os.path.exists(os.path.join(entry.path, "package.py")) - ): + if depth == 0 and os.path.exists(os.path.join(entry.path, "package.py")): if "_" in entry.name: print( f"Invalid package name '{entry.name}': underscores are not allowed in " @@ -144,7 +142,7 @@ def migrate_v1_to_v2( rename_regex = re.compile("^(" + "|".join(re.escape(k) for k in rename.keys()) + ")") if fix: - os.makedirs(new_root, exist_ok=True) + os.makedirs(os.path.join(new_root, repo.subdirectory), exist_ok=True) def _relocate(rel_path: str) -> Tuple[str, str]: old = os.path.join(repo.root, rel_path) @@ -223,6 +221,16 @@ def _relocate(rel_path: str) -> Tuple[str, str]: return result, (updated_repo if fix else None) +def _spack_pkg_to_spack_repo(modulename: str) -> str: + # rewrite spack.pkg.builtin.foo -> spack_repo.builtin.packages.foo.package + parts = modulename.split(".") + assert parts[:2] == ["spack", "pkg"] + parts[0:2] = ["spack_repo"] + parts.insert(2, "packages") + parts.append("package") + return ".".join(parts) + + def migrate_v2_imports( packages_dir: str, root: str, fix: bool, out: IO[str] = sys.stdout, err: IO[str] = sys.stderr ) -> bool: @@ -299,12 +307,41 @@ def migrate_v2_imports( #: Set of symbols of interest that are already defined through imports, assignments, or #: function definitions. defined_symbols: Set[str] = set() - best_line: Optional[int] = None - seen_import = False + module_replacements: Dict[str, str] = {} + parent: Dict[int, ast.AST] = {} + + #: List of (line, col start, old, new) tuples of strings to be replaced inline. + inline_updates: List[Tuple[int, int, str, str]] = [] + + #: List of (line from, line to, new lines) tuples of line replacements + multiline_updates: List[Tuple[int, int, List[str]]] = [] + + with open(pkg_path, "r", encoding="utf-8", newline="") as file: + original_lines = file.readlines() + + if len(original_lines) < 2: # assume packagepy files have at least 2 lines... + continue + + if original_lines[0].endswith("\r\n"): + newline = "\r\n" + elif original_lines[0].endswith("\n"): + newline = "\n" + elif original_lines[0].endswith("\r"): + newline = "\r" + else: + success = False + print(f"{pkg_path}: unknown line ending, cannot fix", file=err) + continue + + updated_lines = original_lines.copy() for node in ast.walk(tree): + for child in ast.iter_child_nodes(node): + if isinstance(child, ast.Attribute): + parent[id(child)] = node + # Get the last import statement from the first block of top-level imports if isinstance(node, ast.Module): for child in ast.iter_child_nodes(node): @@ -353,12 +390,89 @@ def migrate_v2_imports( elif isinstance(node, ast.Name) and node.id in symbol_to_module: referenced_symbols.add(node.id) - # Register imported symbols to make this operation idempotent + # Find lines where spack.pkg is used. + elif ( + isinstance(node, ast.Attribute) + and isinstance(node.value, ast.Name) + and node.value.id == "spack" + and node.attr == "pkg" + ): + # go as many attrs up until we reach a known module name to be replaced + known_module = "spack.pkg" + ancestor = node + while True: + next_parent = parent.get(id(ancestor)) + if next_parent is None or not isinstance(next_parent, ast.Attribute): + break + ancestor = next_parent + known_module = f"{known_module}.{ancestor.attr}" + if known_module in module_replacements: + break + + inline_updates.append( + ( + ancestor.lineno, + ancestor.col_offset, + known_module, + module_replacements[known_module], + ) + ) + elif isinstance(node, ast.ImportFrom): + # Keep track of old style spack.pkg imports, to be replaced. + if node.module and node.module.startswith("spack.pkg.") and node.level == 0: + + depth = node.module.count(".") + + # simple case of find and replace + # from spack.pkg.builtin.my_pkg import MyPkg + # -> from spack_repo.builtin.packages.my_pkg.package import MyPkg + if depth == 3: + module_replacements[node.module] = _spack_pkg_to_spack_repo(node.module) + inline_updates.append( + ( + node.lineno, + node.col_offset, + node.module, + module_replacements[node.module], + ) + ) + + # non-trivial possible multiline case + # from spack.pkg.builtin import (boost, cmake as foo) + # -> import spack_repo.builtin.packages.boost.package as boost + # -> import spack_repo.builtin.packages.cmake.package as foo + elif depth == 2 and node.end_lineno is not None: + _, _, namespace = node.module.rpartition(".") + indent = original_lines[node.lineno - 1][: node.col_offset] + multiline_updates.append( + ( + node.lineno, + node.end_lineno + 1, + [ + f"{indent}import spack_repo.{namespace}.packages." + f"{alias.name}.package as {alias.asname or alias.name}" + f"{newline}" + for alias in node.names + ], + ) + ) + + else: + success = False + print( + f"{pkg_path}:{node.lineno}: don't know how to rewrite `{node.module}`", + file=err, + ) + + # Subtract the symbols that are imported so we don't repeatedly add imports. for alias in node.names: if alias.name in symbol_to_module: - defined_symbols.add(alias.name) - if node.module == "spack.package": + if alias.asname is None: + defined_symbols.add(alias.name) + + # error when symbols are explicitly imported that are no longer available + if node.module == "spack.package" and node.level == 0: success = False print( f"{pkg_path}:{node.lineno}: `{alias.name}` is imported from " @@ -369,59 +483,84 @@ def migrate_v2_imports( if alias.asname and alias.asname in symbol_to_module: defined_symbols.add(alias.asname) + elif isinstance(node, ast.Import): + # normal imports are easy find and replace since they are single lines. + for alias in node.names: + if alias.asname and alias.asname in symbol_to_module: + defined_symbols.add(alias.name) + elif alias.asname is None and alias.name.startswith("spack.pkg."): + module_replacements[alias.name] = _spack_pkg_to_spack_repo(alias.name) + inline_updates.append( + ( + alias.lineno, + alias.col_offset, + alias.name, + module_replacements[alias.name], + ) + ) + # Remove imported symbols from the referenced symbols referenced_symbols.difference_update(defined_symbols) - if not referenced_symbols: + # Sort from last to first so we can modify without messing up the line / col offsets + inline_updates.sort(reverse=True) + + # Nothing to change here. + if not inline_updates and not referenced_symbols: continue - if best_line is None: - print(f"{pkg_path}: failed to update imports", file=err) - success = False - continue + # First do module replacements of spack.pkg imports + for line, col, old, new in inline_updates: + updated_lines[line - 1] = updated_lines[line - 1][:col] + updated_lines[line - 1][ + col: + ].replace(old, new, 1) - # Add the missing imports right after the last import statement - with open(pkg_path, "r", encoding="utf-8", newline="") as file: - lines = file.readlines() + # Then insert new imports for symbols referenced in the package + if referenced_symbols: + if best_line is None: + print(f"{pkg_path}: failed to update imports", file=err) + success = False + continue - # Group missing symbols by their module - missing_imports_by_module: Dict[str, list] = {} - for symbol in referenced_symbols: - module = symbol_to_module[symbol] - if module not in missing_imports_by_module: - missing_imports_by_module[module] = [] - missing_imports_by_module[module].append(symbol) + # Group missing symbols by their module + missing_imports_by_module: Dict[str, list] = {} + for symbol in referenced_symbols: + module = symbol_to_module[symbol] + if module not in missing_imports_by_module: + missing_imports_by_module[module] = [] + missing_imports_by_module[module].append(symbol) - new_lines = [ - f"from {module} import {', '.join(sorted(symbols))}\n" - for module, symbols in sorted(missing_imports_by_module.items()) - ] + new_lines = [ + f"from {module} import {', '.join(sorted(symbols))}{newline}" + for module, symbols in sorted(missing_imports_by_module.items()) + ] - if not seen_import: - new_lines.extend(("\n", "\n")) + if not seen_import: + new_lines.extend((newline, newline)) - if not fix: # only print the diff - success = False # packages need to be fixed, but we didn't do it - diff_start, diff_end = max(1, best_line - 3), min(best_line + 2, len(lines)) - num_changed = diff_end - diff_start + 1 - num_added = num_changed + len(new_lines) + multiline_updates.append((best_line, best_line, new_lines)) + + multiline_updates.sort(reverse=True) + for start, end, new_lines in multiline_updates: + updated_lines[start - 1 : end - 1] = new_lines + + if not fix: rel_pkg_path = os.path.relpath(pkg_path, start=root) - out.write(f"--- a/{rel_pkg_path}\n+++ b/{rel_pkg_path}\n") - out.write(f"@@ -{diff_start},{num_changed} +{diff_start},{num_added} @@\n") - for line in lines[diff_start - 1 : best_line - 1]: - out.write(f" {line}") - for line in new_lines: - out.write(f"+{line}") - for line in lines[best_line - 1 : diff_end]: - out.write(f" {line}") + diff = difflib.unified_diff( + original_lines, + updated_lines, + n=3, + fromfile=f"a/{rel_pkg_path}", + tofile=f"b/{rel_pkg_path}", + ) + out.write("".join(diff)) continue - lines[best_line - 1 : best_line - 1] = new_lines - tmp_file = pkg_path + ".tmp" - with open(tmp_file, "w", encoding="utf-8", newline="") as file: - file.writelines(lines) + # binary mode to avoid newline conversion issues; utf-8 was already required upon read. + with open(tmp_file, "wb") as file: + file.write("".join(updated_lines).encode("utf-8")) os.replace(tmp_file, pkg_path) diff --git a/lib/spack/spack/test/cmd/repo.py b/lib/spack/spack/test/cmd/repo.py index c7019f35682..b5f9a43dc4f 100644 --- a/lib/spack/spack/test/cmd/repo.py +++ b/lib/spack/spack/test/cmd/repo.py @@ -95,24 +95,47 @@ class _7zip(Package): pass """ -OLD_NUMPY = b"""\ -# some comment +# this is written like this to be explicit about line endings and indentation +OLD_NUMPY = ( + b"# some comment\r\n" + b"\r\n" + b"import spack.pkg.builtin.foo, spack.pkg.builtin.bar\r\n" + b"from spack.package import *\r\n" + b"from something.unrelated import AutotoolsPackage\r\n" + b"\r\n" + b"if True:\r\n" + b"\tfrom spack.pkg.builtin import (\r\n" + b"\t\tfoo,\r\n" + b"\t\tbar as baz,\r\n" + b"\t)\r\n" + b"\r\n" + b"class PyNumpy(CMakePackage, AutotoolsPackage):\r\n" + b"\tgenerator('ninja')\r\n" + b"\r\n" + b"\tdef example(self):\r\n" + b"\t\t# unchanged comment: spack.pkg.builtin.foo.something\r\n" + b"\t\treturn spack.pkg.builtin.foo.example(), foo, baz\r\n" +) -from spack.package import * - -class PyNumpy(CMakePackage): - generator("ninja") -""" - -NEW_NUMPY = b"""\ -# some comment - -from spack_repo.builtin.build_systems.cmake import CMakePackage, generator -from spack.package import * - -class PyNumpy(CMakePackage): - generator("ninja") -""" +NEW_NUMPY = ( + b"# some comment\r\n" + b"\r\n" + b"import spack_repo.builtin.packages.foo.package, spack_repo.builtin.packages.bar.package\r\n" + b"from spack_repo.builtin.build_systems.cmake import CMakePackage, generator\r\n" + b"from spack.package import *\r\n" + b"from something.unrelated import AutotoolsPackage\r\n" + b"\r\n" + b"if True:\r\n" + b"\timport spack_repo.builtin.packages.foo.package as foo\r\n" + b"\timport spack_repo.builtin.packages.bar.package as baz\r\n" + b"\r\n" + b"class PyNumpy(CMakePackage, AutotoolsPackage):\r\n" + b"\tgenerator('ninja')\r\n" + b"\r\n" + b"\tdef example(self):\r\n" + b"\t\t# unchanged comment: spack.pkg.builtin.foo.something\r\n" + b"\t\treturn spack_repo.builtin.packages.foo.package.example(), foo, baz\r\n" +) def test_repo_migrate(tmp_path: pathlib.Path, config): @@ -142,7 +165,6 @@ def test_repo_migrate(tmp_path: pathlib.Path, config): assert pkg_py_numpy_new.read_bytes() == NEW_NUMPY -@pytest.mark.not_on_windows("Known failure on windows") def test_migrate_diff(git: Executable, tmp_path: pathlib.Path): root, _ = spack.repo.create_repo(str(tmp_path), "foo", package_api=(2, 0)) r = pathlib.Path(root)