mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix the test and add custom min/max reductions for uncommon MPI types (#2060)
This commit is contained in:
parent
dfae2c6989
commit
ddaa4b7dcb
@ -53,7 +53,8 @@ void AllReduce::eval_cpu(
|
|||||||
distributed::detail::all_min(group(), in, outputs[0], stream());
|
distributed::detail::all_min(group(), in, outputs[0], stream());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("Only all reduce sum, min and max are supported for now");
|
throw std::runtime_error(
|
||||||
|
"Only all reduce sum, min and max are supported for now");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,6 +50,46 @@ void simple_sum(
|
|||||||
template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
|
template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
|
||||||
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
|
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void simple_max(
|
||||||
|
void* input,
|
||||||
|
void* accumulator,
|
||||||
|
int* len,
|
||||||
|
MPI_Datatype* datatype) {
|
||||||
|
T* in = (T*)input;
|
||||||
|
T* acc = (T*)accumulator;
|
||||||
|
int N = *len;
|
||||||
|
|
||||||
|
while (N-- > 0) {
|
||||||
|
*acc = std::max(*acc, *in);
|
||||||
|
acc++;
|
||||||
|
in++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template void simple_max<float16_t>(void*, void*, int*, MPI_Datatype*);
|
||||||
|
template void simple_max<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
|
||||||
|
template void simple_max<complex64_t>(void*, void*, int*, MPI_Datatype*);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void simple_min(
|
||||||
|
void* input,
|
||||||
|
void* accumulator,
|
||||||
|
int* len,
|
||||||
|
MPI_Datatype* datatype) {
|
||||||
|
T* in = (T*)input;
|
||||||
|
T* acc = (T*)accumulator;
|
||||||
|
int N = *len;
|
||||||
|
|
||||||
|
while (N-- > 0) {
|
||||||
|
*acc = std::min(*acc, *in);
|
||||||
|
acc++;
|
||||||
|
in++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template void simple_min<float16_t>(void*, void*, int*, MPI_Datatype*);
|
||||||
|
template void simple_min<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
|
||||||
|
template void simple_min<complex64_t>(void*, void*, int*, MPI_Datatype*);
|
||||||
|
|
||||||
struct MPIWrapper {
|
struct MPIWrapper {
|
||||||
MPIWrapper() {
|
MPIWrapper() {
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
@ -129,9 +169,15 @@ struct MPIWrapper {
|
|||||||
mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
|
mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
|
||||||
mpi_type_commit(&mpi_bfloat16_);
|
mpi_type_commit(&mpi_bfloat16_);
|
||||||
|
|
||||||
// Custom sum ops
|
// Custom reduction ops
|
||||||
mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
|
mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
|
||||||
mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
|
mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
|
||||||
|
mpi_op_create(&simple_max<float16_t>, 1, &op_max_f16_);
|
||||||
|
mpi_op_create(&simple_max<bfloat16_t>, 1, &op_max_bf16_);
|
||||||
|
mpi_op_create(&simple_max<complex64_t>, 1, &op_max_c64_);
|
||||||
|
mpi_op_create(&simple_min<float16_t>, 1, &op_min_f16_);
|
||||||
|
mpi_op_create(&simple_min<bfloat16_t>, 1, &op_min_bf16_);
|
||||||
|
mpi_op_create(&simple_min<complex64_t>, 1, &op_min_c64_);
|
||||||
|
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
}
|
}
|
||||||
@ -194,11 +240,29 @@ struct MPIWrapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
MPI_Op op_max(const array& arr) {
|
MPI_Op op_max(const array& arr) {
|
||||||
return op_max_;
|
switch (arr.dtype()) {
|
||||||
|
case float16:
|
||||||
|
return op_max_f16_;
|
||||||
|
case bfloat16:
|
||||||
|
return op_max_bf16_;
|
||||||
|
case complex64:
|
||||||
|
return op_max_c64_;
|
||||||
|
default:
|
||||||
|
return op_max_;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MPI_Op op_min(const array& arr) {
|
MPI_Op op_min(const array& arr) {
|
||||||
return op_min_;
|
switch (arr.dtype()) {
|
||||||
|
case float16:
|
||||||
|
return op_min_f16_;
|
||||||
|
case bfloat16:
|
||||||
|
return op_min_bf16_;
|
||||||
|
case complex64:
|
||||||
|
return op_min_c64_;
|
||||||
|
default:
|
||||||
|
return op_min_;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void* libmpi_handle_;
|
void* libmpi_handle_;
|
||||||
@ -230,7 +294,13 @@ struct MPIWrapper {
|
|||||||
MPI_Op op_sum_f16_;
|
MPI_Op op_sum_f16_;
|
||||||
MPI_Op op_sum_bf16_;
|
MPI_Op op_sum_bf16_;
|
||||||
MPI_Op op_max_;
|
MPI_Op op_max_;
|
||||||
|
MPI_Op op_max_f16_;
|
||||||
|
MPI_Op op_max_bf16_;
|
||||||
|
MPI_Op op_max_c64_;
|
||||||
MPI_Op op_min_;
|
MPI_Op op_min_;
|
||||||
|
MPI_Op op_min_f16_;
|
||||||
|
MPI_Op op_min_bf16_;
|
||||||
|
MPI_Op op_min_c64_;
|
||||||
|
|
||||||
// Datatypes
|
// Datatypes
|
||||||
MPI_Datatype mpi_bool_;
|
MPI_Datatype mpi_bool_;
|
||||||
|
@ -30,27 +30,51 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
|||||||
def test_all_reduce(self):
|
def test_all_reduce(self):
|
||||||
world = mx.distributed.init()
|
world = mx.distributed.init()
|
||||||
dtypes = [
|
dtypes = [
|
||||||
mx.int8,
|
(mx.int8, 0),
|
||||||
mx.uint8,
|
(mx.uint8, 0),
|
||||||
mx.int16,
|
(mx.int16, 0),
|
||||||
mx.uint16,
|
(mx.uint16, 0),
|
||||||
mx.int32,
|
(mx.int32, 0),
|
||||||
mx.uint32,
|
(mx.uint32, 0),
|
||||||
mx.float32,
|
(mx.float32, 1e-6),
|
||||||
mx.float16,
|
(mx.float16, 5e-3),
|
||||||
mx.bfloat16,
|
(mx.bfloat16, 1e-1),
|
||||||
mx.complex64,
|
(mx.complex64, 1e-6),
|
||||||
]
|
]
|
||||||
for dt in dtypes:
|
sizes = [
|
||||||
x = mx.ones((2, 2, 4), dtype=dt)
|
(7,),
|
||||||
y = mx.distributed.all_sum(x)
|
(10,),
|
||||||
self.assertTrue(mx.all(y == world.size()))
|
(1024,),
|
||||||
|
(1024, 1024),
|
||||||
|
]
|
||||||
|
key = mx.random.key(0)
|
||||||
|
group = world.split(world.rank() % 2)
|
||||||
|
|
||||||
sub = world.split(world.rank() % 2)
|
for dt, rtol in dtypes:
|
||||||
for dt in dtypes:
|
for sh in sizes:
|
||||||
x = mx.ones((2, 2, 4), dtype=dt)
|
for g in [world, group]:
|
||||||
y = mx.distributed.all_sum(x, group=sub)
|
x = (
|
||||||
self.assertTrue(mx.all(y == sub.size()))
|
mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10
|
||||||
|
).astype(dt)
|
||||||
|
|
||||||
|
# All sum
|
||||||
|
y = mx.distributed.all_sum(x[g.rank()], group=g)
|
||||||
|
z = x.sum(0)
|
||||||
|
maxrelerror = (y - z).abs()
|
||||||
|
if rtol > 0:
|
||||||
|
maxrelerror /= z.abs()
|
||||||
|
maxrelerror = maxrelerror.max()
|
||||||
|
self.assertLessEqual(maxrelerror, rtol)
|
||||||
|
|
||||||
|
# All max
|
||||||
|
y = mx.distributed.all_max(x[g.rank()], group=g)
|
||||||
|
z = x.max(0)
|
||||||
|
self.assertTrue(mx.all(y == z))
|
||||||
|
|
||||||
|
# All min
|
||||||
|
y = mx.distributed.all_min(x[g.rank()], group=g)
|
||||||
|
z = x.min(0)
|
||||||
|
self.assertTrue(mx.all(y == z))
|
||||||
|
|
||||||
def test_all_gather(self):
|
def test_all_gather(self):
|
||||||
world = mx.distributed.init()
|
world = mx.distributed.init()
|
||||||
@ -124,22 +148,6 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
|||||||
x = mx.distributed.recv_like(x, neighbor, group=pairs)
|
x = mx.distributed.recv_like(x, neighbor, group=pairs)
|
||||||
mx.eval(y, x)
|
mx.eval(y, x)
|
||||||
|
|
||||||
def test_min_max(self):
    """Check all_max / all_min against a closed-form expected result.

    Every rank contributes ``base + rank * rank_offset``, so the element-wise
    maximum across ranks is the highest rank's contribution and the
    element-wise minimum is rank 0's contribution (``base`` itself).
    """
    world = mx.distributed.init()
    base = mx.arange(16).reshape(4, 4)
    # Single source of truth for the per-rank shift; the expected value below
    # must use the same constant that builds the input.
    rank_offset = 32
    x = base + world.rank() * rank_offset

    def _test_reduction(reduction: str = "all_max"):
        # Compare against the full op name ("all_max"), which is what the
        # loop below passes in; comparing against "max" never matches.
        if reduction == "all_max":
            target = base + (world.size() - 1) * rank_offset
        else:
            target = base
        reducer = getattr(mx.distributed, reduction)
        y = reducer(x)

        self.assertTrue(mx.allclose(y, target))

    for reduction in ["all_max", "all_min"]:
        _test_reduction(reduction)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -51,16 +51,25 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
|||||||
x = (
|
x = (
|
||||||
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
|
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
|
||||||
).astype(dt)
|
).astype(dt)
|
||||||
for reduction in reductions:
|
|
||||||
reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
|
|
||||||
y = reducer_distributed(x[world.rank()])
|
|
||||||
|
|
||||||
reducer = getattr(mx, reduction)
|
# All sum
|
||||||
z = reducer(x, axis=0)
|
y = mx.distributed.all_sum(x[world.rank()])
|
||||||
mx.eval(y, z)
|
z = x.sum(0)
|
||||||
|
maxrelerror = (y - z).abs()
|
||||||
|
if rtol > 0:
|
||||||
|
maxrelerror /= z.abs()
|
||||||
|
maxrelerror = maxrelerror.max()
|
||||||
|
self.assertLessEqual(maxrelerror, rtol)
|
||||||
|
|
||||||
maxrelerror = ((y - z).abs() / z.abs()).max()
|
# All max
|
||||||
self.assertLessEqual(maxrelerror, rtol)
|
y = mx.distributed.all_max(x[world.rank()])
|
||||||
|
z = x.max(0)
|
||||||
|
self.assertTrue(mx.all(y == z))
|
||||||
|
|
||||||
|
# All min
|
||||||
|
y = mx.distributed.all_min(x[world.rank()])
|
||||||
|
z = x.min(0)
|
||||||
|
self.assertTrue(mx.all(y == z))
|
||||||
|
|
||||||
def test_all_gather(self):
|
def test_all_gather(self):
|
||||||
world = mx.distributed.init()
|
world = mx.distributed.init()
|
||||||
|
Loading…
Reference in New Issue
Block a user