Add MPI barrier

This commit is contained in:
Angelos Katharopoulos 2024-11-01 11:41:53 -07:00
parent 26be608470
commit c3ccd4919f
4 changed files with 17 additions and 1 deletions

View File

@ -32,6 +32,8 @@ struct Group {
*/ */
Group split(int color, int key = -1); Group split(int color, int key = -1);
void barrier();
const std::shared_ptr<void>& raw_group() { const std::shared_ptr<void>& raw_group() {
return group_; return group_;
} }

View File

@ -71,6 +71,7 @@ struct MPIWrapper {
LOAD_SYMBOL(MPI_Allgather, all_gather); LOAD_SYMBOL(MPI_Allgather, all_gather);
LOAD_SYMBOL(MPI_Send, send); LOAD_SYMBOL(MPI_Send, send);
LOAD_SYMBOL(MPI_Recv, recv); LOAD_SYMBOL(MPI_Recv, recv);
LOAD_SYMBOL(MPI_Barrier, barrier);
LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous); LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit); LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
LOAD_SYMBOL(MPI_Op_create, mpi_op_create); LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
@ -195,6 +196,7 @@ struct MPIWrapper {
int (*comm_free)(MPI_Comm*); int (*comm_free)(MPI_Comm*);
int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm); int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*); int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
int (*barrier)(MPI_Comm);
// Objects // Objects
MPI_Comm comm_world_; MPI_Comm comm_world_;
@ -263,6 +265,10 @@ struct MPIGroupImpl {
return size_; return size_;
} }
void barrier() {
mpi().barrier(comm_);
}
private: private:
MPI_Comm comm_; MPI_Comm comm_;
bool global_; bool global_;
@ -298,6 +304,11 @@ Group Group::split(int color, int key) {
return Group(std::make_shared<MPIGroupImpl>(new_comm, false)); return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
} }
void Group::barrier() {
auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
mpi_group->barrier();
}
bool is_available() { bool is_available() {
return mpi().is_available(); return mpi().is_available();
} }

View File

@ -17,6 +17,8 @@ Group Group::split(int color, int key) {
throw std::runtime_error("Cannot split the distributed group further"); throw std::runtime_error("Cannot split the distributed group further");
} }
void Group::barrier() {}
bool is_available() { bool is_available() {
return false; return false;
} }

View File

@ -44,7 +44,8 @@ void init_distributed(nb::module_& parent_module) {
color (int): A value to group processes into subgroups. color (int): A value to group processes into subgroups.
key (int, optional): A key to optionally change the rank ordering key (int, optional): A key to optionally change the rank ordering
of the processes. of the processes.
)pbdoc"); )pbdoc")
.def("barrier", &distributed::Group::barrier, "Make a synhronization point for all nodes in the group");
m.def( m.def(
"is_available", "is_available",