add zero vjps for bitwise ops and gather w.r.t. index (#1256)

This commit is contained in:
Alex Barron
2024-07-07 21:34:59 -07:00
committed by GitHub
parent 20bb301195
commit bdb36c9a63
4 changed files with 62 additions and 7 deletions

View File

@@ -471,6 +471,7 @@ class BitwiseBinary : public UnaryPrimitive {
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
bool is_equivalent(const Primitive& other) const override;
void print(std::ostream& os) override;
DEFINE_INPUT_OUTPUT_SHAPE()