mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-09 13:25:32 +08:00
Fix MPI distributed tests with CUDA backend (#2775)
This commit is contained in:
13
.github/actions/test-linux/action.yml
vendored
13
.github/actions/test-linux/action.yml
vendored
@@ -9,13 +9,18 @@ inputs:
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Run MPI tests
|
||||
shell: bash
|
||||
run: |
|
||||
echo "::group::MPI tests"
|
||||
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Run distributed tests
|
||||
# FIXME: This test fails with CUDA build.
|
||||
if: ${{ inputs.cpu-only == 'true' }}
|
||||
shell: bash
|
||||
run: |
|
||||
echo "::group::Distributed tests"
|
||||
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||
if grep -Fq '[WARN]' stderr.log ; then
|
||||
grep -F '[WARN]' stderr.log
|
||||
@@ -34,7 +39,7 @@ runs:
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Run Python tests - GPU
|
||||
if: ${{ !inputs.cpu-only }}
|
||||
if: ${{ inputs.cpu-only == 'false' }}
|
||||
shell: bash
|
||||
env:
|
||||
DEVICE: gpu
|
||||
@@ -53,7 +58,7 @@ runs:
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Run CPP tests - GPU
|
||||
if: ${{ !inputs.cpu-only }}
|
||||
if: ${{ inputs.cpu-only == 'false' }}
|
||||
shell: bash
|
||||
env:
|
||||
DEVICE: gpu
|
||||
|
||||
Reference in New Issue
Block a user