Custom VJP and checkpointing (#541)

* Implement custom_vjp and checkpointing
* Add a dependency management primitive
* Change the eval order to deep branches first
* Add graph depth tracking to the array
Angelos Katharopoulos 2024-01-30 16:04:45 -08:00 committed by GitHub
parent 143e2690d5
commit 0de5988f92
22 changed files with 527 additions and 37 deletions

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#include <functional>
@@ -97,11 +97,13 @@ void array::detach() {
    s.array_desc_->inputs.clear();
    s.array_desc_->siblings.clear();
    s.array_desc_->position = 0;
    s.array_desc_->depth = 0;
    s.array_desc_->primitive = nullptr;
  }
  array_desc_->inputs.clear();
  array_desc_->siblings.clear();
  array_desc_->position = 0;
  array_desc_->depth = 0;
  array_desc_->primitive = nullptr;
}
@@ -180,7 +182,9 @@ array::ArrayDesc::ArrayDesc(
  std::tie(size, strides) = cum_prod(this->shape);
  for (auto& in : inputs) {
    is_tracer |= in.is_tracer();
    depth = std::max(in.graph_depth(), depth);
  }
  depth++;
}

array::ArrayDesc::ArrayDesc(
@@ -195,7 +199,9 @@ array::ArrayDesc::ArrayDesc(
  std::tie(size, strides) = cum_prod(this->shape);
  for (auto& in : inputs) {
    is_tracer |= in.is_tracer();
    depth = std::max(in.graph_depth(), depth);
  }
  depth++;
}

array::ArrayIterator::ArrayIterator(const array& arr, int idx)

View File

@@ -267,6 +267,11 @@ class array {
    return outputs;
  };

  /** The depth of the array in the graph. Evaluated arrays have depth 0. */
  uint16_t graph_depth() const {
    return array_desc_->depth;
  }

  /** Detach the array from the graph. */
  void detach();
@@ -377,6 +382,9 @@ class array {
    // The array's position in the output list
    uint32_t position{0};

    // The depth of the array in the graph.
    uint16_t depth{0};

    explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);

    explicit ArrayDesc(

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
@@ -35,6 +35,8 @@ DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
@@ -47,6 +47,8 @@ DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(Remainder)
DEFAULT(Equal)

View File

@@ -232,6 +232,25 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
  }
}

void CustomVJP::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
       i++, j++) {
    outputs[i].copy_shared_buffer(inputs[j]);
  }
}

void Depends::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0; i < outputs.size(); i++) {
    outputs[i].copy_shared_buffer(inputs[i]);
  }
}

void Erf::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];

View File

@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
-#include <iostream>
#include "mlx/backend/metal/copy.h"
#include "mlx/primitives.h"

View File

@@ -2,7 +2,6 @@
#include <algorithm>
#include <cassert>
-#include <iostream>
#include <numeric>
#include <sstream>

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -615,7 +615,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
}

void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
+  assert(inputs.size() == 3);
  if (!is_floating_point(out.dtype())) {
    throw std::runtime_error(
        "[matmul] Does not yet support non-floating point types.");

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -486,6 +486,18 @@ void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
  unary_op(inputs, out, "cosh");
}

void CustomVJP::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  eval(inputs, outputs);
}

void Depends::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  eval(inputs, outputs);
}

void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
  binary_op(inputs, out, "div");
}

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#include <cassert>

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#include "mlx/primitives.h"
@@ -37,6 +37,8 @@ NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU(Remainder)
NO_GPU(Equal)

View File

@@ -1,4 +1,5 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cmath>
#include <numeric>
@@ -7,6 +8,7 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"

namespace mlx::core {
@@ -3327,4 +3329,26 @@ array diag(const array& a, int k /* = 0 */, StreamOrDevice s /* = {} */) {
  }
}

std::vector<array> depends(
    const std::vector<array>& inputs,
    const std::vector<array>& dependencies) {
  std::vector<array> all_inputs = inputs;
  all_inputs.insert(all_inputs.end(), dependencies.begin(), dependencies.end());

  // Compute the stream. Maybe do it in a smarter way at some point in the
  // future.
  Stream s = (inputs[0].has_primitive()) ? inputs[0].primitive().stream()
                                         : to_stream({});

  // Make the output info
  std::vector<std::vector<int>> shapes;
  std::vector<Dtype> dtypes;
  for (const auto& in : inputs) {
    shapes.emplace_back(in.shape());
    dtypes.emplace_back(in.dtype());
  }

  return array::make_arrays(
      shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
}

} // namespace mlx::core

View File

@@ -1116,4 +1116,13 @@ array diagonal(
/** Extract diagonal from a 2d array or create a diagonal matrix. */
array diag(const array& a, int k = 0, StreamOrDevice s = {});

/**
 * Implements the identity function but allows injecting dependencies on other
 * arrays. This ensures that those other arrays will have been computed
 * when the outputs of this function are computed.
 */
std::vector<array> depends(
    const std::vector<array>& inputs,
    const std::vector<array>& dependencies);

} // namespace mlx::core
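Editor's note: a minimal usage sketch of depends (not part of the commit; the variables are illustrative). The returned arrays are numerically identical to the inputs, but evaluating them guarantees that the dependency arrays have been computed first:

  // Sketch only: force `side_effect` to be computed whenever `ordered` is.
  auto value = array({1.0f, 2.0f}) * 3.0f;
  auto side_effect = exp(array({0.5f}));
  auto ordered = depends({value}, {side_effect})[0];
  eval({ordered}); // evaluates side_effect as well, even though ordered == value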

View File

@@ -1,4 +1,5 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -797,6 +798,43 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
  return {{cosh(inputs[0], stream())}, axes};
}
std::vector<array> CustomVJP::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  std::vector<array> inputs(primals.begin(), primals.end() - outputs.size());
  auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
  for (const auto& cot : cotangents) {
    all_vjps.emplace_back(cot);
  }

  std::vector<array> vjps;
  vjps.reserve(argnums.size());
  for (auto arg : argnums) {
    vjps.push_back(all_vjps[arg]);
  }

  return vjps;
}

std::vector<array> Depends::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {
  std::vector<array> vjps;
  for (auto arg : argnums) {
    if (arg < cotangents.size()) {
      vjps.push_back(cotangents[arg]);
    } else {
      vjps.push_back(zeros_like(primals[arg]));
    }
  }
  return vjps;
}
std::vector<array> Divide::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,

View File

@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -552,6 +552,60 @@ class Cosh : public UnaryPrimitive {
  void eval(const std::vector<array>& inputs, array& out);
};
class CustomVJP : public Primitive {
 public:
  explicit CustomVJP(
      Stream stream,
      std::function<std::vector<array>(
          const std::vector<array>&,
          const std::vector<array>&,
          const std::vector<array>&)> fun)
      : Primitive(stream), vjp_fun_(std::move(fun)) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotan,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_PRINT(CustomVJP);

 private:
  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);

  std::function<std::vector<array>(
      const std::vector<array>&,
      const std::vector<array>&,
      const std::vector<array>&)>
      vjp_fun_;
};

class Depends : public Primitive {
 public:
  explicit Depends(Stream stream) : Primitive(stream) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  std::vector<array> vjp(
      const std::vector<array>& primals,
      const std::vector<array>& cotan,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs) override;

  DEFINE_PRINT(Depends);

 private:
  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
class Divide : public UnaryPrimitive {
 public:
  explicit Divide(Stream stream) : UnaryPrimitive(stream){};

View File

@@ -35,7 +35,7 @@ class Synchronizer : public Primitive {
int detail::InTracing::tracing_counter{0};

void eval(const std::vector<array>& outputs) {
-  std::function<void(const array&)> recurse;
+  std::function<void(const array&, bool)> recurse;
  std::queue<array> tape;
  std::unordered_set<std::uintptr_t> cache;
  std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
@@ -52,21 +52,57 @@ void eval(const std::vector<array>& outputs) {
  auto synchronizer =
      array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);

-  recurse = [&](const array& a) {
+  recurse = [&](const array& a, bool largest_branch_first) {
    auto id = a.id();
    if (cache.find(id) != cache.end()) {
      return;
    }
-    for (auto in : a.inputs()) {
-      recurse(in);
-      // If one of the inputs is being computed on a different
-      // stream, we need to manage the dependency.
+
+    // If the input is being computed on a different stream, we need to manage
+    // the dependency.
+    auto check_dependency = [&](const array& in) {
      if (!in.is_evaled()) {
        if (a.primitive().stream() != in.primitive().stream()) {
          deps.insert({in.primitive_id(), std::shared_future<void>{}});
        }
      }
+    };
    // Recurse to the largest or smallest branch first.
    size_t num_inputs = a.inputs().size();
    if (num_inputs == 1) {
      auto& in = a.inputs()[0];
      recurse(in, true);
      check_dependency(in);
    } else if (num_inputs == 2) {
      auto depth_1 = a.inputs()[0].graph_depth();
      auto depth_2 = a.inputs()[1].graph_depth();
      auto& in1 = a.inputs()[static_cast<int>(
          !((depth_1 > depth_2) == largest_branch_first))];
      auto& in2 = a.inputs()[static_cast<int>(
          ((depth_1 > depth_2) == largest_branch_first))];
      recurse(in1, true);
      check_dependency(in1);
      recurse(in2, true);
      check_dependency(in2);
    } else if (num_inputs > 2) {
      std::vector<int> recursion_order(a.inputs().size());
      std::iota(recursion_order.begin(), recursion_order.end(), 0);
      std::sort(
          recursion_order.begin(),
          recursion_order.end(),
          [&a, largest_branch_first](int i, int j) {
            auto depth_i = a.inputs()[i].graph_depth();
            auto depth_j = a.inputs()[j].graph_depth();
            // Sort deepest-first or shallowest-first depending on the flag.
            return largest_branch_first ? depth_i > depth_j
                                        : depth_i < depth_j;
          });
      for (int idx : recursion_order) {
        auto& in = a.inputs()[idx];
        recurse(in, true);
        check_dependency(in);
      }
    }

    cache.insert(id);
    for (auto& s : a.siblings()) {
      cache.insert(s.id());
@@ -80,7 +116,7 @@ void eval(const std::vector<array>& outputs) {
    }
  };

-  recurse(synchronizer);
+  recurse(synchronizer, false);
  uintptr_t synch_id = synchronizer.primitive_id();
  deps.insert({synch_id, std::shared_future<void>{}});
@@ -713,4 +749,58 @@ std::function<array(const array&)> vmap(
  return [vfun](const array& a) { return vfun({a})[0]; };
}
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
    std::function<std::vector<array>(const std::vector<array>&)> fun,
    std::function<std::vector<array>(
        const std::vector<array>&,
        const std::vector<array>&,
        const std::vector<array>&)> fun_vjp) {
  return [fun = std::move(fun),
          fun_vjp = std::move(fun_vjp)](const std::vector<array>& args) {
    // Compute the outputs
    auto outputs = fun(args);
    for (auto& out : outputs) {
      out = stop_gradient(out);
    }

    // Prepare the inputs to the primitive
    // We also add the outputs to the primitive so that it can "run" the
    // forward pass.
    std::vector<array> inputs = args;
    inputs.insert(inputs.end(), outputs.begin(), outputs.end());

    // Compute the stream. Maybe do it in a smarter way at some point in the
    // future.
    Stream s = (outputs[0].has_primitive()) ? outputs[0].primitive().stream()
                                            : default_stream(default_device());

    // Make the output info
    std::vector<std::vector<int>> shapes;
    std::vector<Dtype> dtypes;
    for (const auto& out : outputs) {
      shapes.emplace_back(out.shape());
      dtypes.emplace_back(out.dtype());
    }

    return array::make_arrays(
        shapes,
        dtypes,
        std::make_shared<CustomVJP>(to_stream(s), fun_vjp),
        inputs);
  };
}

std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
    std::function<std::vector<array>(const std::vector<array>&)> fun) {
  auto vjp_fun = [fun](
                     const std::vector<array>& primals,
                     const std::vector<array>& cotangents,
                     const std::vector<array>& outputs) -> std::vector<array> {
    auto [__, vjps] = vjp(fun, depends(primals, outputs), cotangents);
    return vjps;
  };

  return custom_vjp(fun, vjp_fun);
}
} // namespace mlx::core

View File

@@ -191,4 +191,22 @@ std::function<std::vector<array>(const std::vector<array>&)> vmap(
    const std::vector<int>& in_axes = {},
    const std::vector<int>& out_axes = {});
/**
 * Return the results of calling fun with args. When the VJP of the result is
 * later computed, it is computed by fun_vjp instead of by differentiating
 * through fun.
 */
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
    std::function<std::vector<array>(const std::vector<array>&)> fun,
    std::function<std::vector<array>(
        const std::vector<array>&,
        const std::vector<array>&,
        const std::vector<array>&)> fun_vjp);
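Editor's note: a minimal usage sketch (not part of the commit; values are illustrative). fun_vjp receives the primals, cotangents, and outputs, and returns one VJP per input; here it overrides the true gradient with ones, mirroring the new unit test:

  auto fun = [](const std::vector<array>& inputs) {
    return std::vector<array>{inputs[0] * inputs[1]};
  };
  auto fun_vjp = [](const std::vector<array>& primals,
                    const std::vector<array>& cotangents,
                    const std::vector<array>& outputs) {
    // Ignore the real derivative and return ones instead.
    return std::vector<array>{ones_like(primals[0]), ones_like(primals[1])};
  };
  auto wrapped = custom_vjp(fun, fun_vjp);
  auto [out, grads] = vjp(wrapped, {array(2.0), array(3.0)}, {array(1.0)});
  // grads[0] and grads[1] are 1.0 rather than 3.0 and 2.0.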
/**
 * Checkpoint the gradient of a function. Namely, discard all intermediate
 * state and recalculate it when we need to compute the gradient.
 */
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
    std::function<std::vector<array>(const std::vector<array>&)> fun);
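Editor's note: a minimal usage sketch (not part of the commit). The wrapped function returns the same values as fun, but its intermediates are discarded after the forward pass and recomputed when the VJP is evaluated:

  auto fn = [](const std::vector<array>& inputs) {
    auto h = exp(inputs[0] * inputs[1]); // intermediate that will be recomputed
    return std::vector<array>{sum(h)};
  };
  auto checkpointed_fn = checkpoint(fn);
  // The forward pass runs here and once more while the VJP is computed.
  auto [value, grads] = vjp(checkpointed_fn, {array(1.0), array(2.0)}, {array(1.0)});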
} // namespace mlx::core

View File

@@ -9,6 +9,7 @@ from mlx.nn.layers.base import Module
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import LayerNorm
from mlx.nn.utils import checkpoint


class MultiHeadAttention(Module):
@@ -167,6 +168,7 @@ class TransformerEncoder(Module):
        dropout: float = 0.0,
        activation=relu,
        norm_first: bool = False,
        checkpoint: bool = False,
    ):
        super().__init__()
        self.layers = [
@@ -176,10 +178,14 @@ class TransformerEncoder(Module):
            for i in range(num_layers)
        ]
        self.ln = LayerNorm(dims)
        self.checkpoint = checkpoint

    def __call__(self, x, mask):
        for l in self.layers:
-            x = l(x, mask)
+            if self.checkpoint:
+                x = checkpoint(l)(x, mask)
+            else:
+                x = l(x, mask)
        return self.ln(x)
@@ -255,6 +261,7 @@ class TransformerDecoder(Module):
        dropout: float = 0.0,
        activation=relu,
        norm_first: bool = False,
        checkpoint: bool = False,
    ):
        super().__init__()
        self.layers = [
@@ -264,10 +271,14 @@ class TransformerDecoder(Module):
            for i in range(num_layers)
        ]
        self.ln = LayerNorm(dims)
        self.checkpoint = checkpoint

    def __call__(self, x, memory, x_mask, memory_mask):
        for l in self.layers:
-            x = l(x, memory, x_mask, memory_mask)
+            if self.checkpoint:
+                x = checkpoint(l)(x, memory, x_mask, memory_mask)
+            else:
+                x = l(x, memory, x_mask, memory_mask)
        return self.ln(x)
@@ -307,6 +318,9 @@ class Transformer(Module):
        norm_first (bool, optional): if ``True``, encoder and decoder layers
            will perform layer normalization before attention and MLP
            operations, otherwise after. Default: ``False``.
        checkpoint (bool, optional): if ``True``, perform gradient checkpointing
            to reduce the memory usage at the expense of more computation.
            Default: ``False``.
    """

    def __init__(
@@ -321,6 +335,7 @@ class Transformer(Module):
        custom_encoder: Optional[Any] = None,
        custom_decoder: Optional[Any] = None,
        norm_first: bool = False,
        checkpoint: bool = False,
    ):
        super().__init__()
        if custom_encoder is not None:
@@ -334,6 +349,7 @@ class Transformer(Module):
                dropout,
                activation,
                norm_first,
                checkpoint,
            )

        if custom_decoder is not None:
@@ -347,6 +363,7 @@ class Transformer(Module):
                dropout,
                activation,
                norm_first,
                checkpoint,
            )

    def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
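Editor's note: a usage sketch of the new flag (not part of the commit; dimensions are illustrative and assume the existing Transformer defaults). With checkpoint=True every encoder and decoder layer call is wrapped in mlx.nn.utils.checkpoint:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Transformer(dims=128, num_heads=4, checkpoint=True)
    src = mx.random.normal((2, 16, 128))
    tgt = mx.random.normal((2, 16, 128))
    # Per-layer activations are recomputed during the backward pass.
    out = model(src, tgt, None, None, None)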

View File

@@ -1,11 +1,14 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

from functools import wraps
from typing import Callable

import mlx.core as mx

from .layers.base import Module


-def value_and_grad(model: "mlx.nn.Module", fn: Callable):
+def value_and_grad(model: Module, fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.
@@ -26,8 +29,42 @@ def value_and_grad(model: "mlx.nn.Module", fn: Callable):
    value_grad_fn = mx.value_and_grad(inner_fn)

    @wraps(fn)
    def wrapped_value_grad_fn(*args, **kwargs):
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn
def checkpoint(module: Module, fn: Callable = None):
    """Transform the passed callable to one that performs gradient
    checkpointing with respect to the trainable parameters of the module (and
    the callable's inputs).

    Args:
        module (mlx.nn.Module): The module for whose parameters we will be
            performing gradient checkpointing.
        fn (Callable, optional): The function to checkpoint. If not provided it
            defaults to the provided module.

    Returns:
        A callable that saves the inputs and outputs during the forward pass
        and recomputes all intermediate states during the backward pass.
    """
    if fn is None:
        # Capturing module instead of module.__call__ allows someone to
        # monkey-patch __call__ later on and the correct method will be used
        fn = module

    def inner_fn(params, *args, **kwargs):
        module.update(params)
        return fn(*args, **kwargs)

    checkpointed_fn = mx.checkpoint(inner_fn)

    @wraps(fn)
    def wrapped_checkpointed_fn(*args, **kwargs):
        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)

    return wrapped_checkpointed_fn
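Editor's note: a usage sketch (not part of the commit; the layer and shapes are illustrative). The checkpointed call composes with value_and_grad exactly like a plain module call:

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.nn.utils import checkpoint

    layer = nn.Linear(16, 16)
    x = mx.random.normal((4, 16))

    def loss_fn(x):
        # The layer's intermediate state is recomputed in the backward pass.
        return checkpoint(layer)(x).sum()

    loss, grads = nn.value_and_grad(layer, loss_fn)(x)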

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
-#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <algorithm>
@@ -142,7 +141,8 @@ std::vector<array> tree_flatten(py::object tree, bool strict = true) {
    if (py::isinstance<array>(obj)) {
      flat_tree.push_back(py::cast<array>(obj));
    } else if (strict) {
-      throw std::invalid_argument("Argument is not an array");
+      throw std::invalid_argument(
+          "[tree_flatten] The argument should contain only arrays");
    }
  });
@@ -162,12 +162,48 @@ py::object tree_unflatten(
  });
}

-py::object tree_unflatten_none(
-    py::object tree,
-    const std::vector<array>& values,
-    int index = 0) {
-  return tree_map(tree, [&](py::handle obj) {
-    if (py::isinstance<py::none>(obj)) {
-      return py::cast(values[index++]);
-    } else {
-      return py::cast<py::object>(obj);
+py::object structure_sentinel() {
+  static py::object sentinel;
+  if (sentinel.ptr() == nullptr) {
+    sentinel = py::capsule(&sentinel);
+    // probably not needed but this should make certain that we won't ever
+    // delete the sentinel
+    sentinel.inc_ref();
+  }
+  return sentinel;
+}
+
+std::pair<std::vector<array>, py::object> tree_flatten_with_structure(
+    py::object tree,
+    bool strict = true) {
+  auto sentinel = structure_sentinel();
+  std::vector<array> flat_tree;
+  auto structure = tree_map(
+      tree,
+      [&flat_tree, sentinel = std::move(sentinel), strict](py::handle obj) {
+        if (py::isinstance<array>(obj)) {
+          flat_tree.push_back(py::cast<array>(obj));
+          return sentinel;
+        } else if (!strict) {
+          return py::cast<py::object>(obj);
+        } else {
+          throw std::invalid_argument(
+              "[tree_flatten] The argument should contain only arrays");
+        }
+      });
+
+  return {flat_tree, structure};
+}
+
+py::object tree_unflatten_from_structure(
+    py::object structure,
+    const std::vector<array>& values,
+    int index = 0) {
+  auto sentinel = structure_sentinel();
+  return tree_map(structure, [&](py::handle obj) {
+    if (obj.is(sentinel)) {
+      return py::cast(values[index++]);
+    } else {
+      return py::cast<py::object>(obj);
@@ -472,14 +508,10 @@ struct PyCompiledFun {
  py::object operator()(const py::args& args) {
    auto compile_fun = [this, &args](const std::vector<array>& a) {
-      // Call the python function
-      py::object py_outputs = this->fun(*tree_unflatten(args, a));
-
-      // Flatten the outputs
-      auto outputs = tree_flatten(py_outputs, true);
-
-      py_outputs =
-          tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
+      // Call the python function and flatten the outputs
+      auto [outputs, py_outputs] = tree_flatten_with_structure(
+          std::move(this->fun(*tree_unflatten(args, a))), true);

      tree_cache().insert({this->fun_id, py_outputs});
      return outputs;
    };
@@ -492,15 +524,75 @@ struct PyCompiledFun {
    // Put the outputs back in the container
    py::object py_outputs = tree_cache().at(fun_id);
-    return tree_unflatten_none(py_outputs, outputs);
+    return tree_unflatten_from_structure(py_outputs, outputs);
  };

  ~PyCompiledFun() {
    py::gil_scoped_acquire gil;

    tree_cache().erase(fun_id);
    detail::compile_erase(fun_id);
    fun.release().dec_ref();
  }
};
class PyCheckpointedFun {
 public:
  PyCheckpointedFun(py::function fun) : fun_(std::move(fun)) {}

  ~PyCheckpointedFun() {
    py::gil_scoped_acquire gil;

    fun_.release().dec_ref();
  }

  struct InnerFunction {
    py::object fun_;
    py::object args_structure_;
    std::weak_ptr<py::object> output_structure_;

    InnerFunction(
        py::object fun,
        py::object args_structure,
        std::weak_ptr<py::object> output_structure)
        : fun_(std::move(fun)),
          args_structure_(std::move(args_structure)),
          output_structure_(output_structure) {}
    ~InnerFunction() {
      py::gil_scoped_acquire gil;

      fun_.release().dec_ref();
      args_structure_.release().dec_ref();
    }

    std::vector<array> operator()(const std::vector<array>& inputs) {
      auto args = py::cast<py::tuple>(
          tree_unflatten_from_structure(args_structure_, inputs));
      auto [outputs, output_structure] =
          tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
      if (auto s = output_structure_.lock()) {
        *s = output_structure;
      }
      return outputs;
    }
  };

  py::object operator()(const py::args& args, const py::kwargs& kwargs) {
    auto output_structure = std::make_shared<py::object>();
    auto full_args = py::make_tuple(args, kwargs);
    auto [inputs, args_structure] =
        tree_flatten_with_structure(full_args, false);

    auto outputs = checkpoint(
        InnerFunction(fun_, args_structure, output_structure))(inputs);

    return tree_unflatten_from_structure(*output_structure, outputs);
  }

 private:
  py::function fun_;
};
void init_transforms(py::module_& m) {
  py::options options;
  options.disable_function_signatures();
@@ -802,6 +894,10 @@ void init_transforms(py::module_& m) {
        Globally enable compilation. This will override the environment
        variable ``MLX_DISABLE_COMPILE`` if set.
      )pbdoc");
  m.def(
      "checkpoint",
      [](py::function fun) { return py::cpp_function(PyCheckpointedFun{fun}); },
      "fun"_a);

  // Register static Python object cleanup before the interpreter exits
  auto atexit = py::module_::import("atexit");

View File

@@ -21,6 +21,7 @@ target_sources(tests PRIVATE
  autograd_tests.cpp
  blas_tests.cpp
  compile_tests.cpp
  custom_vjp_tests.cpp
  creations_tests.cpp
  device_tests.cpp
  eval_tests.cpp

View File

@@ -0,0 +1,57 @@
// Copyright © 2023-2024 Apple Inc.

#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test simple custom vjp") {
  auto one = array(1.0);
  auto x = array(2.0);
  auto y = array(3.0);

  auto fn = [](const std::vector<array>& inputs) {
    return std::vector<array>{inputs[0] * inputs[1], inputs[0] + inputs[1]};
  };
  auto transformed_fn = custom_vjp(
      fn,
      [&](const std::vector<array>&,
          const std::vector<array>&,
          const std::vector<array>&) {
        return std::vector<array>{one, one};
      });

  auto [z, g] = vjp(fn, {x, y}, {one, one});
  CHECK_EQ(z[0].item<float>(), 6.0f);
  CHECK_EQ(z[1].item<float>(), 5.0f);
  CHECK_EQ(g[0].item<float>(), 4.0f);
  CHECK_EQ(g[1].item<float>(), 3.0f);

  std::tie(z, g) = vjp(transformed_fn, {x, y}, {one, one});
  CHECK_EQ(z[0].item<float>(), 6.0f);
  CHECK_EQ(z[1].item<float>(), 5.0f);
  CHECK_EQ(g[0].item<float>(), 1.0f);
  CHECK_EQ(g[1].item<float>(), 1.0f);
}

TEST_CASE("test checkpointing") {
  auto one = array(1.0);
  auto x = array(2.0);
  auto y = array(3.0);

  int cnt = 0;
  auto fn = [&cnt](const std::vector<array>& inputs) {
    cnt++;
    auto x = inputs[0] * inputs[1];
    auto y = inputs[0] + inputs[1];
    return std::vector<array>{square(x + y)};
  };
  auto checkpointed_fn = checkpoint(fn);

  auto [z, g] = vjp(checkpointed_fn, {x, y}, {one});
  CHECK_EQ(z[0].item<float>(), 121.0f);
  CHECK_EQ(g[0].item<float>(), 88.0f);
  CHECK_EQ(g[1].item<float>(), 66.0f);
  CHECK_EQ(cnt, 2);
}