From a7faa04cd4d91648ce058f782635eeaa84ec28ac Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 24 Jun 2025 10:39:53 -0700 Subject: [PATCH] Add a special case when not keeping the dims --- mlx/backend/cuda/reduce/reduce_utils.cuh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh index 0cbccf69e..f3b30f587 100644 --- a/mlx/backend/cuda/reduce/reduce_utils.cuh +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -108,6 +108,11 @@ inline void allocate_same_layout( array& out, const array& in, const std::vector& axes) { + if (out.ndim() < in.ndim()) { + out.set_data(allocator::malloc(out.nbytes())); + return; + } + // Calculate the transpositions applied to in in order to apply them to out. std::vector axis_order(in.ndim()); std::iota(axis_order.begin(), axis_order.end(), 0);