Refactor to organize filter object

This commit is contained in:
Philip Sakievich 2025-04-16 22:10:09 -06:00 committed by Harmen Stoppels
parent 0240120d4f
commit 6eda1b4d04
5 changed files with 77 additions and 71 deletions

View File

@ -63,6 +63,7 @@
import spack.util.web as web_util import spack.util.web as web_util
from spack import traverse from spack import traverse
from spack.caches import misc_cache_location from spack.caches import misc_cache_location
from spack.mirrors.utils import MirrorSpecFilter
from spack.oci.image import ( from spack.oci.image import (
Digest, Digest,
ImageReference, ImageReference,
@ -1230,8 +1231,7 @@ def push(
force=self.force, force=self.force,
tmpdir=self.tmpdir, tmpdir=self.tmpdir,
executor=self.executor, executor=self.executor,
exclusions=self.mirror.exclusions, filter=MirrorSpecFilter(self.mirror),
inclusions=self.mirror.inclusions,
) )
self._base_images = base_images self._base_images = base_images
@ -1282,8 +1282,7 @@ def push(
signing_key=self.signing_key, signing_key=self.signing_key,
tmpdir=self.tmpdir, tmpdir=self.tmpdir,
executor=self.executor, executor=self.executor,
exclusions=self.mirror.exclusions, filter=MirrorSpecFilter(self.mirror),
inclusions=self.mirror.inclusions,
) )
@ -1346,34 +1345,6 @@ def fail(self) -> None:
tty.info(f"{self.pre}Failed to push {self.pretty_spec}") tty.info(f"{self.pre}Failed to push {self.pretty_spec}")
def filter_specs(specs: List[spack.spec.Spec], exclude: List[str], include: List[str]):
"""
Determine the intersection of include/exclude filters
Tie goes to keeping
skip | keep | outcome
------------------------
False | False | Keep
True | True | Keep
False | True | Keep
True | False | Skip
"""
filter = []
filtrate = []
ex_specs = [spack.spec.Spec(spec) for spec in exclude]
ic_specs = [spack.spec.Spec(spec) for spec in include]
for spec in specs:
skip = any([spec.satisfies(test) for test in ex_specs])
keep = any([spec.satisfies(test) for test in ic_specs])
if skip and not keep:
filtrate.append(spec)
else:
filter.append(spec)
return filter, filtrate
def _url_push( def _url_push(
specs: List[spack.spec.Spec], specs: List[spack.spec.Spec],
out_url: str, out_url: str,
@ -1382,8 +1353,7 @@ def _url_push(
update_index: bool, update_index: bool,
tmpdir: str, tmpdir: str,
executor: concurrent.futures.Executor, executor: concurrent.futures.Executor,
exclusions: List[str] = [], filter: Optional[MirrorSpecFilter] = None,
inclusions: List[str] = [],
) -> Tuple[List[spack.spec.Spec], List[Tuple[spack.spec.Spec, BaseException]]]: ) -> Tuple[List[spack.spec.Spec], List[Tuple[spack.spec.Spec, BaseException]]]:
"""Pushes to the provided build cache, and returns a list of skipped specs that were already """Pushes to the provided build cache, and returns a list of skipped specs that were already
present (when force=False), and a list of errors. Does not raise on error.""" present (when force=False), and a list of errors. Does not raise on error."""
@ -1414,10 +1384,10 @@ def _url_push(
if not specs_to_upload: if not specs_to_upload:
return skipped, errors return skipped, errors
filter, filtrate = filter_specs(specs_to_upload, exclusions, inclusions) if filter:
filtered, filtrate = filter(specs_to_upload)
skipped.extend(filtrate) skipped.extend(filtrate)
specs_to_upload = filter specs_to_upload = filtered
total = len(specs_to_upload) total = len(specs_to_upload)
@ -1686,8 +1656,7 @@ def _oci_push(
tmpdir: str, tmpdir: str,
executor: concurrent.futures.Executor, executor: concurrent.futures.Executor,
force: bool = False, force: bool = False,
exclusions: List[str] = [], filter: Optional[MirrorSpecFilter] = None,
inclusions: List[str] = [],
) -> Tuple[ ) -> Tuple[
List[spack.spec.Spec], List[spack.spec.Spec],
Dict[str, Tuple[dict, dict]], Dict[str, Tuple[dict, dict]],
@ -1724,10 +1693,10 @@ def _oci_push(
if not blobs_to_upload: if not blobs_to_upload:
return skipped, base_images, checksums, [] return skipped, base_images, checksums, []
filter, filtrate = filter_specs(blobs_to_upload, exclusions, inclusions) if filter:
filtered, filtrate = filter(blobs_to_upload)
skipped.extend(filtrate) skipped.extend(filtrate)
blobs_to_upload = filter blobs_to_upload = filtered
if len(blobs_to_upload) != len(installed_specs_with_deps): if len(blobs_to_upload) != len(installed_specs_with_deps):
tty.info( tty.info(

View File

@ -378,7 +378,7 @@ def mirror_add(args):
else: else:
mirror = spack.mirrors.mirror.Mirror(args.url, name=args.name) mirror = spack.mirrors.mirror.Mirror(args.url, name=args.name)
exclude_specs = mirror.to_dict().get("exclude", []) exclude_specs = []
if args.exclude_file: if args.exclude_file:
exclude_specs.extend(specs_from_text_file(args.exclude_file, concretize=False)) exclude_specs.extend(specs_from_text_file(args.exclude_file, concretize=False))
if args.exclude_specs: if args.exclude_specs:

View File

@ -3,6 +3,7 @@
# SPDX-License-Identifier: (Apache-2.0 OR MIT) # SPDX-License-Identifier: (Apache-2.0 OR MIT)
import os import os
import traceback import traceback
from typing import List
import llnl.util.tty as tty import llnl.util.tty as tty
from llnl.util.filesystem import mkdirp from llnl.util.filesystem import mkdirp
@ -254,3 +255,34 @@ def require_mirror_name(mirror_name):
if not mirror: if not mirror:
raise ValueError(f'no mirror named "{mirror_name}"') raise ValueError(f'no mirror named "{mirror_name}"')
return mirror return mirror
class MirrorSpecFilter:
def __init__(self, mirror: Mirror):
self.exclude = [spack.spec.Spec(spec) for spec in mirror.exclusions]
self.include = [spack.spec.Spec(spec) for spec in mirror.inclusions]
def __call__(self, specs: List[spack.spec.Spec]):
"""
Determine the intersection of include/exclude filters
Tie goes to keeping
skip | keep | outcome
------------------------
False | False | Keep
True | True | Keep
False | True | Keep
True | False | Skip
"""
filter = []
filtrate = []
for spec in specs:
skip = any([spec.satisfies(test) for test in self.exclude])
keep = any([spec.satisfies(test) for test in self.include])
if skip and not keep:
filtrate.append(spec)
else:
filter.append(spec)
return filter, filtrate

View File

@ -59,32 +59,6 @@
legacy_mirror_dir = os.path.join(test_path, "data", "mirrors", "legacy_yaml") legacy_mirror_dir = os.path.join(test_path, "data", "mirrors", "legacy_yaml")
INPUT_SPEC_STRS = ["foo@main", "foo@main dev_path=/tmp", "foo@2.1.3"]
@pytest.mark.parametrize(
"include,exclude,gold",
[
([], [], [0, 1, 2]),
(["dev_path=*", "@main"], [], [0, 1, 2]),
([], ["dev_path=*", "@main"], [2]),
(["dev_path=*"], ["@main"], [1, 2]),
],
)
def test_filter_specs(include, exclude, gold):
input_specs = [spack.spec.Spec(s) for s in INPUT_SPEC_STRS]
filter, filtrate = bindist.filter_specs(input_specs, exclude, include)
assert filter is not None
assert filtrate is not None
# lossless
assert (set(filter) | set(filtrate)) == set(input_specs)
for i in gold:
assert input_specs[i] in filter
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def cache_directory(tmpdir): def cache_directory(tmpdir):
fetch_cache_dir = tmpdir.ensure("fetch_cache", dir=True) fetch_cache_dir = tmpdir.ensure("fetch_cache", dir=True)

View File

@ -18,6 +18,7 @@
import spack.mirrors.mirror import spack.mirrors.mirror
import spack.mirrors.utils import spack.mirrors.utils
import spack.patch import spack.patch
import spack.spec
import spack.stage import spack.stage
import spack.util.executable import spack.util.executable
import spack.util.spack_json as sjson import spack.util.spack_json as sjson
@ -454,3 +455,33 @@ def test_mirror_parse_exclude_include():
m = spack.mirrors.mirror.Mirror(mirror_raw) m = spack.mirrors.mirror.Mirror(mirror_raw)
assert "dev_path=*" in m.exclusions assert "dev_path=*" in m.exclusions
assert "+foo" in m.inclusions assert "+foo" in m.inclusions
INPUT_SPEC_STRS = ["foo@main", "foo@main dev_path=/tmp", "foo@2.1.3"]
@pytest.mark.parametrize(
"include,exclude,gold",
[
([], [], [0, 1, 2]),
(["dev_path=*", "@main"], [], [0, 1, 2]),
([], ["dev_path=*", "@main"], [2]),
(["dev_path=*"], ["@main"], [1, 2]),
],
)
def test_filter_specs(include, exclude, gold):
input_specs = [spack.spec.Spec(s) for s in INPUT_SPEC_STRS]
data = {"include": include, "exclude": exclude}
m = spack.mirrors.mirror.Mirror(data)
filter = spack.mirrors.utils.MirrorSpecFilter(m)
filtered, filtrate = filter(input_specs)
assert filtered is not None
assert filtrate is not None
# lossless
assert (set(filtered) | set(filtrate)) == set(input_specs)
for i in gold:
assert input_specs[i] in filtered