Custom VJP and checkpointing (#541)

* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
This commit is contained in:
Angelos Katharopoulos 2024-01-30 16:04:45 -08:00 committed by GitHub
parent 143e2690d5
commit 0de5988f92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 527 additions and 37 deletions

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <functional>
@ -97,11 +97,13 @@ void array::detach() {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->depth = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->depth = 0;
array_desc_->primitive = nullptr;
}
@ -180,7 +182,9 @@ array::ArrayDesc::ArrayDesc(
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayDesc::ArrayDesc(
@ -195,7 +199,9 @@ array::ArrayDesc::ArrayDesc(
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)

View File

@ -267,6 +267,11 @@ class array {
return outputs;
};
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
uint16_t graph_depth() const {
return array_desc_->depth;
}
/** Detach the array from the graph. */
void detach();
@ -377,6 +382,9 @@ class array {
// The arrays position in the output list
uint32_t position{0};
// The depth of the array in the graph.
uint16_t depth{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
@ -35,6 +35,8 @@ DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
@ -47,6 +47,8 @@ DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(Remainder)
DEFAULT(Equal)

View File

@ -232,6 +232,25 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
}
}
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/backend/metal/copy.h"
#include "mlx/primitives.h"

View File

@ -2,7 +2,6 @@
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <sstream>

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@ -615,7 +615,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
assert(inputs.size() == 3);
if (!is_floating_point(out.dtype())) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@ -486,6 +486,18 @@ void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "cosh");
}
void CustomVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Depends::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cassert>

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include "mlx/primitives.h"
@ -37,6 +37,8 @@ NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU(Remainder)
NO_GPU(Equal)

View File

@ -1,4 +1,5 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cmath>
#include <numeric>
@ -7,6 +8,7 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"
namespace mlx::core {
@ -3327,4 +3329,26 @@ array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
}
}
std::vector<array> depends(
const std::vector<array>& inputs,
const std::vector<array>& dependencies) {
std::vector<array> all_inputs = inputs;
all_inputs.insert(all_inputs.end(), dependencies.begin(), dependencies.end());
// Compute the stream. Maybe do it in a smarter way at some point in the
// future.
Stream s = (inputs[0].has_primitive()) ? inputs[0].primitive().stream()
: to_stream({});
// Make the output info
std::vector<std::vector<int>> shapes;
std::vector<Dtype> dtypes;
for (const auto& in : inputs) {
shapes.emplace_back(in.shape());
dtypes.emplace_back(in.dtype());
}
return array::make_arrays(
shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
}
} // namespace mlx::core

View File

@ -1116,4 +1116,13 @@ array diagonal(
/** Extract diagonal from a 2d array or create a diagonal matrix. */
array diag(const array& a, int k = 0, StreamOrDevice s = {});
/**
* Implements the identity function but allows injecting dependencies to other
* arrays. This ensures that these other arrays will have been computed
* when the outputs of this function are computed.
*/
std::vector<array> depends(
const std::vector<array>& inputs,
const std::vector<array>& dependencies);
} // namespace mlx::core

View File

@ -1,4 +1,5 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@ -797,6 +798,43 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
return {{cosh(inputs[0], stream())}, axes};
}
std::vector<array> CustomVJP::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
std::vector<array> inputs(primals.begin(), primals.end() - outputs.size());
auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
for (const auto& cot : cotangents) {
all_vjps.emplace_back(cot);
}
std::vector<array> vjps;
vjps.reserve(argnums.size());
for (auto arg : argnums) {
vjps.push_back(all_vjps[arg]);
}
return vjps;
}
std::vector<array> Depends::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
std::vector<array> vjps;
for (auto arg : argnums) {
if (arg < cotangents.size()) {
vjps.push_back(cotangents[arg]);
} else {
vjps.push_back(zeros_like(primals[arg]));
}
}
return vjps;
}
std::vector<array> Divide::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@ -552,6 +552,60 @@ class Cosh : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class CustomVJP : public Primitive {
public:
explicit CustomVJP(
Stream stream,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun)
: Primitive(stream), vjp_fun_(std::move(fun)) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(CustomVJP);
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)>
vjp_fun_;
};
class Depends : public Primitive {
public:
explicit Depends(Stream stream) : Primitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotan,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(Depends);
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
class Divide : public UnaryPrimitive {
public:
explicit Divide(Stream stream) : UnaryPrimitive(stream){};

View File

@ -35,7 +35,7 @@ class Synchronizer : public Primitive {
int detail::InTracing::tracing_counter{0};
void eval(const std::vector<array>& outputs) {
std::function<void(const array&)> recurse;
std::function<void(const array&, bool)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
@ -52,21 +52,57 @@ void eval(const std::vector<array>& outputs) {
auto synchronizer =
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
recurse = [&](const array& a) {
recurse = [&](const array& a, bool largest_branch_first) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
for (auto in : a.inputs()) {
recurse(in);
// If one of the inputs is being computed on a different
// stream, we need to manage the dependency.
// If the input is being computed on a different stream, we need to manage
// the dependency.
auto check_dependency = [&](const array& in) {
if (!in.is_evaled()) {
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.primitive_id(), std::shared_future<void>{}});
}
}
};
// Recurse to the largest or smallest branch first.
size_t num_inputs = a.inputs().size();
if (num_inputs == 1) {
auto& in = a.inputs()[0];
recurse(in, true);
check_dependency(in);
} else if (num_inputs == 2) {
auto depth_1 = a.inputs()[0].graph_depth();
auto depth_2 = a.inputs()[1].graph_depth();
auto& in1 = a.inputs()[static_cast<int>(
!((depth_1 > depth_2) == largest_branch_first))];
auto& in2 = a.inputs()[static_cast<int>(
((depth_1 > depth_2) == largest_branch_first))];
recurse(in1, true);
check_dependency(in1);
recurse(in2, true);
check_dependency(in2);
} else if (num_inputs > 2) {
std::vector<int> recursion_order(a.inputs().size());
std::iota(recursion_order.begin(), recursion_order.end(), 0);
std::sort(
recursion_order.begin(),
recursion_order.end(),
[&a, largest_branch_first](int i, int j) {
auto depth_i = a.inputs()[i].graph_depth();
auto depth_j = a.inputs()[j].graph_depth();
return largest_branch_first ? depth_i > depth_j : depth_j < depth_i;
});
for (int idx : recursion_order) {
auto& in = a.inputs()[idx];
recurse(in, true);
check_dependency(in);
}
}
cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
@ -80,7 +116,7 @@ void eval(const std::vector<array>& outputs) {
}
};
recurse(synchronizer);
recurse(synchronizer, false);
uintptr_t synch_id = synchronizer.primitive_id();
deps.insert({synch_id, std::shared_future<void>{}});
@ -713,4 +749,58 @@ std::function<array(const array&)> vmap(
return [vfun](const array& a) { return vfun({a})[0]; };
}
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp) {
return [fun = std::move(fun),
fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
// Compute the outputs
auto outputs = fun(args);
for (auto& out : outputs) {
out = stop_gradient(out);
}
// Prepare the inputs to the primitive
// We also add the outputs to the primitive so that it can "run" the forward
// pass.
std::vector<array> inputs = args;
inputs.insert(inputs.end(), outputs.begin(), outputs.end());
// Compute the stream. Maybe do it in a smarter way at some point in the
// future.
Stream s = (outputs[0].has_primitive()) ? outputs[0].primitive().stream()
: default_stream(default_device());
// Make the output info
std::vector<std::vector<int>> shapes;
std::vector<Dtype> dtypes;
for (const auto& out : outputs) {
shapes.emplace_back(out.shape());
dtypes.emplace_back(out.dtype());
}
return array::make_arrays(
shapes,
dtypes,
std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
inputs);
};
}
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
std::function<std::vector<array>(const std::vector<array>&)> fun) {
auto vjp_fun = [fun](
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<array>& outputs) -> std::vector<array> {
auto [__, vjps] = vjp(fun, depends(primals, outputs), cotangents);
return vjps;
};
return custom_vjp(fun, vjp_fun);
}
} // namespace mlx::core

View File

@ -191,4 +191,22 @@ std::function<std::vector<array>(const std::vector<array>&)> vmap(
const std::vector<int>& in_axes = {},
const std::vector<int>& out_axes = {});
/**
* Return the results of calling fun with args but if their vjp is computed it
* will be computed by fun_vjp.
*/
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp);
/**
* Checkpoint the gradient of a function. Namely, discard all intermediate
* state and recalculate it when we need to compute the gradient.
*/
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
std::function<std::vector<array>(const std::vector<array>&)> fun);
} // namespace mlx::core

View File

@ -9,6 +9,7 @@ from mlx.nn.layers.base import Module
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
from mlx.nn.utils import checkpoint
class MultiHeadAttention(Module):
@ -167,6 +168,7 @@ class TransformerEncoder(Module):
dropout: float = 0.0,
activation=relu,
norm_first: bool = False,
checkpoint: bool = False,
):
super().__init__()
self.layers = [
@ -176,10 +178,14 @@ class TransformerEncoder(Module):
for i in range(num_layers)
]
self.ln = LayerNorm(dims)
self.checkpoint = checkpoint
def __call__(self, x, mask):
for l in self.layers:
x = l(x, mask)
if self.checkpoint:
x = checkpoint(l)(x, mask)
else:
x = l(x, mask)
return self.ln(x)
@ -255,6 +261,7 @@ class TransformerDecoder(Module):
dropout: float = 0.0,
activation=relu,
norm_first: bool = False,
checkpoint: bool = False,
):
super().__init__()
self.layers = [
@ -264,10 +271,14 @@ class TransformerDecoder(Module):
for i in range(num_layers)
]
self.ln = LayerNorm(dims)
self.checkpoint = checkpoint
def __call__(self, x, memory, x_mask, memory_mask):
for l in self.layers:
x = l(x, memory, x_mask, memory_mask)
if self.checkpoint:
x = checkpoint(l)(x, memory, x_mask, memory_mask)
else:
x = l(x, memory, x_mask, memory_mask)
return self.ln(x)
@ -307,6 +318,9 @@ class Transformer(Module):
norm_first (bool, optional): if ``True``, encoder and decoder layers
will perform layer normalization before attention and MLP
operations, otherwise after. Default: ``False``.
chekpoint (bool, optional): if ``True`` perform gradient checkpointing
to reduce the memory usage at the expense of more computation.
Default: ``False``.
"""
def __init__(
@ -321,6 +335,7 @@ class Transformer(Module):
custom_encoder: Optional[Any] = None,
custom_decoder: Optional[Any] = None,
norm_first: bool = False,
checkpoint: bool = False,
):
super().__init__()
if custom_encoder is not None:
@ -334,6 +349,7 @@ class Transformer(Module):
dropout,
activation,
norm_first,
checkpoint,
)
if custom_decoder is not None:
@ -347,6 +363,7 @@ class Transformer(Module):
dropout,
activation,
norm_first,
checkpoint,
)
def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):

View File

@ -1,11 +1,14 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
from functools import wraps
from typing import Callable
import mlx.core as mx
from .layers.base import Module
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
def value_and_grad(model: Module, fn: Callable):
"""Transform the passed function ``fn`` to a function that computes the
gradients of ``fn`` wrt the model's trainable parameters and also its
value.
@ -26,8 +29,42 @@ def value_and_grad(model: "mlx.nn.Module", fn: Callable):
value_grad_fn = mx.value_and_grad(inner_fn)
@wraps(fn)
def wrapped_value_grad_fn(*args, **kwargs):
value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
return value, grad
return wrapped_value_grad_fn
def checkpoint(module: Module, fn: Callable = None):
"""Transform the passed callable to one that performs gradient
checkpointing with respect to the trainable parameters of the module (and
the callable's inputs).
Args:
module (mlx.nn.Module): The module for whose parameters we will be
performing gradient checkpointing.
fn (Callable, optional): The function to checkpoint. If not provided it
defaults to the provided module.
Returns:
A callable that saves the inputs and outputs during the forward pass
and recomputes all intermediate states during the backward pass.
"""
if fn is None:
# Capturing module instead of module.__call__ allows someone to
# monkey-patch __call__ later on and the correct method will be used
fn = module
def inner_fn(params, *args, **kwargs):
module.update(params)
return fn(*args, **kwargs)
checkpointed_fn = mx.checkpoint(inner_fn)
@wraps(fn)
def wrapped_checkpointed_fn(*args, **kwargs):
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)
return wrapped_checkpointed_fn

View File

@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <algorithm>
@ -142,7 +141,8 @@ std::vector<array> tree_flatten(py::object tree, bool strict = true) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument("Argument is not an array");
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
@ -162,12 +162,48 @@ py::object tree_unflatten(
});
}
py::object tree_unflatten_none(
py::object structure_sentinel() {
static py::object sentinel;
if (sentinel.ptr() == nullptr) {
sentinel = py::capsule(&sentinel);
// probably not needed but this should make certain that we won't ever
// delete the sentinel
sentinel.inc_ref();
}
return sentinel;
}
std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
py::object tree,
bool strict = true) {
auto sentinel = structure_sentinel();
std::vector<array> flat_tree;
auto structure = tree_map(
tree,
[&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
return sentinel;
} else if (!strict) {
return py::cast<py::object>(obj);
} else {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
return {flat_tree, structure};
}
py::object tree_unflatten_from_structure(
py::object structure,
const std::vector<array>& values,
int index = 0) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<py::none>(obj)) {
auto sentinel = structure_sentinel();
return tree_map(structure, [&](py::handle obj) {
if (obj.is(sentinel)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
@ -472,14 +508,10 @@ struct PyCompiledFun {
py::object operator()(const py::args& args) {
auto compile_fun = [this, &args](const std::vector<array>& a) {
// Call the python function
py::object py_outputs = this->fun(*tree_unflatten(args, a));
// Call the python function and flatten the outputs
auto [outputs, py_outputs] = tree_flatten_with_structure(
std::move(this->fun(*tree_unflatten(args, a))), true);
// Flatten the outputs
auto outputs = tree_flatten(py_outputs, true);
py_outputs =
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
tree_cache().insert({this->fun_id, py_outputs});
return outputs;
};
@ -492,15 +524,75 @@ struct PyCompiledFun {
// Put the outputs back in the container
py::object py_outputs = tree_cache().at(fun_id);
return tree_unflatten_none(py_outputs, outputs);
return tree_unflatten_from_structure(py_outputs, outputs);
};
~PyCompiledFun() {
py::gil_scoped_acquire gil;
tree_cache().erase(fun_id);
detail::compile_erase(fun_id);
fun.release().dec_ref();
}
};
class PyCheckpointedFun {
public:
PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {}
~PyCheckpointedFun() {
py::gil_scoped_acquire gil;
fun_.release().dec_ref();
}
struct InnerFunction {
py::object fun_;
py::object args_structure_;
std::weak_ptr<py::object> output_structure_;
InnerFunction(
py::object fun,
py::object args_structure,
std::weak_ptr<py::object> output_structure)
: fun_(std::move(fun)),
args_structure_(std::move(args_structure)),
output_structure_(output_structure) {}
~InnerFunction() {
py::gil_scoped_acquire gil;
fun_.release().dec_ref();
args_structure_.release().dec_ref();
}
std::vector<array> operator()(const std::vector<array>& inputs) {
auto args = py::cast<py::tuple>(
tree_unflatten_from_structure(args_structure_, inputs));
auto [outputs, output_structure] =
tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
if (auto s = output_structure_.lock()) {
*s = output_structure;
}
return outputs;
}
};
py::object operator()(const py::args& args, const py::kwargs& kwargs) {
auto output_structure = std::make_shared<py::object>();
auto full_args = py::make_tuple(args, kwargs);
auto [inputs, args_structure] =
tree_flatten_with_structure(full_args, false);
auto outputs = checkpoint(
InnerFunction(fun_, args_structure, output_structure))(inputs);
return tree_unflatten_from_structure(*output_structure, outputs);
}
private:
py::function fun_;
};
void init_transforms(py::module_& m) {
py::options options;
options.disable_function_signatures();
@ -802,6 +894,10 @@ void init_transforms(py::module_& m) {
Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILE`` if set.
)pbdoc");
m.def(
"checkpoint",
[](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
"fun"_a);
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");

View File

@ -21,6 +21,7 @@ target_sources(tests PRIVATE
autograd_tests.cpp
blas_tests.cpp
compile_tests.cpp
custom_vjp_tests.cpp
creations_tests.cpp
device_tests.cpp
eval_tests.cpp

View File

@ -0,0 +1,57 @@
// Copyright © 2023-2024 Apple Inc.
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
TEST_CASE("test simple custom vjp") {
auto one = array(1.0);
auto x = array(2.0);
auto y = array(3.0);
auto fn = [](const std::vector<array>& inputs) {
return std::vector<array>{inputs[0] * inputs[1], inputs[0] + inputs[1]};
};
auto transformed_fn = custom_vjp(
fn,
[&](const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&) {
return std::vector<array>{one, one};
});
auto [z, g] = vjp(fn, {x, y}, {one, one});
CHECK_EQ(z[0].item<float>(), 6.0f);
CHECK_EQ(z[1].item<float>(), 5.0f);
CHECK_EQ(g[0].item<float>(), 4.0f);
CHECK_EQ(g[1].item<float>(), 3.0f);
std::tie(z, g) = vjp(transformed_fn, {x, y}, {one, one});
CHECK_EQ(z[0].item<float>(), 6.0f);
CHECK_EQ(z[1].item<float>(), 5.0f);
CHECK_EQ(g[0].item<float>(), 1.0f);
CHECK_EQ(g[1].item<float>(), 1.0f);
}
TEST_CASE("test checkpointing") {
auto one = array(1.0);
auto x = array(2.0);
auto y = array(3.0);
int cnt = 0;
auto fn = [&cnt](const std::vector<array>& inputs) {
cnt++;
auto x = inputs[0] * inputs[1];
auto y = inputs[0] + inputs[1];
return std::vector<array>{square(x + y)};
};
auto checkpointed_fn = checkpoint(fn);
auto [z, g] = vjp(checkpointed_fn, {x, y}, {one});
CHECK_EQ(z[0].item<float>(), 121.0f);
CHECK_EQ(g[0].item<float>(), 88.0f);
CHECK_EQ(g[1].item<float>(), 66.0f);
CHECK_EQ(cnt, 2);
}