diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu index 90eeacd7c..4d2658534 100644 --- a/mlx/backend/cuda/distributed.cu +++ b/mlx/backend/cuda/distributed.cu @@ -24,6 +24,7 @@ void AllReduce::eval_gpu( out.copy_shared_buffer(in); return {in, out}; } else { + out.set_data(allocator::malloc(out.nbytes())); return {in, out}; } };