mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Improve the ring backend initialization
This commit is contained in:
		| @@ -22,78 +22,20 @@ | ||||
| #include "mlx/backend/cpu/encoder.h" | ||||
| #include "mlx/distributed/distributed.h" | ||||
| #include "mlx/distributed/distributed_impl.h" | ||||
| #include "mlx/dtype_utils.h" | ||||
| #include "mlx/threadpool.h" | ||||
|  | ||||
| #ifndef SOL_TCP | ||||
| #define SOL_TCP IPPROTO_TCP | ||||
| #endif | ||||
|  | ||||
| #define SWITCH_TYPE(x, ...)  \ | ||||
|   switch ((x).dtype()) {     \ | ||||
|     case bool_: {            \ | ||||
|       using T = bool;        \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case int8: {             \ | ||||
|       using T = int8_t;      \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case int16: {            \ | ||||
|       using T = int16_t;     \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case int32: {            \ | ||||
|       using T = int32_t;     \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case int64: {            \ | ||||
|       using T = int64_t;     \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case uint8: {            \ | ||||
|       using T = uint8_t;     \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case uint16: {           \ | ||||
|       using T = uint16_t;    \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case uint32: {           \ | ||||
|       using T = uint32_t;    \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case uint64: {           \ | ||||
|       using T = uint64_t;    \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case bfloat16: {         \ | ||||
|       using T = bfloat16_t;  \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case float16: {          \ | ||||
|       using T = float16_t;   \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case float32: {          \ | ||||
|       using T = float;       \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case float64: {          \ | ||||
|       using T = double;      \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|     case complex64: {        \ | ||||
|       using T = complex64_t; \ | ||||
|       __VA_ARGS__;           \ | ||||
|     } break;                 \ | ||||
|   } | ||||
|  | ||||
| namespace mlx::core::distributed::ring { | ||||
|  | ||||
| constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024; | ||||
| constexpr const size_t ALL_SUM_BUFFERS = 2; | ||||
| constexpr const int CONN_ATTEMPTS = 5; | ||||
| constexpr const int CONN_WAIT = 1000; | ||||
| constexpr const int INIT_TIMEOUT = 20000; | ||||
|  | ||||
| using GroupImpl = mlx::core::distributed::detail::GroupImpl; | ||||
| using json = nlohmann::json; | ||||
| @@ -503,6 +445,7 @@ std::vector<int> make_connections( | ||||
|  | ||||
|   return sockets; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| struct SumOp { | ||||
|   void operator()(const T* input, T* output, size_t N) { | ||||
| @@ -550,19 +493,27 @@ class RingGroup : public GroupImpl { | ||||
|     size_ = nodes.size(); | ||||
|     int connect_to = (rank_ + 1) % size_; | ||||
|  | ||||
|     // We define the connection order by having the rank_ == size_ - 1 connect | ||||
|     // first and accept after. | ||||
|     if (rank_ < connect_to) { | ||||
|       log_info(verbose_, "Rank", rank_, "accepting"); | ||||
|       sockets_left_ = std::move(accept_connections(nodes[rank_])); | ||||
|       log_info(verbose_, "Rank", rank_, "connecting to", connect_to); | ||||
|       sockets_right_ = std::move(make_connections(nodes[connect_to], verbose)); | ||||
|     } else { | ||||
|       log_info(verbose_, "Rank", rank_, "connecting to", connect_to); | ||||
|       sockets_right_ = std::move(make_connections(nodes[connect_to], verbose)); | ||||
|       log_info(verbose_, "Rank", rank_, "accepting"); | ||||
|       sockets_left_ = std::move(accept_connections(nodes[rank_])); | ||||
|     // Initialize the ring by making all the connections. Accepting from the | ||||
|     // left neighbour and connecting to the right neighbour are launched | ||||
|     // concurrently, so neither step has to finish before the other starts. | ||||
|     log_info(verbose_, "Rank", rank_, "accepting"); | ||||
|     log_info(verbose_, "Rank", rank_, "connecting to", connect_to); | ||||
|     auto sl = std::async(std::launch::async, accept_connections, nodes[rank_]); | ||||
|     auto sr = std::async( | ||||
|         std::launch::async, make_connections, nodes[connect_to], verbose); | ||||
|  | ||||
|     // Both futures share a single deadline so that the whole initialization is | ||||
|     // bounded by INIT_TIMEOUT milliseconds. Each future must be waited on | ||||
|     // through its own handle: polling `sl` twice would report the right side as | ||||
|     // ready as soon as the left side is, and `sr.get()` below could then block | ||||
|     // without any timeout. | ||||
|     const auto deadline = std::chrono::steady_clock::now() + | ||||
|         std::chrono::milliseconds(INIT_TIMEOUT); | ||||
|     const std::future_status status_sl = sl.wait_until(deadline); | ||||
|     const std::future_status status_sr = sr.wait_until(deadline); | ||||
|     if (status_sl != std::future_status::ready || | ||||
|         status_sr != std::future_status::ready) { | ||||
|       throw std::runtime_error("[ring] Ring initialization timed out"); | ||||
|     } | ||||
|     // get() already yields a temporary, so no explicit std::move is needed. | ||||
|     sockets_left_ = sl.get(); | ||||
|     sockets_right_ = sr.get(); | ||||
|  | ||||
|     // Failure if we couldn't make right or left sockets | ||||
|     if (sockets_right_.empty()) { | ||||
| @@ -628,18 +579,24 @@ class RingGroup : public GroupImpl { | ||||
|   } | ||||
|  | ||||
|   void all_sum(const array& input, array& output, Stream stream) override { | ||||
|     SWITCH_TYPE( | ||||
|         output, all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>())); | ||||
|     dispatch_all_types(output.dtype(), [&](auto type_tag) { | ||||
|       using T = MLX_GET_TYPE(type_tag); | ||||
|       all_reduce<T, SumOp<T>>(input, output, stream, SumOp<T>()); | ||||
|     }); | ||||
|   } | ||||
|  | ||||
|   void all_max(const array& input, array& output, Stream stream) override { | ||||
|     SWITCH_TYPE( | ||||
|         output, all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>())); | ||||
|     dispatch_all_types(output.dtype(), [&](auto type_tag) { | ||||
|       using T = MLX_GET_TYPE(type_tag); | ||||
|       all_reduce<T, MaxOp<T>>(input, output, stream, MaxOp<T>()); | ||||
|     }); | ||||
|   } | ||||
|  | ||||
|   void all_min(const array& input, array& output, Stream stream) override { | ||||
|     SWITCH_TYPE( | ||||
|         output, all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>())); | ||||
|     dispatch_all_types(output.dtype(), [&](auto type_tag) { | ||||
|       using T = MLX_GET_TYPE(type_tag); | ||||
|       all_reduce<T, MinOp<T>>(input, output, stream, MinOp<T>()); | ||||
|     }); | ||||
|   } | ||||
|  | ||||
|   std::shared_ptr<GroupImpl> split(int color, int key = -1) override { | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_distributed_tests | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): | ||||
| @@ -150,4 +151,4 @@ class TestMPIDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|     mlx_tests.MLXTestRunner() | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_distributed_tests | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Angelos Katharopoulos
					Angelos Katharopoulos