mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 18:39:45 +08:00
Fix logsumexp/softmax not fused for some cases (#2474)
This commit is contained in:
parent
aa7b47481a
commit
7fde1b6a1e
31
mlx/ops.cpp
31
mlx/ops.cpp
@ -2381,9 +2381,20 @@ array logsumexp(
|
|||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[logsumexp] Received non-empty axes for array with 0 dimensions.");
|
"[logsumexp] Received non-empty axes for array with 0 dimensions.");
|
||||||
}
|
}
|
||||||
|
bool reduce_last_dim =
|
||||||
|
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
|
||||||
|
if (reduce_last_dim) {
|
||||||
|
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
|
||||||
|
// is [1, 1, ..., N].
|
||||||
|
for (int i = axes.size() - 2; i >= 0; --i) {
|
||||||
|
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
|
||||||
|
reduce_last_dim = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
bool is_complex = issubdtype(a.dtype(), complexfloating);
|
bool is_complex = issubdtype(a.dtype(), complexfloating);
|
||||||
if (!is_complex && axes.size() == 1 &&
|
if (!is_complex && reduce_last_dim) {
|
||||||
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
auto out_shape = a.shape();
|
auto out_shape = a.shape();
|
||||||
out_shape.back() = 1;
|
out_shape.back() = 1;
|
||||||
@ -3403,10 +3414,20 @@ array softmax(
|
|||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[softmax] Received non-empty axes for array with 0 dimensions.");
|
"[softmax] Received non-empty axes for array with 0 dimensions.");
|
||||||
}
|
}
|
||||||
|
bool reduce_last_dim =
|
||||||
|
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
|
||||||
|
if (reduce_last_dim) {
|
||||||
|
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
|
||||||
|
// is [1, 1, ..., N].
|
||||||
|
for (int i = axes.size() - 2; i >= 0; --i) {
|
||||||
|
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
|
||||||
|
reduce_last_dim = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
bool is_complex = issubdtype(a.dtype(), complexfloating);
|
bool is_complex = issubdtype(a.dtype(), complexfloating);
|
||||||
if (!is_complex && axes.size() == 1 &&
|
if (!is_complex && reduce_last_dim) {
|
||||||
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
|
||||||
auto dtype = at_least_float(a.dtype());
|
auto dtype = at_least_float(a.dtype());
|
||||||
return array(
|
return array(
|
||||||
a.shape(),
|
a.shape(),
|
||||||
|
Loading…
Reference in New Issue
Block a user