Flake8 fixes

This commit is contained in:
Todd Gamblin 2016-08-09 01:37:19 -07:00
parent 102ac7bcf1
commit 0c75c13cc0
16 changed files with 293 additions and 329 deletions

View File

@ -39,13 +39,20 @@
class Lock(object): class Lock(object):
"""This is an implementation of a filesystem lock using Python's lockf.
In Python, `lockf` actually calls `fcntl`, so this should work with any
filesystem implementation that supports locking through the fcntl calls.
This includes distributed filesystems like Lustre (when flock is enabled)
and recent NFS versions.
"""
def __init__(self, file_path): def __init__(self, file_path):
self._file_path = file_path self._file_path = file_path
self._fd = None self._fd = None
self._reads = 0 self._reads = 0
self._writes = 0 self._writes = 0
def _lock(self, op, timeout): def _lock(self, op, timeout):
"""This takes a lock using POSIX locks (``fnctl.lockf``). """This takes a lock using POSIX locks (``fnctl.lockf``).
@ -80,7 +87,6 @@ def _lock(self, op, timeout):
raise LockError("Timed out waiting for lock.") raise LockError("Timed out waiting for lock.")
def _unlock(self): def _unlock(self):
"""Releases a lock using POSIX locks (``fcntl.lockf``) """Releases a lock using POSIX locks (``fcntl.lockf``)
@ -92,7 +98,6 @@ def _unlock(self):
os.close(self._fd) os.close(self._fd)
self._fd = None self._fd = None
def acquire_read(self, timeout=_default_timeout): def acquire_read(self, timeout=_default_timeout):
"""Acquires a recursive, shared lock for reading. """Acquires a recursive, shared lock for reading.
@ -112,7 +117,6 @@ def acquire_read(self, timeout=_default_timeout):
self._reads += 1 self._reads += 1
return False return False
def acquire_write(self, timeout=_default_timeout): def acquire_write(self, timeout=_default_timeout):
"""Acquires a recursive, exclusive lock for writing. """Acquires a recursive, exclusive lock for writing.
@ -132,7 +136,6 @@ def acquire_write(self, timeout=_default_timeout):
self._writes += 1 self._writes += 1
return False return False
def release_read(self): def release_read(self):
"""Releases a read lock. """Releases a read lock.
@ -153,7 +156,6 @@ def release_read(self):
self._reads -= 1 self._reads -= 1
return False return False
def release_write(self): def release_write(self):
"""Releases a write lock. """Releases a write lock.

View File

@ -1,3 +1,4 @@
# flake8: noqa
############################################################################## ##############################################################################
# Copyright (c) 2013-2016, Lawrence Livermore National Security, LLC. # Copyright (c) 2013-2016, Lawrence Livermore National Security, LLC.
# Produced at the Lawrence Livermore National Laboratory. # Produced at the Lawrence Livermore National Laboratory.
@ -147,7 +148,7 @@
_tmp_candidates = (_default_tmp, '/nfs/tmp2', '/tmp', '/var/tmp') _tmp_candidates = (_default_tmp, '/nfs/tmp2', '/tmp', '/var/tmp')
for path in _tmp_candidates: for path in _tmp_candidates:
# don't add a second username if it's already unique by user. # don't add a second username if it's already unique by user.
if not _tmp_user in path: if _tmp_user not in path:
tmp_dirs.append(join_path(path, '%u', 'spack-stage')) tmp_dirs.append(join_path(path, '%u', 'spack-stage'))
else: else:
tmp_dirs.append(join_path(path, 'spack-stage')) tmp_dirs.append(join_path(path, 'spack-stage'))
@ -179,11 +180,12 @@
# Spack internal code should call 'import spack' and accesses other # Spack internal code should call 'import spack' and accesses other
# variables (spack.repo, paths, etc.) directly. # variables (spack.repo, paths, etc.) directly.
# #
# TODO: maybe this should be separated out and should go in build_environment.py? # TODO: maybe this should be separated out to build_environment.py?
# TODO: it's not clear where all the stuff that needs to be included in packages # TODO: it's not clear where all the stuff that needs to be included in
# should live. This file is overloaded for spack core vs. for packages. # packages should live. This file is overloaded for spack core vs.
# for packages.
# #
__all__ = ['Package', 'StagedPackage', 'CMakePackage', \ __all__ = ['Package', 'StagedPackage', 'CMakePackage',
'Version', 'when', 'ver', 'alldeps', 'nolink'] 'Version', 'when', 'ver', 'alldeps', 'nolink']
from spack.package import Package, ExtensionConflictError from spack.package import Package, ExtensionConflictError
from spack.package import StagedPackage, CMakePackage from spack.package import StagedPackage, CMakePackage
@ -204,8 +206,8 @@
__all__ += spack.util.executable.__all__ __all__ += spack.util.executable.__all__
from spack.package import \ from spack.package import \
install_dependency_symlinks, flatten_dependencies, DependencyConflictError, \ install_dependency_symlinks, flatten_dependencies, \
InstallError, ExternalPackageError DependencyConflictError, InstallError, ExternalPackageError
__all__ += [ __all__ += [
'install_dependency_symlinks', 'flatten_dependencies', 'DependencyConflictError', 'install_dependency_symlinks', 'flatten_dependencies',
'InstallError', 'ExternalPackageError'] 'DependencyConflictError', 'InstallError', 'ExternalPackageError']

View File

@ -45,14 +45,14 @@ def setup_parser(subparser):
def purge(parser, args): def purge(parser, args):
# Special case: no flags. # Special case: no flags.
if not any((args.stage, args.cache, args.all)): if not any((args.stage, args.downloads, args.user_cache, args.all)):
stage.purge() stage.purge()
return return
# handle other flags with fall through. # handle other flags with fall through.
if args.stage or args.all: if args.stage or args.all:
stage.purge() stage.purge()
if args.cache or args.all: if args.downloads or args.all:
spack.fetch_cache.destroy() spack.fetch_cache.destroy()
if args.user_cache or args.all: if args.user_cache or args.all:
spack.user_cache.destroy() spack.user_cache.destroy()

View File

@ -23,11 +23,9 @@
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
############################################################################## ##############################################################################
import os import os
from pprint import pprint
from llnl.util.filesystem import join_path, mkdirp from llnl.util.filesystem import join_path, mkdirp
from llnl.util.tty.colify import colify from llnl.util.tty.colify import colify
from llnl.util.lang import list_modules
import spack import spack
import spack.test import spack.test
@ -35,11 +33,13 @@
description = "Run unit tests" description = "Run unit tests"
def setup_parser(subparser): def setup_parser(subparser):
subparser.add_argument( subparser.add_argument(
'names', nargs='*', help="Names of tests to run.") 'names', nargs='*', help="Names of tests to run.")
subparser.add_argument( subparser.add_argument(
'-l', '--list', action='store_true', dest='list', help="Show available tests") '-l', '--list', action='store_true', dest='list',
help="Show available tests")
subparser.add_argument( subparser.add_argument(
'--createXmlOutput', action='store_true', dest='createXmlOutput', '--createXmlOutput', action='store_true', dest='createXmlOutput',
help="Create JUnit XML from test results") help="Create JUnit XML from test results")
@ -69,6 +69,7 @@ def fetch(self):
def __str__(self): def __str__(self):
return "[mock fetcher]" return "[mock fetcher]"
def test(parser, args): def test(parser, args):
if args.list: if args.list:
print "Available tests:" print "Available tests:"

View File

@ -357,9 +357,9 @@ def _write(self, type, value, traceback):
This is a helper function called by the WriteTransaction context This is a helper function called by the WriteTransaction context
manager. If there is an exception while the write lock is active, manager. If there is an exception while the write lock is active,
nothing will be written to the database file, but the in-memory database nothing will be written to the database file, but the in-memory
*may* be left in an inconsistent state. It will be consistent after the database *may* be left in an inconsistent state. It will be consistent
start of the next transaction, when it reads from disk again. after the start of the next transaction, when it reads from disk again.
This routine does no locking. This routine does no locking.

View File

@ -28,7 +28,6 @@
from llnl.util.filesystem import * from llnl.util.filesystem import *
from llnl.util.lock import * from llnl.util.lock import *
import spack
from spack.error import SpackError from spack.error import SpackError
@ -54,11 +53,14 @@ def __init__(self, root):
self._locks = {} self._locks = {}
def purge(self): def destroy(self):
"""Remove all files under the cache root.""" """Remove all files under the cache root."""
for f in os.listdir(self.root): for f in os.listdir(self.root):
path = join_path(self.root, f) path = join_path(self.root, f)
shutil.rmtree(f) if os.path.isdir(path):
shutil.rmtree(path, True)
else:
os.remove(path)
def cache_path(self, key): def cache_path(self, key):
"""Path to the file in the cache for a particular key.""" """Path to the file in the cache for a particular key."""
@ -154,7 +156,6 @@ def __exit__(cm, type, value, traceback):
return WriteTransaction(self._get_lock(key), WriteContextManager) return WriteTransaction(self._get_lock(key), WriteContextManager)
def mtime(self, key): def mtime(self, key):
"""Return modification time of cache file, or 0 if it does not exist. """Return modification time of cache file, or 0 if it does not exist.
@ -168,7 +169,6 @@ def mtime(self, key):
sinfo = os.stat(self.cache_path(key)) sinfo = os.stat(self.cache_path(key))
return sinfo.st_mtime return sinfo.st_mtime
def remove(self, key): def remove(self, key):
lock = self._get_lock(key) lock = self._get_lock(key)
try: try:
@ -178,4 +178,6 @@ def remove(self, key):
lock.release_write() lock.release_write()
os.unlink(self._lock_path(key)) os.unlink(self._lock_path(key))
class CacheError(SpackError): pass
class CacheError(SpackError):
pass

View File

@ -1416,6 +1416,7 @@ def use_cray_compiler_names():
os.environ['FC'] = 'ftn' os.environ['FC'] = 'ftn'
os.environ['F77'] = 'ftn' os.environ['F77'] = 'ftn'
def flatten_dependencies(spec, flat_dir): def flatten_dependencies(spec, flat_dir):
"""Make each dependency of spec present in dir via symlink.""" """Make each dependency of spec present in dir via symlink."""
for dep in spec.traverse(root=False): for dep in spec.traverse(root=False):

View File

@ -25,7 +25,7 @@
""" """
The ``virtual`` module contains utility classes for virtual dependencies. The ``virtual`` module contains utility classes for virtual dependencies.
""" """
import itertools from itertools import product as iproduct
from pprint import pformat from pprint import pformat
import yaml import yaml
@ -52,8 +52,6 @@ class ProviderIndex(object):
matching implementation of MPI. matching implementation of MPI.
""" """
def __init__(self, specs=None, restrict=False): def __init__(self, specs=None, restrict=False):
"""Create a new ProviderIndex. """Create a new ProviderIndex.
@ -71,7 +69,8 @@ def __init__(self, specs=None, restrict=False):
as possible without overly restricting results, so it is as possible without overly restricting results, so it is
not the best name. not the best name.
""" """
if specs is None: specs = [] if specs is None:
specs = []
self.restrict = restrict self.restrict = restrict
self.providers = {} self.providers = {}
@ -85,7 +84,6 @@ def __init__(self, specs=None, restrict=False):
self.update(spec) self.update(spec)
def update(self, spec): def update(self, spec):
if not isinstance(spec, spack.spec.Spec): if not isinstance(spec, spack.spec.Spec):
spec = spack.spec.Spec(spec) spec = spack.spec.Spec(spec)
@ -104,7 +102,7 @@ def update(self, spec):
provided_name = provided_spec.name provided_name = provided_spec.name
provider_map = self.providers.setdefault(provided_name, {}) provider_map = self.providers.setdefault(provided_name, {})
if not provided_spec in provider_map: if provided_spec not in provider_map:
provider_map[provided_spec] = set() provider_map[provided_spec] = set()
if self.restrict: if self.restrict:
@ -126,7 +124,6 @@ def update(self, spec):
constrained.constrain(provider_spec) constrained.constrain(provider_spec)
provider_map[provided_spec].add(constrained) provider_map[provided_spec].add(constrained)
def providers_for(self, *vpkg_specs): def providers_for(self, *vpkg_specs):
"""Gives specs of all packages that provide virtual packages """Gives specs of all packages that provide virtual packages
with the supplied specs.""" with the supplied specs."""
@ -138,26 +135,25 @@ def providers_for(self, *vpkg_specs):
# Add all the providers that satisfy the vpkg spec. # Add all the providers that satisfy the vpkg spec.
if vspec.name in self.providers: if vspec.name in self.providers:
for provider_spec, spec_set in self.providers[vspec.name].items(): for p_spec, spec_set in self.providers[vspec.name].items():
if provider_spec.satisfies(vspec, deps=False): if p_spec.satisfies(vspec, deps=False):
providers.update(spec_set) providers.update(spec_set)
# Return providers in order # Return providers in order
return sorted(providers) return sorted(providers)
# TODO: this is pretty darned nasty, and inefficient, but there # TODO: this is pretty darned nasty, and inefficient, but there
# are not that many vdeps in most specs. # are not that many vdeps in most specs.
def _cross_provider_maps(self, lmap, rmap): def _cross_provider_maps(self, lmap, rmap):
result = {} result = {}
for lspec, rspec in itertools.product(lmap, rmap): for lspec, rspec in iproduct(lmap, rmap):
try: try:
constrained = lspec.constrained(rspec) constrained = lspec.constrained(rspec)
except spack.spec.UnsatisfiableSpecError: except spack.spec.UnsatisfiableSpecError:
continue continue
# lp and rp are left and right provider specs. # lp and rp are left and right provider specs.
for lp_spec, rp_spec in itertools.product(lmap[lspec], rmap[rspec]): for lp_spec, rp_spec in iproduct(lmap[lspec], rmap[rspec]):
if lp_spec.name == rp_spec.name: if lp_spec.name == rp_spec.name:
try: try:
const = lp_spec.constrained(rp_spec, deps=False) const = lp_spec.constrained(rp_spec, deps=False)
@ -166,12 +162,10 @@ def _cross_provider_maps(self, lmap, rmap):
continue continue
return result return result
def __contains__(self, name): def __contains__(self, name):
"""Whether a particular vpkg name is in the index.""" """Whether a particular vpkg name is in the index."""
return name in self.providers return name in self.providers
def satisfies(self, other): def satisfies(self, other):
"""Check that providers of virtual specs are compatible.""" """Check that providers of virtual specs are compatible."""
common = set(self.providers) & set(other.providers) common = set(self.providers) & set(other.providers)
@ -189,7 +183,6 @@ def satisfies(self, other):
return all(c in result for c in common) return all(c in result for c in common)
def to_yaml(self, stream=None): def to_yaml(self, stream=None):
provider_list = self._transform( provider_list = self._transform(
lambda vpkg, pset: [ lambda vpkg, pset: [
@ -198,7 +191,6 @@ def to_yaml(self, stream=None):
yaml.dump({'provider_index': {'providers': provider_list}}, yaml.dump({'provider_index': {'providers': provider_list}},
stream=stream) stream=stream)
@staticmethod @staticmethod
def from_yaml(stream): def from_yaml(stream):
try: try:
@ -211,7 +203,7 @@ def from_yaml(stream):
raise spack.spec.SpackYAMLError( raise spack.spec.SpackYAMLError(
"YAML ProviderIndex was not a dict.") "YAML ProviderIndex was not a dict.")
if not 'provider_index' in yfile: if 'provider_index' not in yfile:
raise spack.spec.SpackYAMLError( raise spack.spec.SpackYAMLError(
"YAML ProviderIndex does not start with 'provider_index'") "YAML ProviderIndex does not start with 'provider_index'")
@ -224,7 +216,6 @@ def from_yaml(stream):
set(spack.spec.Spec.from_node_dict(p) for p in plist))) set(spack.spec.Spec.from_node_dict(p) for p in plist)))
return index return index
def merge(self, other): def merge(self, other):
"""Merge `other` ProviderIndex into this one.""" """Merge `other` ProviderIndex into this one."""
other = other.copy() # defensive copy. other = other.copy() # defensive copy.
@ -242,7 +233,6 @@ def merge(self, other):
spdict[provided_spec] += opdict[provided_spec] spdict[provided_spec] += opdict[provided_spec]
def remove_provider(self, pkg_name): def remove_provider(self, pkg_name):
"""Remove a provider from the ProviderIndex.""" """Remove a provider from the ProviderIndex."""
empty_pkg_dict = [] empty_pkg_dict = []
@ -264,7 +254,6 @@ def remove_provider(self, pkg_name):
for pkg in empty_pkg_dict: for pkg in empty_pkg_dict:
del self.providers[pkg] del self.providers[pkg]
def copy(self): def copy(self):
"""Deep copy of this ProviderIndex.""" """Deep copy of this ProviderIndex."""
clone = ProviderIndex() clone = ProviderIndex()
@ -272,15 +261,12 @@ def copy(self):
lambda vpkg, pset: (vpkg, set((p.copy() for p in pset)))) lambda vpkg, pset: (vpkg, set((p.copy() for p in pset))))
return clone return clone
def __eq__(self, other): def __eq__(self, other):
return self.providers == other.providers return self.providers == other.providers
def _transform(self, transform_fun, out_mapping_type=dict): def _transform(self, transform_fun, out_mapping_type=dict):
return _transform(self.providers, transform_fun, out_mapping_type) return _transform(self.providers, transform_fun, out_mapping_type)
def __str__(self): def __str__(self):
return pformat( return pformat(
_transform(self.providers, _transform(self.providers,

View File

@ -38,7 +38,6 @@
import yaml import yaml
import llnl.util.tty as tty import llnl.util.tty as tty
from llnl.util.lock import Lock
from llnl.util.filesystem import * from llnl.util.filesystem import *
import spack import spack
@ -142,7 +141,6 @@ def __init__(self, *repo_dirs, **kwargs):
"To remove the bad repository, run this command:", "To remove the bad repository, run this command:",
" spack repo rm %s" % root) " spack repo rm %s" % root)
def swap(self, other): def swap(self, other):
"""Convenience function to make swapping repostiories easier. """Convenience function to make swapping repostiories easier.
@ -160,7 +158,6 @@ def swap(self, other):
setattr(self, attr, getattr(other, attr)) setattr(self, attr, getattr(other, attr))
setattr(other, attr, tmp) setattr(other, attr, tmp)
def _add(self, repo): def _add(self, repo):
"""Add a repository to the namespace and path indexes. """Add a repository to the namespace and path indexes.
@ -174,31 +171,28 @@ def _add(self, repo):
if repo.namespace in self.by_namespace: if repo.namespace in self.by_namespace:
raise DuplicateRepoError( raise DuplicateRepoError(
"Package repos '%s' and '%s' both provide namespace %s" "Package repos '%s' and '%s' both provide namespace %s"
% (repo.root, self.by_namespace[repo.namespace].root, repo.namespace)) % (repo.root, self.by_namespace[repo.namespace].root,
repo.namespace))
# Add repo to the pkg indexes # Add repo to the pkg indexes
self.by_namespace[repo.full_namespace] = repo self.by_namespace[repo.full_namespace] = repo
self.by_path[repo.root] = repo self.by_path[repo.root] = repo
def put_first(self, repo): def put_first(self, repo):
"""Add repo first in the search path.""" """Add repo first in the search path."""
self._add(repo) self._add(repo)
self.repos.insert(0, repo) self.repos.insert(0, repo)
def put_last(self, repo): def put_last(self, repo):
"""Add repo last in the search path.""" """Add repo last in the search path."""
self._add(repo) self._add(repo)
self.repos.append(repo) self.repos.append(repo)
def remove(self, repo): def remove(self, repo):
"""Remove a repo from the search path.""" """Remove a repo from the search path."""
if repo in self.repos: if repo in self.repos:
self.repos.remove(repo) self.repos.remove(repo)
def get_repo(self, namespace, default=NOT_PROVIDED): def get_repo(self, namespace, default=NOT_PROVIDED):
"""Get a repository by namespace. """Get a repository by namespace.
Arguments Arguments
@ -218,12 +212,10 @@ def get_repo(self, namespace, default=NOT_PROVIDED):
return default return default
return self.by_namespace[fullspace] return self.by_namespace[fullspace]
def first_repo(self): def first_repo(self):
"""Get the first repo in precedence order.""" """Get the first repo in precedence order."""
return self.repos[0] if self.repos else None return self.repos[0] if self.repos else None
def all_package_names(self): def all_package_names(self):
"""Return all unique package names in all repositories.""" """Return all unique package names in all repositories."""
if self._all_package_names is None: if self._all_package_names is None:
@ -234,12 +226,10 @@ def all_package_names(self):
self._all_package_names = sorted(all_pkgs, key=lambda n: n.lower()) self._all_package_names = sorted(all_pkgs, key=lambda n: n.lower())
return self._all_package_names return self._all_package_names
def all_packages(self): def all_packages(self):
for name in self.all_package_names(): for name in self.all_package_names():
yield self.get(name) yield self.get(name)
@property @property
def provider_index(self): def provider_index(self):
"""Merged ProviderIndex from all Repos in the RepoPath.""" """Merged ProviderIndex from all Repos in the RepoPath."""
@ -250,7 +240,6 @@ def provider_index(self):
return self._provider_index return self._provider_index
@_autospec @_autospec
def providers_for(self, vpkg_spec): def providers_for(self, vpkg_spec):
providers = self.provider_index.providers_for(vpkg_spec) providers = self.provider_index.providers_for(vpkg_spec)
@ -258,12 +247,10 @@ def providers_for(self, vpkg_spec):
raise UnknownPackageError(vpkg_spec.name) raise UnknownPackageError(vpkg_spec.name)
return providers return providers
@_autospec @_autospec
def extensions_for(self, extendee_spec): def extensions_for(self, extendee_spec):
return [p for p in self.all_packages() if p.extends(extendee_spec)] return [p for p in self.all_packages() if p.extends(extendee_spec)]
def find_module(self, fullname, path=None): def find_module(self, fullname, path=None):
"""Implements precedence for overlaid namespaces. """Implements precedence for overlaid namespaces.
@ -290,7 +277,6 @@ def find_module(self, fullname, path=None):
return None return None
def load_module(self, fullname): def load_module(self, fullname):
"""Handles loading container namespaces when necessary. """Handles loading container namespaces when necessary.
@ -307,7 +293,6 @@ def load_module(self, fullname):
sys.modules[fullname] = module sys.modules[fullname] = module
return module return module
@_autospec @_autospec
def repo_for_pkg(self, spec): def repo_for_pkg(self, spec):
"""Given a spec, get the repository for its package.""" """Given a spec, get the repository for its package."""
@ -329,7 +314,6 @@ def repo_for_pkg(self, spec):
# that can operate on packages that don't exist yet. # that can operate on packages that don't exist yet.
return self.first_repo() return self.first_repo()
@_autospec @_autospec
def get(self, spec, new=False): def get(self, spec, new=False):
"""Find a repo that contains the supplied spec's package. """Find a repo that contains the supplied spec's package.
@ -338,12 +322,10 @@ def get(self, spec, new=False):
""" """
return self.repo_for_pkg(spec).get(spec) return self.repo_for_pkg(spec).get(spec)
def get_pkg_class(self, pkg_name): def get_pkg_class(self, pkg_name):
"""Find a class for the spec's package and return the class object.""" """Find a class for the spec's package and return the class object."""
return self.repo_for_pkg(pkg_name).get_pkg_class(pkg_name) return self.repo_for_pkg(pkg_name).get_pkg_class(pkg_name)
@_autospec @_autospec
def dump_provenance(self, spec, path): def dump_provenance(self, spec, path):
"""Dump provenance information for a spec to a particular path. """Dump provenance information for a spec to a particular path.
@ -353,24 +335,19 @@ def dump_provenance(self, spec, path):
""" """
return self.repo_for_pkg(spec).dump_provenance(spec, path) return self.repo_for_pkg(spec).dump_provenance(spec, path)
def dirname_for_package_name(self, pkg_name): def dirname_for_package_name(self, pkg_name):
return self.repo_for_pkg(pkg_name).dirname_for_package_name(pkg_name) return self.repo_for_pkg(pkg_name).dirname_for_package_name(pkg_name)
def filename_for_package_name(self, pkg_name): def filename_for_package_name(self, pkg_name):
return self.repo_for_pkg(pkg_name).filename_for_package_name(pkg_name) return self.repo_for_pkg(pkg_name).filename_for_package_name(pkg_name)
def exists(self, pkg_name): def exists(self, pkg_name):
return any(repo.exists(pkg_name) for repo in self.repos) return any(repo.exists(pkg_name) for repo in self.repos)
def __contains__(self, pkg_name): def __contains__(self, pkg_name):
return self.exists(pkg_name) return self.exists(pkg_name)
class Repo(object): class Repo(object):
"""Class representing a package repository in the filesystem. """Class representing a package repository in the filesystem.
@ -404,7 +381,8 @@ def __init__(self, root, namespace=repo_namespace):
# check and raise BadRepoError on fail. # check and raise BadRepoError on fail.
def check(condition, msg): def check(condition, msg):
if not condition: raise BadRepoError(msg) if not condition:
raise BadRepoError(msg)
# Validate repository layout. # Validate repository layout.
self.config_file = join_path(self.root, repo_config_name) self.config_file = join_path(self.root, repo_config_name)
@ -422,12 +400,14 @@ def check(condition, msg):
self.namespace = config['namespace'] self.namespace = config['namespace']
check(re.match(r'[a-zA-Z][a-zA-Z0-9_.]+', self.namespace), check(re.match(r'[a-zA-Z][a-zA-Z0-9_.]+', self.namespace),
("Invalid namespace '%s' in repo '%s'. " % (self.namespace, self.root)) + ("Invalid namespace '%s' in repo '%s'. "
% (self.namespace, self.root)) +
"Namespaces must be valid python identifiers separated by '.'") "Namespaces must be valid python identifiers separated by '.'")
# Set up 'full_namespace' to include the super-namespace # Set up 'full_namespace' to include the super-namespace
if self.super_namespace: if self.super_namespace:
self.full_namespace = "%s.%s" % (self.super_namespace, self.namespace) self.full_namespace = "%s.%s" % (
self.super_namespace, self.namespace)
else: else:
self.full_namespace = self.namespace self.full_namespace = self.namespace
@ -465,7 +445,7 @@ def _create_namespace(self):
for l in range(1, len(self._names) + 1): for l in range(1, len(self._names) + 1):
ns = '.'.join(self._names[:l]) ns = '.'.join(self._names[:l])
if not ns in sys.modules: if ns not in sys.modules:
module = SpackNamespace(ns) module = SpackNamespace(ns)
module.__loader__ = self module.__loader__ = self
sys.modules[ns] = module sys.modules[ns] = module
@ -485,7 +465,6 @@ def _create_namespace(self):
# but keep track of the parent in this loop # but keep track of the parent in this loop
parent = module parent = module
def real_name(self, import_name): def real_name(self, import_name):
"""Allow users to import Spack packages using Python identifiers. """Allow users to import Spack packages using Python identifiers.
@ -511,13 +490,11 @@ def real_name(self, import_name):
return name return name
return None return None
def is_prefix(self, fullname): def is_prefix(self, fullname):
"""True if fullname is a prefix of this Repo's namespace.""" """True if fullname is a prefix of this Repo's namespace."""
parts = fullname.split('.') parts = fullname.split('.')
return self._names[:len(parts)] == parts return self._names[:len(parts)] == parts
def find_module(self, fullname, path=None): def find_module(self, fullname, path=None):
"""Python find_module import hook. """Python find_module import hook.
@ -533,7 +510,6 @@ def find_module(self, fullname, path=None):
return None return None
def load_module(self, fullname): def load_module(self, fullname):
"""Python importer load hook. """Python importer load hook.
@ -565,7 +541,6 @@ def load_module(self, fullname):
return module return module
def _read_config(self): def _read_config(self):
"""Check for a YAML config file in this db's root directory.""" """Check for a YAML config file in this db's root directory."""
try: try:
@ -574,23 +549,23 @@ def _read_config(self):
if (not yaml_data or 'repo' not in yaml_data or if (not yaml_data or 'repo' not in yaml_data or
not isinstance(yaml_data['repo'], dict)): not isinstance(yaml_data['repo'], dict)):
tty.die("Invalid %s in repository %s" tty.die("Invalid %s in repository %s" % (
% (repo_config_name, self.root)) repo_config_name, self.root))
return yaml_data['repo'] return yaml_data['repo']
except exceptions.IOError, e: except exceptions.IOError:
tty.die("Error reading %s when opening %s" tty.die("Error reading %s when opening %s"
% (self.config_file, self.root)) % (self.config_file, self.root))
@_autospec @_autospec
def get(self, spec, new=False): def get(self, spec, new=False):
if spec.virtual: if spec.virtual:
raise UnknownPackageError(spec.name) raise UnknownPackageError(spec.name)
if spec.namespace and spec.namespace != self.namespace: if spec.namespace and spec.namespace != self.namespace:
raise UnknownPackageError("Repository %s does not contain package %s" raise UnknownPackageError(
"Repository %s does not contain package %s"
% (self.namespace, spec.fullname)) % (self.namespace, spec.fullname))
key = hash(spec) key = hash(spec)
@ -599,14 +574,13 @@ def get(self, spec, new=False):
try: try:
copy = spec.copy() # defensive copy. Package owns its spec. copy = spec.copy() # defensive copy. Package owns its spec.
self._instances[key] = package_class(copy) self._instances[key] = package_class(copy)
except Exception, e: except Exception:
if spack.debug: if spack.debug:
sys.excepthook(*sys.exc_info()) sys.excepthook(*sys.exc_info())
raise FailedConstructorError(spec.fullname, *sys.exc_info()) raise FailedConstructorError(spec.fullname, *sys.exc_info())
return self._instances[key] return self._instances[key]
@_autospec @_autospec
def dump_provenance(self, spec, path): def dump_provenance(self, spec, path):
"""Dump provenance information for a spec to a particular path. """Dump provenance information for a spec to a particular path.
@ -619,7 +593,8 @@ def dump_provenance(self, spec, path):
raise UnknownPackageError(spec.name) raise UnknownPackageError(spec.name)
if spec.namespace and spec.namespace != self.namespace: if spec.namespace and spec.namespace != self.namespace:
raise UnknownPackageError("Repository %s does not contain package %s." raise UnknownPackageError(
"Repository %s does not contain package %s."
% (self.namespace, spec.fullname)) % (self.namespace, spec.fullname))
# Install any patch files needed by packages. # Install any patch files needed by packages.
@ -635,12 +610,10 @@ def dump_provenance(self, spec, path):
# Install the package.py file itself. # Install the package.py file itself.
install(self.filename_for_package_name(spec), path) install(self.filename_for_package_name(spec), path)
def purge(self): def purge(self):
"""Clear entire package instance cache.""" """Clear entire package instance cache."""
self._instances.clear() self._instances.clear()
def _update_provider_index(self): def _update_provider_index(self):
# Check modification dates of all packages # Check modification dates of all packages
self._fast_package_check() self._fast_package_check()
@ -669,7 +642,6 @@ def read():
self._provider_index.to_yaml(new) self._provider_index.to_yaml(new)
@property @property
def provider_index(self): def provider_index(self):
"""A provider index with names *specific* to this repo.""" """A provider index with names *specific* to this repo."""
@ -677,7 +649,6 @@ def provider_index(self):
self._update_provider_index() self._update_provider_index()
return self._provider_index return self._provider_index
@_autospec @_autospec
def providers_for(self, vpkg_spec): def providers_for(self, vpkg_spec):
providers = self.provider_index.providers_for(vpkg_spec) providers = self.provider_index.providers_for(vpkg_spec)
@ -685,18 +656,15 @@ def providers_for(self, vpkg_spec):
raise UnknownPackageError(vpkg_spec.name) raise UnknownPackageError(vpkg_spec.name)
return providers return providers
@_autospec @_autospec
def extensions_for(self, extendee_spec): def extensions_for(self, extendee_spec):
return [p for p in self.all_packages() if p.extends(extendee_spec)] return [p for p in self.all_packages() if p.extends(extendee_spec)]
def _check_namespace(self, spec): def _check_namespace(self, spec):
"""Check that the spec's namespace is the same as this repository's.""" """Check that the spec's namespace is the same as this repository's."""
if spec.namespace and spec.namespace != self.namespace: if spec.namespace and spec.namespace != self.namespace:
raise UnknownNamespaceError(spec.namespace) raise UnknownNamespaceError(spec.namespace)
@_autospec @_autospec
def dirname_for_package_name(self, spec): def dirname_for_package_name(self, spec):
"""Get the directory name for a particular package. This is the """Get the directory name for a particular package. This is the
@ -704,7 +672,6 @@ def dirname_for_package_name(self, spec):
self._check_namespace(spec) self._check_namespace(spec)
return join_path(self.packages_path, spec.name) return join_path(self.packages_path, spec.name)
@_autospec @_autospec
def filename_for_package_name(self, spec): def filename_for_package_name(self, spec):
"""Get the filename for the module we should load for a particular """Get the filename for the module we should load for a particular
@ -719,7 +686,6 @@ def filename_for_package_name(self, spec):
pkg_dir = self.dirname_for_package_name(spec.name) pkg_dir = self.dirname_for_package_name(spec.name)
return join_path(pkg_dir, package_file_name) return join_path(pkg_dir, package_file_name)
def _fast_package_check(self): def _fast_package_check(self):
"""List packages in the repo and check whether index is up to date. """List packages in the repo and check whether index is up to date.
@ -783,13 +749,11 @@ def _fast_package_check(self):
return self._all_package_names return self._all_package_names
def all_package_names(self): def all_package_names(self):
"""Returns a sorted list of all package names in the Repo.""" """Returns a sorted list of all package names in the Repo."""
self._fast_package_check() self._fast_package_check()
return self._all_package_names return self._all_package_names
def all_packages(self): def all_packages(self):
"""Iterator over all packages in the repository. """Iterator over all packages in the repository.
@ -799,7 +763,6 @@ def all_packages(self):
for name in self.all_package_names(): for name in self.all_package_names():
yield self.get(name) yield self.get(name)
def exists(self, pkg_name): def exists(self, pkg_name):
"""Whether a package with the supplied name exists.""" """Whether a package with the supplied name exists."""
if self._all_package_names: if self._all_package_names:
@ -813,7 +776,6 @@ def exists(self, pkg_name):
filename = self.filename_for_package_name(pkg_name) filename = self.filename_for_package_name(pkg_name)
return os.path.exists(filename) return os.path.exists(filename)
def _get_pkg_module(self, pkg_name): def _get_pkg_module(self, pkg_name):
"""Create a module for a particular package. """Create a module for a particular package.
@ -845,7 +807,6 @@ def _get_pkg_module(self, pkg_name):
return self._modules[pkg_name] return self._modules[pkg_name]
def get_pkg_class(self, pkg_name): def get_pkg_class(self, pkg_name):
"""Get the class for the package out of its module. """Get the class for the package out of its module.
@ -853,7 +814,6 @@ def get_pkg_class(self, pkg_name):
package. Then extracts the package class from the module package. Then extracts the package class from the module
according to Spack's naming convention. according to Spack's naming convention.
""" """
fullname = pkg_name
namespace, _, pkg_name = pkg_name.rpartition('.') namespace, _, pkg_name = pkg_name.rpartition('.')
if namespace and (namespace != self.namespace): if namespace and (namespace != self.namespace):
raise InvalidNamespaceError('Invalid namespace for %s repo: %s' raise InvalidNamespaceError('Invalid namespace for %s repo: %s'
@ -868,15 +828,12 @@ def get_pkg_class(self, pkg_name):
return cls return cls
def __str__(self): def __str__(self):
return "[Repo '%s' at '%s']" % (self.namespace, self.root) return "[Repo '%s' at '%s']" % (self.namespace, self.root)
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
def __contains__(self, pkg_name): def __contains__(self, pkg_name):
return self.exists(pkg_name) return self.exists(pkg_name)
@ -885,30 +842,37 @@ def create_repo(root, namespace=None):
"""Create a new repository in root with the specified namespace. """Create a new repository in root with the specified namespace.
If the namespace is not provided, use basename of root. If the namespace is not provided, use basename of root.
Return the canonicalized path and the namespace of the created repository. Return the canonicalized path and namespace of the created repository.
""" """
root = canonicalize_path(root) root = canonicalize_path(root)
if not namespace: if not namespace:
namespace = os.path.basename(root) namespace = os.path.basename(root)
if not re.match(r'\w[\.\w-]*', namespace): if not re.match(r'\w[\.\w-]*', namespace):
raise InvalidNamespaceError("'%s' is not a valid namespace." % namespace) raise InvalidNamespaceError(
"'%s' is not a valid namespace." % namespace)
existed = False existed = False
if os.path.exists(root): if os.path.exists(root):
if os.path.isfile(root): if os.path.isfile(root):
raise BadRepoError('File %s already exists and is not a directory' % root) raise BadRepoError('File %s already exists and is not a directory'
% root)
elif os.path.isdir(root): elif os.path.isdir(root):
if not os.access(root, os.R_OK | os.W_OK): if not os.access(root, os.R_OK | os.W_OK):
raise BadRepoError('Cannot create new repo in %s: cannot access directory.' % root) raise BadRepoError(
'Cannot create new repo in %s: cannot access directory.'
% root)
if os.listdir(root): if os.listdir(root):
raise BadRepoError('Cannot create new repo in %s: directory is not empty.' % root) raise BadRepoError(
'Cannot create new repo in %s: directory is not empty.'
% root)
existed = True existed = True
full_path = os.path.realpath(root) full_path = os.path.realpath(root)
parent = os.path.dirname(full_path) parent = os.path.dirname(full_path)
if not os.access(parent, os.R_OK | os.W_OK): if not os.access(parent, os.R_OK | os.W_OK):
raise BadRepoError("Cannot create repository in %s: can't access parent!" % root) raise BadRepoError(
"Cannot create repository in %s: can't access parent!" % root)
try: try:
config_path = os.path.join(root, repo_config_name) config_path = os.path.join(root, repo_config_name)

View File

@ -82,6 +82,7 @@
'cmd.test_compiler_cmd', 'cmd.test_compiler_cmd',
] ]
def list_tests(): def list_tests():
"""Return names of all tests that can be run for Spack.""" """Return names of all tests that can be run for Spack."""
return test_names return test_names

View File

@ -29,6 +29,7 @@
from spack.concretize import find_spec from spack.concretize import find_spec
from spack.test.mock_packages_test import * from spack.test.mock_packages_test import *
class ConcretizeTest(MockPackagesTest): class ConcretizeTest(MockPackagesTest):
def check_spec(self, abstract, concrete): def check_spec(self, abstract, concrete):
@ -59,7 +60,6 @@ def check_spec(self, abstract, concrete):
if abstract.architecture and abstract.architecture.concrete: if abstract.architecture and abstract.architecture.concrete:
self.assertEqual(abstract.architecture, concrete.architecture) self.assertEqual(abstract.architecture, concrete.architecture)
def check_concretize(self, abstract_spec): def check_concretize(self, abstract_spec):
abstract = Spec(abstract_spec) abstract = Spec(abstract_spec)
concrete = abstract.concretized() concrete = abstract.concretized()
@ -70,29 +70,24 @@ def check_concretize(self, abstract_spec):
return concrete return concrete
def test_concretize_no_deps(self): def test_concretize_no_deps(self):
self.check_concretize('libelf') self.check_concretize('libelf')
self.check_concretize('libelf@0.8.13') self.check_concretize('libelf@0.8.13')
def test_concretize_dag(self): def test_concretize_dag(self):
self.check_concretize('callpath') self.check_concretize('callpath')
self.check_concretize('mpileaks') self.check_concretize('mpileaks')
self.check_concretize('libelf') self.check_concretize('libelf')
def test_concretize_variant(self): def test_concretize_variant(self):
self.check_concretize('mpich+debug') self.check_concretize('mpich+debug')
self.check_concretize('mpich~debug') self.check_concretize('mpich~debug')
self.check_concretize('mpich debug=2') self.check_concretize('mpich debug=2')
self.check_concretize('mpich') self.check_concretize('mpich')
def test_conretize_compiler_flags(self): def test_conretize_compiler_flags(self):
self.check_concretize('mpich cppflags="-O3"') self.check_concretize('mpich cppflags="-O3"')
def test_concretize_preferred_version(self): def test_concretize_preferred_version(self):
spec = self.check_concretize('python') spec = self.check_concretize('python')
self.assertEqual(spec.versions, ver('2.7.11')) self.assertEqual(spec.versions, ver('2.7.11'))
@ -100,7 +95,6 @@ def test_concretize_preferred_version(self):
spec = self.check_concretize('python@3.5.1') spec = self.check_concretize('python@3.5.1')
self.assertEqual(spec.versions, ver('3.5.1')) self.assertEqual(spec.versions, ver('3.5.1'))
def test_concretize_with_virtual(self): def test_concretize_with_virtual(self):
self.check_concretize('mpileaks ^mpi') self.check_concretize('mpileaks ^mpi')
self.check_concretize('mpileaks ^mpi@:1.1') self.check_concretize('mpileaks ^mpi@:1.1')
@ -111,7 +105,6 @@ def test_concretize_with_virtual(self):
self.check_concretize('mpileaks ^mpi@:1') self.check_concretize('mpileaks ^mpi@:1')
self.check_concretize('mpileaks ^mpi@1.2:2') self.check_concretize('mpileaks ^mpi@1.2:2')
def test_concretize_with_restricted_virtual(self): def test_concretize_with_restricted_virtual(self):
self.check_concretize('mpileaks ^mpich2') self.check_concretize('mpileaks ^mpich2')
@ -142,55 +135,55 @@ def test_concretize_with_restricted_virtual(self):
concrete = self.check_concretize('mpileaks ^mpich2@1.3.1:1.4') concrete = self.check_concretize('mpileaks ^mpich2@1.3.1:1.4')
self.assertTrue(concrete['mpich2'].satisfies('mpich2@1.3.1:1.4')) self.assertTrue(concrete['mpich2'].satisfies('mpich2@1.3.1:1.4'))
def test_concretize_with_provides_when(self): def test_concretize_with_provides_when(self):
"""Make sure insufficient versions of MPI are not in providers list when """Make sure insufficient versions of MPI are not in providers list when
we ask for some advanced version. we ask for some advanced version.
""" """
self.assertTrue(not any(spec.satisfies('mpich2@:1.0') self.assertTrue(
not any(spec.satisfies('mpich2@:1.0')
for spec in spack.repo.providers_for('mpi@2.1'))) for spec in spack.repo.providers_for('mpi@2.1')))
self.assertTrue(not any(spec.satisfies('mpich2@:1.1') self.assertTrue(
not any(spec.satisfies('mpich2@:1.1')
for spec in spack.repo.providers_for('mpi@2.2'))) for spec in spack.repo.providers_for('mpi@2.2')))
self.assertTrue(not any(spec.satisfies('mpich@:1') self.assertTrue(
not any(spec.satisfies('mpich@:1')
for spec in spack.repo.providers_for('mpi@2'))) for spec in spack.repo.providers_for('mpi@2')))
self.assertTrue(not any(spec.satisfies('mpich@:1') self.assertTrue(
not any(spec.satisfies('mpich@:1')
for spec in spack.repo.providers_for('mpi@3'))) for spec in spack.repo.providers_for('mpi@3')))
self.assertTrue(not any(spec.satisfies('mpich2') self.assertTrue(
not any(spec.satisfies('mpich2')
for spec in spack.repo.providers_for('mpi@3'))) for spec in spack.repo.providers_for('mpi@3')))
def test_concretize_two_virtuals(self): def test_concretize_two_virtuals(self):
"""Test a package with multiple virtual dependencies.""" """Test a package with multiple virtual dependencies."""
s = Spec('hypre').concretize() Spec('hypre').concretize()
def test_concretize_two_virtuals_with_one_bound(self): def test_concretize_two_virtuals_with_one_bound(self):
"""Test a package with multiple virtual dependencies and one preset.""" """Test a package with multiple virtual dependencies and one preset."""
s = Spec('hypre ^openblas').concretize() Spec('hypre ^openblas').concretize()
def test_concretize_two_virtuals_with_two_bound(self): def test_concretize_two_virtuals_with_two_bound(self):
"""Test a package with multiple virtual dependencies and two of them preset.""" """Test a package with multiple virtual deps and two of them preset."""
s = Spec('hypre ^openblas ^netlib-lapack').concretize() Spec('hypre ^openblas ^netlib-lapack').concretize()
def test_concretize_two_virtuals_with_dual_provider(self): def test_concretize_two_virtuals_with_dual_provider(self):
"""Test a package with multiple virtual dependencies and force a provider """Test a package with multiple virtual dependencies and force a provider
that provides both.""" that provides both."""
s = Spec('hypre ^openblas-with-lapack').concretize() Spec('hypre ^openblas-with-lapack').concretize()
def test_concretize_two_virtuals_with_dual_provider_and_a_conflict(self): def test_concretize_two_virtuals_with_dual_provider_and_a_conflict(self):
"""Test a package with multiple virtual dependencies and force a provider """Test a package with multiple virtual dependencies and force a
that provides both, and another conflicting package that provides one.""" provider that provides both, and another conflicting package that
provides one.
"""
s = Spec('hypre ^openblas-with-lapack ^netlib-lapack') s = Spec('hypre ^openblas-with-lapack ^netlib-lapack')
self.assertRaises(spack.spec.MultipleProviderError, s.concretize) self.assertRaises(spack.spec.MultipleProviderError, s.concretize)
def test_virtual_is_fully_expanded_for_callpath(self): def test_virtual_is_fully_expanded_for_callpath(self):
# force dependence on fake "zmpi" by asking for MPI 10.0 # force dependence on fake "zmpi" by asking for MPI 10.0
spec = Spec('callpath ^mpi@10.0') spec = Spec('callpath ^mpi@10.0')
@ -207,7 +200,6 @@ def test_virtual_is_fully_expanded_for_callpath(self):
self.assertTrue('fake' in spec._dependencies['zmpi'].spec) self.assertTrue('fake' in spec._dependencies['zmpi'].spec)
def test_virtual_is_fully_expanded_for_mpileaks(self): def test_virtual_is_fully_expanded_for_mpileaks(self):
spec = Spec('mpileaks ^mpi@10.0') spec = Spec('mpileaks ^mpi@10.0')
self.assertTrue('mpi' in spec._dependencies) self.assertTrue('mpi' in spec._dependencies)
@ -217,23 +209,24 @@ def test_virtual_is_fully_expanded_for_mpileaks(self):
self.assertTrue('zmpi' in spec._dependencies) self.assertTrue('zmpi' in spec._dependencies)
self.assertTrue('callpath' in spec._dependencies) self.assertTrue('callpath' in spec._dependencies)
self.assertTrue('zmpi' in spec._dependencies['callpath']. self.assertTrue(
spec._dependencies) 'zmpi' in spec._dependencies['callpath']
self.assertTrue('fake' in spec._dependencies['callpath']. .spec._dependencies)
spec._dependencies['zmpi']. self.assertTrue(
spec._dependencies) 'fake' in spec._dependencies['callpath']
.spec._dependencies['zmpi']
.spec._dependencies)
self.assertTrue(all(not 'mpi' in d._dependencies for d in spec.traverse())) self.assertTrue(
all('mpi' not in d._dependencies for d in spec.traverse()))
self.assertTrue('zmpi' in spec) self.assertTrue('zmpi' in spec)
self.assertTrue('mpi' in spec) self.assertTrue('mpi' in spec)
def test_my_dep_depends_on_provider_of_my_virtual_dep(self): def test_my_dep_depends_on_provider_of_my_virtual_dep(self):
spec = Spec('indirect_mpich') spec = Spec('indirect_mpich')
spec.normalize() spec.normalize()
spec.concretize() spec.concretize()
def test_compiler_inheritance(self): def test_compiler_inheritance(self):
spec = Spec('mpileaks') spec = Spec('mpileaks')
spec.normalize() spec.normalize()
@ -245,26 +238,26 @@ def test_compiler_inheritance(self):
self.assertTrue(spec['libdwarf'].compiler.satisfies('clang')) self.assertTrue(spec['libdwarf'].compiler.satisfies('clang'))
self.assertTrue(spec['libelf'].compiler.satisfies('clang')) self.assertTrue(spec['libelf'].compiler.satisfies('clang'))
def test_external_package(self): def test_external_package(self):
spec = Spec('externaltool%gcc') spec = Spec('externaltool%gcc')
spec.concretize() spec.concretize()
self.assertEqual(spec['externaltool'].external, '/path/to/external_tool') self.assertEqual(
spec['externaltool'].external, '/path/to/external_tool')
self.assertFalse('externalprereq' in spec) self.assertFalse('externalprereq' in spec)
self.assertTrue(spec['externaltool'].compiler.satisfies('gcc')) self.assertTrue(spec['externaltool'].compiler.satisfies('gcc'))
def test_external_package_module(self): def test_external_package_module(self):
# No tcl modules on darwin/linux machines # No tcl modules on darwin/linux machines
# TODO: improved way to check for this. # TODO: improved way to check for this.
if (spack.architecture.platform().name == 'darwin' or platform = spack.architecture.platform().name
spack.architecture.platform().name == 'linux'): if (platform == 'darwin' or platform == 'linux'):
return return
spec = Spec('externalmodule') spec = Spec('externalmodule')
spec.concretize() spec.concretize()
self.assertEqual(spec['externalmodule'].external_module, 'external-module') self.assertEqual(
spec['externalmodule'].external_module, 'external-module')
self.assertFalse('externalprereq' in spec) self.assertFalse('externalprereq' in spec)
self.assertTrue(spec['externalmodule'].compiler.satisfies('gcc')) self.assertTrue(spec['externalmodule'].compiler.satisfies('gcc'))
@ -277,16 +270,16 @@ def test_nobuild_package(self):
got_error = True got_error = True
self.assertTrue(got_error) self.assertTrue(got_error)
def test_external_and_virtual(self): def test_external_and_virtual(self):
spec = Spec('externaltest') spec = Spec('externaltest')
spec.concretize() spec.concretize()
self.assertEqual(spec['externaltool'].external, '/path/to/external_tool') self.assertEqual(
self.assertEqual(spec['stuff'].external, '/path/to/external_virtual_gcc') spec['externaltool'].external, '/path/to/external_tool')
self.assertEqual(
spec['stuff'].external, '/path/to/external_virtual_gcc')
self.assertTrue(spec['externaltool'].compiler.satisfies('gcc')) self.assertTrue(spec['externaltool'].compiler.satisfies('gcc'))
self.assertTrue(spec['stuff'].compiler.satisfies('gcc')) self.assertTrue(spec['stuff'].compiler.satisfies('gcc'))
def test_find_spec_parents(self): def test_find_spec_parents(self):
"""Tests the spec finding logic used by concretization. """ """Tests the spec finding logic used by concretization. """
s = Spec('a +foo', s = Spec('a +foo',
@ -297,7 +290,6 @@ def test_find_spec_parents(self):
self.assertEqual('a', find_spec(s['b'], lambda s: '+foo' in s).name) self.assertEqual('a', find_spec(s['b'], lambda s: '+foo' in s).name)
def test_find_spec_children(self): def test_find_spec_children(self):
s = Spec('a', s = Spec('a',
Spec('b +foo', Spec('b +foo',
@ -312,7 +304,6 @@ def test_find_spec_children(self):
Spec('e +foo')) Spec('e +foo'))
self.assertEqual('c', find_spec(s['b'], lambda s: '+foo' in s).name) self.assertEqual('c', find_spec(s['b'], lambda s: '+foo' in s).name)
def test_find_spec_sibling(self): def test_find_spec_sibling(self):
s = Spec('a', s = Spec('a',
Spec('b +foo', Spec('b +foo',
@ -330,7 +321,6 @@ def test_find_spec_sibling(self):
Spec('f +foo'))) Spec('f +foo')))
self.assertEqual('f', find_spec(s['b'], lambda s: '+foo' in s).name) self.assertEqual('f', find_spec(s['b'], lambda s: '+foo' in s).name)
def test_find_spec_self(self): def test_find_spec_self(self):
s = Spec('a', s = Spec('a',
Spec('b +foo', Spec('b +foo',
@ -339,7 +329,6 @@ def test_find_spec_self(self):
Spec('e')) Spec('e'))
self.assertEqual('b', find_spec(s['b'], lambda s: '+foo' in s).name) self.assertEqual('b', find_spec(s['b'], lambda s: '+foo' in s).name)
def test_find_spec_none(self): def test_find_spec_none(self):
s = Spec('a', s = Spec('a',
Spec('b', Spec('b',
@ -348,7 +337,6 @@ def test_find_spec_none(self):
Spec('e')) Spec('e'))
self.assertEqual(None, find_spec(s['b'], lambda s: '+foo' in s)) self.assertEqual(None, find_spec(s['b'], lambda s: '+foo' in s))
def test_compiler_child(self): def test_compiler_child(self):
s = Spec('mpileaks%clang ^dyninst%gcc') s = Spec('mpileaks%clang ^dyninst%gcc')
s.concretize() s.concretize()

View File

@ -31,7 +31,6 @@
import spack import spack
from llnl.util.filesystem import join_path from llnl.util.filesystem import join_path
from llnl.util.lock import *
from llnl.util.tty.colify import colify from llnl.util.tty.colify import colify
from spack.test.mock_database import MockDatabase from spack.test.mock_database import MockDatabase
@ -104,10 +103,12 @@ def test_010_all_install_sanity(self):
self.assertEqual(len(libelf_specs), 1) self.assertEqual(len(libelf_specs), 1)
# Query by dependency # Query by dependency
self.assertEqual(len([s for s in all_specs if s.satisfies('mpileaks ^mpich')]), 1) self.assertEqual(
self.assertEqual(len([s for s in all_specs if s.satisfies('mpileaks ^mpich2')]), 1) len([s for s in all_specs if s.satisfies('mpileaks ^mpich')]), 1)
self.assertEqual(len([s for s in all_specs if s.satisfies('mpileaks ^zmpi')]), 1) self.assertEqual(
len([s for s in all_specs if s.satisfies('mpileaks ^mpich2')]), 1)
self.assertEqual(
len([s for s in all_specs if s.satisfies('mpileaks ^zmpi')]), 1)
def test_015_write_and_read(self): def test_015_write_and_read(self):
# write and read DB # write and read DB
@ -122,7 +123,6 @@ def test_015_write_and_read(self):
self.assertEqual(new_rec.path, rec.path) self.assertEqual(new_rec.path, rec.path)
self.assertEqual(new_rec.installed, rec.installed) self.assertEqual(new_rec.installed, rec.installed)
def _check_db_sanity(self): def _check_db_sanity(self):
"""Utiilty function to check db against install layout.""" """Utiilty function to check db against install layout."""
expected = sorted(spack.install_layout.all_specs()) expected = sorted(spack.install_layout.all_specs())
@ -132,12 +132,10 @@ def _check_db_sanity(self):
for e, a in zip(expected, actual): for e, a in zip(expected, actual):
self.assertEqual(e, a) self.assertEqual(e, a)
def test_020_db_sanity(self): def test_020_db_sanity(self):
"""Make sure query() returns what's actually in the db.""" """Make sure query() returns what's actually in the db."""
self._check_db_sanity() self._check_db_sanity()
def test_030_db_sanity_from_another_process(self): def test_030_db_sanity_from_another_process(self):
def read_and_modify(): def read_and_modify():
self._check_db_sanity() # check that other process can read DB self._check_db_sanity() # check that other process can read DB
@ -152,14 +150,12 @@ def read_and_modify():
with self.installed_db.read_transaction(): with self.installed_db.read_transaction():
self.assertEqual(len(self.installed_db.query('mpileaks ^zmpi')), 0) self.assertEqual(len(self.installed_db.query('mpileaks ^zmpi')), 0)
def test_040_ref_counts(self): def test_040_ref_counts(self):
"""Ensure that we got ref counts right when we read the DB.""" """Ensure that we got ref counts right when we read the DB."""
self.installed_db._check_ref_counts() self.installed_db._check_ref_counts()
def test_050_basic_query(self): def test_050_basic_query(self):
"""Ensure that querying the database is consistent with what is installed.""" """Ensure querying database is consistent with what is installed."""
# query everything # query everything
self.assertEqual(len(spack.installed_db.query()), 13) self.assertEqual(len(spack.installed_db.query()), 13)
@ -186,7 +182,6 @@ def test_050_basic_query(self):
self.assertEqual(len(self.installed_db.query('mpileaks ^mpich2')), 1) self.assertEqual(len(self.installed_db.query('mpileaks ^mpich2')), 1)
self.assertEqual(len(self.installed_db.query('mpileaks ^zmpi')), 1) self.assertEqual(len(self.installed_db.query('mpileaks ^zmpi')), 1)
def _check_remove_and_add_package(self, spec): def _check_remove_and_add_package(self, spec):
"""Remove a spec from the DB, then add it and make sure everything's """Remove a spec from the DB, then add it and make sure everything's
still ok once it is added. This checks that it was still ok once it is added. This checks that it was
@ -215,15 +210,12 @@ def _check_remove_and_add_package(self, spec):
self._check_db_sanity() self._check_db_sanity()
self.installed_db._check_ref_counts() self.installed_db._check_ref_counts()
def test_060_remove_and_add_root_package(self): def test_060_remove_and_add_root_package(self):
self._check_remove_and_add_package('mpileaks ^mpich') self._check_remove_and_add_package('mpileaks ^mpich')
def test_070_remove_and_add_dependency_package(self): def test_070_remove_and_add_dependency_package(self):
self._check_remove_and_add_package('dyninst') self._check_remove_and_add_package('dyninst')
def test_080_root_ref_counts(self): def test_080_root_ref_counts(self):
rec = self.installed_db.get_record('mpileaks ^mpich') rec = self.installed_db.get_record('mpileaks ^mpich')
@ -231,44 +223,52 @@ def test_080_root_ref_counts(self):
self.installed_db.remove('mpileaks ^mpich') self.installed_db.remove('mpileaks ^mpich')
# record no longer in DB # record no longer in DB
self.assertEqual(self.installed_db.query('mpileaks ^mpich', installed=any), []) self.assertEqual(
self.installed_db.query('mpileaks ^mpich', installed=any), [])
# record's deps have updated ref_counts # record's deps have updated ref_counts
self.assertEqual(self.installed_db.get_record('callpath ^mpich').ref_count, 0) self.assertEqual(
self.installed_db.get_record('callpath ^mpich').ref_count, 0)
self.assertEqual(self.installed_db.get_record('mpich').ref_count, 1) self.assertEqual(self.installed_db.get_record('mpich').ref_count, 1)
# put the spec back # Put the spec back
self.installed_db.add(rec.spec, rec.path) self.installed_db.add(rec.spec, rec.path)
# record is present again # record is present again
self.assertEqual(len(self.installed_db.query('mpileaks ^mpich', installed=any)), 1) self.assertEqual(
len(self.installed_db.query('mpileaks ^mpich', installed=any)), 1)
# dependencies have ref counts updated # dependencies have ref counts updated
self.assertEqual(self.installed_db.get_record('callpath ^mpich').ref_count, 1) self.assertEqual(
self.installed_db.get_record('callpath ^mpich').ref_count, 1)
self.assertEqual(self.installed_db.get_record('mpich').ref_count, 2) self.assertEqual(self.installed_db.get_record('mpich').ref_count, 2)
def test_090_non_root_ref_counts(self): def test_090_non_root_ref_counts(self):
mpileaks_mpich_rec = self.installed_db.get_record('mpileaks ^mpich') self.installed_db.get_record('mpileaks ^mpich')
callpath_mpich_rec = self.installed_db.get_record('callpath ^mpich') self.installed_db.get_record('callpath ^mpich')
# "force remove" a non-root spec from the DB # "force remove" a non-root spec from the DB
self.installed_db.remove('callpath ^mpich') self.installed_db.remove('callpath ^mpich')
# record still in DB but marked uninstalled # record still in DB but marked uninstalled
self.assertEqual(self.installed_db.query('callpath ^mpich', installed=True), []) self.assertEqual(
self.assertEqual(len(self.installed_db.query('callpath ^mpich', installed=any)), 1) self.installed_db.query('callpath ^mpich', installed=True), [])
self.assertEqual(
len(self.installed_db.query('callpath ^mpich', installed=any)), 1)
# record and its deps have same ref_counts # record and its deps have same ref_counts
self.assertEqual(self.installed_db.get_record('callpath ^mpich', installed=any).ref_count, 1) self.assertEqual(self.installed_db.get_record(
'callpath ^mpich', installed=any).ref_count, 1)
self.assertEqual(self.installed_db.get_record('mpich').ref_count, 2) self.assertEqual(self.installed_db.get_record('mpich').ref_count, 2)
# remove only dependent of uninstalled callpath record # remove only dependent of uninstalled callpath record
self.installed_db.remove('mpileaks ^mpich') self.installed_db.remove('mpileaks ^mpich')
# record and parent are completely gone. # record and parent are completely gone.
self.assertEqual(self.installed_db.query('mpileaks ^mpich', installed=any), []) self.assertEqual(
self.assertEqual(self.installed_db.query('callpath ^mpich', installed=any), []) self.installed_db.query('mpileaks ^mpich', installed=any), [])
self.assertEqual(
self.installed_db.query('callpath ^mpich', installed=any), [])
# mpich ref count updated properly. # mpich ref count updated properly.
mpich_rec = self.installed_db.get_record('mpich') mpich_rec = self.installed_db.get_record('mpich')
@ -282,14 +282,16 @@ def fail_while_writing():
with self.installed_db.read_transaction(): with self.installed_db.read_transaction():
self.assertEqual( self.assertEqual(
len(self.installed_db.query('mpileaks ^zmpi', installed=any)), 1) len(self.installed_db.query('mpileaks ^zmpi', installed=any)),
1)
self.assertRaises(Exception, fail_while_writing) self.assertRaises(Exception, fail_while_writing)
# reload DB and make sure zmpi is still there. # reload DB and make sure zmpi is still there.
with self.installed_db.read_transaction(): with self.installed_db.read_transaction():
self.assertEqual( self.assertEqual(
len(self.installed_db.query('mpileaks ^zmpi', installed=any)), 1) len(self.installed_db.query('mpileaks ^zmpi', installed=any)),
1)
def test_110_no_write_with_exception_on_install(self): def test_110_no_write_with_exception_on_install(self):
def fail_while_writing(): def fail_while_writing():

View File

@ -30,7 +30,6 @@
import tempfile import tempfile
import unittest import unittest
import spack
from spack.file_cache import FileCache from spack.file_cache import FileCache

View File

@ -46,21 +46,21 @@ def setUp(self):
self.lock_path = join_path(self.tempdir, 'lockfile') self.lock_path = join_path(self.tempdir, 'lockfile')
touch(self.lock_path) touch(self.lock_path)
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tempdir, ignore_errors=True) shutil.rmtree(self.tempdir, ignore_errors=True)
def multiproc_test(self, *functions): def multiproc_test(self, *functions):
"""Order some processes using simple barrier synchronization.""" """Order some processes using simple barrier synchronization."""
b = Barrier(len(functions), timeout=barrier_timeout) b = Barrier(len(functions), timeout=barrier_timeout)
procs = [Process(target=f, args=(b,)) for f in functions] procs = [Process(target=f, args=(b,)) for f in functions]
for p in procs: p.start()
for p in procs:
p.start()
for p in procs: for p in procs:
p.join() p.join()
self.assertEqual(p.exitcode, 0) self.assertEqual(p.exitcode, 0)
# #
# Process snippets below can be composed into tests. # Process snippets below can be composed into tests.
# #
@ -88,7 +88,6 @@ def timeout_read(self, barrier):
self.assertRaises(LockError, lock.acquire_read, 0.1) self.assertRaises(LockError, lock.acquire_read, 0.1)
barrier.wait() barrier.wait()
# #
# Test that exclusive locks on other processes time out when an # Test that exclusive locks on other processes time out when an
# exclusive lock is held. # exclusive lock is held.
@ -97,11 +96,13 @@ def test_write_lock_timeout_on_write(self):
self.multiproc_test(self.acquire_write, self.timeout_write) self.multiproc_test(self.acquire_write, self.timeout_write)
def test_write_lock_timeout_on_write_2(self): def test_write_lock_timeout_on_write_2(self):
self.multiproc_test(self.acquire_write, self.timeout_write, self.timeout_write) self.multiproc_test(
self.acquire_write, self.timeout_write, self.timeout_write)
def test_write_lock_timeout_on_write_3(self): def test_write_lock_timeout_on_write_3(self):
self.multiproc_test(self.acquire_write, self.timeout_write, self.timeout_write, self.timeout_write) self.multiproc_test(
self.acquire_write, self.timeout_write, self.timeout_write,
self.timeout_write)
# #
# Test that shared locks on other processes time out when an # Test that shared locks on other processes time out when an
@ -111,11 +112,13 @@ def test_read_lock_timeout_on_write(self):
self.multiproc_test(self.acquire_write, self.timeout_read) self.multiproc_test(self.acquire_write, self.timeout_read)
def test_read_lock_timeout_on_write_2(self): def test_read_lock_timeout_on_write_2(self):
self.multiproc_test(self.acquire_write, self.timeout_read, self.timeout_read) self.multiproc_test(
self.acquire_write, self.timeout_read, self.timeout_read)
def test_read_lock_timeout_on_write_3(self): def test_read_lock_timeout_on_write_3(self):
self.multiproc_test(self.acquire_write, self.timeout_read, self.timeout_read, self.timeout_read) self.multiproc_test(
self.acquire_write, self.timeout_read, self.timeout_read,
self.timeout_read)
# #
# Test that exclusive locks time out when shared locks are held. # Test that exclusive locks time out when shared locks are held.
@ -124,27 +127,35 @@ def test_write_lock_timeout_on_read(self):
self.multiproc_test(self.acquire_read, self.timeout_write) self.multiproc_test(self.acquire_read, self.timeout_write)
def test_write_lock_timeout_on_read_2(self): def test_write_lock_timeout_on_read_2(self):
self.multiproc_test(self.acquire_read, self.timeout_write, self.timeout_write) self.multiproc_test(
self.acquire_read, self.timeout_write, self.timeout_write)
def test_write_lock_timeout_on_read_3(self): def test_write_lock_timeout_on_read_3(self):
self.multiproc_test(self.acquire_read, self.timeout_write, self.timeout_write, self.timeout_write) self.multiproc_test(
self.acquire_read, self.timeout_write, self.timeout_write,
self.timeout_write)
# #
# Test that exclusive locks time while lots of shared locks are held. # Test that exclusive locks time while lots of shared locks are held.
# #
def test_write_lock_timeout_with_multiple_readers_2_1(self): def test_write_lock_timeout_with_multiple_readers_2_1(self):
self.multiproc_test(self.acquire_read, self.acquire_read, self.timeout_write) self.multiproc_test(
self.acquire_read, self.acquire_read, self.timeout_write)
def test_write_lock_timeout_with_multiple_readers_2_2(self): def test_write_lock_timeout_with_multiple_readers_2_2(self):
self.multiproc_test(self.acquire_read, self.acquire_read, self.timeout_write, self.timeout_write) self.multiproc_test(
self.acquire_read, self.acquire_read, self.timeout_write,
self.timeout_write)
def test_write_lock_timeout_with_multiple_readers_3_1(self): def test_write_lock_timeout_with_multiple_readers_3_1(self):
self.multiproc_test(self.acquire_read, self.acquire_read, self.acquire_read, self.timeout_write) self.multiproc_test(
self.acquire_read, self.acquire_read, self.acquire_read,
self.timeout_write)
def test_write_lock_timeout_with_multiple_readers_3_2(self): def test_write_lock_timeout_with_multiple_readers_3_2(self):
self.multiproc_test(self.acquire_read, self.acquire_read, self.acquire_read, self.timeout_write, self.timeout_write) self.multiproc_test(
self.acquire_read, self.acquire_read, self.acquire_read,
self.timeout_write, self.timeout_write)
# #
# Longer test case that ensures locks are reusable. Ordering is # Longer test case that ensures locks are reusable. Ordering is
@ -271,13 +282,17 @@ def exit_fn(t, v, tb):
lock = Lock(self.lock_path) lock = Lock(self.lock_path)
vals = {'entered': False, 'exited': False, 'exception': False} vals = {'entered': False, 'exited': False, 'exception': False}
with ReadTransaction(lock, enter_fn, exit_fn): pass with ReadTransaction(lock, enter_fn, exit_fn):
pass
self.assertTrue(vals['entered']) self.assertTrue(vals['entered'])
self.assertTrue(vals['exited']) self.assertTrue(vals['exited'])
self.assertFalse(vals['exception']) self.assertFalse(vals['exception'])
vals = {'entered': False, 'exited': False, 'exception': False} vals = {'entered': False, 'exited': False, 'exception': False}
with WriteTransaction(lock, enter_fn, exit_fn): pass with WriteTransaction(lock, enter_fn, exit_fn):
pass
self.assertTrue(vals['entered']) self.assertTrue(vals['entered'])
self.assertTrue(vals['exited']) self.assertTrue(vals['exited'])
self.assertFalse(vals['exception']) self.assertFalse(vals['exception'])
@ -329,7 +344,9 @@ def exit_fn(t, v, tb):
vals = {'entered': False, 'exited': False, 'exited_fn': False, vals = {'entered': False, 'exited': False, 'exited_fn': False,
'exception': False, 'exception_fn': False} 'exception': False, 'exception_fn': False}
with ReadTransaction(lock, TestContextManager, exit_fn): pass with ReadTransaction(lock, TestContextManager, exit_fn):
pass
self.assertTrue(vals['entered']) self.assertTrue(vals['entered'])
self.assertTrue(vals['exited']) self.assertTrue(vals['exited'])
self.assertFalse(vals['exception']) self.assertFalse(vals['exception'])
@ -338,7 +355,9 @@ def exit_fn(t, v, tb):
vals = {'entered': False, 'exited': False, 'exited_fn': False, vals = {'entered': False, 'exited': False, 'exited_fn': False,
'exception': False, 'exception_fn': False} 'exception': False, 'exception_fn': False}
with ReadTransaction(lock, TestContextManager): pass with ReadTransaction(lock, TestContextManager):
pass
self.assertTrue(vals['entered']) self.assertTrue(vals['entered'])
self.assertTrue(vals['exited']) self.assertTrue(vals['exited'])
self.assertFalse(vals['exception']) self.assertFalse(vals['exception'])
@ -347,7 +366,9 @@ def exit_fn(t, v, tb):
vals = {'entered': False, 'exited': False, 'exited_fn': False, vals = {'entered': False, 'exited': False, 'exited_fn': False,
'exception': False, 'exception_fn': False} 'exception': False, 'exception_fn': False}
with WriteTransaction(lock, TestContextManager, exit_fn): pass with WriteTransaction(lock, TestContextManager, exit_fn):
pass
self.assertTrue(vals['entered']) self.assertTrue(vals['entered'])
self.assertTrue(vals['exited']) self.assertTrue(vals['exited'])
self.assertFalse(vals['exception']) self.assertFalse(vals['exception'])
@ -356,7 +377,9 @@ def exit_fn(t, v, tb):
vals = {'entered': False, 'exited': False, 'exited_fn': False, vals = {'entered': False, 'exited': False, 'exited_fn': False,
'exception': False, 'exception_fn': False} 'exception': False, 'exception_fn': False}
with WriteTransaction(lock, TestContextManager): pass with WriteTransaction(lock, TestContextManager):
pass
self.assertTrue(vals['entered']) self.assertTrue(vals['entered'])
self.assertTrue(vals['exited']) self.assertTrue(vals['exited'])
self.assertFalse(vals['exception']) self.assertFalse(vals['exception'])

View File

@ -22,27 +22,28 @@
# License along with this program; if not, write to the Free Software # License along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
############################################################################## ##############################################################################
"""Tests for provider index cache files.
Tests assume that mock packages provide this:
{'blas': {
blas: set([netlib-blas, openblas, openblas-with-lapack])},
'lapack': {lapack: set([netlib-lapack, openblas-with-lapack])},
'mpi': {mpi@:1: set([mpich@:1]),
mpi@:2.0: set([mpich2]),
mpi@:2.1: set([mpich2@1.1:]),
mpi@:2.2: set([mpich2@1.2:]),
mpi@:3: set([mpich@3:]),
mpi@:10.0: set([zmpi])},
'stuff': {stuff: set([externalvirtual])}}
"""
from StringIO import StringIO from StringIO import StringIO
import unittest
import spack import spack
from spack.spec import Spec from spack.spec import Spec
from spack.provider_index import ProviderIndex from spack.provider_index import ProviderIndex
from spack.test.mock_packages_test import * from spack.test.mock_packages_test import *
# Test assume that mock packages provide this:
#
# {'blas': {
# blas: set([netlib-blas, openblas, openblas-with-lapack])},
# 'lapack': {lapack: set([netlib-lapack, openblas-with-lapack])},
# 'mpi': {mpi@:1: set([mpich@:1]),
# mpi@:2.0: set([mpich2]),
# mpi@:2.1: set([mpich2@1.1:]),
# mpi@:2.2: set([mpich2@1.2:]),
# mpi@:3: set([mpich@3:]),
# mpi@:10.0: set([zmpi])},
# 'stuff': {stuff: set([externalvirtual])}}
#
class ProviderIndexTest(MockPackagesTest): class ProviderIndexTest(MockPackagesTest):
@ -57,7 +58,6 @@ def test_yaml_round_trip(self):
self.assertEqual(p, q) self.assertEqual(p, q)
def test_providers_for_simple(self): def test_providers_for_simple(self):
p = ProviderIndex(spack.repo.all_package_names()) p = ProviderIndex(spack.repo.all_package_names())
@ -70,7 +70,6 @@ def test_providers_for_simple(self):
self.assertTrue(Spec('netlib-lapack') in lapack_providers) self.assertTrue(Spec('netlib-lapack') in lapack_providers)
self.assertTrue(Spec('openblas-with-lapack') in lapack_providers) self.assertTrue(Spec('openblas-with-lapack') in lapack_providers)
def test_mpi_providers(self): def test_mpi_providers(self):
p = ProviderIndex(spack.repo.all_package_names()) p = ProviderIndex(spack.repo.all_package_names())
@ -83,13 +82,11 @@ def test_mpi_providers(self):
self.assertTrue(Spec('mpich@3:') in mpi_3_providers) self.assertTrue(Spec('mpich@3:') in mpi_3_providers)
self.assertTrue(Spec('zmpi') in mpi_3_providers) self.assertTrue(Spec('zmpi') in mpi_3_providers)
def test_equal(self): def test_equal(self):
p = ProviderIndex(spack.repo.all_package_names()) p = ProviderIndex(spack.repo.all_package_names())
q = ProviderIndex(spack.repo.all_package_names()) q = ProviderIndex(spack.repo.all_package_names())
self.assertEqual(p, q) self.assertEqual(p, q)
def test_copy(self): def test_copy(self):
p = ProviderIndex(spack.repo.all_package_names()) p = ProviderIndex(spack.repo.all_package_names())
q = p.copy() q = p.copy()

View File

@ -30,6 +30,7 @@
from spack.spec import Spec from spack.spec import Spec
from spack.test.mock_packages_test import * from spack.test.mock_packages_test import *
class SpecYamlTest(MockPackagesTest): class SpecYamlTest(MockPackagesTest):
def check_yaml_round_trip(self, spec): def check_yaml_round_trip(self, spec):
@ -37,30 +38,25 @@ def check_yaml_round_trip(self, spec):
spec_from_yaml = Spec.from_yaml(yaml_text) spec_from_yaml = Spec.from_yaml(yaml_text)
self.assertTrue(spec.eq_dag(spec_from_yaml)) self.assertTrue(spec.eq_dag(spec_from_yaml))
def test_simple_spec(self): def test_simple_spec(self):
spec = Spec('mpileaks') spec = Spec('mpileaks')
self.check_yaml_round_trip(spec) self.check_yaml_round_trip(spec)
def test_normal_spec(self): def test_normal_spec(self):
spec = Spec('mpileaks+debug~opt') spec = Spec('mpileaks+debug~opt')
spec.normalize() spec.normalize()
self.check_yaml_round_trip(spec) self.check_yaml_round_trip(spec)
def test_ambiguous_version_spec(self): def test_ambiguous_version_spec(self):
spec = Spec('mpileaks@1.0:5.0,6.1,7.3+debug~opt') spec = Spec('mpileaks@1.0:5.0,6.1,7.3+debug~opt')
spec.normalize() spec.normalize()
self.check_yaml_round_trip(spec) self.check_yaml_round_trip(spec)
def test_concrete_spec(self): def test_concrete_spec(self):
spec = Spec('mpileaks+debug~opt') spec = Spec('mpileaks+debug~opt')
spec.concretize() spec.concretize()
self.check_yaml_round_trip(spec) self.check_yaml_round_trip(spec)
def test_yaml_subdag(self): def test_yaml_subdag(self):
spec = Spec('mpileaks^mpich+debug') spec = Spec('mpileaks^mpich+debug')
spec.concretize() spec.concretize()