WIP
This commit is contained in:
parent
5e5024342f
commit
d195576fba
@ -26,6 +26,8 @@
|
|||||||
The ``virtual`` module contains utility classes for virtual dependencies.
|
The ``virtual`` module contains utility classes for virtual dependencies.
|
||||||
"""
|
"""
|
||||||
import itertools
|
import itertools
|
||||||
|
from pprint import pformat
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from yaml.error import MarkedYAMLError
|
from yaml.error import MarkedYAMLError
|
||||||
|
|
||||||
@ -48,15 +50,30 @@ class ProviderIndex(object):
|
|||||||
|
|
||||||
Calling providers_for(spec) will find specs that provide a
|
Calling providers_for(spec) will find specs that provide a
|
||||||
matching implementation of MPI.
|
matching implementation of MPI.
|
||||||
"""
|
|
||||||
def __init__(self, specs=None, **kwargs):
|
|
||||||
# TODO: come up with another name for this. This "restricts"
|
|
||||||
# values to the verbatim impu specs (i.e., it doesn't
|
|
||||||
# pre-apply package's constraints, and keeps things as broad
|
|
||||||
# as possible, so it's really the wrong name)
|
|
||||||
if specs is None: specs = []
|
|
||||||
self.restrict = kwargs.setdefault('restrict', False)
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, specs=None, restrict=False):
|
||||||
|
"""Create a new ProviderIndex.
|
||||||
|
|
||||||
|
Optional arguments:
|
||||||
|
|
||||||
|
specs
|
||||||
|
List (or sequence) of specs. If provided, will call
|
||||||
|
`update` on this ProviderIndex with each spec in the list.
|
||||||
|
|
||||||
|
restrict
|
||||||
|
"restricts" values to the verbatim input specs; do not
|
||||||
|
pre-apply package's constraints.
|
||||||
|
|
||||||
|
TODO: rename this. It is intended to keep things as broad
|
||||||
|
as possible without overly restricting results, so it is
|
||||||
|
not the best name.
|
||||||
|
"""
|
||||||
|
if specs is None: specs = []
|
||||||
|
|
||||||
|
self.restrict = restrict
|
||||||
self.providers = {}
|
self.providers = {}
|
||||||
|
|
||||||
for spec in specs:
|
for spec in specs:
|
||||||
@ -174,10 +191,9 @@ def satisfies(self, other):
|
|||||||
|
|
||||||
|
|
||||||
def to_yaml(self, stream=None):
|
def to_yaml(self, stream=None):
|
||||||
provider_list = dict(
|
provider_list = self._transform(
|
||||||
(name, [[vpkg.to_node_dict(), [p.to_node_dict() for p in pset]]
|
lambda vpkg, pset: [
|
||||||
for vpkg, pset in pdict.items()])
|
vpkg.to_node_dict(), [p.to_node_dict() for p in pset]], list)
|
||||||
for name, pdict in self.providers.items())
|
|
||||||
|
|
||||||
yaml.dump({'provider_index': {'providers': provider_list}},
|
yaml.dump({'provider_index': {'providers': provider_list}},
|
||||||
stream=stream)
|
stream=stream)
|
||||||
@ -201,12 +217,11 @@ def from_yaml(stream):
|
|||||||
|
|
||||||
index = ProviderIndex()
|
index = ProviderIndex()
|
||||||
providers = yfile['provider_index']['providers']
|
providers = yfile['provider_index']['providers']
|
||||||
index.providers = dict(
|
index.providers = _transform(
|
||||||
(name, dict((spack.spec.Spec.from_node_dict(vpkg),
|
providers,
|
||||||
set(spack.spec.Spec.from_node_dict(p) for p in plist))
|
lambda vpkg, plist: (
|
||||||
for vpkg, plist in pdict_list))
|
spack.spec.Spec.from_node_dict(vpkg),
|
||||||
for name, pdict_list in providers.items())
|
set(spack.spec.Spec.from_node_dict(p) for p in plist)))
|
||||||
|
|
||||||
return index
|
return index
|
||||||
|
|
||||||
|
|
||||||
@ -253,12 +268,39 @@ def remove_provider(self, pkg_name):
|
|||||||
def copy(self):
|
def copy(self):
|
||||||
"""Deep copy of this ProviderIndex."""
|
"""Deep copy of this ProviderIndex."""
|
||||||
clone = ProviderIndex()
|
clone = ProviderIndex()
|
||||||
clone.providers = dict(
|
clone.providers = self._transform(
|
||||||
(name, dict((vpkg, set((p.copy() for p in pset)))
|
lambda vpkg, pset: (vpkg, set((p.copy() for p in pset))))
|
||||||
for vpkg, pset in pdict.items()))
|
|
||||||
for name, pdict in self.providers.items())
|
|
||||||
return clone
|
return clone
|
||||||
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.providers == other.providers
|
return self.providers == other.providers
|
||||||
|
|
||||||
|
|
||||||
|
def _transform(self, transform_fun, out_mapping_type=dict):
|
||||||
|
return _transform(self.providers, transform_fun, out_mapping_type)
|
||||||
|
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return pformat(
|
||||||
|
_transform(self.providers,
|
||||||
|
lambda k, v: (k, list(v))))
|
||||||
|
|
||||||
|
|
||||||
|
def _transform(providers, transform_fun, out_mapping_type=dict):
|
||||||
|
"""Syntactic sugar for transforming a providers dict.
|
||||||
|
|
||||||
|
transform_fun takes a (vpkg, pset) mapping and runs it on each
|
||||||
|
pair in nested dicts.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def mapiter(mappings):
|
||||||
|
if isinstance(mappings, dict):
|
||||||
|
return mappings.iteritems()
|
||||||
|
else:
|
||||||
|
return iter(mappings)
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
(name, out_mapping_type([
|
||||||
|
transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)]))
|
||||||
|
for name, mappings in providers.items())
|
||||||
|
@ -153,9 +153,6 @@ def test_concretize_with_provides_when(self):
|
|||||||
self.assertTrue(not any(spec.satisfies('mpich2@:1.1')
|
self.assertTrue(not any(spec.satisfies('mpich2@:1.1')
|
||||||
for spec in spack.repo.providers_for('mpi@2.2')))
|
for spec in spack.repo.providers_for('mpi@2.2')))
|
||||||
|
|
||||||
self.assertTrue(not any(spec.satisfies('mpich2@:1.1')
|
|
||||||
for spec in spack.repo.providers_for('mpi@2.2')))
|
|
||||||
|
|
||||||
self.assertTrue(not any(spec.satisfies('mpich@:1')
|
self.assertTrue(not any(spec.satisfies('mpich@:1')
|
||||||
for spec in spack.repo.providers_for('mpi@2')))
|
for spec in spack.repo.providers_for('mpi@2')))
|
||||||
|
|
||||||
|
@ -26,12 +26,27 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import spack
|
import spack
|
||||||
|
from spack.spec import Spec
|
||||||
from spack.provider_index import ProviderIndex
|
from spack.provider_index import ProviderIndex
|
||||||
|
from spack.test.mock_packages_test import *
|
||||||
|
|
||||||
|
# Test assume that mock packages provide this:
|
||||||
|
#
|
||||||
|
# {'blas': {
|
||||||
|
# blas: set([netlib-blas, openblas, openblas-with-lapack])},
|
||||||
|
# 'lapack': {lapack: set([netlib-lapack, openblas-with-lapack])},
|
||||||
|
# 'mpi': {mpi@:1: set([mpich@:1]),
|
||||||
|
# mpi@:2.0: set([mpich2]),
|
||||||
|
# mpi@:2.1: set([mpich2@1.1:]),
|
||||||
|
# mpi@:2.2: set([mpich2@1.2:]),
|
||||||
|
# mpi@:3: set([mpich@3:]),
|
||||||
|
# mpi@:10.0: set([zmpi])},
|
||||||
|
# 'stuff': {stuff: set([externalvirtual])}}
|
||||||
|
#
|
||||||
|
|
||||||
class ProviderIndexTest(unittest.TestCase):
|
class ProviderIndexTest(MockPackagesTest):
|
||||||
|
|
||||||
def test_write_and_read(self):
|
def test_yaml_round_trip(self):
|
||||||
p = ProviderIndex(spack.repo.all_package_names())
|
p = ProviderIndex(spack.repo.all_package_names())
|
||||||
|
|
||||||
ostream = StringIO()
|
ostream = StringIO()
|
||||||
@ -40,10 +55,46 @@ def test_write_and_read(self):
|
|||||||
istream = StringIO(ostream.getvalue())
|
istream = StringIO(ostream.getvalue())
|
||||||
q = ProviderIndex.from_yaml(istream)
|
q = ProviderIndex.from_yaml(istream)
|
||||||
|
|
||||||
self.assertTrue(p == q)
|
self.assertEqual(p, q)
|
||||||
|
|
||||||
|
|
||||||
|
def test_providers_for_simple(self):
|
||||||
|
p = ProviderIndex(spack.repo.all_package_names())
|
||||||
|
|
||||||
|
blas_providers = p.providers_for('blas')
|
||||||
|
self.assertTrue(Spec('netlib-blas') in blas_providers)
|
||||||
|
self.assertTrue(Spec('openblas') in blas_providers)
|
||||||
|
self.assertTrue(Spec('openblas-with-lapack') in blas_providers)
|
||||||
|
|
||||||
|
lapack_providers = p.providers_for('lapack')
|
||||||
|
self.assertTrue(Spec('netlib-lapack') in lapack_providers)
|
||||||
|
self.assertTrue(Spec('openblas-with-lapack') in lapack_providers)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mpi_providers(self):
|
||||||
|
p = ProviderIndex(spack.repo.all_package_names())
|
||||||
|
|
||||||
|
mpi_2_providers = p.providers_for('mpi@2')
|
||||||
|
self.assertTrue(Spec('mpich2') in mpi_2_providers)
|
||||||
|
self.assertTrue(Spec('mpich@3:') in mpi_2_providers)
|
||||||
|
|
||||||
|
mpi_3_providers = p.providers_for('mpi@3')
|
||||||
|
self.assertTrue(Spec('mpich2') not in mpi_3_providers)
|
||||||
|
self.assertTrue(Spec('mpich@3:') in mpi_3_providers)
|
||||||
|
self.assertTrue(Spec('zmpi') in mpi_3_providers)
|
||||||
|
|
||||||
|
|
||||||
|
def test_equal(self):
|
||||||
|
p = ProviderIndex(spack.repo.all_package_names())
|
||||||
|
q = ProviderIndex(spack.repo.all_package_names())
|
||||||
|
self.assertEqual(p, q)
|
||||||
|
|
||||||
|
|
||||||
def test_copy(self):
|
def test_copy(self):
|
||||||
p = ProviderIndex(spack.repo.all_package_names())
|
p = ProviderIndex(spack.repo.all_package_names())
|
||||||
q = p.copy()
|
q = p.copy()
|
||||||
self.assertTrue(p == q)
|
self.assertEqual(p, q)
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy(self):
|
||||||
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user