mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 22:11:15 +08:00
Make it work even for donated inputs
This commit is contained in:
parent
794a48a9f6
commit
4aefacf0bc
@ -248,8 +248,22 @@ void all_sum(Group group_, const array& input_, array& output) {
|
|||||||
throw std::runtime_error("Only pairwise communication supported for now");
|
throw std::runtime_error("Only pairwise communication supported for now");
|
||||||
}
|
}
|
||||||
array input = ensure_row_contiguous(input_);
|
array input = ensure_row_contiguous(input_);
|
||||||
|
|
||||||
|
// Donation not supported
|
||||||
if (input.data<void>() == output.data<void>()) {
|
if (input.data<void>() == output.data<void>()) {
|
||||||
throw std::runtime_error("Donation not supported");
|
array temp(
|
||||||
|
allocator::malloc_or_wait(output.nbytes()),
|
||||||
|
output.shape(),
|
||||||
|
output.dtype());
|
||||||
|
if (group->rank() == 0) {
|
||||||
|
group->send(input.data<char>(), input.nbytes(), 1);
|
||||||
|
group->recv(temp.data<char>(), output.nbytes(), 1);
|
||||||
|
sum_inplace(temp, output);
|
||||||
|
} else {
|
||||||
|
group->recv(temp.data<char>(), output.nbytes(), 0);
|
||||||
|
group->send(input.data<char>(), input.nbytes(), 0);
|
||||||
|
sum_inplace(temp, output);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (group->rank() == 0) {
|
if (group->rank() == 0) {
|
||||||
group->send(input.data<char>(), input.nbytes(), 1);
|
group->send(input.data<char>(), input.nbytes(), 1);
|
||||||
|
Loading…
Reference in New Issue
Block a user