mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom logsumexp (#2028)
* initial custom logsumexp * more tests * comments + fix
This commit is contained in:
33
mlx/ops.cpp
33
mlx/ops.cpp
@@ -2359,6 +2359,29 @@ array logsumexp(
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (a.size() == 0) {
|
||||
throw std::invalid_argument("[logsumexp] Received empty array.");
|
||||
}
|
||||
if (a.ndim() == 0 && !axes.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[logsumexp] Received non-empty axes for array with 0 dimensions.");
|
||||
}
|
||||
bool is_complex = issubdtype(a.dtype(), complexfloating);
|
||||
if (!is_complex && axes.size() == 1 &&
|
||||
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto out_shape = a.shape();
|
||||
out_shape.back() = 1;
|
||||
auto out = array(
|
||||
std::move(out_shape),
|
||||
dtype,
|
||||
std::make_shared<LogSumExp>(to_stream(s)),
|
||||
{astype(a, dtype, s)});
|
||||
if (!keepdims) {
|
||||
out = squeeze(out, -1, s);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
auto maxval = stop_gradient(max(a, axes, true, s), s);
|
||||
auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
|
||||
out = add(out, reshape(maxval, out.shape(), s), s);
|
||||
@@ -3347,8 +3370,14 @@ array softmax(
|
||||
if (a.size() == 0) {
|
||||
return a;
|
||||
}
|
||||
if (a.ndim() == 0 && !axes.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[softmax] Received non-empty axes for array with 0 dimensions.");
|
||||
}
|
||||
|
||||
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
bool is_complex = issubdtype(a.dtype(), complexfloating);
|
||||
if (!is_complex && axes.size() == 1 &&
|
||||
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
return array(
|
||||
a.shape(),
|
||||
@@ -3357,7 +3386,7 @@ array softmax(
|
||||
{astype(a, dtype, s)});
|
||||
} else {
|
||||
auto in = a;
|
||||
if (precise) {
|
||||
if (precise && !is_complex) {
|
||||
in = astype(a, float32, s);
|
||||
}
|
||||
auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);
|
||||
|
||||
Reference in New Issue
Block a user