mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Make it work even for donated inputs
This commit is contained in:
		@@ -248,8 +248,22 @@ void all_sum(Group group_, const array& input_, array& output) {
 | 
			
		||||
    throw std::runtime_error("Only pairwise communication supported for now");
 | 
			
		||||
  }
 | 
			
		||||
  array input = ensure_row_contiguous(input_);
 | 
			
		||||
 | 
			
		||||
  // Donation not supported
 | 
			
		||||
  if (input.data<void>() == output.data<void>()) {
 | 
			
		||||
    throw std::runtime_error("Donation not supported");
 | 
			
		||||
    array temp(
 | 
			
		||||
        allocator::malloc_or_wait(output.nbytes()),
 | 
			
		||||
        output.shape(),
 | 
			
		||||
        output.dtype());
 | 
			
		||||
    if (group->rank() == 0) {
 | 
			
		||||
      group->send(input.data<char>(), input.nbytes(), 1);
 | 
			
		||||
      group->recv(temp.data<char>(), output.nbytes(), 1);
 | 
			
		||||
      sum_inplace(temp, output);
 | 
			
		||||
    } else {
 | 
			
		||||
      group->recv(temp.data<char>(), output.nbytes(), 0);
 | 
			
		||||
      group->send(input.data<char>(), input.nbytes(), 0);
 | 
			
		||||
      sum_inplace(temp, output);
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    if (group->rank() == 0) {
 | 
			
		||||
      group->send(input.data<char>(), input.nbytes(), 1);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user