diff --git a/docs/src/python/transforms.rst b/docs/src/python/transforms.rst index ad9ba579b..fbdfd4f08 100644 --- a/docs/src/python/transforms.rst +++ b/docs/src/python/transforms.rst @@ -10,6 +10,7 @@ Transforms eval compile + custom_function disable_compile enable_compile grad diff --git a/mlx/array.cpp b/mlx/array.cpp index f4c5700c6..29a86aaa0 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -17,6 +17,10 @@ bool in_tracing() { return detail::InTracing::in_tracing(); } +bool retain_graph() { + return detail::RetainGraph::retain_graph(); +} + } // namespace array::array(const std::complex& val, Dtype dtype /* = complex64 */) @@ -102,7 +106,7 @@ void array::eval() { } bool array::is_tracer() const { - return array_desc_->is_tracer && in_tracing(); + return array_desc_->is_tracer && in_tracing() || retain_graph(); } void array::set_data(allocator::Buffer buffer, deleter_t d) { diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 6de0a6416..eee93f2ab 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -36,7 +36,7 @@ DEFAULT(Ceil) DEFAULT(Concatenate) DEFAULT(Conjugate) DEFAULT(Copy) -DEFAULT_MULTI(CustomVJP) +DEFAULT_MULTI(CustomTransforms) DEFAULT_MULTI(Depends) DEFAULT_MULTI(DivMod) DEFAULT(NumberOfElements) diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index 08dc34037..6fb0e9edb 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -66,7 +66,7 @@ void Copy::eval(const std::vector& inputs, array& out) { out.copy_shared_buffer(inputs[0]); } -void CustomVJP::eval( +void CustomTransforms::eval( const std::vector& inputs, std::vector& outputs) { assert(inputs.size() > outputs.size()); diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index c5b5e44b8..f8932c5f8 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -52,7 +52,7 @@ DEFAULT(Convolution) DEFAULT(Copy) DEFAULT(Cos) DEFAULT(Cosh) -DEFAULT_MULTI(CustomVJP) +DEFAULT_MULTI(CustomTransforms) DEFAULT_MULTI(Depends) DEFAULT(Divide) DEFAULT(NumberOfElements) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 64896c7e1..1b7f8122a 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -171,7 +171,7 @@ void Copy::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } -void CustomVJP::eval_gpu( +void CustomTransforms::eval_gpu( const std::vector& inputs, std::vector& outputs) { eval(inputs, outputs); diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 60c731930..ff60e4d22 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -42,7 +42,7 @@ NO_CPU(Convolution) NO_CPU(Copy) NO_CPU(Cos) NO_CPU(Cosh) -NO_CPU_MULTI(CustomVJP) +NO_CPU_MULTI(CustomTransforms) NO_CPU_MULTI(Depends) NO_CPU(Divide) NO_CPU_MULTI(DivMod) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 7dce75f7d..1410c92e8 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -43,7 +43,7 @@ NO_GPU(Convolution) NO_GPU(Copy) NO_GPU(Cos) NO_GPU(Cosh) -NO_GPU_MULTI(CustomVJP) +NO_GPU_MULTI(CustomTransforms) NO_GPU_MULTI(Depends) NO_GPU(Divide) NO_GPU_MULTI(DivMod) diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 0a0bd2066..135c79490 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -18,7 +18,7 @@ std::vector 
Custom::vjp( auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents); std::vector vjp_outs; for (int i = 0, j = 0; i < vjps.size(); ++i) { - if (i < argnums.size() && i == argnums[j]) { + if (j < argnums.size() && i == argnums[j]) { vjp_outs.push_back(vjps[i]); j++; } @@ -30,15 +30,16 @@ std::vector Custom::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { - auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents); - std::vector jvp_outs; - for (int i = 0, j = 0; i < jvps.size(); ++i) { - if (i < argnums.size() && i == argnums[j]) { - jvp_outs.push_back(jvps[i]); - j++; + std::vector all_tangents; + for (int i = 0, j = 0; i < primals.size(); i++) { + if (j < argnums.size() && i == argnums[j]) { + all_tangents.emplace_back(tangents[j++]); + } else { + all_tangents.emplace_back(zeros_like(primals[i])); } } - return jvp_outs; + auto [_, jvps] = mlx::core::jvp(fallback_, primals, all_tangents); + return jvps; } std::pair, std::vector> Custom::vmap( diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 8b2833ad4..289bd6053 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1113,17 +1113,21 @@ std::pair, std::vector> Cosh::vmap( return {{cosh(inputs[0], stream())}, axes}; } -std::vector CustomVJP::vjp( +std::vector CustomTransforms::vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) { - std::vector inputs(primals.begin(), primals.end() - outputs.size()); + // Extract the inputs to the VJP function + std::vector inputs(primals.begin(), primals.end() - num_outputs_); + + // Compute all the vjps auto all_vjps = vjp_fun_(inputs, cotangents, outputs); for (const auto& cot : cotangents) { all_vjps.emplace_back(cot); } + // Select the vjps requested std::vector vjps; vjps.reserve(argnums.size()); for (auto arg : argnums) { @@ -1133,6 +1137,26 @@ std::vector CustomVJP::vjp( return vjps; } +std::vector CustomTransforms::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + // Extract the inputs to the JVP function + std::vector inputs(primals.begin(), primals.end() - num_outputs_); + + // Compute the jvps + return jvp_fun_(inputs, tangents, argnums); +} + +std::pair, std::vector> CustomTransforms::vmap( + const std::vector& inputs_, + const std::vector& axes_) { + // Extract the inputs to the vmap function + std::vector inputs(inputs_.begin(), inputs_.end() - num_outputs_); + std::vector axes(axes_.begin(), axes_.end() - num_outputs_); + return vmap_fun_(inputs, axes); +} + std::vector Depends::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 4255cd4b1..c7fb29f1f 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -729,37 +729,56 @@ class Cosh : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; -class CustomVJP : public Primitive { +class CustomTransforms : public Primitive { public: - explicit CustomVJP( + explicit CustomTransforms( Stream stream, + int num_outputs, std::function( const std::vector&, const std::vector&, - const std::vector&)> fun) - : Primitive(stream), vjp_fun_(std::move(fun)) {} + const std::vector&)> vjp, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> jvp, + std::function, std::vector>( + const std::vector&, + const std::vector&)> vmap) + : Primitive(stream), + num_outputs_(num_outputs), + vjp_fun_(std::move(vjp)), + jvp_fun_(std::move(jvp)), + 
vmap_fun_(std::move(vmap)) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - std::vector vjp( - const std::vector& primals, - const std::vector& cotan, - const std::vector& argnums, - const std::vector& outputs) override; - - DEFINE_PRINT(CustomVJP); + DEFINE_GRADS(); + DEFINE_VMAP(); + DEFINE_PRINT(CustomTransforms); private: void eval(const std::vector& inputs, std::vector& outputs); + int num_outputs_; + std::function( const std::vector&, const std::vector&, const std::vector&)> vjp_fun_; + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> + jvp_fun_; + std::function, std::vector>( + const std::vector&, + const std::vector&)> + vmap_fun_; }; class Depends : public Primitive { diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 39ba5304b..0624358e2 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -33,8 +33,11 @@ class Synchronizer : public Primitive { // Initialize the static tracing counter from transforms_impl.h . // // This is used to implement the in_tracing() function the returns true if we -// are currently under a function transformation. +// are currently under a function transformation and the retain_graph() +// function which returns true if we are forced to retain the graph during +// evaluation. int detail::InTracing::tracing_counter{0}; +int detail::RetainGraph::tracing_counter{0}; array eval_impl(std::vector outputs, bool async) { std::queue tape; @@ -331,7 +334,11 @@ std::pair, std::vector> vjp( } } - auto vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs); + std::vector vjps; + { + detail::RetainGraph retain; + vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs); + } // Accumulate the vector-jacobian products for each input for (int i = 0; i < argnums.size(); ++i) { auto in_id = a.inputs()[argnums[i]].id(); @@ -778,14 +785,27 @@ std::function vmap( return [vfun](const array& a) { return vfun({a})[0]; }; } -std::function(const std::vector&)> custom_vjp( +std::function(const std::vector&)> custom_function( std::function(const std::vector&)> fun, - std::function( + std::optional( const std::vector&, const std::vector&, - const std::vector&)> fun_vjp) { + const std::vector&)>> fun_vjp /* = std::nullopt */, + std::optional( + const std::vector&, + const std::vector&, + const std::vector&)>> fun_jvp /* = std::nullopt */, + std::optional, std::vector>( + const std::vector&, + const std::vector&)>> fun_vmap /* = std::nullopt */) { + if (!fun_vjp.has_value() && !fun_jvp.has_value() && !fun_vmap.has_value()) { + return fun; + } + return [fun = std::move(fun), - fun_vjp = std::move(fun_vjp)](const std::vector& args) { + fun_vjp = std::move(fun_vjp), + fun_jvp = std::move(fun_jvp), + fun_vmap = std::move(fun_vmap)](const std::vector& args) { // Compute the outputs auto outputs = fun(args); for (auto& out : outputs) { @@ -814,11 +834,63 @@ std::function(const std::vector&)> custom_vjp( return array::make_arrays( std::move(shapes), dtypes, - std::make_shared(to_stream(s), fun_vjp), + std::make_shared( + to_stream(s), + outputs.size(), + + // We use the passed vjp function or compute it from the inputs and + // passed cotangents. Note that this may be less efficient than + // using `fun` directly because we may not be able to fully reuse + // the outputs of the forward pass. 
+ fun_vjp.value_or( + [fun](auto primals, auto cotangents, auto outputs) { + auto [__, vjps] = vjp(fun, primals, cotangents); + return vjps; + }), + + // We use the passed jvp function or compute it from the primals + // and tangents. Similarly we can't take full advantage of the + // argnums so it is best to use `fun` directly if we don't need a + // custom transform. + // + // TODO: Use stop_gradient to make full use of argnums and not + // waste computation. + fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) { + std::vector all_tangents; + for (int i = 0, j = 0; i < primals.size(); i++) { + if (j < argnums.size() && i == argnums[j]) { + all_tangents.emplace_back(tangents[j++]); + } else { + all_tangents.emplace_back(zeros_like(primals[i])); + } + } + auto [__, jvps] = jvp(fun, primals, all_tangents); + return jvps; + }), + + // Same as above, we use the passed vmap function or we compute it + // from `fun`. The output axes is selected to be all 0s which again + // may be suboptimal but the only thing we can do without any + // information for `fun`. + fun_vmap.value_or( + [fun, out_size = outputs.size()](auto inputs, auto in_axes) + -> std::pair, std::vector> { + std::vector out_axes(out_size, 0); + return {vmap(fun, in_axes, out_axes)(inputs), out_axes}; + })), inputs); }; } +std::function(const std::vector&)> custom_vjp( + std::function(const std::vector&)> fun, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> fun_vjp) { + return custom_function(fun, fun_vjp, std::nullopt, std::nullopt); +} + std::function(const std::vector&)> checkpoint( std::function(const std::vector&)> fun) { auto vjp_fun = [fun]( diff --git a/mlx/transforms.h b/mlx/transforms.h index d64f6060e..a1a57d493 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -2,6 +2,8 @@ #pragma once +#include + #include "mlx/array.h" namespace mlx::core { @@ -179,8 +181,31 @@ std::function(const std::vector&)> vmap( const std::vector& out_axes = {}); /** - * Return the results of calling fun with args but if their vjp is computed it - * will be computed by fun_vjp. + * Redefine the transformations of `fun` according to the provided functions. + * + * Namely when calling the vjp of `fun` then `fun_vjp` will be called, + * `fun_jvp` for the jvp and `fun_vmap` for vmap. + * + * If any transformation is not provided, then a default one is created by + * calling `vjp`, `jvp` and `vmap` on the function directly. + */ +std::function(const std::vector&)> custom_function( + std::function(const std::vector&)> fun, + std::optional( + const std::vector&, + const std::vector&, + const std::vector&)>> fun_vjp = std::nullopt, + std::optional( + const std::vector&, + const std::vector&, + const std::vector&)>> fun_jvp = std::nullopt, + std::optional, std::vector>( + const std::vector&, + const std::vector&)>> fun_vmap = std::nullopt); + +/** + * Return a function that behaves exactly like `fun` but if the vjp of the + * results is computed `fun_vjp` will be used instead of `vjp(fun, ...)` . 
*/ std::function(const std::vector&)> custom_vjp( std::function(const std::vector&)> fun, diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 76dc0ad84..6f67305e8 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -50,4 +50,20 @@ struct InTracing { static int tracing_counter; }; +struct RetainGraph { + RetainGraph() { + tracing_counter++; + } + ~RetainGraph() { + tracing_counter--; + } + + static bool retain_graph() { + return tracing_counter > 0; + } + + private: + static int tracing_counter; +}; + } // namespace mlx::core::detail diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index d14c41369..e5acf94e4 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -593,7 +593,454 @@ class PyCheckpointedFun { nb::callable fun_; }; +/** + * PyCustomFunction is the class that implements the python decorator + * `mx.custom_function`. + * + * It implements a callable that instead of simply calling `fun` it creates a + * CustomTransforms primitive via the `custom_function` C++ op which allows us + * to redefine the vjp, jvp and vmap transformations. + * + * The implementation is verbose due to explicit handling of the destruction of + * various python objects to make sure that there is no double-free and that + * all of them are deleted while under GIL. + * + * Namely, for every one of the functions passed to the C++ `custom_function` + * we create a callable struct that holds the following python objects (when + * needed). + * + * - An nb::callable which holds the passed function or transform + * - An nb::object holding input structure, namely the `(args, kwargs)` + * passed to the function in order to be able to recreate the arguments + * from the input arrays. + * - A std::shared_ptr holding the output structure name the + * structure of the return value of `fun`. It is a shared_ptr so that it + * can be set when the function is called and then used in the `vjp` + * transform. We delete the object only when the shared_ptr is about to be + * deleted see `output_structure_.use_count() == 1` to make sure that the + * object is deleted under GIL. 
+ */ +class PyCustomFunction { + public: + PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {} + ~PyCustomFunction() { + nb::gil_scoped_acquire gil; + + fun_.release().dec_ref(); + if (vjp_fun_.has_value()) { + (*vjp_fun_).release().dec_ref(); + } + if (jvp_fun_.has_value()) { + (*jvp_fun_).release().dec_ref(); + } + if (vmap_fun_.has_value()) { + (*vmap_fun_).release().dec_ref(); + } + } + + struct InnerFunction { + nb::callable fun_; + nb::object input_structure_; + std::shared_ptr output_structure_; + + InnerFunction( + nb::callable fun, + nb::object input_structure, + std::shared_ptr output_structure) + : fun_(std::move(fun)), + input_structure_(std::move(input_structure)), + output_structure_(std::move(output_structure)) {} + ~InnerFunction() { + nb::gil_scoped_acquire gil; + + fun_.release().dec_ref(); + input_structure_.release().dec_ref(); + if (output_structure_.use_count() == 1) { + output_structure_->release().dec_ref(); + } + } + + std::vector operator()(const std::vector& inputs) { + nb::gil_scoped_acquire gil; + + auto new_inputs = nb::cast( + tree_unflatten_from_structure(input_structure_, inputs)); + std::vector outputs; + std::tie(outputs, *output_structure_) = + tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1])); + return outputs; + } + }; + + struct InnerVJPFunction { + nb::callable vjp_fun_; + nb::object input_structure_; + std::shared_ptr output_structure_; + + InnerVJPFunction( + nb::callable vjp_fun, + nb::object input_structure, + std::shared_ptr output_structure) + : vjp_fun_(std::move(vjp_fun)), + input_structure_(std::move(input_structure)), + output_structure_(std::move(output_structure)) {} + ~InnerVJPFunction() { + nb::gil_scoped_acquire gil; + + vjp_fun_.release().dec_ref(); + input_structure_.release().dec_ref(); + if (output_structure_.use_count() == 1) { + output_structure_->release().dec_ref(); + } + } + + std::vector operator()( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& outputs) { + nb::gil_scoped_acquire gil; + + auto new_inputs = nb::cast( + tree_unflatten_from_structure(input_structure_, primals)); + auto args = nb::cast(new_inputs[0]); + auto new_cotangents = + tree_unflatten_from_structure(*output_structure_, cotangents); + auto new_outputs = + tree_unflatten_from_structure(*output_structure_, outputs); + + if (args.size() == 1) { + return tree_flatten( + vjp_fun_(args[0], new_cotangents, new_outputs, **new_inputs[1]), + false); + } else { + return tree_flatten( + vjp_fun_(args, new_cotangents, new_outputs, **new_inputs[1]), + false); + } + } + }; + + struct InnerJVPFunction { + nb::callable jvp_fun_; + nb::object input_structure_; + + InnerJVPFunction(nb::callable jvp_fun, nb::object input_structure) + : jvp_fun_(std::move(jvp_fun)), + input_structure_(std::move(input_structure)) {} + ~InnerJVPFunction() { + nb::gil_scoped_acquire gil; + + jvp_fun_.release().dec_ref(); + input_structure_.release().dec_ref(); + } + + std::vector operator()( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + nb::gil_scoped_acquire gil; + + auto new_inputs = nb::cast( + tree_unflatten_from_structure(input_structure_, primals)); + auto args = nb::cast(new_inputs[0]); + auto kwargs = nb::cast(new_inputs[1]); + if (kwargs.size() > 0) { + throw std::invalid_argument( + "[custom jvp] Function should only accept positional arguments"); + } + + // Make a new pytree which has tangents or None when a tangent is not + // available. 
+ std::vector have_tangents(primals.size(), false); + for (auto arg : argnums) { + have_tangents[arg] = true; + } + int array_index = 0; + int tangent_index = 0; + auto new_tangents = + nb::cast(tree_map(args, [&](nb::handle element) { + if (nb::isinstance(element) && + have_tangents[array_index++]) { + return nb::cast(tangents[tangent_index++]); + } else { + return nb::none(); + } + })); + + if (args.size() == 1) { + return tree_flatten(jvp_fun_(args[0], new_tangents[0]), false); + } else { + return tree_flatten(jvp_fun_(args, new_tangents), false); + } + } + }; + + struct InnerVmapFunction { + nb::callable vmap_fun_; + nb::object input_structure_; + + InnerVmapFunction(nb::callable vmap_fun, nb::object input_structure) + : vmap_fun_(std::move(vmap_fun)), + input_structure_(std::move(input_structure)) {} + ~InnerVmapFunction() { + nb::gil_scoped_acquire gil; + + vmap_fun_.release().dec_ref(); + input_structure_.release().dec_ref(); + } + + std::pair, std::vector> operator()( + const std::vector& inputs, + const std::vector& axes) { + nb::gil_scoped_acquire gil; + + auto new_inputs = nb::cast( + tree_unflatten_from_structure(input_structure_, inputs)); + auto args = nb::cast(new_inputs[0]); + auto kwargs = nb::cast(new_inputs[1]); + if (kwargs.size() > 0) { + throw std::invalid_argument( + "[custom vmap] Function should only accept positional arguments"); + } + + int arr_index; + auto new_axes = + nb::cast(tree_map(args, [&](nb::handle element) { + int axis = axes[arr_index++]; + if (nb::isinstance(element) && axis >= 0) { + return nb::cast(axis); + } else { + return nb::none(); + } + })); + + nb::object result; + if (args.size() == 1) { + result = vmap_fun_(args[0], new_axes[0]); + } else { + result = vmap_fun_(args, new_axes); + } + + if (!nb::isinstance(result)) { + throw std::invalid_argument( + "[custom vmap] Vmap function should return a tuple with 2 items."); + } + nb::tuple result_tuple = nb::cast(result); + if (result_tuple.size() != 2) { + throw std::invalid_argument( + "[custom vmap] Vmap function should return a tuple with 2 items."); + } + + std::vector outputs; + std::vector output_axes; + tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) { + if (nb::isinstance(objects[0])) { + outputs.push_back(nb::cast(objects[0])); + output_axes.push_back( + objects[1].is_none() ? -1 : nb::cast(objects[1])); + } + }); + + return {outputs, output_axes}; + } + }; + + nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { + if (!vjp_fun_.has_value() && !jvp_fun_.has_value() && + !vmap_fun_.has_value()) { + return fun_(*args, **kwargs); + } + + // Extract the inputs and their structure in capturable vars + std::vector input_arrays; + nb::object input_structure; + auto full_args = nb::make_tuple(args, kwargs); + std::tie(input_arrays, input_structure) = + tree_flatten_with_structure(full_args, false); + + // The output structure will be stored here to be used in the custom vjp + // function + auto output_structure = std::make_shared(); + + // Make a function that calls fun_ in the forward pass and vjp_ in the + // backward pass. Then call it immediately and return the results. 
+ auto f = custom_function( + InnerFunction(fun_, input_structure, output_structure), + make_vjp_function(input_structure, output_structure), + make_jvp_function(input_structure), + make_vmap_function(input_structure)); + + auto outputs = f(input_arrays); + return tree_unflatten_from_structure(*output_structure, outputs); + } + + PyCustomFunction& set_vjp(nb::callable vjp_fun) { + vjp_fun_ = vjp_fun; + return *this; + } + + PyCustomFunction& set_jvp(nb::callable jvp_fun) { + jvp_fun_ = jvp_fun; + return *this; + } + + PyCustomFunction& set_vmap(nb::callable vmap_fun) { + vmap_fun_ = vmap_fun; + return *this; + } + + private: + std::optional make_vjp_function( + nb::object input_structure, + std::shared_ptr output_structure) { + if (!vjp_fun_.has_value()) { + return std::nullopt; + } + + return InnerVJPFunction(*vjp_fun_, input_structure, output_structure); + } + + std::optional make_jvp_function( + nb::object input_structure) { + if (!jvp_fun_.has_value()) { + return std::nullopt; + } + + return InnerJVPFunction(*jvp_fun_, input_structure); + } + + std::optional make_vmap_function( + nb::object input_structure) { + if (!vmap_fun_.has_value()) { + return std::nullopt; + } + + return InnerVmapFunction(*vmap_fun_, input_structure); + } + + nb::callable fun_; + std::optional vjp_fun_; + std::optional jvp_fun_; + std::optional vmap_fun_; +}; + void init_transforms(nb::module_& m) { + nb::class_( + m, + "custom_function", + R"pbdoc( + Set up a function for custom gradient and vmap definitions. + + This class is meant to be used as a function decorator. Instances are + callables that behave identically to the wrapped function. However, when + a function transformation is used (e.g. computing gradients using + :func:`value_and_grad`) then the functions defined via :method:`vjp`, + :method:`jvp` and :method:`vmap` are used instead of the default + transformation. + + Note, all custom transformations are optional. Undefined transformations + fall back to the default behaviour. + + Example usage: + + .. code-block:: python + + import mlx.core as mx + + @mx.custom_function + def f(x, y): + return mx.sin(x) * y + + @f.vjp + def f_vjp(primals, cotangent, output): + x, y = primals + return cotan * mx.cos(x) * y, cotan * mx.sin(x) + + @f.jvp + def f_jvp(primals, tangents): + x, y = primals + dx, dy = tangents + return dx * mx.cos(x) * y + dy * mx.sin(x) + + @f.vmap + def f_vmap(inputs, axes): + x, y = inputs + ax, ay = axes + if ay != ax and ax is not None: + y = y.swapaxes(ay, ax) + return mx.sin(x) * y, (ax or ay) + )pbdoc") + .def( + nb::init(), + "f"_a, + nb::sig("def __init__(self, f: callable)")) + .def("__call__", &PyCustomFunction::call_impl) + .def( + "vjp", + &PyCustomFunction::set_vjp, + "f"_a, + nb::sig("def vjp(self, f_vjp: callable)"), + R"pbdoc( + Define a custom vjp for the wrapped function. + + The vjp function takes three arguments: + + - *primals*: A pytree that contains all the positional arguments to + the function. It could be a single array, a tuple of arrays or a + full blown tuple of dicts of arrays etc. + - *cotangents*: A pytree that matches the structure of the output + but contains the cotangents (usually the gradients of the loss + function with respect to the outputs). + - *outputs*: The outputs of the function to be used to avoid + recomputing them for the gradient computation. + + The vjp function should return the same pytree structure as the + primals but containing the corresponding computed cotangents. 
+ )pbdoc") + .def( + "jvp", + &PyCustomFunction::set_jvp, + "f"_a, + nb::sig("def jvp(self, f_jvp: callable)"), + R"pbdoc( + Define a custom jvp for the wrapped function. + + The jvp function takes two arguments: + + - *primals*: A pytree that contains all the positional arguments to + the function. It could be a single array, a tuple of arrays or a + full blown tuple of dicts of arrays etc. + - *tangents*: A pytree that matches the structure of the inputs but + instead contains the gradients wrt to each input. Tangents could + be ``None`` if some inputs don't have an associated gradient. + + The jvp function should return the same pytree structure as the + outputs of the function but containing the tangents. + )pbdoc") + .def( + "vmap", + &PyCustomFunction::set_vmap, + "f"_a, + nb::sig("def vmap(self, f_vmap: callable)"), + R"pbdoc( + Define a custom vectorization transformation for the wrapped function. + + The vmap function takes two arguments: + + - *inputs*: A pytree that contains all the positional arguments to + the function. It could be a single array, a tuple of arrays or a + full blown tuple of dicts of arrays etc. + - *axes*: A pytree that matches the structure of the inputs but + instead contains the vectorization axis for each input or + ``None`` if an input is not vectorized. + + The vmap function should return the outputs of the original + function but vectorized over the provided axes. It should also + return a pytree with the vectorization axes of each output. If some + outputs are no longer vectorized, then their vectorization axis + should be ``None``. + )pbdoc"); + m.def( "eval", [](const nb::args& args) { @@ -888,8 +1335,10 @@ void init_transforms(nb::module_& m) { const nb::object& outputs, bool shapeless) { // Try to get the name - auto n = fun.attr("__name__"); - auto name = n.is_none() ? "compiled" : nb::cast(n); + auto n = + nb::hasattr(fun, "__name__") ? fun.attr("__name__") : nb::none(); + auto name = n.is_none() ? "compiled" + : nb::cast(fun.attr("__name__")); // Try to get the signature std::ostringstream sig; diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index f5e49f402..7db0d49ff 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -496,6 +496,90 @@ class TestAutograd(mlx_tests.MLXTestCase): expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0]) self.assertTrue(mx.allclose(out, expected)) + def test_custom_function(self): + # Make a custom function + my_exp = mx.custom_function(mx.exp) + + # Ensure everything works + dy = mx.grad(my_exp)(mx.array(1.0)) + self.assertTrue(mx.allclose(dy, mx.exp(mx.array(1.0)))) + (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)]) + self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0)))) + self.assertTrue(mx.allclose(ex, dex)) + ex = mx.vmap(my_exp)(mx.ones(10)) + self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10)))) + + # Ensure that the vjp is being overriden but everything else still + # works. + @my_exp.vjp + def my_exp_vjp(x, dx, ex): + return mx.ones_like(x) * 42 + + dy = mx.grad(my_exp)(mx.array(1.0)) + self.assertTrue(mx.allclose(dy, mx.array(42.0))) + (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)]) + self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0)))) + self.assertTrue(mx.allclose(ex, dex)) + ex = mx.vmap(my_exp)(mx.ones(10)) + self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10)))) + + # Ensure that setting the jvp and vmap also works. 
+ @my_exp.jvp + def my_exp_jvp(x, dx): + return mx.ones_like(x) * 7 * dx + + @my_exp.vmap + def my_exp_vmap(x, axis): + return mx.ones_like(x) * 3, axis + + dy = mx.grad(my_exp)(mx.array(1.0)) + self.assertTrue(mx.allclose(dy, mx.array(42.0))) + (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)]) + self.assertTrue(mx.allclose(dex, mx.array(7.0))) + self.assertTrue(mx.allclose(ex, mx.exp(mx.array(1.0)))) + ex = mx.vmap(my_exp)(mx.ones(10)) + self.assertTrue(mx.allclose(ex, 3 * mx.ones(10))) + + # Test pytrees + @mx.custom_function + def my_double(params): + return {"out": 2 * params["x"] * params["y"]} + + dy = mx.grad(lambda p: my_double(p)["out"].sum())( + {"x": mx.ones(2), "y": mx.ones(2)} + ) + self.assertTrue(mx.allclose(dy["x"], mx.ones(2) * 2)) + self.assertTrue(mx.allclose(dy["y"], mx.ones(2) * 2)) + + @my_double.vjp + def random_grads(primals, cotangents, outputs): + return {"x": mx.zeros_like(primals["x"]), "y": mx.ones_like(primals["y"])} + + dy = mx.grad(lambda p: my_double(p)["out"].sum())( + {"x": mx.ones(2), "y": mx.ones(2)} + ) + self.assertTrue(mx.allclose(dy["x"], mx.zeros(2))) + self.assertTrue(mx.allclose(dy["y"], mx.ones(2))) + + def outer_f(a, b): + return my_double({"x": a, "y": b})["out"] + + inputs = [mx.random.normal(shape=(2,)) for i in range(2)] + tans = [mx.random.normal(shape=(2,)) for i in range(2)] + out1, dout1 = mx.jvp(outer_f, inputs, tans) + + @my_double.jvp + def random_grads(primals, tangents): + return { + "out": 2 * primals["x"] * tangents["y"] + + 2 * primals["y"] * tangents["x"] + + 1 + } + + out2, dout2 = mx.jvp(outer_f, inputs, tans) + self.assertTrue(mx.allclose(out1[0], out2[0])) + self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0])) + if __name__ == "__main__": unittest.main()
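
A minimal end-to-end sketch of the new `mx.custom_function` decorator, adapted from the docstring example and the tests added in this diff; the function `f` and the gradient/vmap rules below are illustrative only, and the vmap rule is simplified under the assumption that `x` carries the batch axis. Any transform that is not overridden falls back to the default `vjp`/`jvp`/`vmap` of the wrapped function, mirroring the `value_or` defaults in `mlx/transforms.cpp`.

```python
import mlx.core as mx


@mx.custom_function
def f(x, y):
    return mx.sin(x) * y


@f.vjp
def f_vjp(primals, cotangents, outputs):
    # primals is the tuple of positional inputs, cotangents matches the
    # output structure (a single array here), and outputs holds the forward
    # results (unused in this simple rule). Return one cotangent per primal.
    x, y = primals
    return cotangents * mx.cos(x) * y, cotangents * mx.sin(x)


@f.jvp
def f_jvp(primals, tangents):
    # tangents mirrors primals; entries may be None for inputs without a
    # tangent (not the case in the call below).
    x, y = primals
    dx, dy = tangents
    return dx * mx.cos(x) * y + dy * mx.sin(x)


@f.vmap
def f_vmap(inputs, axes):
    # Return the vectorized outputs together with their vectorization axes.
    # Simplified rule: assumes x is vmapped and aligns y to x's batch axis.
    x, y = inputs
    ax, ay = axes
    if ay is not None and ay != ax:
        y = y.swapaxes(ay, ax)
    return mx.sin(x) * y, ax


x = mx.random.normal(shape=(4,))
y = mx.random.normal(shape=(4,))

# Reverse mode uses f_vjp, forward mode uses f_jvp, and vmap uses f_vmap.
df_dx = mx.grad(lambda a, b: f(a, b).sum())(x, y)
outs, jvps = mx.jvp(f, [x, y], [mx.ones_like(x), mx.ones_like(y)])
batched = mx.vmap(f)(x, y)
```

As in the tests above, the overrides can be attached incrementally: a function wrapped with `mx.custom_function` but with no overrides behaves exactly like the original, and setting only `vjp` (the old `custom_vjp` behavior) leaves `jvp` and `vmap` on the default path.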