tests/py-chainer: convert to new stand-alone test process (#38365)
* tests/py-chainer: convert to new stand-alone test process * py-chainer: add skip_modules entry for onnx_chainer
This commit is contained in:
parent
f53c68e005
commit
930b843885
@ -24,6 +24,8 @@ class PyChainer(PythonPackage):
|
||||
|
||||
maintainers("adamjstewart")
|
||||
|
||||
skip_modules = ["onnx_chainer"]
|
||||
|
||||
version("7.2.0", sha256="6e2fba648cc5b8a5421e494385b76fe5ec154f1028a1c5908557f5d16c04f0b3")
|
||||
version("6.7.0", sha256="87cb3378a35e7c5c695028ec91d58dc062356bc91412384ea939d71374610389")
|
||||
|
||||
@ -48,25 +50,21 @@ def cache_test_sources(self):
|
||||
if "+mn" in self.spec:
|
||||
self.cache_extra_test_sources("examples")
|
||||
|
||||
def test(self):
|
||||
if "+mn" in self.spec:
|
||||
# Run test of ChainerMN
|
||||
test_dir = self.test_suite.current_test_data_dir
|
||||
def test_chainermn(self):
|
||||
"""run the ChainerMN test"""
|
||||
if "+mn" not in self.spec:
|
||||
raise SkipTest("Test only supported when built with +mn")
|
||||
|
||||
mnist_dir = join_path(self.install_test_root, "examples", "chainermn", "mnist")
|
||||
mnist_file = join_path(mnist_dir, "train_mnist.py")
|
||||
mpi_name = self.spec["mpi"].prefix.bin.mpirun
|
||||
python_exe = self.spec["python"].command.path
|
||||
opts = ["-n", "4", python_exe, mnist_file, "-o", test_dir]
|
||||
env["OMP_NUM_THREADS"] = "4"
|
||||
mnist_file = join_path(self.install_test_root.examples.chainermn.mnist, "train_mnist.py")
|
||||
mpirun = which(self.spec["mpi"].prefix.bin.mpirun)
|
||||
opts = ["-n", "4", self.spec["python"].command.path, mnist_file, "-o", "."]
|
||||
env["OMP_NUM_THREADS"] = "4"
|
||||
|
||||
self.run_test(mpi_name, options=opts, work_dir=test_dir)
|
||||
mpirun(*opts)
|
||||
|
||||
# check results
|
||||
json_open = open(join_path(test_dir, "log"), "r")
|
||||
json_load = json.load(json_open)
|
||||
v = dict([(d.get("epoch"), d.get("main/accuracy")) for d in json_load])
|
||||
if 1 not in v or 20 not in v:
|
||||
raise RuntimeError("Cannot find epoch 1 or epoch 20")
|
||||
if abs(1.0 - v[1]) < abs(1.0 - v[20]):
|
||||
raise RuntimeError("ChainerMN Test Failed !")
|
||||
# check results
|
||||
json_open = open(join_path(".", "log"), "r")
|
||||
json_load = json.load(json_open)
|
||||
v = dict([(d.get("epoch"), d.get("main/accuracy")) for d in json_load])
|
||||
assert (1 in v) or (20 in v), "Cannot find epoch 1 or epoch 20"
|
||||
assert abs(1.0 - v[1]) >= abs(1.0 - v[20]), "ChainerMN Test Failed!"
|
||||
|
Loading…
Reference in New Issue
Block a user