mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-20 03:48:15 +08:00
Custom transforms (#1246)
This commit is contained in:

committed by
GitHub

parent
a3c287354f
commit
5c1fa64fb0
@@ -17,6 +17,10 @@ bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
@@ -102,7 +106,7 @@ void array::eval() {
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing();
|
||||
return array_desc_->is_tracer && in_tracing() || retain_graph();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
|
@@ -36,7 +36,7 @@ DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(NumberOfElements)
|
||||
|
@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void CustomVJP::eval(
|
||||
void CustomTransforms::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
|
@@ -52,7 +52,7 @@ DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(NumberOfElements)
|
||||
|
@@ -171,7 +171,7 @@ void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void CustomVJP::eval_gpu(
|
||||
void CustomTransforms::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
|
@@ -42,7 +42,7 @@ NO_CPU(Convolution)
|
||||
NO_CPU(Copy)
|
||||
NO_CPU(Cos)
|
||||
NO_CPU(Cosh)
|
||||
NO_CPU_MULTI(CustomVJP)
|
||||
NO_CPU_MULTI(CustomTransforms)
|
||||
NO_CPU_MULTI(Depends)
|
||||
NO_CPU(Divide)
|
||||
NO_CPU_MULTI(DivMod)
|
||||
|
@@ -43,7 +43,7 @@ NO_GPU(Convolution)
|
||||
NO_GPU(Copy)
|
||||
NO_GPU(Cos)
|
||||
NO_GPU(Cosh)
|
||||
NO_GPU_MULTI(CustomVJP)
|
||||
NO_GPU_MULTI(CustomTransforms)
|
||||
NO_GPU_MULTI(Depends)
|
||||
NO_GPU(Divide)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
|
17
mlx/fast.cpp
17
mlx/fast.cpp
@@ -18,7 +18,7 @@ std::vector<array> Custom::vjp(
|
||||
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
|
||||
std::vector<array> vjp_outs;
|
||||
for (int i = 0, j = 0; i < vjps.size(); ++i) {
|
||||
if (i < argnums.size() && i == argnums[j]) {
|
||||
if (j < argnums.size() && i == argnums[j]) {
|
||||
vjp_outs.push_back(vjps[i]);
|
||||
j++;
|
||||
}
|
||||
@@ -30,15 +30,16 @@ std::vector<array> Custom::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
|
||||
std::vector<array> jvp_outs;
|
||||
for (int i = 0, j = 0; i < jvps.size(); ++i) {
|
||||
if (i < argnums.size() && i == argnums[j]) {
|
||||
jvp_outs.push_back(jvps[i]);
|
||||
j++;
|
||||
std::vector<array> all_tangents;
|
||||
for (int i = 0, j = 0; i < primals.size(); i++) {
|
||||
if (j < argnums.size() && i == argnums[j]) {
|
||||
all_tangents.emplace_back(tangents[j++]);
|
||||
} else {
|
||||
all_tangents.emplace_back(zeros_like(primals[i]));
|
||||
}
|
||||
}
|
||||
return jvp_outs;
|
||||
auto [_, jvps] = mlx::core::jvp(fallback_, primals, all_tangents);
|
||||
return jvps;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
|
||||
|
@@ -1113,17 +1113,21 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
|
||||
return {{cosh(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> CustomVJP::vjp(
|
||||
std::vector<array> CustomTransforms::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
std::vector<array> inputs(primals.begin(), primals.end() - outputs.size());
|
||||
// Extract the inputs to the VJP function
|
||||
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);
|
||||
|
||||
// Compute all the vjps
|
||||
auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
|
||||
for (const auto& cot : cotangents) {
|
||||
all_vjps.emplace_back(cot);
|
||||
}
|
||||
|
||||
// Select the vjps requested
|
||||
std::vector<array> vjps;
|
||||
vjps.reserve(argnums.size());
|
||||
for (auto arg : argnums) {
|
||||
@@ -1133,6 +1137,26 @@ std::vector<array> CustomVJP::vjp(
|
||||
return vjps;
|
||||
}
|
||||
|
||||
std::vector<array> CustomTransforms::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Extract the inputs to the JVP function
|
||||
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);
|
||||
|
||||
// Compute the jvps
|
||||
return jvp_fun_(inputs, tangents, argnums);
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> CustomTransforms::vmap(
|
||||
const std::vector<array>& inputs_,
|
||||
const std::vector<int>& axes_) {
|
||||
// Extract the inputs to the vmap function
|
||||
std::vector<array> inputs(inputs_.begin(), inputs_.end() - num_outputs_);
|
||||
std::vector<int> axes(axes_.begin(), axes_.end() - num_outputs_);
|
||||
return vmap_fun_(inputs, axes);
|
||||
}
|
||||
|
||||
std::vector<array> Depends::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@@ -729,37 +729,56 @@ class Cosh : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class CustomVJP : public Primitive {
|
||||
class CustomTransforms : public Primitive {
|
||||
public:
|
||||
explicit CustomVJP(
|
||||
explicit CustomTransforms(
|
||||
Stream stream,
|
||||
int num_outputs,
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&)> fun)
|
||||
: Primitive(stream), vjp_fun_(std::move(fun)) {}
|
||||
const std::vector<array>&)> vjp,
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)> jvp,
|
||||
std::function<std::pair<std::vector<array>, std::vector<int>>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)> vmap)
|
||||
: Primitive(stream),
|
||||
num_outputs_(num_outputs),
|
||||
vjp_fun_(std::move(vjp)),
|
||||
jvp_fun_(std::move(jvp)),
|
||||
vmap_fun_(std::move(vmap)) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotan,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
DEFINE_PRINT(CustomVJP);
|
||||
DEFINE_GRADS();
|
||||
DEFINE_VMAP();
|
||||
DEFINE_PRINT(CustomTransforms);
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
|
||||
int num_outputs_;
|
||||
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&)>
|
||||
vjp_fun_;
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)>
|
||||
jvp_fun_;
|
||||
std::function<std::pair<std::vector<array>, std::vector<int>>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)>
|
||||
vmap_fun_;
|
||||
};
|
||||
|
||||
class Depends : public Primitive {
|
||||
|
@@ -33,8 +33,11 @@ class Synchronizer : public Primitive {
|
||||
// Initialize the static tracing counter from transforms_impl.h .
|
||||
//
|
||||
// This is used to implement the in_tracing() function the returns true if we
|
||||
// are currently under a function transformation.
|
||||
// are currently under a function transformation and the retain_graph()
|
||||
// function which returns true if we are forced to retain the graph during
|
||||
// evaluation.
|
||||
int detail::InTracing::tracing_counter{0};
|
||||
int detail::RetainGraph::tracing_counter{0};
|
||||
|
||||
array eval_impl(std::vector<array> outputs, bool async) {
|
||||
std::queue<array> tape;
|
||||
@@ -331,7 +334,11 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
}
|
||||
}
|
||||
|
||||
auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
|
||||
std::vector<array> vjps;
|
||||
{
|
||||
detail::RetainGraph retain;
|
||||
vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
|
||||
}
|
||||
// Accumulate the vector-jacobian products for each input
|
||||
for (int i = 0; i < argnums.size(); ++i) {
|
||||
auto in_id = a.inputs()[argnums[i]].id();
|
||||
@@ -778,14 +785,27 @@ std::function<array(const array&)> vmap(
|
||||
return [vfun](const array& a) { return vfun({a})[0]; };
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
std::function<std::vector<array>(
|
||||
std::optional<std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&)> fun_vjp) {
|
||||
const std::vector<array>&)>> fun_vjp /* = std::nullopt */,
|
||||
std::optional<std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)>> fun_jvp /* = std::nullopt */,
|
||||
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)>> fun_vmap /* = std::nullopt */) {
|
||||
if (!fun_vjp.has_value() && !fun_jvp.has_value() && !fun_vmap.has_value()) {
|
||||
return fun;
|
||||
}
|
||||
|
||||
return [fun = std::move(fun),
|
||||
fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
|
||||
fun_vjp = std::move(fun_vjp),
|
||||
fun_jvp = std::move(fun_jvp),
|
||||
fun_vmap = std::move(fun_vmap)](const std::vector<array>& args) {
|
||||
// Compute the outputs
|
||||
auto outputs = fun(args);
|
||||
for (auto& out : outputs) {
|
||||
@@ -814,11 +834,63 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
||||
return array::make_arrays(
|
||||
std::move(shapes),
|
||||
dtypes,
|
||||
std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
|
||||
std::make_shared<CustomTransforms>(
|
||||
to_stream(s),
|
||||
outputs.size(),
|
||||
|
||||
// We use the passed vjp function or compute it from the inputs and
|
||||
// passed cotangents. Note that this may be less efficient than
|
||||
// using `fun` directly because we may not be able to fully reuse
|
||||
// the outputs of the forward pass.
|
||||
fun_vjp.value_or(
|
||||
[fun](auto primals, auto cotangents, auto outputs) {
|
||||
auto [__, vjps] = vjp(fun, primals, cotangents);
|
||||
return vjps;
|
||||
}),
|
||||
|
||||
// We use the passed jvp function or compute it from the primals
|
||||
// and tangents. Similarly we can't take full advantage of the
|
||||
// argnums so it is best to use `fun` directly if we don't need a
|
||||
// custom transform.
|
||||
//
|
||||
// TODO: Use stop_gradient to make full use of argnums and not
|
||||
// waste computation.
|
||||
fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
|
||||
std::vector<array> all_tangents;
|
||||
for (int i = 0, j = 0; i < primals.size(); i++) {
|
||||
if (j < argnums.size() && i == argnums[j]) {
|
||||
all_tangents.emplace_back(tangents[j++]);
|
||||
} else {
|
||||
all_tangents.emplace_back(zeros_like(primals[i]));
|
||||
}
|
||||
}
|
||||
auto [__, jvps] = jvp(fun, primals, all_tangents);
|
||||
return jvps;
|
||||
}),
|
||||
|
||||
// Same as above, we use the passed vmap function or we compute it
|
||||
// from `fun`. The output axes is selected to be all 0s which again
|
||||
// may be suboptimal but the only thing we can do without any
|
||||
// information for `fun`.
|
||||
fun_vmap.value_or(
|
||||
[fun, out_size = outputs.size()](auto inputs, auto in_axes)
|
||||
-> std::pair<std::vector<array>, std::vector<int>> {
|
||||
std::vector<int> out_axes(out_size, 0);
|
||||
return {vmap(fun, in_axes, out_axes)(inputs), out_axes};
|
||||
})),
|
||||
inputs);
|
||||
};
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&)> fun_vjp) {
|
||||
return custom_function(fun, fun_vjp, std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun) {
|
||||
auto vjp_fun = [fun](
|
||||
|
@@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -179,8 +181,31 @@ std::function<std::vector<array>(const std::vector<array>&)> vmap(
|
||||
const std::vector<int>& out_axes = {});
|
||||
|
||||
/**
|
||||
* Return the results of calling fun with args but if their vjp is computed it
|
||||
* will be computed by fun_vjp.
|
||||
* Redefine the transformations of `fun` according to the provided functions.
|
||||
*
|
||||
* Namely when calling the vjp of `fun` then `fun_vjp` will be called,
|
||||
* `fun_jvp` for the jvp and `fun_vmap` for vmap.
|
||||
*
|
||||
* If any transformation is not provided, then a default one is created by
|
||||
* calling `vjp`, `jvp` and `vmap` on the function directly.
|
||||
*/
|
||||
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
std::optional<std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&)>> fun_vjp = std::nullopt,
|
||||
std::optional<std::function<std::vector<array>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)>> fun_jvp = std::nullopt,
|
||||
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&)>> fun_vmap = std::nullopt);
|
||||
|
||||
/**
|
||||
* Return a function that behaves exactly like `fun` but if the vjp of the
|
||||
* results is computed `fun_vjp` will be used instead of `vjp(fun, ...)` .
|
||||
*/
|
||||
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
|
@@ -50,4 +50,20 @@ struct InTracing {
|
||||
static int tracing_counter;
|
||||
};
|
||||
|
||||
struct RetainGraph {
|
||||
RetainGraph() {
|
||||
tracing_counter++;
|
||||
}
|
||||
~RetainGraph() {
|
||||
tracing_counter--;
|
||||
}
|
||||
|
||||
static bool retain_graph() {
|
||||
return tracing_counter > 0;
|
||||
}
|
||||
|
||||
private:
|
||||
static int tracing_counter;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
Reference in New Issue
Block a user