This commit is contained in:
Todd Gamblin 2016-06-12 23:50:59 -07:00
parent 5e5024342f
commit d195576fba
3 changed files with 119 additions and 29 deletions

View File

@ -26,6 +26,8 @@
The ``virtual`` module contains utility classes for virtual dependencies.
"""
import itertools
from pprint import pformat
import yaml
from yaml.error import MarkedYAMLError
@ -48,15 +50,30 @@ class ProviderIndex(object):
Calling providers_for(spec) will find specs that provide a
matching implementation of MPI.
"""
def __init__(self, specs=None, **kwargs):
# TODO: come up with another name for this. This "restricts"
# values to the verbatim impu specs (i.e., it doesn't
# pre-apply package's constraints, and keeps things as broad
# as possible, so it's really the wrong name)
if specs is None: specs = []
self.restrict = kwargs.setdefault('restrict', False)
"""
def __init__(self, specs=None, restrict=False):
"""Create a new ProviderIndex.
Optional arguments:
specs
List (or sequence) of specs. If provided, will call
`update` on this ProviderIndex with each spec in the list.
restrict
"restricts" values to the verbatim input specs; do not
pre-apply package's constraints.
TODO: rename this. It is intended to keep things as broad
as possible without overly restricting results, so it is
not the best name.
"""
if specs is None: specs = []
self.restrict = restrict
self.providers = {}
for spec in specs:
@ -174,10 +191,9 @@ def satisfies(self, other):
def to_yaml(self, stream=None):
provider_list = dict(
(name, [[vpkg.to_node_dict(), [p.to_node_dict() for p in pset]]
for vpkg, pset in pdict.items()])
for name, pdict in self.providers.items())
provider_list = self._transform(
lambda vpkg, pset: [
vpkg.to_node_dict(), [p.to_node_dict() for p in pset]], list)
yaml.dump({'provider_index': {'providers': provider_list}},
stream=stream)
@ -201,12 +217,11 @@ def from_yaml(stream):
index = ProviderIndex()
providers = yfile['provider_index']['providers']
index.providers = dict(
(name, dict((spack.spec.Spec.from_node_dict(vpkg),
set(spack.spec.Spec.from_node_dict(p) for p in plist))
for vpkg, plist in pdict_list))
for name, pdict_list in providers.items())
index.providers = _transform(
providers,
lambda vpkg, plist: (
spack.spec.Spec.from_node_dict(vpkg),
set(spack.spec.Spec.from_node_dict(p) for p in plist)))
return index
@ -253,12 +268,39 @@ def remove_provider(self, pkg_name):
def copy(self):
"""Deep copy of this ProviderIndex."""
clone = ProviderIndex()
clone.providers = dict(
(name, dict((vpkg, set((p.copy() for p in pset)))
for vpkg, pset in pdict.items()))
for name, pdict in self.providers.items())
clone.providers = self._transform(
lambda vpkg, pset: (vpkg, set((p.copy() for p in pset))))
return clone
def __eq__(self, other):
return self.providers == other.providers
def _transform(self, transform_fun, out_mapping_type=dict):
return _transform(self.providers, transform_fun, out_mapping_type)
def __str__(self):
return pformat(
_transform(self.providers,
lambda k, v: (k, list(v))))
def _transform(providers, transform_fun, out_mapping_type=dict):
"""Syntactic sugar for transforming a providers dict.
transform_fun takes a (vpkg, pset) mapping and runs it on each
pair in nested dicts.
"""
def mapiter(mappings):
if isinstance(mappings, dict):
return mappings.iteritems()
else:
return iter(mappings)
return dict(
(name, out_mapping_type([
transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)]))
for name, mappings in providers.items())

View File

@ -153,9 +153,6 @@ def test_concretize_with_provides_when(self):
self.assertTrue(not any(spec.satisfies('mpich2@:1.1')
for spec in spack.repo.providers_for('mpi@2.2')))
self.assertTrue(not any(spec.satisfies('mpich2@:1.1')
for spec in spack.repo.providers_for('mpi@2.2')))
self.assertTrue(not any(spec.satisfies('mpich@:1')
for spec in spack.repo.providers_for('mpi@2')))

View File

@ -26,12 +26,27 @@
import unittest
import spack
from spack.spec import Spec
from spack.provider_index import ProviderIndex
from spack.test.mock_packages_test import *
# Test assume that mock packages provide this:
#
# {'blas': {
# blas: set([netlib-blas, openblas, openblas-with-lapack])},
# 'lapack': {lapack: set([netlib-lapack, openblas-with-lapack])},
# 'mpi': {mpi@:1: set([mpich@:1]),
# mpi@:2.0: set([mpich2]),
# mpi@:2.1: set([mpich2@1.1:]),
# mpi@:2.2: set([mpich2@1.2:]),
# mpi@:3: set([mpich@3:]),
# mpi@:10.0: set([zmpi])},
# 'stuff': {stuff: set([externalvirtual])}}
#
class ProviderIndexTest(unittest.TestCase):
class ProviderIndexTest(MockPackagesTest):
def test_write_and_read(self):
def test_yaml_round_trip(self):
p = ProviderIndex(spack.repo.all_package_names())
ostream = StringIO()
@ -40,10 +55,46 @@ def test_write_and_read(self):
istream = StringIO(ostream.getvalue())
q = ProviderIndex.from_yaml(istream)
self.assertTrue(p == q)
self.assertEqual(p, q)
def test_providers_for_simple(self):
p = ProviderIndex(spack.repo.all_package_names())
blas_providers = p.providers_for('blas')
self.assertTrue(Spec('netlib-blas') in blas_providers)
self.assertTrue(Spec('openblas') in blas_providers)
self.assertTrue(Spec('openblas-with-lapack') in blas_providers)
lapack_providers = p.providers_for('lapack')
self.assertTrue(Spec('netlib-lapack') in lapack_providers)
self.assertTrue(Spec('openblas-with-lapack') in lapack_providers)
def test_mpi_providers(self):
p = ProviderIndex(spack.repo.all_package_names())
mpi_2_providers = p.providers_for('mpi@2')
self.assertTrue(Spec('mpich2') in mpi_2_providers)
self.assertTrue(Spec('mpich@3:') in mpi_2_providers)
mpi_3_providers = p.providers_for('mpi@3')
self.assertTrue(Spec('mpich2') not in mpi_3_providers)
self.assertTrue(Spec('mpich@3:') in mpi_3_providers)
self.assertTrue(Spec('zmpi') in mpi_3_providers)
def test_equal(self):
p = ProviderIndex(spack.repo.all_package_names())
q = ProviderIndex(spack.repo.all_package_names())
self.assertEqual(p, q)
def test_copy(self):
p = ProviderIndex(spack.repo.all_package_names())
q = p.copy()
self.assertTrue(p == q)
self.assertEqual(p, q)
def test_copy(self):
pass