mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
Fix the test and add custom min/max reductions for uncommon MPI types (#2060)
This commit is contained in:

committed by
GitHub

parent
dfae2c6989
commit
ddaa4b7dcb
@@ -53,7 +53,8 @@ void AllReduce::eval_cpu(
|
||||
distributed::detail::all_min(group(), in, outputs[0], stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum, min and max are supported for now");
|
||||
throw std::runtime_error(
|
||||
"Only all reduce sum, min and max are supported for now");
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -50,6 +50,46 @@ void simple_sum(
|
||||
template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
|
||||
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
|
||||
|
||||
template <typename T>
|
||||
void simple_max(
|
||||
void* input,
|
||||
void* accumulator,
|
||||
int* len,
|
||||
MPI_Datatype* datatype) {
|
||||
T* in = (T*)input;
|
||||
T* acc = (T*)accumulator;
|
||||
int N = *len;
|
||||
|
||||
while (N-- > 0) {
|
||||
*acc = std::max(*acc, *in);
|
||||
acc++;
|
||||
in++;
|
||||
}
|
||||
}
|
||||
template void simple_max<float16_t>(void*, void*, int*, MPI_Datatype*);
|
||||
template void simple_max<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
|
||||
template void simple_max<complex64_t>(void*, void*, int*, MPI_Datatype*);
|
||||
|
||||
template <typename T>
|
||||
void simple_min(
|
||||
void* input,
|
||||
void* accumulator,
|
||||
int* len,
|
||||
MPI_Datatype* datatype) {
|
||||
T* in = (T*)input;
|
||||
T* acc = (T*)accumulator;
|
||||
int N = *len;
|
||||
|
||||
while (N-- > 0) {
|
||||
*acc = std::min(*acc, *in);
|
||||
acc++;
|
||||
in++;
|
||||
}
|
||||
}
|
||||
template void simple_min<float16_t>(void*, void*, int*, MPI_Datatype*);
|
||||
template void simple_min<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
|
||||
template void simple_min<complex64_t>(void*, void*, int*, MPI_Datatype*);
|
||||
|
||||
struct MPIWrapper {
|
||||
MPIWrapper() {
|
||||
initialized_ = false;
|
||||
@@ -129,9 +169,15 @@ struct MPIWrapper {
|
||||
mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
|
||||
mpi_type_commit(&mpi_bfloat16_);
|
||||
|
||||
// Custom sum ops
|
||||
// Custom reduction ops
|
||||
mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
|
||||
mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
|
||||
mpi_op_create(&simple_max<float16_t>, 1, &op_max_f16_);
|
||||
mpi_op_create(&simple_max<bfloat16_t>, 1, &op_max_bf16_);
|
||||
mpi_op_create(&simple_max<complex64_t>, 1, &op_max_c64_);
|
||||
mpi_op_create(&simple_min<float16_t>, 1, &op_min_f16_);
|
||||
mpi_op_create(&simple_min<bfloat16_t>, 1, &op_min_bf16_);
|
||||
mpi_op_create(&simple_min<complex64_t>, 1, &op_min_c64_);
|
||||
|
||||
initialized_ = true;
|
||||
}
|
||||
@@ -194,11 +240,29 @@ struct MPIWrapper {
|
||||
}
|
||||
|
||||
MPI_Op op_max(const array& arr) {
|
||||
return op_max_;
|
||||
switch (arr.dtype()) {
|
||||
case float16:
|
||||
return op_max_f16_;
|
||||
case bfloat16:
|
||||
return op_max_bf16_;
|
||||
case complex64:
|
||||
return op_max_c64_;
|
||||
default:
|
||||
return op_max_;
|
||||
}
|
||||
}
|
||||
|
||||
MPI_Op op_min(const array& arr) {
|
||||
return op_min_;
|
||||
switch (arr.dtype()) {
|
||||
case float16:
|
||||
return op_min_f16_;
|
||||
case bfloat16:
|
||||
return op_min_bf16_;
|
||||
case complex64:
|
||||
return op_min_c64_;
|
||||
default:
|
||||
return op_min_;
|
||||
}
|
||||
}
|
||||
|
||||
void* libmpi_handle_;
|
||||
@@ -230,7 +294,13 @@ struct MPIWrapper {
|
||||
MPI_Op op_sum_f16_;
|
||||
MPI_Op op_sum_bf16_;
|
||||
MPI_Op op_max_;
|
||||
MPI_Op op_max_f16_;
|
||||
MPI_Op op_max_bf16_;
|
||||
MPI_Op op_max_c64_;
|
||||
MPI_Op op_min_;
|
||||
MPI_Op op_min_f16_;
|
||||
MPI_Op op_min_bf16_;
|
||||
MPI_Op op_min_c64_;
|
||||
|
||||
// Datatypes
|
||||
MPI_Datatype mpi_bool_;
|
||||
|
Reference in New Issue
Block a user