Custom logsumexp (#2028)

* initial custom logsumexp

* more tests

* comments + fix
This commit is contained in:
Awni Hannun
2025-03-31 07:36:55 -07:00
committed by GitHub
parent ec2854b13a
commit de5f38fd48
27 changed files with 590 additions and 255 deletions

View File

@@ -2509,6 +2509,49 @@ std::pair<std::vector<array>, std::vector<int>> LogAddExp::vmap(
return {{logaddexp(a, b, stream())}, {to_ax}};
}
std::pair<std::vector<array>, std::vector<int>> LogSumExp::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  // Single input, single batch axis.
  auto batched = inputs[0];
  auto batch_axis = axes[0];

  // The reduction below runs over the last axis. If the batch axis currently
  // occupies that position, swap it with the second-to-last axis so that it
  // is not reduced away, and report its new position.
  const bool batch_is_last = (batch_axis == (batched.ndim() - 1));
  if (batch_is_last) {
    batched = swapaxes(batched, -1, -2, stream());
    batch_axis = batched.ndim() - 2;
  }

  // Reduce over the last axis (third argument `true` as in the op call
  // convention used throughout this hunk), leaving the batch axis in place.
  auto reduced = logsumexp(batched, -1, true, stream());
  return {{reduced}, {batch_axis}};
}
std::vector<array> LogSumExp::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  // Exactly one primal and one cotangent are expected.
  assert(primals.size() == 1);
  assert(cotangents.size() == 1);

  // Softmax of the primal over its last axis, then scale it elementwise by
  // the incoming cotangent.
  const std::vector<int> last_axis{-1};
  auto weights = softmax(primals[0], last_axis, true, stream());
  auto grad = multiply(cotangents[0], weights, stream());
  return {grad};
}
std::vector<array> LogSumExp::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 1);
assert(tangents.size() == 1);
return {multiply(
tangents[0],
softmax(primals[0], std::vector<int>{-1}, true, stream()),
stream())};
}
std::vector<Shape> LogSumExp::output_shapes(const std::vector<array>& inputs) {
auto s = inputs[0].shape();
s.back() = 1;
return {s};
}
std::vector<array> Matmul::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,