matmul jvps (#1772)

This commit is contained in:
Awni Hannun
2025-01-17 10:36:26 -08:00
committed by GitHub
parent f288db8d34
commit 0c259961ac
4 changed files with 138 additions and 21 deletions

View File

@@ -193,12 +193,7 @@ class AddMM : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_GRADS()
DEFINE_VMAP()
DEFINE_PRINT(AddMM)
@@ -1459,12 +1454,7 @@ class Matmul : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_GRADS()
DEFINE_VMAP()
DEFINE_PRINT(Matmul)
DEFINE_DEFAULT_IS_EQUIVALENT()