Mirror of https://github.com/ml-explore/mlx.git
Data parallel helper (#1407)
Commit 914409fef9 (parent 8d68a3e805), committed by GitHub.
@@ -32,8 +32,29 @@ array ensure_row_contiguous(const array& arr) {
   }
 }
 
+template <typename T>
+void simple_sum(
+    void* input,
+    void* accumulator,
+    int* len,
+    MPI_Datatype* datatype) {
+  T* in = (T*)input;
+  T* acc = (T*)accumulator;
+  int N = *len;
+
+  while (N-- > 0) {
+    *acc += *in;
+    acc++;
+    in++;
+  }
+}
+
+template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
+template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
+
 struct MPIWrapper {
   MPIWrapper() {
     initialized_ = false;
 
     libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
     if (libmpi_handle_ == nullptr) {
       return;
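MPI's built-in MPI_SUM knows nothing about MLX's float16_t and bfloat16_t, so the commit introduces simple_sum, which follows MPI's standard contract for user-defined reduction operators: MPI hands the function two buffers of *len elements and expects the first to be folded into the second, element by element. A minimal sketch of that contract for plain doubles (illustration only, not MLX code):

#include <mpi.h>

// MPI's required signature for user-defined reduction operators:
//   void fn(void* invec, void* inoutvec, int* len, MPI_Datatype* datatype);
// The result must be accumulated into inoutvec, exactly as simple_sum<T>
// does in the hunk above.
void sum_double(void* invec, void* inoutvec, int* len, MPI_Datatype*) {
  auto* in = static_cast<double*>(invec);
  auto* acc = static_cast<double*>(inoutvec);
  for (int i = 0; i < *len; ++i) {
    acc[i] += in[i];
  }
}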
@@ -50,6 +71,9 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Allgather, all_gather);
     LOAD_SYMBOL(MPI_Send, send);
     LOAD_SYMBOL(MPI_Recv, recv);
+    LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
+    LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
+    LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
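The three new LOAD_SYMBOL lines resolve MPI_Type_contiguous, MPI_Type_commit, and MPI_Op_create from the dlopen'd libmpi, since the float16/bfloat16 setup in init_safe() below needs them. A hypothetical sketch of how such a macro can be written with dlsym (the actual MLX macro is not shown in this hunk and may differ):

#include <dlfcn.h>

// Hypothetical LOAD_SYMBOL: resolve `symbol` from the dlopen'd libmpi handle
// and store it in the matching member; if the lookup fails, mark the wrapper
// unavailable and return from the constructor.
#define LOAD_SYMBOL(symbol, variable)                              \
  do {                                                             \
    variable = (decltype(variable))dlsym(libmpi_handle_, #symbol); \
    if (variable == nullptr) {                                     \
      libmpi_handle_ = nullptr;                                    \
      return;                                                      \
    }                                                              \
  } while (0)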
@@ -79,7 +103,24 @@ struct MPIWrapper {
     if (!is_available()) {
       return false;
     }
-    return init(nullptr, nullptr) == MPI_SUCCESS;
+    bool success = init(nullptr, nullptr) == MPI_SUCCESS;
+
+    // Initialize custom types and ops
+    if (success && !initialized_) {
+      // Custom float16 dtypes
+      mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);
+      mpi_type_commit(&mpi_float16_);
+      mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
+      mpi_type_commit(&mpi_bfloat16_);
+
+      // Custom sum ops
+      mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
+      mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
+
+      initialized_ = true;
+    }
+
+    return success;
   }
 
   void finalize_safe() {
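At init time the wrapper now describes each 16-bit float simply as two contiguous bytes (MPI only needs the element size to move the data; the arithmetic comes from the custom op) and registers the simple_sum instantiations as commutative reduction operators. A standalone sketch of the same registration against the public MPI API, with a hypothetical sum_half user function; the wrapper itself goes through the dlsym'd function pointers instead:

#include <mpi.h>

// Hypothetical user function with the MPI operator signature (see the
// simple_sum contract above); its definition would live elsewhere.
void sum_half(void* invec, void* inoutvec, int* len, MPI_Datatype* datatype);

// Register a 2-byte element type plus a matching sum op.
void register_half(MPI_Datatype* mpi_half, MPI_Op* op_sum_half) {
  MPI_Type_contiguous(2, MPI_UINT8_T, mpi_half); // element = 2 raw bytes
  MPI_Type_commit(mpi_half);                     // make it usable in comms
  MPI_Op_create(&sum_half, /*commute=*/1, op_sum_half); // addition commutes
}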
@@ -117,13 +158,21 @@ struct MPIWrapper {
       case complex64:
         return mpi_complex_;
       case float16:
+        return mpi_float16_;
       case bfloat16:
-        throw std::runtime_error("MPI doesn't support 16-bit floats");
+        return mpi_bfloat16_;
     }
   }
 
-  MPI_Op op_sum() {
-    return op_sum_;
+  MPI_Op op_sum(const array& arr) {
+    switch (arr.dtype()) {
+      case float16:
+        return op_sum_f16_;
+      case bfloat16:
+        return op_sum_bf16_;
+      default:
+        return op_sum_;
+    }
   }
 
   void* libmpi_handle_;
@@ -152,6 +201,8 @@ struct MPIWrapper {
 
   // Ops
   MPI_Op op_sum_;
+  MPI_Op op_sum_f16_;
+  MPI_Op op_sum_bf16_;
 
   // Datatypes
   MPI_Datatype mpi_bool_;
@@ -165,6 +216,16 @@ struct MPIWrapper {
   MPI_Datatype mpi_uint64_;
   MPI_Datatype mpi_float_;
   MPI_Datatype mpi_complex_;
+  MPI_Datatype mpi_float16_;
+  MPI_Datatype mpi_bfloat16_;
+
+ private:
+  bool initialized_;
+
+  // Private API
+  int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);
+  int (*mpi_type_commit)(MPI_Datatype*);
+  int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);
 };
 
 MPIWrapper& mpi() {
@@ -276,7 +337,7 @@ void all_sum(Group group, const array& input_, array& output) {
       output.data<void>(),
       input.size(),
       mpi().datatype(input),
-      mpi().op_sum(),
+      mpi().op_sum(input),
       to_comm(group));
 }
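With both the datatype and the reduction op chosen from the array's dtype, an fp16 or bf16 all_sum now goes through the all-reduce with the custom (type, op) pair instead of raising. An end-to-end sketch of that path outside MLX, assuming a compiler that provides the _Float16 extension (Clang on Apple Silicon does); build with mpic++ and run with mpirun -np 2 (illustration only, not MLX code):

#include <mpi.h>

#include <cstdio>

// User-defined sum over _Float16 elements; same contract as simple_sum<T>.
static void sum_f16(void* invec, void* inoutvec, int* len, MPI_Datatype*) {
  auto* in = static_cast<_Float16*>(invec);
  auto* acc = static_cast<_Float16*>(inoutvec);
  for (int i = 0; i < *len; ++i) {
    acc[i] = static_cast<_Float16>(
        static_cast<float>(acc[i]) + static_cast<float>(in[i]));
  }
}

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);

  // Describe a 16-bit element as two raw bytes and register the custom op.
  MPI_Datatype mpi_f16;
  MPI_Op op_sum_f16;
  MPI_Type_contiguous(2, MPI_UINT8_T, &mpi_f16);
  MPI_Type_commit(&mpi_f16);
  MPI_Op_create(&sum_f16, /*commute=*/1, &op_sum_f16);

  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  _Float16 send[4];
  _Float16 recv[4];
  for (int i = 0; i < 4; ++i) {
    send[i] = static_cast<_Float16>(rank + i + 1);
  }

  // Mirrors the all_sum call site above: custom datatype + custom sum op.
  MPI_Allreduce(send, recv, 4, mpi_f16, op_sum_f16, MPI_COMM_WORLD);

  if (rank == 0) {
    std::printf("recv[0] = %g\n", static_cast<float>(recv[0]));
  }

  MPI_Op_free(&op_sum_f16);
  MPI_Type_free(&mpi_f16);
  MPI_Finalize();
  return 0;
}

On two ranks, each recv[i] is the sum of the two ranks' send[i], computed entirely by the user-defined operator rather than by MPI_SUM.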
|