Custom transforms (#1246)
This commit is contained in:
parent a3c287354f
commit 5c1fa64fb0
@@ -10,6 +10,7 @@ Transforms

eval
compile
custom_function
disable_compile
enable_compile
grad
@@ -17,6 +17,10 @@ bool in_tracing() {
return detail::InTracing::in_tracing();
}

bool retain_graph() {
return detail::RetainGraph::retain_graph();
}

} // namespace

array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)

@@ -102,7 +106,7 @@ void array::eval() {
}

bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing();
return array_desc_->is_tracer && in_tracing() || retain_graph();
}

void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -36,7 +36,7 @@ DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(inputs[0]);
}

void CustomVJP::eval(
void CustomTransforms::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
@@ -52,7 +52,7 @@ DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
@@ -171,7 +171,7 @@ void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}

void CustomVJP::eval_gpu(
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
@@ -42,7 +42,7 @@ NO_CPU(Convolution)
NO_CPU(Copy)
NO_CPU(Cos)
NO_CPU(Cosh)
NO_CPU_MULTI(CustomVJP)
NO_CPU_MULTI(CustomTransforms)
NO_CPU_MULTI(Depends)
NO_CPU(Divide)
NO_CPU_MULTI(DivMod)
@@ -43,7 +43,7 @@ NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(CustomTransforms)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
mlx/fast.cpp
@@ -18,7 +18,7 @@ std::vector<array> Custom::vjp(
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
std::vector<array> vjp_outs;
for (int i = 0, j = 0; i < vjps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
if (j < argnums.size() && i == argnums[j]) {
vjp_outs.push_back(vjps[i]);
j++;
}

@@ -30,15 +30,16 @@ std::vector<array> Custom::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
std::vector<array> jvp_outs;
for (int i = 0, j = 0; i < jvps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
jvp_outs.push_back(jvps[i]);
j++;
std::vector<array> all_tangents;
for (int i = 0, j = 0; i < primals.size(); i++) {
if (j < argnums.size() && i == argnums[j]) {
all_tangents.emplace_back(tangents[j++]);
} else {
all_tangents.emplace_back(zeros_like(primals[i]));
}
}
return jvp_outs;
auto [_, jvps] = mlx::core::jvp(fallback_, primals, all_tangents);
return jvps;
}

std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
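The rewritten `Custom::jvp` above pads the caller-supplied tangents with zeros so that every primal has a tangent before differentiating the fallback function. A minimal Python sketch of the same padding step (the helper name `pad_tangents` is hypothetical and only for illustration):

import mlx.core as mx

def pad_tangents(primals, tangents, argnums):
    # Place each provided tangent at its argnum position; every other
    # primal gets a zero tangent, mirroring the loop in Custom::jvp above.
    all_tangents = []
    j = 0
    for i, p in enumerate(primals):
        if j < len(argnums) and i == argnums[j]:
            all_tangents.append(tangents[j])
            j += 1
        else:
            all_tangents.append(mx.zeros_like(p))
    return all_tangents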
@@ -1113,17 +1113,21 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
return {{cosh(inputs[0], stream())}, axes};
}

std::vector<array> CustomVJP::vjp(
std::vector<array> CustomTransforms::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
std::vector<array> inputs(primals.begin(), primals.end() - outputs.size());
// Extract the inputs to the VJP function
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);

// Compute all the vjps
auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
for (const auto& cot : cotangents) {
all_vjps.emplace_back(cot);
}

// Select the vjps requested
std::vector<array> vjps;
vjps.reserve(argnums.size());
for (auto arg : argnums) {

@@ -1133,6 +1137,26 @@ std::vector<array> CustomVJP::vjp(
return vjps;
}

std::vector<array> CustomTransforms::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Extract the inputs to the JVP function
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);

// Compute the jvps
return jvp_fun_(inputs, tangents, argnums);
}

std::pair<std::vector<array>, std::vector<int>> CustomTransforms::vmap(
const std::vector<array>& inputs_,
const std::vector<int>& axes_) {
// Extract the inputs to the vmap function
std::vector<array> inputs(inputs_.begin(), inputs_.end() - num_outputs_);
std::vector<int> axes(axes_.begin(), axes_.end() - num_outputs_);
return vmap_fun_(inputs, axes);
}

std::vector<array> Depends::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
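In `CustomTransforms::vjp` above, the primals arrive with the outputs appended at the end; the stored vjp function is called on the inputs only, the cotangents are appended as the vjps of those trailing outputs, and `argnums` selects the requested entries. A rough Python sketch of that selection step (`select_vjps` is a hypothetical name, not part of this diff):

def select_vjps(input_vjps, cotangents, argnums):
    # Candidate list: user-provided vjps for the inputs, followed by the
    # cotangents, which serve as the vjps of the appended outputs.
    all_vjps = list(input_vjps) + list(cotangents)
    # Keep only the entries that were requested.
    return [all_vjps[arg] for arg in argnums]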
@@ -729,37 +729,56 @@ class Cosh : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};

class CustomVJP : public Primitive {
class CustomTransforms : public Primitive {
public:
explicit CustomVJP(
explicit CustomTransforms(
Stream stream,
int num_outputs,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun)
: Primitive(stream), vjp_fun_(std::move(fun)) {}
const std::vector<array>&)> vjp,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)> jvp,
std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)> vmap)
: Primitive(stream),
num_outputs_(num_outputs),
vjp_fun_(std::move(vjp)),
jvp_fun_(std::move(jvp)),
vmap_fun_(std::move(vmap)) {}

void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;

std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;

DEFINE_PRINT(CustomVJP);
DEFINE_GRADS();
DEFINE_VMAP();
DEFINE_PRINT(CustomTransforms);

private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);

int num_outputs_;

std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)>
vjp_fun_;
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)>
jvp_fun_;
std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)>
vmap_fun_;
};

class Depends : public Primitive {
@@ -33,8 +33,11 @@ class Synchronizer : public Primitive {
// Initialize the static tracing counter from transforms_impl.h.
//
// This is used to implement the in_tracing() function that returns true if we
// are currently under a function transformation.
// are currently under a function transformation and the retain_graph()
// function which returns true if we are forced to retain the graph during
// evaluation.
int detail::InTracing::tracing_counter{0};
int detail::RetainGraph::tracing_counter{0};

array eval_impl(std::vector<array> outputs, bool async) {
std::queue<array> tape;

@@ -331,7 +334,11 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}

auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
std::vector<array> vjps;
{
detail::RetainGraph retain;
vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
}
// Accumulate the vector-jacobian products for each input
for (int i = 0; i < argnums.size(); ++i) {
auto in_id = a.inputs()[argnums[i]].id();

@@ -778,14 +785,27 @@ std::function<array(const array&)> vmap(
return [vfun](const array& a) { return vfun({a})[0]; };
}

std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp) {
const std::vector<array>&)>> fun_vjp /* = std::nullopt */,
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)>> fun_jvp /* = std::nullopt */,
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)>> fun_vmap /* = std::nullopt */) {
if (!fun_vjp.has_value() && !fun_jvp.has_value() && !fun_vmap.has_value()) {
return fun;
}

return [fun = std::move(fun),
fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
fun_vjp = std::move(fun_vjp),
fun_jvp = std::move(fun_jvp),
fun_vmap = std::move(fun_vmap)](const std::vector<array>& args) {
// Compute the outputs
auto outputs = fun(args);
for (auto& out : outputs) {

@@ -814,11 +834,63 @@ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
return array::make_arrays(
std::move(shapes),
dtypes,
std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
std::make_shared<CustomTransforms>(
to_stream(s),
outputs.size(),

// We use the passed vjp function or compute it from the inputs and
// passed cotangents. Note that this may be less efficient than
// using `fun` directly because we may not be able to fully reuse
// the outputs of the forward pass.
fun_vjp.value_or(
[fun](auto primals, auto cotangents, auto outputs) {
auto [__, vjps] = vjp(fun, primals, cotangents);
return vjps;
}),

// We use the passed jvp function or compute it from the primals
// and tangents. Similarly we can't take full advantage of the
// argnums so it is best to use `fun` directly if we don't need a
// custom transform.
//
// TODO: Use stop_gradient to make full use of argnums and not
// waste computation.
fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
std::vector<array> all_tangents;
for (int i = 0, j = 0; i < primals.size(); i++) {
if (j < argnums.size() && i == argnums[j]) {
all_tangents.emplace_back(tangents[j++]);
} else {
all_tangents.emplace_back(zeros_like(primals[i]));
}
}
auto [__, jvps] = jvp(fun, primals, all_tangents);
return jvps;
}),

// Same as above, we use the passed vmap function or we compute it
// from `fun`. The output axes are selected to be all 0s which, again,
// may be suboptimal but is the only thing we can do without any
// information about `fun`.
fun_vmap.value_or(
[fun, out_size = outputs.size()](auto inputs, auto in_axes)
-> std::pair<std::vector<array>, std::vector<int>> {
std::vector<int> out_axes(out_size, 0);
return {vmap(fun, in_axes, out_axes)(inputs), out_axes};
})),
inputs);
};
}

std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp) {
return custom_function(fun, fun_vjp, std::nullopt, std::nullopt);
}

std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
std::function<std::vector<array>(const std::vector<array>&)> fun) {
auto vjp_fun = [fun](
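As the comments above note, when no custom vjp is supplied, `custom_function` falls back to differentiating the original function. A minimal Python sketch of that default behaviour, assuming the public `mx.vjp` transform and using the hypothetical helper name `default_vjp`:

import mlx.core as mx

def default_vjp(fun):
    # Fallback used when no custom vjp is provided: differentiate `fun`
    # directly, discarding the recomputed outputs.
    def fallback(primals, cotangents, outputs):
        _, vjps = mx.vjp(fun, primals, cotangents)
        return vjps
    return fallback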
@@ -2,6 +2,8 @@

#pragma once

#include <optional>

#include "mlx/array.h"

namespace mlx::core {

@@ -179,8 +181,31 @@ std::function<std::vector<array>(const std::vector<array>&)> vmap(
const std::vector<int>& out_axes = {});

/**
* Return the results of calling fun with args but if their vjp is computed it
* will be computed by fun_vjp.
* Redefine the transformations of `fun` according to the provided functions.
*
* Namely, when the vjp of `fun` is computed, `fun_vjp` will be called;
* `fun_jvp` is used for the jvp and `fun_vmap` for vmap.
*
* If any transformation is not provided, then a default one is created by
* calling `vjp`, `jvp` and `vmap` on the function directly.
*/
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)>> fun_vjp = std::nullopt,
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)>> fun_jvp = std::nullopt,
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)>> fun_vmap = std::nullopt);

/**
* Return a function that behaves exactly like `fun` but if the vjp of the
* results is computed `fun_vjp` will be used instead of `vjp(fun, ...)`.
*/
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> fun,
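On the Python side this is exposed as `mx.custom_function` (bound later in this diff). With no overrides registered the wrapper behaves exactly like the original function; a minimal sketch, assuming mlx is installed:

import mlx.core as mx

f = mx.custom_function(mx.exp)  # no vjp/jvp/vmap overrides registered
y = f(mx.array(1.0))            # identical to calling mx.exp directly
assert mx.allclose(y, mx.exp(mx.array(1.0)))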
@@ -50,4 +50,20 @@ struct InTracing {
static int tracing_counter;
};

struct RetainGraph {
RetainGraph() {
tracing_counter++;
}
~RetainGraph() {
tracing_counter--;
}

static bool retain_graph() {
return tracing_counter > 0;
}

private:
static int tracing_counter;
};

} // namespace mlx::core::detail
@@ -593,7 +593,454 @@ class PyCheckpointedFun {
nb::callable fun_;
};

/**
* PyCustomFunction is the class that implements the python decorator
* `mx.custom_function`.
*
* It implements a callable that, instead of simply calling `fun`, creates a
* CustomTransforms primitive via the `custom_function` C++ op, which allows us
* to redefine the vjp, jvp and vmap transformations.
*
* The implementation is verbose due to explicit handling of the destruction of
* various python objects to make sure that there is no double-free and that
* all of them are deleted while holding the GIL.
*
* Namely, for every one of the functions passed to the C++ `custom_function`
* we create a callable struct that holds the following python objects (when
* needed).
*
* - An nb::callable which holds the passed function or transform.
* - An nb::object holding the input structure, namely the `(args, kwargs)`
*   passed to the function, in order to be able to recreate the arguments
*   from the input arrays.
* - A std::shared_ptr<nb::object> holding the output structure, namely the
*   structure of the return value of `fun`. It is a shared_ptr so that it
*   can be set when the function is called and then used in the `vjp`
*   transform. We delete the object only when the shared_ptr is about to be
*   deleted (see `output_structure_.use_count() == 1`) to make sure that the
*   object is deleted under the GIL.
*/
class PyCustomFunction {
public:
PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
~PyCustomFunction() {
nb::gil_scoped_acquire gil;

fun_.release().dec_ref();
if (vjp_fun_.has_value()) {
(*vjp_fun_).release().dec_ref();
}
if (jvp_fun_.has_value()) {
(*jvp_fun_).release().dec_ref();
}
if (vmap_fun_.has_value()) {
(*vmap_fun_).release().dec_ref();
}
}
struct InnerFunction {
nb::callable fun_;
nb::object input_structure_;
std::shared_ptr<nb::object> output_structure_;

InnerFunction(
nb::callable fun,
nb::object input_structure,
std::shared_ptr<nb::object> output_structure)
: fun_(std::move(fun)),
input_structure_(std::move(input_structure)),
output_structure_(std::move(output_structure)) {}
~InnerFunction() {
nb::gil_scoped_acquire gil;

fun_.release().dec_ref();
input_structure_.release().dec_ref();
if (output_structure_.use_count() == 1) {
output_structure_->release().dec_ref();
}
}

std::vector<array> operator()(const std::vector<array>& inputs) {
nb::gil_scoped_acquire gil;

auto new_inputs = nb::cast<nb::tuple>(
tree_unflatten_from_structure(input_structure_, inputs));
std::vector<array> outputs;
std::tie(outputs, *output_structure_) =
tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
return outputs;
}
};

struct InnerVJPFunction {
nb::callable vjp_fun_;
nb::object input_structure_;
std::shared_ptr<nb::object> output_structure_;

InnerVJPFunction(
nb::callable vjp_fun,
nb::object input_structure,
std::shared_ptr<nb::object> output_structure)
: vjp_fun_(std::move(vjp_fun)),
input_structure_(std::move(input_structure)),
output_structure_(std::move(output_structure)) {}
~InnerVJPFunction() {
nb::gil_scoped_acquire gil;

vjp_fun_.release().dec_ref();
input_structure_.release().dec_ref();
if (output_structure_.use_count() == 1) {
output_structure_->release().dec_ref();
}
}

std::vector<array> operator()(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<array>& outputs) {
nb::gil_scoped_acquire gil;

auto new_inputs = nb::cast<nb::tuple>(
tree_unflatten_from_structure(input_structure_, primals));
auto args = nb::cast<nb::tuple>(new_inputs[0]);
auto new_cotangents =
tree_unflatten_from_structure(*output_structure_, cotangents);
auto new_outputs =
tree_unflatten_from_structure(*output_structure_, outputs);

if (args.size() == 1) {
return tree_flatten(
vjp_fun_(args[0], new_cotangents, new_outputs, **new_inputs[1]),
false);
} else {
return tree_flatten(
vjp_fun_(args, new_cotangents, new_outputs, **new_inputs[1]),
false);
}
}
};
struct InnerJVPFunction {
nb::callable jvp_fun_;
nb::object input_structure_;

InnerJVPFunction(nb::callable jvp_fun, nb::object input_structure)
: jvp_fun_(std::move(jvp_fun)),
input_structure_(std::move(input_structure)) {}
~InnerJVPFunction() {
nb::gil_scoped_acquire gil;

jvp_fun_.release().dec_ref();
input_structure_.release().dec_ref();
}

std::vector<array> operator()(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
nb::gil_scoped_acquire gil;

auto new_inputs = nb::cast<nb::tuple>(
tree_unflatten_from_structure(input_structure_, primals));
auto args = nb::cast<nb::tuple>(new_inputs[0]);
auto kwargs = nb::cast<nb::dict>(new_inputs[1]);
if (kwargs.size() > 0) {
throw std::invalid_argument(
"[custom jvp] Function should only accept positional arguments");
}

// Make a new pytree which has tangents or None when a tangent is not
// available.
std::vector<bool> have_tangents(primals.size(), false);
for (auto arg : argnums) {
have_tangents[arg] = true;
}
int array_index = 0;
int tangent_index = 0;
auto new_tangents =
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
if (nb::isinstance<array>(element) &&
have_tangents[array_index++]) {
return nb::cast(tangents[tangent_index++]);
} else {
return nb::none();
}
}));

if (args.size() == 1) {
return tree_flatten(jvp_fun_(args[0], new_tangents[0]), false);
} else {
return tree_flatten(jvp_fun_(args, new_tangents), false);
}
}
};

struct InnerVmapFunction {
nb::callable vmap_fun_;
nb::object input_structure_;

InnerVmapFunction(nb::callable vmap_fun, nb::object input_structure)
: vmap_fun_(std::move(vmap_fun)),
input_structure_(std::move(input_structure)) {}
~InnerVmapFunction() {
nb::gil_scoped_acquire gil;

vmap_fun_.release().dec_ref();
input_structure_.release().dec_ref();
}

std::pair<std::vector<array>, std::vector<int>> operator()(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
nb::gil_scoped_acquire gil;

auto new_inputs = nb::cast<nb::tuple>(
tree_unflatten_from_structure(input_structure_, inputs));
auto args = nb::cast<nb::tuple>(new_inputs[0]);
auto kwargs = nb::cast<nb::dict>(new_inputs[1]);
if (kwargs.size() > 0) {
throw std::invalid_argument(
"[custom vmap] Function should only accept positional arguments");
}

int arr_index = 0;
auto new_axes =
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
int axis = axes[arr_index++];
if (nb::isinstance<array>(element) && axis >= 0) {
return nb::cast(axis);
} else {
return nb::none();
}
}));

nb::object result;
if (args.size() == 1) {
result = vmap_fun_(args[0], new_axes[0]);
} else {
result = vmap_fun_(args, new_axes);
}

if (!nb::isinstance<nb::tuple>(result)) {
throw std::invalid_argument(
"[custom vmap] Vmap function should return a tuple with 2 items.");
}
nb::tuple result_tuple = nb::cast<nb::tuple>(result);
if (result_tuple.size() != 2) {
throw std::invalid_argument(
"[custom vmap] Vmap function should return a tuple with 2 items.");
}

std::vector<array> outputs;
std::vector<int> output_axes;
tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
if (nb::isinstance<array>(objects[0])) {
outputs.push_back(nb::cast<array>(objects[0]));
output_axes.push_back(
objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
}
});

return {outputs, output_axes};
}
};
|
||||
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
|
||||
if (!vjp_fun_.has_value() && !jvp_fun_.has_value() &&
|
||||
!vmap_fun_.has_value()) {
|
||||
return fun_(*args, **kwargs);
|
||||
}
|
||||
|
||||
// Extract the inputs and their structure in capturable vars
|
||||
std::vector<array> input_arrays;
|
||||
nb::object input_structure;
|
||||
auto full_args = nb::make_tuple(args, kwargs);
|
||||
std::tie(input_arrays, input_structure) =
|
||||
tree_flatten_with_structure(full_args, false);
|
||||
|
||||
// The output structure will be stored here to be used in the custom vjp
|
||||
// function
|
||||
auto output_structure = std::make_shared<nb::object>();
|
||||
|
||||
// Make a function that calls fun_ in the forward pass and vjp_ in the
|
||||
// backward pass. Then call it immediately and return the results.
|
||||
auto f = custom_function(
|
||||
InnerFunction(fun_, input_structure, output_structure),
|
||||
make_vjp_function(input_structure, output_structure),
|
||||
make_jvp_function(input_structure),
|
||||
make_vmap_function(input_structure));
|
||||
|
||||
auto outputs = f(input_arrays);
|
||||
return tree_unflatten_from_structure(*output_structure, outputs);
|
||||
}
|
||||
|
||||
PyCustomFunction& set_vjp(nb::callable vjp_fun) {
|
||||
vjp_fun_ = vjp_fun;
|
||||
return *this;
|
||||
}
|
||||
|
||||
PyCustomFunction& set_jvp(nb::callable jvp_fun) {
|
||||
jvp_fun_ = jvp_fun;
|
||||
return *this;
|
||||
}
|
||||
|
||||
PyCustomFunction& set_vmap(nb::callable vmap_fun) {
|
||||
vmap_fun_ = vmap_fun;
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
std::optional<InnerVJPFunction> make_vjp_function(
|
||||
nb::object input_structure,
|
||||
std::shared_ptr<nb::object> output_structure) {
|
||||
if (!vjp_fun_.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return InnerVJPFunction(*vjp_fun_, input_structure, output_structure);
|
||||
}
|
||||
|
||||
std::optional<InnerJVPFunction> make_jvp_function(
|
||||
nb::object input_structure) {
|
||||
if (!jvp_fun_.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return InnerJVPFunction(*jvp_fun_, input_structure);
|
||||
}
|
||||
|
||||
std::optional<InnerVmapFunction> make_vmap_function(
|
||||
nb::object input_structure) {
|
||||
if (!vmap_fun_.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return InnerVmapFunction(*vmap_fun_, input_structure);
|
||||
}
|
||||
|
||||
nb::callable fun_;
|
||||
std::optional<nb::callable> vjp_fun_;
|
||||
std::optional<nb::callable> jvp_fun_;
|
||||
std::optional<nb::callable> vmap_fun_;
|
||||
};
|
||||
|
||||
void init_transforms(nb::module_& m) {
nb::class_<PyCustomFunction>(
m,
"custom_function",
R"pbdoc(
Set up a function for custom gradient and vmap definitions.

This class is meant to be used as a function decorator. Instances are
callables that behave identically to the wrapped function. However, when
a function transformation is used (e.g. computing gradients using
:func:`value_and_grad`) then the functions defined via :method:`vjp`,
:method:`jvp` and :method:`vmap` are used instead of the default
transformation.

Note that all custom transformations are optional. Undefined transformations
fall back to the default behaviour.

Example usage:

.. code-block:: python

    import mlx.core as mx

    @mx.custom_function
    def f(x, y):
        return mx.sin(x) * y

    @f.vjp
    def f_vjp(primals, cotangent, output):
        x, y = primals
        return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)

    @f.jvp
    def f_jvp(primals, tangents):
        x, y = primals
        dx, dy = tangents
        return dx * mx.cos(x) * y + dy * mx.sin(x)

    @f.vmap
    def f_vmap(inputs, axes):
        x, y = inputs
        ax, ay = axes
        if ay != ax and ax is not None:
            y = y.swapaxes(ay, ax)
        return mx.sin(x) * y, (ax or ay)
)pbdoc")
.def(
nb::init<nb::callable>(),
"f"_a,
nb::sig("def __init__(self, f: callable)"))
.def("__call__", &PyCustomFunction::call_impl)
.def(
"vjp",
&PyCustomFunction::set_vjp,
"f"_a,
nb::sig("def vjp(self, f_vjp: callable)"),
R"pbdoc(
Define a custom vjp for the wrapped function.

The vjp function takes three arguments:

- *primals*: A pytree that contains all the positional arguments to
  the function. It could be a single array, a tuple of arrays or a
  full blown tuple of dicts of arrays etc.
- *cotangents*: A pytree that matches the structure of the output
  but contains the cotangents (usually the gradients of the loss
  function with respect to the outputs).
- *outputs*: The outputs of the function to be used to avoid
  recomputing them for the gradient computation.

The vjp function should return the same pytree structure as the
primals but containing the corresponding computed cotangents.
)pbdoc")
.def(
"jvp",
&PyCustomFunction::set_jvp,
"f"_a,
nb::sig("def jvp(self, f_jvp: callable)"),
R"pbdoc(
Define a custom jvp for the wrapped function.

The jvp function takes two arguments:

- *primals*: A pytree that contains all the positional arguments to
  the function. It could be a single array, a tuple of arrays or a
  full blown tuple of dicts of arrays etc.
- *tangents*: A pytree that matches the structure of the inputs but
  instead contains the gradients with respect to each input. Tangents
  could be ``None`` if some inputs don't have an associated gradient.

The jvp function should return the same pytree structure as the
outputs of the function but containing the tangents.
)pbdoc")
.def(
"vmap",
&PyCustomFunction::set_vmap,
"f"_a,
nb::sig("def vmap(self, f_vmap: callable)"),
R"pbdoc(
Define a custom vectorization transformation for the wrapped function.

The vmap function takes two arguments:

- *inputs*: A pytree that contains all the positional arguments to
  the function. It could be a single array, a tuple of arrays or a
  full blown tuple of dicts of arrays etc.
- *axes*: A pytree that matches the structure of the inputs but
  instead contains the vectorization axis for each input or
  ``None`` if an input is not vectorized.

The vmap function should return the outputs of the original
function but vectorized over the provided axes. It should also
return a pytree with the vectorization axes of each output. If some
outputs are no longer vectorized, then their vectorization axis
should be ``None``.
)pbdoc");

m.def(
"eval",
[](const nb::args& args) {
@@ -888,8 +1335,10 @@ void init_transforms(nb::module_& m) {
const nb::object& outputs,
bool shapeless) {
// Try to get the name
auto n = fun.attr("__name__");
auto name = n.is_none() ? "compiled" : nb::cast<std::string>(n);
auto n =
nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none();
auto name = n.is_none() ? "compiled"
: nb::cast<std::string>(fun.attr("__name__"));

// Try to get the signature
std::ostringstream sig;
@@ -496,6 +496,90 @@ class TestAutograd(mlx_tests.MLXTestCase):
        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
        self.assertTrue(mx.allclose(out, expected))

    def test_custom_function(self):
        # Make a custom function
        my_exp = mx.custom_function(mx.exp)

        # Ensure everything works
        dy = mx.grad(my_exp)(mx.array(1.0))
        self.assertTrue(mx.allclose(dy, mx.exp(mx.array(1.0))))
        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
        self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))
        self.assertTrue(mx.allclose(ex, dex))
        ex = mx.vmap(my_exp)(mx.ones(10))
        self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))

        # Ensure that the vjp is being overridden but everything else still
        # works.
        @my_exp.vjp
        def my_exp_vjp(x, dx, ex):
            return mx.ones_like(x) * 42

        dy = mx.grad(my_exp)(mx.array(1.0))
        self.assertTrue(mx.allclose(dy, mx.array(42.0)))
        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
        self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))
        self.assertTrue(mx.allclose(ex, dex))
        ex = mx.vmap(my_exp)(mx.ones(10))
        self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))

        # Ensure that setting the jvp and vmap also works.
        @my_exp.jvp
        def my_exp_jvp(x, dx):
            return mx.ones_like(x) * 7 * dx

        @my_exp.vmap
        def my_exp_vmap(x, axis):
            return mx.ones_like(x) * 3, axis

        dy = mx.grad(my_exp)(mx.array(1.0))
        self.assertTrue(mx.allclose(dy, mx.array(42.0)))
        (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
        self.assertTrue(mx.allclose(dex, mx.array(7.0)))
        self.assertTrue(mx.allclose(ex, mx.exp(mx.array(1.0))))
        ex = mx.vmap(my_exp)(mx.ones(10))
        self.assertTrue(mx.allclose(ex, 3 * mx.ones(10)))

        # Test pytrees
        @mx.custom_function
        def my_double(params):
            return {"out": 2 * params["x"] * params["y"]}

        dy = mx.grad(lambda p: my_double(p)["out"].sum())(
            {"x": mx.ones(2), "y": mx.ones(2)}
        )
        self.assertTrue(mx.allclose(dy["x"], mx.ones(2) * 2))
        self.assertTrue(mx.allclose(dy["y"], mx.ones(2) * 2))

        @my_double.vjp
        def random_grads(primals, cotangents, outputs):
            return {"x": mx.zeros_like(primals["x"]), "y": mx.ones_like(primals["y"])}

        dy = mx.grad(lambda p: my_double(p)["out"].sum())(
            {"x": mx.ones(2), "y": mx.ones(2)}
        )
        self.assertTrue(mx.allclose(dy["x"], mx.zeros(2)))
        self.assertTrue(mx.allclose(dy["y"], mx.ones(2)))

        def outer_f(a, b):
            return my_double({"x": a, "y": b})["out"]

        inputs = [mx.random.normal(shape=(2,)) for i in range(2)]
        tans = [mx.random.normal(shape=(2,)) for i in range(2)]
        out1, dout1 = mx.jvp(outer_f, inputs, tans)

        @my_double.jvp
        def random_grads(primals, tangents):
            return {
                "out": 2 * primals["x"] * tangents["y"]
                + 2 * primals["y"] * tangents["x"]
                + 1
            }

        out2, dout2 = mx.jvp(outer_f, inputs, tans)
        self.assertTrue(mx.allclose(out1[0], out2[0]))
        self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0]))


if __name__ == "__main__":
    unittest.main()