diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh index 0cbccf69e..f3b30f587 100644 --- a/mlx/backend/cuda/reduce/reduce_utils.cuh +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -108,6 +108,11 @@ inline void allocate_same_layout( array& out, const array& in, const std::vector& axes) { + if (out.ndim() < in.ndim()) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } + // Calculate the transpositions applied to in in order to apply them to out. std::vector axis_order(in.ndim()); std::iota(axis_order.begin(), axis_order.end(), 0);