mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 13:51:13 +08:00
Make it work even for donated inputs
This commit is contained in:
parent
794a48a9f6
commit
4aefacf0bc
@ -248,8 +248,22 @@ void all_sum(Group group_, const array& input_, array& output) {
|
||||
throw std::runtime_error("Only pairwise communication supported for now");
|
||||
}
|
||||
array input = ensure_row_contiguous(input_);
|
||||
|
||||
// Donation not supported
|
||||
if (input.data<void>() == output.data<void>()) {
|
||||
throw std::runtime_error("Donation not supported");
|
||||
array temp(
|
||||
allocator::malloc_or_wait(output.nbytes()),
|
||||
output.shape(),
|
||||
output.dtype());
|
||||
if (group->rank() == 0) {
|
||||
group->send(input.data<char>(), input.nbytes(), 1);
|
||||
group->recv(temp.data<char>(), output.nbytes(), 1);
|
||||
sum_inplace(temp, output);
|
||||
} else {
|
||||
group->recv(temp.data<char>(), output.nbytes(), 0);
|
||||
group->send(input.data<char>(), input.nbytes(), 0);
|
||||
sum_inplace(temp, output);
|
||||
}
|
||||
} else {
|
||||
if (group->rank() == 0) {
|
||||
group->send(input.data<char>(), input.nbytes(), 1);
|
||||
|
Loading…
Reference in New Issue
Block a user