diff --git a/mlx/array.cpp b/mlx/array.cpp index 2362b586f..59902c86d 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include @@ -97,11 +97,13 @@ void array::detach() { s.array_desc_->inputs.clear(); s.array_desc_->siblings.clear(); s.array_desc_->position = 0; + s.array_desc_->depth = 0; s.array_desc_->primitive = nullptr; } array_desc_->inputs.clear(); array_desc_->siblings.clear(); array_desc_->position = 0; + array_desc_->depth = 0; array_desc_->primitive = nullptr; } @@ -180,7 +182,9 @@ array::ArrayDesc::ArrayDesc( std::tie(size, strides) = cum_prod(this->shape); for (auto& in : inputs) { is_tracer |= in.is_tracer(); + depth = std::max(in.graph_depth(), depth); } + depth++; } array::ArrayDesc::ArrayDesc( @@ -195,7 +199,9 @@ array::ArrayDesc::ArrayDesc( std::tie(size, strides) = cum_prod(this->shape); for (auto& in : inputs) { is_tracer |= in.is_tracer(); + depth = std::max(in.graph_depth(), depth); } + depth++; } array::ArrayIterator::ArrayIterator(const array& arr, int idx) diff --git a/mlx/array.h b/mlx/array.h index 6e8375a71..2b849a7ae 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -267,6 +267,11 @@ class array { return outputs; }; + /** The depth of the array in the graph. Evaluated arrays have depth 0. */ + uint16_t graph_depth() const { + return array_desc_->depth; + } + /** Detach the array from the graph. */ void detach(); @@ -377,6 +382,9 @@ class array { // The arrays position in the output list uint32_t position{0}; + // The depth of the array in the graph. + uint16_t depth{0}; + explicit ArrayDesc(const std::vector& shape, Dtype dtype); explicit ArrayDesc( diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index bee736b50..b24a2bc74 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include @@ -35,6 +35,8 @@ DEFAULT(Broadcast) DEFAULT(Ceil) DEFAULT(Concatenate) DEFAULT(Copy) +DEFAULT_MULTI(CustomVJP) +DEFAULT_MULTI(Depends) DEFAULT(Equal) DEFAULT(Erf) DEFAULT(ErfInv) diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 622f33fd4..bcb8bd7b4 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
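For reference (not part of the patch): the depth bookkeeping above means a leaf array reports graph_depth() == 0 and every lazily built node is one deeper than its deepest input, while detach() resets the depth to zero. A minimal sketch against the public C++ API in mlx/mlx.h:

#include <cassert>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array a(1.0f), b(2.0f), c(3.0f); // leaves: graph_depth() == 0
  auto s = add(a, b);              // depth = max(0, 0) + 1 = 1
  auto p = multiply(s, c);         // depth = max(1, 0) + 1 = 2
  assert(a.graph_depth() == 0);
  assert(s.graph_depth() == 1);
  assert(p.graph_depth() == 2);
  return 0;
}

This uint16_t depth is what eval() uses later in the patch to decide which input branch to traverse first.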
#ifdef ACCELERATE_NEW_LAPACK #include @@ -47,6 +47,8 @@ DEFAULT(Convolution) DEFAULT(Copy) DEFAULT(Cos) DEFAULT(Cosh) +DEFAULT_MULTI(CustomVJP) +DEFAULT_MULTI(Depends) DEFAULT(Divide) DEFAULT(Remainder) DEFAULT(Equal) diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index ce36a8d5d..37c61761a 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -232,6 +232,25 @@ void Cosh::eval(const std::vector& inputs, array& out) { } } +void CustomVJP::eval( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() > outputs.size()); + for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size(); + i++, j++) { + outputs[i].copy_shared_buffer(inputs[j]); + } +} + +void Depends::eval( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() > outputs.size()); + for (int i = 0; i < outputs.size(); i++) { + outputs[i].copy_shared_buffer(inputs[i]); + } +} + void Erf::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp index 8482468a0..be55cca29 100644 --- a/mlx/backend/common/quantized.cpp +++ b/mlx/backend/common/quantized.cpp @@ -1,7 +1,6 @@ // Copyright © 2023 Apple Inc. #include -#include #include "mlx/backend/metal/copy.h" #include "mlx/primitives.h" diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index d632556f0..9fbbffb33 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 6d48f07a7..c0b3cb19b 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include @@ -615,7 +615,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { } void AddMM::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 2); + assert(inputs.size() == 3); if (!is_floating_point(out.dtype())) { throw std::runtime_error( "[matmul] Does not yet support non-floating point types."); diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 9dabe1f3f..757888a66 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include @@ -486,6 +486,18 @@ void Cosh::eval_gpu(const std::vector& inputs, array& out) { unary_op(inputs, out, "cosh"); } +void CustomVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + eval(inputs, outputs); +} + +void Depends::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + eval(inputs, outputs); +} + void Divide::eval_gpu(const std::vector& inputs, array& out) { binary_op(inputs, out, "div"); } diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 34c3e6d4f..7754711e3 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 034fba760..477ceff20 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
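For reference (not part of the patch): both new multi-output CPU kernels are pure aliasing. A CustomVJP node is built with its primals followed by the already computed forward outputs, so eval shares buffers with the trailing block of inputs; Depends shares buffers with the leading block and only carries its extra dependency inputs for scheduling. A stand-alone sketch of the index arithmetic, with plain integers standing in for arrays:

#include <cassert>
#include <vector>

int main() {
  // CustomVJP built from 3 primals and 2 precomputed outputs:
  // inputs = {p0, p1, p2, o0, o1}; the outputs alias {o0, o1}.
  std::vector<int> inputs = {10, 11, 12, 100, 101};
  std::vector<int> outputs(2);
  for (int i = 0, j = inputs.size() - outputs.size(); i < (int)outputs.size();
       i++, j++) {
    outputs[i] = inputs[j]; // copy_shared_buffer(...) in the real primitive
  }
  assert(outputs[0] == 100 && outputs[1] == 101);

  // Depends built from 1 passthrough input and 2 dependency inputs:
  // inputs = {x, d0, d1}; the single output aliases x.
  std::vector<int> dep_inputs = {42, 7, 8};
  std::vector<int> dep_outputs(1);
  for (int i = 0; i < (int)dep_outputs.size(); i++) {
    dep_outputs[i] = dep_inputs[i];
  }
  assert(dep_outputs[0] == 42);
  return 0;
}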
#include "mlx/primitives.h" @@ -37,6 +37,8 @@ NO_GPU(Convolution) NO_GPU(Copy) NO_GPU(Cos) NO_GPU(Cosh) +NO_GPU_MULTI(CustomVJP) +NO_GPU_MULTI(Depends) NO_GPU(Divide) NO_GPU(Remainder) NO_GPU(Equal) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 57ebda494..31324f3cc 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1,4 +1,5 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. + #include #include #include @@ -7,6 +8,7 @@ #include "mlx/ops.h" #include "mlx/primitives.h" +#include "mlx/transforms.h" #include "mlx/utils.h" namespace mlx::core { @@ -3327,4 +3329,26 @@ array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) { } } +std::vector depends( + const std::vector& inputs, + const std::vector& dependencies) { + std::vector all_inputs = inputs; + all_inputs.insert(all_inputs.end(), dependencies.begin(), dependencies.end()); + + // Compute the stream. Maybe do it in a smarter way at some point in the + // future. + Stream s = (inputs[0].has_primitive()) ? inputs[0].primitive().stream() + : to_stream({}); + // Make the output info + std::vector> shapes; + std::vector dtypes; + for (const auto& in : inputs) { + shapes.emplace_back(in.shape()); + dtypes.emplace_back(in.dtype()); + } + + return array::make_arrays( + shapes, dtypes, std::make_shared(to_stream(s)), all_inputs); +} + } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index 506a29d84..a4b1dd1ef 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1116,4 +1116,13 @@ array diagonal( /** Extract diagonal from a 2d array or create a diagonal matrix. */ array diag(const array& a, int k = 0, StreamOrDevice s = {}); +/** + * Implements the identity function but allows injecting dependencies to other + * arrays. This ensures that these other arrays will have been computed + * when the outputs of this function are computed. + */ +std::vector depends( + const std::vector& inputs, + const std::vector& dependencies); + } // namespace mlx::core diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 3b96a576d..018cce265 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1,4 +1,5 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. + #include #include #include @@ -797,6 +798,43 @@ std::pair, std::vector> Cosh::vmap( return {{cosh(inputs[0], stream())}, axes}; } +std::vector CustomVJP::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) { + std::vector inputs(primals.begin(), primals.end() - outputs.size()); + auto all_vjps = vjp_fun_(inputs, cotangents, outputs); + for (const auto& cot : cotangents) { + all_vjps.emplace_back(cot); + } + + std::vector vjps; + vjps.reserve(argnums.size()); + for (auto arg : argnums) { + vjps.push_back(all_vjps[arg]); + } + + return vjps; +} + +std::vector Depends::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) { + std::vector vjps; + + for (auto arg : argnums) { + if (arg < cotangents.size()) { + vjps.push_back(cotangents[arg]); + } else { + vjps.push_back(zeros_like(primals[arg])); + } + } + return vjps; +} + std::vector Divide::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 2f9f6d6b3..265095694 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
#pragma once @@ -552,6 +552,60 @@ class Cosh : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class CustomVJP : public Primitive { + public: + explicit CustomVJP( + Stream stream, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> fun) + : Primitive(stream), vjp_fun_(std::move(fun)) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotan, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_PRINT(CustomVJP); + + private: + void eval(const std::vector& inputs, std::vector& outputs); + + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> + vjp_fun_; +}; + +class Depends : public Primitive { + public: + explicit Depends(Stream stream) : Primitive(stream) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotan, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_PRINT(Depends); + + private: + void eval(const std::vector& inputs, std::vector& outputs); +}; + class Divide : public UnaryPrimitive { public: explicit Divide(Stream stream) : UnaryPrimitive(stream){}; diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 88d5038f2..aeafcac7d 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -35,7 +35,7 @@ class Synchronizer : public Primitive { int detail::InTracing::tracing_counter{0}; void eval(const std::vector& outputs) { - std::function recurse; + std::function recurse; std::queue tape; std::unordered_set cache; std::unordered_map> deps; @@ -52,21 +52,57 @@ void eval(const std::vector& outputs) { auto synchronizer = array({}, bool_, std::make_unique(stream), outputs); - recurse = [&](const array& a) { + recurse = [&](const array& a, bool largest_branch_first) { auto id = a.id(); if (cache.find(id) != cache.end()) { return; } - for (auto in : a.inputs()) { - recurse(in); - // If one of the inputs is being computed on a different - // stream, we need to manage the dependency. + + // If the input is being computed on a different stream, we need to manage + // the dependency. + auto check_dependency = [&](const array& in) { if (!in.is_evaled()) { if (a.primitive().stream() != in.primitive().stream()) { deps.insert({in.primitive_id(), std::shared_future{}}); } } + }; + + // Recurse to the largest or smallest branch first. 
+ size_t num_inputs = a.inputs().size(); + if (num_inputs == 1) { + auto& in = a.inputs()[0]; + recurse(in, true); + check_dependency(in); + } else if (num_inputs == 2) { + auto depth_1 = a.inputs()[0].graph_depth(); + auto depth_2 = a.inputs()[1].graph_depth(); + auto& in1 = a.inputs()[static_cast( + !((depth_1 > depth_2) == largest_branch_first))]; + auto& in2 = a.inputs()[static_cast( + ((depth_1 > depth_2) == largest_branch_first))]; + recurse(in1, true); + check_dependency(in1); + recurse(in2, true); + check_dependency(in2); + } else if (num_inputs > 2) { + std::vector recursion_order(a.inputs().size()); + std::iota(recursion_order.begin(), recursion_order.end(), 0); + std::sort( + recursion_order.begin(), + recursion_order.end(), + [&a, largest_branch_first](int i, int j) { + auto depth_i = a.inputs()[i].graph_depth(); + auto depth_j = a.inputs()[j].graph_depth(); + return largest_branch_first ? depth_i > depth_j : depth_j < depth_i; + }); + for (int idx : recursion_order) { + auto& in = a.inputs()[idx]; + recurse(in, true); + check_dependency(in); + } } + cache.insert(id); for (auto& s : a.siblings()) { cache.insert(s.id()); @@ -80,7 +116,7 @@ void eval(const std::vector& outputs) { } }; - recurse(synchronizer); + recurse(synchronizer, false); uintptr_t synch_id = synchronizer.primitive_id(); deps.insert({synch_id, std::shared_future{}}); @@ -713,4 +749,58 @@ std::function vmap( return [vfun](const array& a) { return vfun({a})[0]; }; } +std::function(const std::vector&)> custom_vjp( + std::function(const std::vector&)> fun, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> fun_vjp) { + return [fun = std::move(fun), + fun_vjp = std::move(fun_vjp)](const std::vector& args) { + // Compute the outputs + auto outputs = fun(args); + for (auto& out : outputs) { + out = stop_gradient(out); + } + + // Prepare the inputs to the primitive + // We also add the outputs to the primitive so that it can "run" the forward + // pass. + std::vector inputs = args; + inputs.insert(inputs.end(), outputs.begin(), outputs.end()); + + // Compute the stream. Maybe do it in a smarter way at some point in the + // future. + Stream s = (outputs[0].has_primitive()) ? outputs[0].primitive().stream() + : default_stream(default_device()); + + // Make the output info + std::vector> shapes; + std::vector dtypes; + for (const auto& out : outputs) { + shapes.emplace_back(out.shape()); + dtypes.emplace_back(out.dtype()); + } + + return array::make_arrays( + shapes, + dtypes, + std::make_shared(to_stream(s), fun_vjp), + inputs); + }; +} + +std::function(const std::vector&)> checkpoint( + std::function(const std::vector&)> fun) { + auto vjp_fun = [fun]( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& outputs) -> std::vector { + auto [__, vjps] = vjp(fun, depends(primals, outputs), cotangents); + return vjps; + }; + + return custom_vjp(fun, vjp_fun); +} + } // namespace mlx::core diff --git a/mlx/transforms.h b/mlx/transforms.h index 7a13b8f7e..297571e1d 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -191,4 +191,22 @@ std::function(const std::vector&)> vmap( const std::vector& in_axes = {}, const std::vector& out_axes = {}); +/** + * Return the results of calling fun with args but if their vjp is computed it + * will be computed by fun_vjp. 
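For reference (not part of the patch), a hedged usage sketch of custom_vjp(): the wrapped function runs normally in the forward pass (its outputs are detached with stop_gradient), and gradients come exclusively from fun_vjp. The override here is a toy straight-through rule, deliberately not the true product-rule gradient:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Forward: f(x, y) = x * y.
  auto fun = [](const std::vector<array>& ins) {
    return std::vector<array>{multiply(ins[0], ins[1])};
  };
  // Backward override: forward the cotangent to both inputs unchanged.
  auto fun_vjp = [](const std::vector<array>& primals,
                    const std::vector<array>& cotans,
                    const std::vector<array>& outputs) {
    return std::vector<array>{cotans[0], cotans[0]};
  };
  auto fn = custom_vjp(fun, fun_vjp);
  auto [outs, grads] = vjp(fn, {array(2.0f), array(3.0f)}, {array(1.0f)});
  eval(grads); // grads[0] == grads[1] == 1.0 instead of 3.0 and 2.0
  return 0;
}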
+ */ +std::function(const std::vector&)> custom_vjp( + std::function(const std::vector&)> fun, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> fun_vjp); + +/** + * Checkpoint the gradient of a function. Namely, discard all intermediate + * state and recalculate it when we need to compute the gradient. + */ +std::function(const std::vector&)> checkpoint( + std::function(const std::vector&)> fun); + } // namespace mlx::core diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py index c61f5405f..6b88d4a8f 100644 --- a/python/mlx/nn/layers/transformer.py +++ b/python/mlx/nn/layers/transformer.py @@ -9,6 +9,7 @@ from mlx.nn.layers.base import Module from mlx.nn.layers.dropout import Dropout from mlx.nn.layers.linear import Linear from mlx.nn.layers.normalization import LayerNorm +from mlx.nn.utils import checkpoint class MultiHeadAttention(Module): @@ -167,6 +168,7 @@ class TransformerEncoder(Module): dropout: float = 0.0, activation=relu, norm_first: bool = False, + checkpoint: bool = False, ): super().__init__() self.layers = [ @@ -176,10 +178,14 @@ class TransformerEncoder(Module): for i in range(num_layers) ] self.ln = LayerNorm(dims) + self.checkpoint = checkpoint def __call__(self, x, mask): for l in self.layers: - x = l(x, mask) + if self.checkpoint: + x = checkpoint(l)(x, mask) + else: + x = l(x, mask) return self.ln(x) @@ -255,6 +261,7 @@ class TransformerDecoder(Module): dropout: float = 0.0, activation=relu, norm_first: bool = False, + checkpoint: bool = False, ): super().__init__() self.layers = [ @@ -264,10 +271,14 @@ class TransformerDecoder(Module): for i in range(num_layers) ] self.ln = LayerNorm(dims) + self.checkpoint = checkpoint def __call__(self, x, memory, x_mask, memory_mask): for l in self.layers: - x = l(x, memory, x_mask, memory_mask) + if self.checkpoint: + x = checkpoint(l)(x, memory, x_mask, memory_mask) + else: + x = l(x, memory, x_mask, memory_mask) return self.ln(x) @@ -307,6 +318,9 @@ class Transformer(Module): norm_first (bool, optional): if ``True``, encoder and decoder layers will perform layer normalization before attention and MLP operations, otherwise after. Default: ``False``. + chekpoint (bool, optional): if ``True`` perform gradient checkpointing + to reduce the memory usage at the expense of more computation. + Default: ``False``. """ def __init__( @@ -321,6 +335,7 @@ class Transformer(Module): custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None, norm_first: bool = False, + checkpoint: bool = False, ): super().__init__() if custom_encoder is not None: @@ -334,6 +349,7 @@ class Transformer(Module): dropout, activation, norm_first, + checkpoint, ) if custom_decoder is not None: @@ -347,6 +363,7 @@ class Transformer(Module): dropout, activation, norm_first, + checkpoint, ) def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask): diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py index 8ab80cb00..f651ce92e 100644 --- a/python/mlx/nn/utils.py +++ b/python/mlx/nn/utils.py @@ -1,11 +1,14 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. +from functools import wraps from typing import Callable import mlx.core as mx +from .layers.base import Module -def value_and_grad(model: "mlx.nn.Module", fn: Callable): + +def value_and_grad(model: Module, fn: Callable): """Transform the passed function ``fn`` to a function that computes the gradients of ``fn`` wrt the model's trainable parameters and also its value. 
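For reference (not part of the patch), a hedged usage sketch of the core checkpoint() transform declared here; the Python Transformer(..., checkpoint=True) flag below reaches this through mlx.nn.utils.checkpoint and mx.checkpoint:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto fun = [](const std::vector<array>& ins) {
    // Intermediates built here are discarded and recomputed in the backward pass.
    return std::vector<array>{square(add(multiply(ins[0], ins[1]), ins[0]))};
  };
  auto ckpt_fun = checkpoint(fun);
  auto [outs, grads] = vjp(ckpt_fun, {array(2.0f), array(3.0f)}, {array(1.0f)});
  eval(grads); // same values as vjp(fun, ...), at lower peak memory
  return 0;
}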
@@ -26,8 +29,42 @@ def value_and_grad(model: "mlx.nn.Module", fn: Callable): value_grad_fn = mx.value_and_grad(inner_fn) + @wraps(fn) def wrapped_value_grad_fn(*args, **kwargs): value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs) return value, grad return wrapped_value_grad_fn + + +def checkpoint(module: Module, fn: Callable = None): + """Transform the passed callable to one that performs gradient + checkpointing with respect to the trainable parameters of the module (and + the callable's inputs). + + Args: + module (mlx.nn.Module): The module for whose parameters we will be + performing gradient checkpointing. + fn (Callable, optional): The function to checkpoint. If not provided it + defaults to the provided module. + + Returns: + A callable that saves the inputs and outputs during the forward pass + and recomputes all intermediate states during the backward pass. + """ + if fn is None: + # Capturing module instead of module.__call__ allows someone to + # monkey-patch __call__ later on and the correct method will be used + fn = module + + def inner_fn(params, *args, **kwargs): + module.update(params) + return fn(*args, **kwargs) + + checkpointed_fn = mx.checkpoint(inner_fn) + + @wraps(fn) + def wrapped_checkpointed_fn(*args, **kwargs): + return checkpointed_fn(module.trainable_parameters(), *args, **kwargs) + + return wrapped_checkpointed_fn diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 0c15daf85..8937892fb 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -1,5 +1,4 @@ // Copyright © 2023-2024 Apple Inc. -#include #include #include #include @@ -142,7 +141,8 @@ std::vector tree_flatten(py::object tree, bool strict = true) { if (py::isinstance(obj)) { flat_tree.push_back(py::cast(obj)); } else if (strict) { - throw std::invalid_argument("Argument is not an array"); + throw std::invalid_argument( + "[tree_flatten] The argument should contain only arrays"); } }); @@ -162,12 +162,48 @@ py::object tree_unflatten( }); } -py::object tree_unflatten_none( +py::object structure_sentinel() { + static py::object sentinel; + + if (sentinel.ptr() == nullptr) { + sentinel = py::capsule(&sentinel); + // probably not needed but this should make certain that we won't ever + // delete the sentinel + sentinel.inc_ref(); + } + + return sentinel; +} + +std::pair, py::object> tree_flatten_with_structure( py::object tree, + bool strict = true) { + auto sentinel = structure_sentinel(); + std::vector flat_tree; + auto structure = tree_map( + tree, + [&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) { + if (py::isinstance(obj)) { + flat_tree.push_back(py::cast(obj)); + return sentinel; + } else if (!strict) { + return py::cast(obj); + } else { + throw std::invalid_argument( + "[tree_flatten] The argument should contain only arrays"); + } + }); + + return {flat_tree, structure}; +} + +py::object tree_unflatten_from_structure( + py::object structure, const std::vector& values, int index = 0) { - return tree_map(tree, [&](py::handle obj) { - if (py::isinstance(obj)) { + auto sentinel = structure_sentinel(); + return tree_map(structure, [&](py::handle obj) { + if (obj.is(sentinel)) { return py::cast(values[index++]); } else { return py::cast(obj); @@ -472,14 +508,10 @@ struct PyCompiledFun { py::object operator()(const py::args& args) { auto compile_fun = [this, &args](const std::vector& a) { - // Call the python function - py::object py_outputs = this->fun(*tree_unflatten(args, a)); + // Call the python function and flatten 
the outputs + auto [outputs, py_outputs] = tree_flatten_with_structure( + std::move(this->fun(*tree_unflatten(args, a))), true); - // Flatten the outputs - auto outputs = tree_flatten(py_outputs, true); - - py_outputs = - tree_map(py_outputs, [](const py::handle& x) { return py::none(); }); tree_cache().insert({this->fun_id, py_outputs}); return outputs; }; @@ -492,15 +524,75 @@ struct PyCompiledFun { // Put the outputs back in the container py::object py_outputs = tree_cache().at(fun_id); - return tree_unflatten_none(py_outputs, outputs); + return tree_unflatten_from_structure(py_outputs, outputs); }; ~PyCompiledFun() { + py::gil_scoped_acquire gil; + tree_cache().erase(fun_id); detail::compile_erase(fun_id); + fun.release().dec_ref(); } }; +class PyCheckpointedFun { + public: + PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {} + + ~PyCheckpointedFun() { + py::gil_scoped_acquire gil; + + fun_.release().dec_ref(); + } + + struct InnerFunction { + py::object fun_; + py::object args_structure_; + std::weak_ptr output_structure_; + + InnerFunction( + py::object fun, + py::object args_structure, + std::weak_ptr output_structure) + : fun_(std::move(fun)), + args_structure_(std::move(args_structure)), + output_structure_(output_structure) {} + ~InnerFunction() { + py::gil_scoped_acquire gil; + + fun_.release().dec_ref(); + args_structure_.release().dec_ref(); + } + + std::vector operator()(const std::vector& inputs) { + auto args = py::cast( + tree_unflatten_from_structure(args_structure_, inputs)); + auto [outputs, output_structure] = + tree_flatten_with_structure(fun_(*args[0], **args[1]), false); + if (auto s = output_structure_.lock()) { + *s = output_structure; + } + return outputs; + } + }; + + py::object operator()(const py::args& args, const py::kwargs& kwargs) { + auto output_structure = std::make_shared(); + auto full_args = py::make_tuple(args, kwargs); + auto [inputs, args_structure] = + tree_flatten_with_structure(full_args, false); + + auto outputs = checkpoint( + InnerFunction(fun_, args_structure, output_structure))(inputs); + + return tree_unflatten_from_structure(*output_structure, outputs); + } + + private: + py::function fun_; +}; + void init_transforms(py::module_& m) { py::options options; options.disable_function_signatures(); @@ -802,6 +894,10 @@ void init_transforms(py::module_& m) { Globally enable compilation. This will override the environment variable ``MLX_DISABLE_COMPILE`` if set. )pbdoc"); + m.def( + "checkpoint", + [](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); }, + "fun"_a); // Register static Python object cleanup before the interpreter exits auto atexit = py::module_::import("atexit"); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f881d5792..be9570f7e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -21,6 +21,7 @@ target_sources(tests PRIVATE autograd_tests.cpp blas_tests.cpp compile_tests.cpp + custom_vjp_tests.cpp creations_tests.cpp device_tests.cpp eval_tests.cpp diff --git a/tests/custom_vjp_tests.cpp b/tests/custom_vjp_tests.cpp new file mode 100644 index 000000000..f916b694b --- /dev/null +++ b/tests/custom_vjp_tests.cpp @@ -0,0 +1,57 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test simple custom vjp") { + auto one = array(1.0); + auto x = array(2.0); + auto y = array(3.0); + + auto fn = [](const std::vector& inputs) { + return std::vector{inputs[0] * inputs[1], inputs[0] + inputs[1]}; + }; + auto transformed_fn = custom_vjp( + fn, + [&](const std::vector&, + const std::vector&, + const std::vector&) { + return std::vector{one, one}; + }); + + auto [z, g] = vjp(fn, {x, y}, {one, one}); + CHECK_EQ(z[0].item(), 6.0f); + CHECK_EQ(z[1].item(), 5.0f); + CHECK_EQ(g[0].item(), 4.0f); + CHECK_EQ(g[1].item(), 3.0f); + + std::tie(z, g) = vjp(transformed_fn, {x, y}, {one, one}); + CHECK_EQ(z[0].item(), 6.0f); + CHECK_EQ(z[1].item(), 5.0f); + CHECK_EQ(g[0].item(), 1.0f); + CHECK_EQ(g[1].item(), 1.0f); +} + +TEST_CASE("test checkpointing") { + auto one = array(1.0); + auto x = array(2.0); + auto y = array(3.0); + + int cnt = 0; + auto fn = [&cnt](const std::vector& inputs) { + cnt++; + auto x = inputs[0] * inputs[1]; + auto y = inputs[0] + inputs[1]; + return std::vector{square(x + y)}; + }; + auto checkpointed_fn = checkpoint(fn); + + auto [z, g] = vjp(checkpointed_fn, {x, y}, {one}); + CHECK_EQ(z[0].item(), 121.0f); + CHECK_EQ(g[0].item(), 88.0f); + CHECK_EQ(g[1].item(), 66.0f); + CHECK_EQ(cnt, 2); +}