mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 08:41:13 +08:00
Add MPI barrier
This commit is contained in:
parent
26be608470
commit
c3ccd4919f
@ -32,6 +32,8 @@ struct Group {
|
|||||||
*/
|
*/
|
||||||
Group split(int color, int key = -1);
|
Group split(int color, int key = -1);
|
||||||
|
|
||||||
|
void barrier();
|
||||||
|
|
||||||
const std::shared_ptr<void>& raw_group() {
|
const std::shared_ptr<void>& raw_group() {
|
||||||
return group_;
|
return group_;
|
||||||
}
|
}
|
||||||
|
@ -71,6 +71,7 @@ struct MPIWrapper {
|
|||||||
LOAD_SYMBOL(MPI_Allgather, all_gather);
|
LOAD_SYMBOL(MPI_Allgather, all_gather);
|
||||||
LOAD_SYMBOL(MPI_Send, send);
|
LOAD_SYMBOL(MPI_Send, send);
|
||||||
LOAD_SYMBOL(MPI_Recv, recv);
|
LOAD_SYMBOL(MPI_Recv, recv);
|
||||||
|
LOAD_SYMBOL(MPI_Barrier, barrier);
|
||||||
LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
|
LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
|
||||||
LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
|
LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
|
||||||
LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
|
LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
|
||||||
@ -195,6 +196,7 @@ struct MPIWrapper {
|
|||||||
int (*comm_free)(MPI_Comm*);
|
int (*comm_free)(MPI_Comm*);
|
||||||
int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
|
int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
|
||||||
int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
|
int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
|
||||||
|
int (*barrier)(MPI_Comm);
|
||||||
|
|
||||||
// Objects
|
// Objects
|
||||||
MPI_Comm comm_world_;
|
MPI_Comm comm_world_;
|
||||||
@ -263,6 +265,10 @@ struct MPIGroupImpl {
|
|||||||
return size_;
|
return size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void barrier() {
|
||||||
|
mpi().barrier(comm_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MPI_Comm comm_;
|
MPI_Comm comm_;
|
||||||
bool global_;
|
bool global_;
|
||||||
@ -298,6 +304,11 @@ Group Group::split(int color, int key) {
|
|||||||
return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
|
return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Group::barrier() {
|
||||||
|
auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
|
||||||
|
mpi_group->barrier();
|
||||||
|
}
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return mpi().is_available();
|
return mpi().is_available();
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,8 @@ Group Group::split(int color, int key) {
|
|||||||
throw std::runtime_error("Cannot split the distributed group further");
|
throw std::runtime_error("Cannot split the distributed group further");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Group::barrier() {}
|
||||||
|
|
||||||
bool is_available() {
|
bool is_available() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,8 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
color (int): A value to group processes into subgroups.
|
color (int): A value to group processes into subgroups.
|
||||||
key (int, optional): A key to optionally change the rank ordering
|
key (int, optional): A key to optionally change the rank ordering
|
||||||
of the processes.
|
of the processes.
|
||||||
)pbdoc");
|
)pbdoc")
|
||||||
|
.def("barrier", &distributed::Group::barrier, "Make a synhronization point for all nodes in the group");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"is_available",
|
"is_available",
|
||||||
|
Loading…
Reference in New Issue
Block a user