mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add MPI barrier
This commit is contained in:
		| @@ -32,6 +32,8 @@ struct Group { | ||||
|    */ | ||||
|   Group split(int color, int key = -1); | ||||
|  | ||||
|   void barrier(); | ||||
|  | ||||
|   const std::shared_ptr<void>& raw_group() { | ||||
|     return group_; | ||||
|   } | ||||
|   | ||||
| @@ -71,6 +71,7 @@ struct MPIWrapper { | ||||
|     LOAD_SYMBOL(MPI_Allgather, all_gather); | ||||
|     LOAD_SYMBOL(MPI_Send, send); | ||||
|     LOAD_SYMBOL(MPI_Recv, recv); | ||||
|     LOAD_SYMBOL(MPI_Barrier, barrier); | ||||
|     LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous); | ||||
|     LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit); | ||||
|     LOAD_SYMBOL(MPI_Op_create, mpi_op_create); | ||||
| @@ -195,6 +196,7 @@ struct MPIWrapper { | ||||
|   int (*comm_free)(MPI_Comm*); | ||||
|   int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm); | ||||
|   int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*); | ||||
|   int (*barrier)(MPI_Comm); | ||||
|  | ||||
|   // Objects | ||||
|   MPI_Comm comm_world_; | ||||
| @@ -263,6 +265,10 @@ struct MPIGroupImpl { | ||||
|     return size_; | ||||
|   } | ||||
|  | ||||
|   void barrier() { | ||||
|     mpi().barrier(comm_); | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   MPI_Comm comm_; | ||||
|   bool global_; | ||||
| @@ -298,6 +304,11 @@ Group Group::split(int color, int key) { | ||||
|   return Group(std::make_shared<MPIGroupImpl>(new_comm, false)); | ||||
| } | ||||
|  | ||||
| void Group::barrier() { | ||||
|   auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_); | ||||
|   mpi_group->barrier(); | ||||
| } | ||||
|  | ||||
| bool is_available() { | ||||
|   return mpi().is_available(); | ||||
| } | ||||
|   | ||||
| @@ -17,6 +17,8 @@ Group Group::split(int color, int key) { | ||||
|   throw std::runtime_error("Cannot split the distributed group further"); | ||||
| } | ||||
|  | ||||
| void Group::barrier() {} | ||||
|  | ||||
| bool is_available() { | ||||
|   return false; | ||||
| } | ||||
|   | ||||
| @@ -44,7 +44,8 @@ void init_distributed(nb::module_& parent_module) { | ||||
|               color (int): A value to group processes into subgroups. | ||||
|               key (int, optional): A key to optionally change the rank ordering | ||||
|                 of the processes. | ||||
|           )pbdoc"); | ||||
|           )pbdoc") | ||||
|       .def("barrier", &distributed::Group::barrier, "Make a synhronization point for all nodes in the group"); | ||||
|  | ||||
|   m.def( | ||||
|       "is_available", | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos