From 4aefacf0bc22f0fbfac08f80a5a3c00e4e3c9a68 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 29 Aug 2024 01:52:56 -0700 Subject: [PATCH] Make it work even for donated inputs --- mlx/distributed/sockets/sockets.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mlx/distributed/sockets/sockets.cpp b/mlx/distributed/sockets/sockets.cpp index 1ea19a3a8..6f725314f 100644 --- a/mlx/distributed/sockets/sockets.cpp +++ b/mlx/distributed/sockets/sockets.cpp @@ -248,8 +248,22 @@ void all_sum(Group group_, const array& input_, array& output) { throw std::runtime_error("Only pairwise communication supported for now"); } array input = ensure_row_contiguous(input_); + + // Donation not supported if (input.data() == output.data()) { - throw std::runtime_error("Donation not supported"); + array temp( + allocator::malloc_or_wait(output.nbytes()), + output.shape(), + output.dtype()); + if (group->rank() == 0) { + group->send(input.data(), input.nbytes(), 1); + group->recv(temp.data(), output.nbytes(), 1); + sum_inplace(temp, output); + } else { + group->recv(temp.data(), output.nbytes(), 0); + group->send(input.data(), input.nbytes(), 0); + sum_inplace(temp, output); + } } else { if (group->rank() == 0) { group->send(input.data(), input.nbytes(), 1);