From 7fde1b6a1e9af278d9e219ec0452229e768e4792 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 9 Aug 2025 06:07:17 +0900 Subject: [PATCH] Fix logsumexp/softmax not fused for some cases (#2474) --- mlx/ops.cpp | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 6c4f76424..14e135deb 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2381,9 +2381,20 @@ array logsumexp( throw std::invalid_argument( "[logsumexp] Received non-empty axes for array with 0 dimensions."); } + bool reduce_last_dim = + !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1); + if (reduce_last_dim) { + // For 2 or more axes, check that the axes are consecutive up to the last + // dim and that every reduced axis before it has size 1, e.g. [..., 1, N]. + for (int i = axes.size() - 2; i >= 0; --i) { + if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) { + reduce_last_dim = false; + break; + } + } + } bool is_complex = issubdtype(a.dtype(), complexfloating); - if (!is_complex && axes.size() == 1 && - (a.ndim() == axes[0] + 1 || axes[0] == -1)) { + if (!is_complex && reduce_last_dim) { auto dtype = at_least_float(a.dtype()); auto out_shape = a.shape(); out_shape.back() = 1; @@ -3403,10 +3414,20 @@ array softmax( throw std::invalid_argument( "[softmax] Received non-empty axes for array with 0 dimensions."); } - + bool reduce_last_dim = + !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1); + if (reduce_last_dim) { + // For 2 or more axes, check that the axes are consecutive up to the last + // dim and that every reduced axis before it has size 1, e.g. [..., 1, N]. + for (int i = axes.size() - 2; i >= 0; --i) { + if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) { + reduce_last_dim = false; + break; + } + } + } bool is_complex = issubdtype(a.dtype(), complexfloating); - if (!is_complex && axes.size() == 1 && - (a.ndim() == axes[0] + 1 || axes[0] == -1)) { + if (!is_complex && reduce_last_dim) { auto dtype = at_least_float(a.dtype()); return array( a.shape(),