Make it work even for donated inputs

This commit is contained in:
Angelos Katharopoulos 2024-08-29 01:52:56 -07:00
parent 794a48a9f6
commit 4aefacf0bc

View File

@ -248,8 +248,22 @@ void all_sum(Group group_, const array& input_, array& output) {
throw std::runtime_error("Only pairwise communication supported for now");
}
array input = ensure_row_contiguous(input_);
// Donation not supported
if (input.data<void>() == output.data<void>()) {
throw std::runtime_error("Donation not supported");
array temp(
allocator::malloc_or_wait(output.nbytes()),
output.shape(),
output.dtype());
if (group->rank() == 0) {
group->send(input.data<char>(), input.nbytes(), 1);
group->recv(temp.data<char>(), output.nbytes(), 1);
sum_inplace(temp, output);
} else {
group->recv(temp.data<char>(), output.nbytes(), 0);
group->send(input.data<char>(), input.nbytes(), 0);
sum_inplace(temp, output);
}
} else {
if (group->rank() == 0) {
group->send(input.data<char>(), input.nbytes(), 1);