mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom logsumexp (#2028)
* initial custom logsumexp * more tests * comments + fix
This commit is contained in:
@@ -2509,6 +2509,49 @@ std::pair<std::vector<array>, std::vector<int>> LogAddExp::vmap(
|
||||
return {{logaddexp(a, b, stream())}, {to_ax}};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> LogSumExp::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0];
|
||||
auto in = inputs[0];
|
||||
if (ax == (in.ndim() - 1)) {
|
||||
in = swapaxes(in, -1, -2, stream());
|
||||
ax = in.ndim() - 2;
|
||||
}
|
||||
return {{logsumexp(in, -1, true, stream())}, {ax}};
|
||||
}
|
||||
|
||||
std::vector<array> LogSumExp::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
assert(primals.size() == 1);
|
||||
assert(cotangents.size() == 1);
|
||||
return {multiply(
|
||||
cotangents[0],
|
||||
softmax(primals[0], std::vector<int>{-1}, true, stream()),
|
||||
stream())};
|
||||
}
|
||||
|
||||
std::vector<array> LogSumExp::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
assert(primals.size() == 1);
|
||||
assert(tangents.size() == 1);
|
||||
return {multiply(
|
||||
tangents[0],
|
||||
softmax(primals[0], std::vector<int>{-1}, true, stream()),
|
||||
stream())};
|
||||
}
|
||||
|
||||
std::vector<Shape> LogSumExp::output_shapes(const std::vector<array>& inputs) {
|
||||
auto s = inputs[0].shape();
|
||||
s.back() = 1;
|
||||
return {s};
|
||||
}
|
||||
|
||||
std::vector<array> Matmul::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
||||
Reference in New Issue
Block a user