Fix the test and add custom min/max reductions for uncommon MPI types (#2060)

This commit is contained in:
Angelos Katharopoulos 2025-04-10 17:01:17 -07:00 committed by GitHub
parent dfae2c6989
commit ddaa4b7dcb
4 changed files with 135 additions and 47 deletions
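Context for the diff below: MPI's built-in MPI_MAX and MPI_MIN reductions are only defined for MPI's predefined numeric datatypes, so min/max all-reduces over float16, bfloat16, or complex64 values (which MLX describes to MPI via custom datatypes) need user-defined reduction functions registered with MPI_Op_create. The following standalone sketch only illustrates that generic MPI mechanism; it is not MLX code, the names (complex_max, mpi_c64, max_op) are illustrative, and the ordering used for complex values is an arbitrary choice made for the example.

#include <mpi.h>
#include <complex>

// Signature required by MPI_Op_create: fold `in` into `inout`, element-wise.
static void complex_max(void* in, void* inout, int* len, MPI_Datatype*) {
  auto* a = static_cast<std::complex<float>*>(in);
  auto* b = static_cast<std::complex<float>*>(inout);
  for (int i = 0; i < *len; i++) {
    // Arbitrary ordering for illustration: real part first, then imaginary.
    if (a[i].real() > b[i].real() ||
        (a[i].real() == b[i].real() && a[i].imag() > b[i].imag())) {
      b[i] = a[i];
    }
  }
}

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);

  // Describe the 8-byte complex payload to MPI as two contiguous floats.
  MPI_Datatype mpi_c64;
  MPI_Type_contiguous(2, MPI_FLOAT, &mpi_c64);
  MPI_Type_commit(&mpi_c64);

  // Register the custom reduction; the `1` marks it as commutative.
  MPI_Op max_op;
  MPI_Op_create(&complex_max, 1, &max_op);

  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  std::complex<float> local(float(rank), -float(rank)), global;
  MPI_Allreduce(&local, &global, 1, mpi_c64, max_op, MPI_COMM_WORLD);

  MPI_Op_free(&max_op);
  MPI_Type_free(&mpi_c64);
  MPI_Finalize();
  return 0;
}

The commit applies this same pattern inside MPIWrapper: one user function per unsupported dtype, registered in the constructor, and returned from op_max/op_min, with the built-in MPI_MAX/MPI_MIN kept as the default branch for types MPI already handles.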

View File

@@ -53,7 +53,8 @@ void AllReduce::eval_cpu(
      distributed::detail::all_min(group(), in, outputs[0], stream());
      break;
    default:
      throw std::runtime_error("Only all reduce sum, min and max are supported for now");
      throw std::runtime_error(
          "Only all reduce sum, min and max are supported for now");
  }
}

View File

@@ -50,6 +50,46 @@ void simple_sum(
template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);

template <typename T>
void simple_max(
    void* input,
    void* accumulator,
    int* len,
    MPI_Datatype* datatype) {
  T* in = (T*)input;
  T* acc = (T*)accumulator;
  int N = *len;
  while (N-- > 0) {
    *acc = std::max(*acc, *in);
    acc++;
    in++;
  }
}

template void simple_max<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_max<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_max<complex64_t>(void*, void*, int*, MPI_Datatype*);

template <typename T>
void simple_min(
    void* input,
    void* accumulator,
    int* len,
    MPI_Datatype* datatype) {
  T* in = (T*)input;
  T* acc = (T*)accumulator;
  int N = *len;
  while (N-- > 0) {
    *acc = std::min(*acc, *in);
    acc++;
    in++;
  }
}

template void simple_min<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_min<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_min<complex64_t>(void*, void*, int*, MPI_Datatype*);

struct MPIWrapper {
  MPIWrapper() {
    initialized_ = false;
@@ -129,9 +169,15 @@ struct MPIWrapper {
    mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
    mpi_type_commit(&mpi_bfloat16_);

    // Custom sum ops
    // Custom reduction ops
    mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
    mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
    mpi_op_create(&simple_max<float16_t>, 1, &op_max_f16_);
    mpi_op_create(&simple_max<bfloat16_t>, 1, &op_max_bf16_);
    mpi_op_create(&simple_max<complex64_t>, 1, &op_max_c64_);
    mpi_op_create(&simple_min<float16_t>, 1, &op_min_f16_);
    mpi_op_create(&simple_min<bfloat16_t>, 1, &op_min_bf16_);
    mpi_op_create(&simple_min<complex64_t>, 1, &op_min_c64_);

    initialized_ = true;
  }
@@ -194,11 +240,29 @@
  }

  MPI_Op op_max(const array& arr) {
    return op_max_;
    switch (arr.dtype()) {
      case float16:
        return op_max_f16_;
      case bfloat16:
        return op_max_bf16_;
      case complex64:
        return op_max_c64_;
      default:
        return op_max_;
    }
  }

  MPI_Op op_min(const array& arr) {
    return op_min_;
    switch (arr.dtype()) {
      case float16:
        return op_min_f16_;
      case bfloat16:
        return op_min_bf16_;
      case complex64:
        return op_min_c64_;
      default:
        return op_min_;
    }
  }

  void* libmpi_handle_;
@@ -230,7 +294,13 @@
  MPI_Op op_sum_f16_;
  MPI_Op op_sum_bf16_;
  MPI_Op op_max_;
  MPI_Op op_max_f16_;
  MPI_Op op_max_bf16_;
  MPI_Op op_max_c64_;
  MPI_Op op_min_;
  MPI_Op op_min_f16_;
  MPI_Op op_min_bf16_;
  MPI_Op op_min_c64_;

  // Datatypes
  MPI_Datatype mpi_bool_;

View File

@@ -30,27 +30,51 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
    def test_all_reduce(self):
        world = mx.distributed.init()
        dtypes = [
            mx.int8,
            mx.uint8,
            mx.int16,
            mx.uint16,
            mx.int32,
            mx.uint32,
            mx.float32,
            mx.float16,
            mx.bfloat16,
            mx.complex64,
            (mx.int8, 0),
            (mx.uint8, 0),
            (mx.int16, 0),
            (mx.uint16, 0),
            (mx.int32, 0),
            (mx.uint32, 0),
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
            (mx.complex64, 1e-6),
        ]
        for dt in dtypes:
            x = mx.ones((2, 2, 4), dtype=dt)
            y = mx.distributed.all_sum(x)
            self.assertTrue(mx.all(y == world.size()))
        sizes = [
            (7,),
            (10,),
            (1024,),
            (1024, 1024),
        ]
        key = mx.random.key(0)
        group = world.split(world.rank() % 2)
        sub = world.split(world.rank() % 2)
        for dt in dtypes:
            x = mx.ones((2, 2, 4), dtype=dt)
            y = mx.distributed.all_sum(x, group=sub)
            self.assertTrue(mx.all(y == sub.size()))

        for dt, rtol in dtypes:
            for sh in sizes:
                for g in [world, group]:
                    x = (
                        mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10
                    ).astype(dt)

                    # All sum
                    y = mx.distributed.all_sum(x[g.rank()], group=g)
                    z = x.sum(0)
                    maxrelerror = (y - z).abs()
                    if rtol > 0:
                        maxrelerror /= z.abs()
                    maxrelerror = maxrelerror.max()
                    self.assertLessEqual(maxrelerror, rtol)

                    # All max
                    y = mx.distributed.all_max(x[g.rank()], group=g)
                    z = x.max(0)
                    self.assertTrue(mx.all(y == z))

                    # All min
                    y = mx.distributed.all_min(x[g.rank()], group=g)
                    z = x.min(0)
                    self.assertTrue(mx.all(y == z))

    def test_all_gather(self):
        world = mx.distributed.init()
@@ -124,22 +148,6 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
                x = mx.distributed.recv_like(x, neighbor, group=pairs)
            mx.eval(y, x)

    def test_min_max(self):
        world = mx.distributed.init()
        base = mx.arange(16).reshape(4, 4)
        x = base + world.rank() * 32

        def _test_reduction(reduction: str = "all_max"):
            target = base + ((world.size() - 1) * 16) * (reduction == "max")
            reducer = getattr(mx.distributed, reduction)
            y = reducer(x)
            self.assertTrue(mx.allclose(y, target))

        for reduction in ["all_max", "all_min"]:
            _test_reduction(reduction)

if __name__ == "__main__":
    unittest.main()

View File

@@ -51,16 +51,25 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
                x = (
                    mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
                ).astype(dt)
                for reduction in reductions:
                    reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
                    y = reducer_distributed(x[world.rank()])
                    reducer = getattr(mx, reduction)
                    z = reducer(x, axis=0)
                    mx.eval(y, z)

                # All sum
                y = mx.distributed.all_sum(x[world.rank()])
                z = x.sum(0)
                maxrelerror = (y - z).abs()
                if rtol > 0:
                    maxrelerror /= z.abs()
                maxrelerror = maxrelerror.max()
                self.assertLessEqual(maxrelerror, rtol)
                    maxrelerror = ((y - z).abs() / z.abs()).max()
                    self.assertLessEqual(maxrelerror, rtol)

                # All max
                y = mx.distributed.all_max(x[world.rank()])
                z = x.max(0)
                self.assertTrue(mx.all(y == z))

                # All min
                y = mx.distributed.all_min(x[world.rank()])
                z = x.min(0)
                self.assertTrue(mx.all(y == z))

    def test_all_gather(self):
        world = mx.distributed.init()