Custom transforms (#1246)

This commit is contained in:
Angelos Katharopoulos
2024-07-10 18:00:01 -07:00
committed by GitHub
parent a3c287354f
commit 5c1fa64fb0
16 changed files with 734 additions and 39 deletions

View File

@@ -729,37 +729,56 @@ class Cosh : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class CustomVJP : public Primitive {
class CustomTransforms : public Primitive {
public:
explicit CustomVJP(
explicit CustomTransforms(
Stream stream,
int num_outputs,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun)
: Primitive(stream), vjp_fun_(std::move(fun)) {}
const std::vector<array>&)> vjp,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)> jvp,
std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)> vmap)
: Primitive(stream),
num_outputs_(num_outputs),
vjp_fun_(std::move(vjp)),
jvp_fun_(std::move(jvp)),
vmap_fun_(std::move(vmap)) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(CustomVJP);
DEFINE_GRADS();
DEFINE_VMAP();
DEFINE_PRINT(CustomTransforms);
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
int num_outputs_;
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)>
vjp_fun_;
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)>
jvp_fun_;
std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)>
vmap_fun_;
};
class Depends : public Primitive {