From d999675cb98564087d65986977cc037216502cae Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 24 Jun 2025 10:51:03 -0700 Subject: [PATCH] Make check more general --- mlx/backend/cuda/reduce/reduce_utils.cuh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh index f3b30f587..b76411261 100644 --- a/mlx/backend/cuda/reduce/reduce_utils.cuh +++ b/mlx/backend/cuda/reduce/reduce_utils.cuh @@ -108,11 +108,16 @@ inline void allocate_same_layout( array& out, const array& in, const std::vector& axes) { - if (out.ndim() < in.ndim()) { + if (in.flags().row_contiguous) { out.set_data(allocator::malloc(out.nbytes())); return; } + if (out.ndim() < in.ndim()) { + throw std::runtime_error( + "Reduction without keepdims only supported for row-contiguous inputs"); + } + // Calculate the transpositions applied to in in order to apply them to out. std::vector axis_order(in.ndim()); std::iota(axis_order.begin(), axis_order.end(), 0);