Custom logsumexp (#2028)

* initial custom logsumexp

* more tests

* comments + fix
This commit is contained in:
Awni Hannun
2025-03-31 07:36:55 -07:00
committed by GitHub
parent ec2854b13a
commit de5f38fd48
27 changed files with 590 additions and 255 deletions

View File

@@ -2359,6 +2359,29 @@ array logsumexp(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
if (a.size() == 0) {
throw std::invalid_argument("[logsumexp] Received empty array.");
}
if (a.ndim() == 0 && !axes.empty()) {
throw std::invalid_argument(
"[logsumexp] Received non-empty axes for array with 0 dimensions.");
}
bool is_complex = issubdtype(a.dtype(), complexfloating);
if (!is_complex && axes.size() == 1 &&
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype());
auto out_shape = a.shape();
out_shape.back() = 1;
auto out = array(
std::move(out_shape),
dtype,
std::make_shared<LogSumExp>(to_stream(s)),
{astype(a, dtype, s)});
if (!keepdims) {
out = squeeze(out, -1, s);
}
return out;
}
auto maxval = stop_gradient(max(a, axes, true, s), s);
auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
out = add(out, reshape(maxval, out.shape(), s), s);
@@ -3347,8 +3370,14 @@ array softmax(
if (a.size() == 0) {
return a;
}
if (a.ndim() == 0 && !axes.empty()) {
throw std::invalid_argument(
"[softmax] Received non-empty axes for array with 0 dimensions.");
}
if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
bool is_complex = issubdtype(a.dtype(), complexfloating);
if (!is_complex && axes.size() == 1 &&
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype());
return array(
a.shape(),
@@ -3357,7 +3386,7 @@ array softmax(
{astype(a, dtype, s)});
} else {
auto in = a;
if (precise) {
if (precise && !is_complex) {
in = astype(a, float32, s);
}
auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);