add zero vjps for bitwise ops and gather w.r.t. index (#1256)

This commit is contained in:
Alex Barron
2024-07-07 21:34:59 -07:00
committed by GitHub
parent 20bb301195
commit bdb36c9a63
4 changed files with 62 additions and 7 deletions

View File

@@ -523,7 +523,7 @@ std::vector<array> AsType::vjp(
const std::vector<array>&) {
if (cotangents[0].dtype() != dtype_) {
throw std::invalid_argument(
"[astype] Type of cotangentsgent does not much primal output type.");
"[astype] Type of cotangents does not match primal output type.");
}
return {astype(cotangents[0], primals[0].dtype(), stream())};
}
@@ -629,6 +629,26 @@ std::pair<std::vector<array>, std::vector<int>> BitwiseBinary::vmap(
{to_ax}};
}
std::vector<array> BitwiseBinary::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(primals.size() == 2);
std::vector<array> vjps = {zeros_like(tangents[0], stream())};
if (argnums.size() > 1) {
vjps.push_back(vjps.back());
}
return vjps;
}
std::vector<array> BitwiseBinary::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
return jvp(primals, cotangents, argnums);
}
std::vector<array> Broadcast::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
@@ -1598,13 +1618,19 @@ std::vector<array> Gather::vjp(
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
if (argnums.size() > 1 || argnums[0] != 0) {
throw std::invalid_argument(
"[gather] Cannot calculate VJP with respect to indices.");
std::vector<array> vjps;
for (int argnum : argnums) {
if (argnum > 0) {
// Grads w.r.t. indices are zero
vjps.push_back(
zeros(primals[argnum].shape(), primals[argnum].dtype(), stream()));
} else {
auto src = zeros_like(primals[0], stream());
std::vector<array> inds(primals.begin() + 1, primals.end());
vjps.push_back(scatter_add(src, inds, cotangents[0], axes_, stream()));
}
}
auto src = zeros_like(primals[0], stream());
std::vector<array> inds(primals.begin() + 1, primals.end());
return {scatter_add(src, inds, cotangents[0], axes_, stream())};
return vjps;
}
std::vector<array> Gather::jvp(

View File

@@ -471,6 +471,7 @@ class BitwiseBinary : public UnaryPrimitive {
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
bool is_equivalent(const Primitive& other) const override;
void print(std::ostream& os) override;
DEFINE_INPUT_OUTPUT_SHAPE()