From ddaa4b7dcbdc1412df907fe2c6b881084bf34e49 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 10 Apr 2025 17:01:17 -0700
Subject: [PATCH] Fix the test and add custom min/max reductions for uncommon
 MPI types (#2060)

---
 mlx/backend/cpu/distributed.cpp       |  3 +-
 mlx/distributed/mpi/mpi.cpp           | 76 ++++++++++++++++++++++++--
 python/tests/mpi_test_distributed.py  | 78 +++++++++++++++------------
 python/tests/ring_test_distributed.py | 25 ++++++---
 4 files changed, 135 insertions(+), 47 deletions(-)

diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp
index dd4d179ac..a3edf8f49 100644
--- a/mlx/backend/cpu/distributed.cpp
+++ b/mlx/backend/cpu/distributed.cpp
@@ -53,7 +53,8 @@ void AllReduce::eval_cpu(
       distributed::detail::all_min(group(), in, outputs[0], stream());
       break;
     default:
-      throw std::runtime_error("Only all reduce sum, min and max are supported for now");
+      throw std::runtime_error(
+          "Only all reduce sum, min and max are supported for now");
   }
 }

diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp
index f48009397..e80a1759f 100644
--- a/mlx/distributed/mpi/mpi.cpp
+++ b/mlx/distributed/mpi/mpi.cpp
@@ -50,6 +50,46 @@ void simple_sum(
 template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
 template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);

+template <typename T>
+void simple_max(
+    void* input,
+    void* accumulator,
+    int* len,
+    MPI_Datatype* datatype) {
+  T* in = (T*)input;
+  T* acc = (T*)accumulator;
+  int N = *len;
+
+  while (N-- > 0) {
+    *acc = std::max(*acc, *in);
+    acc++;
+    in++;
+  }
+}
+template void simple_max<float16_t>(void*, void*, int*, MPI_Datatype*);
+template void simple_max<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
+template void simple_max<complex64_t>(void*, void*, int*, MPI_Datatype*);
+
+template <typename T>
+void simple_min(
+    void* input,
+    void* accumulator,
+    int* len,
+    MPI_Datatype* datatype) {
+  T* in = (T*)input;
+  T* acc = (T*)accumulator;
+  int N = *len;
+
+  while (N-- > 0) {
+    *acc = std::min(*acc, *in);
+    acc++;
+    in++;
+  }
+}
+template void simple_min<float16_t>(void*, void*, int*, MPI_Datatype*);
+template void simple_min<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
+template void simple_min<complex64_t>(void*, void*, int*, MPI_Datatype*);
+
 struct MPIWrapper {
   MPIWrapper() {
     initialized_ = false;
@@ -129,9 +169,15 @@ struct MPIWrapper {
     mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
     mpi_type_commit(&mpi_bfloat16_);

-    // Custom sum ops
+    // Custom reduction ops
     mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
     mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
+    mpi_op_create(&simple_max<float16_t>, 1, &op_max_f16_);
+    mpi_op_create(&simple_max<bfloat16_t>, 1, &op_max_bf16_);
+    mpi_op_create(&simple_max<complex64_t>, 1, &op_max_c64_);
+    mpi_op_create(&simple_min<float16_t>, 1, &op_min_f16_);
+    mpi_op_create(&simple_min<bfloat16_t>, 1, &op_min_bf16_);
+    mpi_op_create(&simple_min<complex64_t>, 1, &op_min_c64_);

     initialized_ = true;
   }
@@ -194,11 +240,29 @@ struct MPIWrapper {
   }

   MPI_Op op_max(const array& arr) {
-    return op_max_;
+    switch (arr.dtype()) {
+      case float16:
+        return op_max_f16_;
+      case bfloat16:
+        return op_max_bf16_;
+      case complex64:
+        return op_max_c64_;
+      default:
+        return op_max_;
+    }
   }

   MPI_Op op_min(const array& arr) {
-    return op_min_;
+    switch (arr.dtype()) {
+      case float16:
+        return op_min_f16_;
+      case bfloat16:
+        return op_min_bf16_;
+      case complex64:
+        return op_min_c64_;
+      default:
+        return op_min_;
+    }
   }

   void* libmpi_handle_;
@@ -230,7 +294,13 @@ struct MPIWrapper {
   MPI_Op op_sum_f16_;
   MPI_Op op_sum_bf16_;
   MPI_Op op_max_;
+  MPI_Op op_max_f16_;
+  MPI_Op op_max_bf16_;
+  MPI_Op op_max_c64_;
   MPI_Op op_min_;
+  MPI_Op op_min_f16_;
+  MPI_Op op_min_bf16_;
+  MPI_Op op_min_c64_;

   // Datatypes
   MPI_Datatype mpi_bool_;
diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py
index 65fbd09ce..26d340dbe 100644
--- a/python/tests/mpi_test_distributed.py
+++ b/python/tests/mpi_test_distributed.py
@@ -30,27 +30,51 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
     def test_all_reduce(self):
         world = mx.distributed.init()
         dtypes = [
-            mx.int8,
-            mx.uint8,
-            mx.int16,
-            mx.uint16,
-            mx.int32,
-            mx.uint32,
-            mx.float32,
-            mx.float16,
-            mx.bfloat16,
-            mx.complex64,
+            (mx.int8, 0),
+            (mx.uint8, 0),
+            (mx.int16, 0),
+            (mx.uint16, 0),
+            (mx.int32, 0),
+            (mx.uint32, 0),
+            (mx.float32, 1e-6),
+            (mx.float16, 5e-3),
+            (mx.bfloat16, 1e-1),
+            (mx.complex64, 1e-6),
         ]
-        for dt in dtypes:
-            x = mx.ones((2, 2, 4), dtype=dt)
-            y = mx.distributed.all_sum(x)
-            self.assertTrue(mx.all(y == world.size()))
+        sizes = [
+            (7,),
+            (10,),
+            (1024,),
+            (1024, 1024),
+        ]
+        key = mx.random.key(0)
+        group = world.split(world.rank() % 2)

-        sub = world.split(world.rank() % 2)
-        for dt in dtypes:
-            x = mx.ones((2, 2, 4), dtype=dt)
-            y = mx.distributed.all_sum(x, group=sub)
-            self.assertTrue(mx.all(y == sub.size()))
+        for dt, rtol in dtypes:
+            for sh in sizes:
+                for g in [world, group]:
+                    x = (
+                        mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10
+                    ).astype(dt)
+
+                    # All sum
+                    y = mx.distributed.all_sum(x[g.rank()], group=g)
+                    z = x.sum(0)
+                    maxrelerror = (y - z).abs()
+                    if rtol > 0:
+                        maxrelerror /= z.abs()
+                    maxrelerror = maxrelerror.max()
+                    self.assertLessEqual(maxrelerror, rtol)
+
+                    # All max
+                    y = mx.distributed.all_max(x[g.rank()], group=g)
+                    z = x.max(0)
+                    self.assertTrue(mx.all(y == z))
+
+                    # All min
+                    y = mx.distributed.all_min(x[g.rank()], group=g)
+                    z = x.min(0)
+                    self.assertTrue(mx.all(y == z))

     def test_all_gather(self):
         world = mx.distributed.init()
@@ -124,22 +148,6 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
             x = mx.distributed.recv_like(x, neighbor, group=pairs)
             mx.eval(y, x)

-    def test_min_max(self):
-        world = mx.distributed.init()
-        base = mx.arange(16).reshape(4, 4)
-        x = base + world.rank() * 32
-
-        def _test_reduction(reduction: str = "all_max"):
-
-            target = base + ((world.size() - 1) * 16) * (reduction == "max")
-            reducer = getattr(mx.distributed, reduction)
-            y = reducer(x)
-
-            self.assertTrue(mx.allclose(y, target))
-
-        for reduction in ["all_max", "all_min"]:
-            _test_reduction(reduction)
-

 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py
index d74c534e0..77d45dbad 100644
--- a/python/tests/ring_test_distributed.py
+++ b/python/tests/ring_test_distributed.py
@@ -51,16 +51,25 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
                 x = (
                     mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                 ).astype(dt)
-                for reduction in reductions:
-                    reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
-                    y = reducer_distributed(x[world.rank()])
-                    reducer = getattr(mx, reduction)
-                    z = reducer(x, axis=0)
-                    mx.eval(y, z)
+                # All sum
+                y = mx.distributed.all_sum(x[world.rank()])
+                z = x.sum(0)
+                maxrelerror = (y - z).abs()
+                if rtol > 0:
+                    maxrelerror /= z.abs()
+                maxrelerror = maxrelerror.max()
+                self.assertLessEqual(maxrelerror, rtol)

-                    maxrelerror = ((y - z).abs() / z.abs()).max()
-                    self.assertLessEqual(maxrelerror, rtol)
+
+                # All max
+                y = mx.distributed.all_max(x[world.rank()])
+                z = x.max(0)
+                self.assertTrue(mx.all(y == z))
+
+                # All min
+                y = mx.distributed.all_min(x[world.rank()])
+                z = x.min(0)
+                self.assertTrue(mx.all(y == z))

     def test_all_gather(self):
         world = mx.distributed.init()
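
Note on the mechanism in mpi.cpp: MPI's predefined MPI_MAX/MPI_MIN apply only
to its predefined integer and floating-point datatypes. bfloat16 (and, in this
wrapper, float16) travel over MPI as opaque two-byte datatypes built with
mpi_type_contiguous(2, mpi_uint8_, ...), and the MPI standard defines no
min/max for complex types, which is why the patch registers user-defined ops
with mpi_op_create and dispatches on dtype in op_max()/op_min(). For
complex64, std::max/std::min rely on the comparison operators MLX's
complex64_t provides. The standalone sketch below shows the same
MPI_Op_create pattern; it is not MLX code, Fixed16 and fixed16_max are
invented names, and for brevity it orders raw 16-bit patterns instead of
decoding real half-precision values:

    // Compile with mpic++; run with e.g. mpirun -np 4 ./a.out
    #include <mpi.h>
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // A 16-bit payload MPI cannot reduce natively, standing in for
    // float16_t/bfloat16_t.
    struct Fixed16 {
      uint16_t bits;
    };

    // Signature required by MPI_Op_create: fold `in` element-wise into
    // `inout`, just as simple_max/simple_min do in the patch.
    void fixed16_max(void* in, void* inout, int* len, MPI_Datatype*) {
      auto* a = static_cast<Fixed16*>(in);
      auto* b = static_cast<Fixed16*>(inout);
      for (int i = 0; i < *len; i++) {
        b[i].bits = std::max(a[i].bits, b[i].bits);
      }
    }

    int main(int argc, char** argv) {
      MPI_Init(&argc, &argv);
      int rank;
      MPI_Comm_rank(MPI_COMM_WORLD, &rank);

      // Describe the payload to MPI as two contiguous bytes, mirroring how
      // the wrapper builds its bfloat16 datatype.
      MPI_Datatype fixed16;
      MPI_Type_contiguous(2, MPI_UINT8_T, &fixed16);
      MPI_Type_commit(&fixed16);

      // commute = 1: max is commutative, so MPI may fold in any order.
      MPI_Op max_op;
      MPI_Op_create(&fixed16_max, 1, &max_op);

      Fixed16 x{static_cast<uint16_t>(100 + rank)};
      Fixed16 y{};
      MPI_Allreduce(&x, &y, 1, fixed16, max_op, MPI_COMM_WORLD);
      std::printf("rank %d: reduced bits = %u\n", rank,
                  static_cast<unsigned>(y.bits));

      MPI_Op_free(&max_op);
      MPI_Type_free(&fixed16);
      MPI_Finalize();
      return 0;
    }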
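
The rewritten tests replace the all-ones input (for which every reduction was
trivially exact) with random data across several shapes, a split sub-group,
and a per-dtype tolerance. A sketch of the accuracy check in C++ terms
(check_allreduce is a hypothetical helper, not MLX code): with rtol == 0, as
for the integer dtypes, the assertion degenerates to an exact comparison;
with rtol > 0 it bounds the maximum relative error.

    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <cstddef>

    void check_allreduce(
        const float* got,  // result of the distributed reduction
        const float* want, // reference computed on the stacked inputs
        size_t n,
        float rtol) {
      float maxerr = 0.0f;
      for (size_t i = 0; i < n; i++) {
        float err = std::fabs(got[i] - want[i]);
        if (rtol > 0) {
          err /= std::fabs(want[i]); // mirrors `maxrelerror /= z.abs()`
        }
        maxerr = std::max(maxerr, err);
      }
      assert(maxerr <= rtol); // rtol == 0 forces exact agreement
    }

    int main() {
      const float got[] = {1.0f, 2.0f, 4.0f};
      const float want[] = {1.0f, 2.0f, 4.0f};
      check_allreduce(got, want, 3, 0.0f); // exact, like the integer dtypes
      return 0;
    }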
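
Why only all_sum gets a tolerance while all_max/all_min are checked with
mx.all(y == z): floating-point summation depends on the order in which ranks
are folded, so low-precision dtypes accumulate rounding differences (hence
the loose 1e-1 bound for bfloat16), whereas min and max merely select one of
the original elements bit for bit, making exact equality a safe expectation
for every dtype.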