mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Custom transforms (#1246)
This commit is contained in:
committed by
GitHub
parent
a3c287354f
commit
5c1fa64fb0
@@ -729,37 +729,56 @@ class Cosh : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class CustomVJP : public Primitive {
|
||||
class CustomTransforms : public Primitive {
|
||||
public:
|
||||
explicit CustomVJP(
|
||||
explicit CustomTransforms(
|
||||
Stream stream,
|
||||
int num_outputs,
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&)> fun)
|
||||
: Primitive(stream), vjp_fun_(std::move(fun)) {}
|
||||
const std::vector<array>&)> vjp,
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)> jvp,
|
||||
std::function<std::pair<std::vector<array>, std::vector<int>>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)> vmap)
|
||||
: Primitive(stream),
|
||||
num_outputs_(num_outputs),
|
||||
vjp_fun_(std::move(vjp)),
|
||||
jvp_fun_(std::move(jvp)),
|
||||
vmap_fun_(std::move(vmap)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotan,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(CustomVJP);
|
||||
DEFINE_GRADS();
|
||||
DEFINE_VMAP();
|
||||
DEFINE_PRINT(CustomTransforms);
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
|
||||
int num_outputs_;
|
||||
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&)>
|
||||
vjp_fun_;
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)>
|
||||
jvp_fun_;
|
||||
std::function<std::pair<std::vector<array>, std::vector<int>>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)>
|
||||
vmap_fun_;
|
||||
};
|
||||
|
||||
class Depends : public Primitive {
|
||||
|
||||
Reference in New Issue
Block a user