WIP
This commit is contained in:
parent
5e5024342f
commit
d195576fba
@ -26,6 +26,8 @@
|
||||
The ``virtual`` module contains utility classes for virtual dependencies.
|
||||
"""
|
||||
import itertools
|
||||
from pprint import pformat
|
||||
|
||||
import yaml
|
||||
from yaml.error import MarkedYAMLError
|
||||
|
||||
@ -48,15 +50,30 @@ class ProviderIndex(object):
|
||||
|
||||
Calling providers_for(spec) will find specs that provide a
|
||||
matching implementation of MPI.
|
||||
"""
|
||||
def __init__(self, specs=None, **kwargs):
|
||||
# TODO: come up with another name for this. This "restricts"
|
||||
# values to the verbatim impu specs (i.e., it doesn't
|
||||
# pre-apply package's constraints, and keeps things as broad
|
||||
# as possible, so it's really the wrong name)
|
||||
if specs is None: specs = []
|
||||
self.restrict = kwargs.setdefault('restrict', False)
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, specs=None, restrict=False):
|
||||
"""Create a new ProviderIndex.
|
||||
|
||||
Optional arguments:
|
||||
|
||||
specs
|
||||
List (or sequence) of specs. If provided, will call
|
||||
`update` on this ProviderIndex with each spec in the list.
|
||||
|
||||
restrict
|
||||
"restricts" values to the verbatim input specs; do not
|
||||
pre-apply package's constraints.
|
||||
|
||||
TODO: rename this. It is intended to keep things as broad
|
||||
as possible without overly restricting results, so it is
|
||||
not the best name.
|
||||
"""
|
||||
if specs is None: specs = []
|
||||
|
||||
self.restrict = restrict
|
||||
self.providers = {}
|
||||
|
||||
for spec in specs:
|
||||
@ -174,10 +191,9 @@ def satisfies(self, other):
|
||||
|
||||
|
||||
def to_yaml(self, stream=None):
|
||||
provider_list = dict(
|
||||
(name, [[vpkg.to_node_dict(), [p.to_node_dict() for p in pset]]
|
||||
for vpkg, pset in pdict.items()])
|
||||
for name, pdict in self.providers.items())
|
||||
provider_list = self._transform(
|
||||
lambda vpkg, pset: [
|
||||
vpkg.to_node_dict(), [p.to_node_dict() for p in pset]], list)
|
||||
|
||||
yaml.dump({'provider_index': {'providers': provider_list}},
|
||||
stream=stream)
|
||||
@ -201,12 +217,11 @@ def from_yaml(stream):
|
||||
|
||||
index = ProviderIndex()
|
||||
providers = yfile['provider_index']['providers']
|
||||
index.providers = dict(
|
||||
(name, dict((spack.spec.Spec.from_node_dict(vpkg),
|
||||
set(spack.spec.Spec.from_node_dict(p) for p in plist))
|
||||
for vpkg, plist in pdict_list))
|
||||
for name, pdict_list in providers.items())
|
||||
|
||||
index.providers = _transform(
|
||||
providers,
|
||||
lambda vpkg, plist: (
|
||||
spack.spec.Spec.from_node_dict(vpkg),
|
||||
set(spack.spec.Spec.from_node_dict(p) for p in plist)))
|
||||
return index
|
||||
|
||||
|
||||
@ -253,12 +268,39 @@ def remove_provider(self, pkg_name):
|
||||
def copy(self):
|
||||
"""Deep copy of this ProviderIndex."""
|
||||
clone = ProviderIndex()
|
||||
clone.providers = dict(
|
||||
(name, dict((vpkg, set((p.copy() for p in pset)))
|
||||
for vpkg, pset in pdict.items()))
|
||||
for name, pdict in self.providers.items())
|
||||
clone.providers = self._transform(
|
||||
lambda vpkg, pset: (vpkg, set((p.copy() for p in pset))))
|
||||
return clone
|
||||
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.providers == other.providers
|
||||
|
||||
|
||||
def _transform(self, transform_fun, out_mapping_type=dict):
|
||||
return _transform(self.providers, transform_fun, out_mapping_type)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return pformat(
|
||||
_transform(self.providers,
|
||||
lambda k, v: (k, list(v))))
|
||||
|
||||
|
||||
def _transform(providers, transform_fun, out_mapping_type=dict):
|
||||
"""Syntactic sugar for transforming a providers dict.
|
||||
|
||||
transform_fun takes a (vpkg, pset) mapping and runs it on each
|
||||
pair in nested dicts.
|
||||
|
||||
"""
|
||||
def mapiter(mappings):
|
||||
if isinstance(mappings, dict):
|
||||
return mappings.iteritems()
|
||||
else:
|
||||
return iter(mappings)
|
||||
|
||||
return dict(
|
||||
(name, out_mapping_type([
|
||||
transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)]))
|
||||
for name, mappings in providers.items())
|
||||
|
@ -153,9 +153,6 @@ def test_concretize_with_provides_when(self):
|
||||
self.assertTrue(not any(spec.satisfies('mpich2@:1.1')
|
||||
for spec in spack.repo.providers_for('mpi@2.2')))
|
||||
|
||||
self.assertTrue(not any(spec.satisfies('mpich2@:1.1')
|
||||
for spec in spack.repo.providers_for('mpi@2.2')))
|
||||
|
||||
self.assertTrue(not any(spec.satisfies('mpich@:1')
|
||||
for spec in spack.repo.providers_for('mpi@2')))
|
||||
|
||||
|
@ -26,12 +26,27 @@
|
||||
import unittest
|
||||
|
||||
import spack
|
||||
from spack.spec import Spec
|
||||
from spack.provider_index import ProviderIndex
|
||||
from spack.test.mock_packages_test import *
|
||||
|
||||
# Test assume that mock packages provide this:
|
||||
#
|
||||
# {'blas': {
|
||||
# blas: set([netlib-blas, openblas, openblas-with-lapack])},
|
||||
# 'lapack': {lapack: set([netlib-lapack, openblas-with-lapack])},
|
||||
# 'mpi': {mpi@:1: set([mpich@:1]),
|
||||
# mpi@:2.0: set([mpich2]),
|
||||
# mpi@:2.1: set([mpich2@1.1:]),
|
||||
# mpi@:2.2: set([mpich2@1.2:]),
|
||||
# mpi@:3: set([mpich@3:]),
|
||||
# mpi@:10.0: set([zmpi])},
|
||||
# 'stuff': {stuff: set([externalvirtual])}}
|
||||
#
|
||||
|
||||
class ProviderIndexTest(unittest.TestCase):
|
||||
class ProviderIndexTest(MockPackagesTest):
|
||||
|
||||
def test_write_and_read(self):
|
||||
def test_yaml_round_trip(self):
|
||||
p = ProviderIndex(spack.repo.all_package_names())
|
||||
|
||||
ostream = StringIO()
|
||||
@ -40,10 +55,46 @@ def test_write_and_read(self):
|
||||
istream = StringIO(ostream.getvalue())
|
||||
q = ProviderIndex.from_yaml(istream)
|
||||
|
||||
self.assertTrue(p == q)
|
||||
self.assertEqual(p, q)
|
||||
|
||||
|
||||
def test_providers_for_simple(self):
|
||||
p = ProviderIndex(spack.repo.all_package_names())
|
||||
|
||||
blas_providers = p.providers_for('blas')
|
||||
self.assertTrue(Spec('netlib-blas') in blas_providers)
|
||||
self.assertTrue(Spec('openblas') in blas_providers)
|
||||
self.assertTrue(Spec('openblas-with-lapack') in blas_providers)
|
||||
|
||||
lapack_providers = p.providers_for('lapack')
|
||||
self.assertTrue(Spec('netlib-lapack') in lapack_providers)
|
||||
self.assertTrue(Spec('openblas-with-lapack') in lapack_providers)
|
||||
|
||||
|
||||
def test_mpi_providers(self):
|
||||
p = ProviderIndex(spack.repo.all_package_names())
|
||||
|
||||
mpi_2_providers = p.providers_for('mpi@2')
|
||||
self.assertTrue(Spec('mpich2') in mpi_2_providers)
|
||||
self.assertTrue(Spec('mpich@3:') in mpi_2_providers)
|
||||
|
||||
mpi_3_providers = p.providers_for('mpi@3')
|
||||
self.assertTrue(Spec('mpich2') not in mpi_3_providers)
|
||||
self.assertTrue(Spec('mpich@3:') in mpi_3_providers)
|
||||
self.assertTrue(Spec('zmpi') in mpi_3_providers)
|
||||
|
||||
|
||||
def test_equal(self):
|
||||
p = ProviderIndex(spack.repo.all_package_names())
|
||||
q = ProviderIndex(spack.repo.all_package_names())
|
||||
self.assertEqual(p, q)
|
||||
|
||||
|
||||
def test_copy(self):
|
||||
p = ProviderIndex(spack.repo.all_package_names())
|
||||
q = p.copy()
|
||||
self.assertTrue(p == q)
|
||||
self.assertEqual(p, q)
|
||||
|
||||
|
||||
def test_copy(self):
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user