diff --git a/mlx/distributed/sockets/sockets.cpp b/mlx/distributed/sockets/sockets.cpp index 0b9b38c58..1cf10a11b 100644 --- a/mlx/distributed/sockets/sockets.cpp +++ b/mlx/distributed/sockets/sockets.cpp @@ -385,6 +385,19 @@ Group Group::split(int color, int key) { throw std::runtime_error("Splitting not supported yet"); } +void Group::barrier() { + char buff[128]; + std::memset(buff, 1, 128); + + auto group = std::static_pointer_cast(raw_group()); + int size = group->size(); + int rank = group->rank(); + + for (int distance = 1; distance <= size / 2; distance *= 2) { + group->send_recv_sum(buff, 128, rank ^ distance); + } +} + Group init(bool strict /* = false */) { static std::shared_ptr global_group = nullptr;