Bitwise Inverse (#1862)

* add bitwise inverse

* add vmap + fix nojit

* inverse -> invert

* add to compile + remove unused
This commit is contained in:
Alex Barron
2025-02-13 08:44:14 -08:00
committed by GitHub
parent e425dc00c0
commit 5cd97f7ffe
19 changed files with 147 additions and 8 deletions

View File

@@ -4522,6 +4522,14 @@ std::pair<std::vector<array>, std::vector<int>> Tanh::vmap(
return {{tanh(inputs[0], stream())}, axes};
}
std::pair<std::vector<array>, std::vector<int>> BitwiseInvert::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{bitwise_invert(inputs[0], stream())}, axes};
}
std::vector<array> BlockMaskedMM::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,