Limit grad recursion depth by not recursing through non-grad inputs (#1764)

* limit grad recursion depth

* add grad of module test
This commit is contained in:
Awni Hannun 2025-01-14 14:33:18 -08:00 committed by GitHub
parent 5cc5201914
commit 33421c1dd3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 136 additions and 100 deletions

View File

@ -278,7 +278,8 @@ void eval(std::vector<array> outputs) {
std::pair<std::vector<array>, std::vector<array>> vjp( std::pair<std::vector<array>, std::vector<array>> vjp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun, const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotans) { const std::vector<array>& cotans,
const std::vector<int>& argnums) {
// Set the global tracing flag. // Set the global tracing flag.
detail::InTracing in_tracing; detail::InTracing in_tracing;
@ -330,10 +331,14 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
// to the tape which need a gradient. // to the tape which need a gradient.
std::unordered_set<std::uintptr_t> cache; std::unordered_set<std::uintptr_t> cache;
std::unordered_set<std::uintptr_t> calc_grad; std::unordered_set<std::uintptr_t> calc_grad;
for (auto& primal : primals_) { for (int i = 0, j = 0; i < primals_.size(); ++i) {
auto& primal = primals_[i];
primal.set_tracer(false); primal.set_tracer(false);
calc_grad.insert(primal.id());
cache.insert(primal.id()); cache.insert(primal.id());
if (j < argnums.size() && argnums[j] == i) {
j++;
calc_grad.insert(primal.id());
}
} }
std::vector<array> tape; std::vector<array> tape;
@ -435,7 +440,8 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
} }
} }
std::vector<array> vjps; std::vector<array> vjps;
for (auto& primal : primals_) { for (auto arg : argnums) {
auto& primal = primals_[arg];
if (auto cotan_it = cotan_map.find(primal.id()); if (auto cotan_it = cotan_map.find(primal.id());
cotan_it != cotan_map.end()) { cotan_it != cotan_map.end()) {
vjps.push_back(cotan_it->second); vjps.push_back(cotan_it->second);
@ -448,6 +454,15 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
return {outputs, vjps}; return {outputs, vjps};
} }
std::pair<std::vector<array>, std::vector<array>> vjp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& cotans) {
std::vector<int> argnums(primals.size());
std::iota(argnums.begin(), argnums.end(), 0);
return vjp(fun, primals, cotans, argnums);
}
std::pair<array, array> vjp( std::pair<array, array> vjp(
const std::function<array(const array&)>& fun, const std::function<array(const array&)>& fun,
const array& primal, const array& primal,
@ -606,15 +621,10 @@ ValueAndGradFn value_and_grad(
<< inputs.size() << " inputs."; << inputs.size() << " inputs.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
std::vector<int> sorted_argnums(args.begin(), args.end());
auto gfun = [&fun, &inputs, &args](const std::vector<array>& ginputs) { auto gfun = [&fun](const std::vector<array>& inputs) {
std::vector<array> inputs_(inputs); auto outputs = fun(inputs);
auto argit = args.begin();
for (int i = 0; i < ginputs.size(); ++i) {
inputs_[*argit] = ginputs[i];
++argit;
}
auto outputs = fun(inputs_);
for (int i = 1; i < outputs.size(); i++) { for (int i = 1; i < outputs.size(); i++) {
auto& out = outputs[i]; auto& out = outputs[i];
auto s = out.has_primitive() ? out.primitive().stream() auto s = out.has_primitive() ? out.primitive().stream()
@ -624,12 +634,8 @@ ValueAndGradFn value_and_grad(
return outputs; return outputs;
}; };
std::vector<array> ginputs;
for (auto arg : args) {
ginputs.push_back(inputs[arg]);
}
// Set the incoming gradient to float32, vjp will cast it to the output type // Set the incoming gradient to float32, vjp will cast it to the output type
auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)}); auto [outputs, grads] = vjp(gfun, inputs, {array(1.0f)}, sorted_argnums);
return std::make_pair(outputs, grads); return std::make_pair(outputs, grads);
}; };
} }

View File

@ -1,16 +1,18 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <numeric>
#include <sstream>
#include <unordered_set>
#include <nanobind/nanobind.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h> #include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h> #include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_set.h>
#include <nanobind/stl/variant.h> #include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h> #include <nanobind/stl/vector.h>
#include <algorithm>
#include <numeric>
#include <sstream>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/compile.h" #include "mlx/compile.h"
#include "mlx/compile_impl.h" #include "mlx/compile_impl.h"
@ -27,44 +29,45 @@ using namespace nb::literals;
using mx::operator<<; using mx::operator<<;
using IntOrVec = std::variant<int, std::vector<int>>; using IntOrVec = std::variant<int, std::vector<int>>;
using StrOrVec = std::variant<std::string, std::vector<std::string>>; using StrOrSet = std::variant<std::string, std::unordered_set<std::string>>;
inline std::string type_name_str(const nb::handle& o) { inline std::string type_name_str(const nb::handle& o) {
return nb::cast<std::string>(nb::type_name(o.type())); return nb::cast<std::string>(nb::type_name(o.type()));
} }
template <typename T>
std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
std::vector<T> vals;
if (auto pv = std::get_if<T>(&v); pv) {
vals.push_back(*pv);
} else {
vals = std::get<std::vector<T>>(v);
}
return vals;
}
auto validate_argnums_argnames( auto validate_argnums_argnames(
const std::optional<IntOrVec>& argnums, const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) { const StrOrSet& argnames) {
auto vec_names = to_vector(argnames); std::unordered_set<std::string> setnames;
if (auto pv = std::get_if<std::string>(&argnames); pv) {
setnames = {*pv};
} else {
setnames = std::get<std::unordered_set<std::string>>(argnames);
}
if (!argnums.has_value()) { if (!argnums.has_value()) {
// argnums was not provided and argnames was empty // argnums was not provided and argnames was empty
if (vec_names.empty()) { if (setnames.empty()) {
return std::make_pair(std::vector<int>{0}, vec_names); return std::make_pair(std::vector<int>{0}, setnames);
} else { } else {
return std::make_pair(std::vector<int>{}, vec_names); return std::make_pair(std::vector<int>{}, setnames);
} }
} }
return std::make_pair(to_vector(*argnums), vec_names); std::vector<int> vecnums;
if (auto pv = std::get_if<int>(&(*argnums)); pv) {
vecnums = {*pv};
} else {
vecnums = std::get<std::vector<int>>(*argnums);
}
return std::make_pair(vecnums, setnames);
} }
auto py_value_and_grad( auto py_value_and_grad(
const nb::callable& fun, const nb::callable& fun,
std::vector<int> argnums, std::vector<int> argnums,
std::vector<std::string> argnames, std::unordered_set<std::string> argnames,
const std::string& error_msg_tag, const std::string& error_msg_tag,
bool scalar_func_only) { bool scalar_func_only) {
// Sanitize argnums // Sanitize argnums
@ -72,7 +75,7 @@ auto py_value_and_grad(
throw std::invalid_argument( throw std::invalid_argument(
error_msg_tag + " Gradient wrt no argument requested"); error_msg_tag + " Gradient wrt no argument requested");
} }
if (argnums.size() > 0) { for (auto arg : argnums) {
std::sort(argnums.begin(), argnums.end()); std::sort(argnums.begin(), argnums.end());
if (argnums[0] < 0) { if (argnums[0] < 0) {
std::ostringstream msg; std::ostringstream msg;
@ -81,10 +84,18 @@ auto py_value_and_grad(
<< argnums[0]; << argnums[0];
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
for (int i = 1; i < argnums.size(); ++i) {
if (argnums[i] == argnums[i - 1]) {
std::ostringstream msg;
msg << error_msg_tag << " Duplicate argument index " << argnums[0]
<< " is not allowed.";
throw std::invalid_argument(msg.str());
}
}
} }
return [fun, argnums, argnames, error_msg_tag, scalar_func_only]( return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
const nb::args& args, const nb::kwargs& kwargs) { nb::args& args, nb::kwargs& kwargs) {
// Sanitize the input // Sanitize the input
if (argnums.size() > 0 && argnums.back() >= args.size()) { if (argnums.size() > 0 && argnums.back() >= args.size()) {
std::ostringstream msg; std::ostringstream msg;
@ -112,59 +123,59 @@ auto py_value_and_grad(
// Collect the arrays // Collect the arrays
std::vector<mx::array> arrays; std::vector<mx::array> arrays;
std::vector<int> counts(1, 0); std::vector<int> counts(1, 0);
for (auto i : argnums) { std::vector<int> gradient_indices;
auto argsi = tree_flatten(args[i]); for (int i = 0, j = 0; i < args.size(); ++i) {
bool needs_grad = (j < argnums.size() && argnums[j] == i);
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
if (needs_grad) {
auto old_size = gradient_indices.size();
gradient_indices.resize(old_size + argsi.size());
std::iota(
gradient_indices.begin() + old_size,
gradient_indices.end(),
arrays.size());
j++;
counts.push_back(argsi.size());
}
arrays.insert(arrays.end(), argsi.begin(), argsi.end()); arrays.insert(arrays.end(), argsi.begin(), argsi.end());
counts.push_back(argsi.size());
} }
for (auto& key : argnames) { for (auto item : kwargs) {
auto argsk = tree_flatten(kwargs[key.c_str()]); bool needs_grad =
(argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
if (needs_grad) {
auto old_size = gradient_indices.size();
gradient_indices.resize(old_size + argsk.size());
std::iota(
gradient_indices.begin() + old_size,
gradient_indices.end(),
arrays.size());
counts.push_back(argsk.size());
}
arrays.insert(arrays.end(), argsk.begin(), argsk.end()); arrays.insert(arrays.end(), argsk.begin(), argsk.end());
counts.push_back(argsk.size());
} }
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin()); std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
std::vector<int> gradient_indices(arrays.size());
std::iota(gradient_indices.begin(), gradient_indices.end(), 0);
// value_out will hold the output of the python function in order to be // value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values // able to reconstruct the python tree of extra return values
nb::object py_value_out; nb::object py_value_out;
auto value_and_grads = mx::value_and_grad( auto value_and_grads = mx::value_and_grad(
[&fun, [&fun,
&arrays,
&args, &args,
&kwargs, &kwargs,
&argnums,
&argnames,
&counts,
&py_value_out, &py_value_out,
&error_msg_tag, &error_msg_tag,
scalar_func_only](const std::vector<mx::array>& a) { scalar_func_only](const std::vector<mx::array>& a) {
// Copy the arguments nb::list tree;
nb::list args_cpy; tree.append(args);
nb::kwargs kwargs_cpy = nb::kwargs(); tree.append(kwargs);
int j = 0; tree_replace(tree, arrays, a);
for (int i = 0; i < args.size(); ++i) {
if (j < argnums.size() && i == argnums[j]) {
args_cpy.append(tree_unflatten(args[i], a, counts[j]));
j++;
} else {
args_cpy.append(args[i]);
}
}
for (auto& key : argnames) {
kwargs_cpy[key.c_str()] =
tree_unflatten(kwargs[key.c_str()], a, counts[j]);
j++;
}
for (auto item : kwargs) {
if (kwargs_cpy.contains(item.first)) {
continue;
}
kwargs_cpy[item.first] = item.second;
}
// Call the python function // Call the python function
py_value_out = fun(*args_cpy, **kwargs_cpy); py_value_out = fun(*tree[0], **tree[1]);
tree_replace(tree, arrays, a);
// Validate the return value of the python function // Validate the return value of the python function
if (!nb::isinstance<mx::array>(py_value_out)) { if (!nb::isinstance<mx::array>(py_value_out)) {
@ -247,10 +258,13 @@ auto py_value_and_grad(
py_grads = positional_grads; py_grads = positional_grads;
} else { } else {
nb::dict grads_; nb::dict grads_;
for (int i = 0; i < argnames.size(); i++) { int i = 0;
auto& k = argnames[i]; for (auto item : kwargs) {
grads_[k.c_str()] = tree_unflatten( auto k = nb::cast<std::string>(item.first);
kwargs[k.c_str()], gradients, counts[i + argnums.size()]); if (argnames.find(k) != argnames.end()) {
grads_[k.c_str()] = tree_unflatten(
nb::borrow(item.second), gradients, counts[i++ + argnums.size()]);
}
} }
keyword_grads = grads_; keyword_grads = grads_;
@ -1207,17 +1221,17 @@ void init_transforms(nb::module_& m) {
"value_and_grad", "value_and_grad",
[](const nb::callable& fun, [](const nb::callable& fun,
const std::optional<IntOrVec>& argnums, const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) { const StrOrSet& argnames) {
auto [argnums_vec, argnames_vec] = auto [argnums_vec, argnames_set] =
validate_argnums_argnames(argnums, argnames); validate_argnums_argnames(argnums, argnames);
return nb::cpp_function(py_value_and_grad( return nb::cpp_function(py_value_and_grad(
fun, argnums_vec, argnames_vec, "[value_and_grad]", false)); fun, argnums_vec, argnames_set, "[value_and_grad]", false));
}, },
"fun"_a, "fun"_a,
"argnums"_a = nb::none(), "argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{}, "argnames"_a = std::vector<std::string>{},
nb::sig( nb::sig(
"def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"), "def value_and_grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
R"pbdoc( R"pbdoc(
Returns a function which computes the value and gradient of ``fun``. Returns a function which computes the value and gradient of ``fun``.
@ -1271,21 +1285,20 @@ void init_transforms(nb::module_& m) {
"grad", "grad",
[](const nb::callable& fun, [](const nb::callable& fun,
const std::optional<IntOrVec>& argnums, const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) { const StrOrSet& argnames) {
auto [argnums_vec, argnames_vec] = auto [argnums_vec, argnames_set] =
validate_argnums_argnames(argnums, argnames); validate_argnums_argnames(argnums, argnames);
auto fn = auto fn =
py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true); py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
return nb::cpp_function( return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) {
[fn](const nb::args& args, const nb::kwargs& kwargs) { return fn(args, kwargs).second;
return fn(args, kwargs).second; });
});
}, },
"fun"_a, "fun"_a,
"argnums"_a = nb::none(), "argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{}, "argnames"_a = std::vector<std::string>{},
nb::sig( nb::sig(
"def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"), "def grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
R"pbdoc( R"pbdoc(
Returns a function which computes the gradient of ``fun``. Returns a function which computes the gradient of ``fun``.

View File

@ -146,7 +146,7 @@ void tree_visit(
return recurse(trees); return recurse(trees);
} }
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) { void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
std::function<void(nb::handle)> recurse; std::function<void(nb::handle)> recurse;
recurse = [&](nb::handle subtree) { recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree) || if (nb::isinstance<nb::list>(subtree) ||
@ -178,10 +178,11 @@ void tree_visit_update(
} }
return nb::cast<nb::object>(l); return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtree)) { } else if (nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) { nb::list l(subtree);
recurse(item); for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
} }
return nb::cast<nb::object>(subtree); return nb::cast<nb::object>(nb::tuple(l));
} else if (nb::isinstance<nb::dict>(subtree)) { } else if (nb::isinstance<nb::dict>(subtree)) {
auto d = nb::cast<nb::dict>(subtree); auto d = nb::cast<nb::dict>(subtree);
for (auto item : d) { for (auto item : d) {
@ -224,7 +225,7 @@ void tree_replace(
}); });
} }
std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) { std::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) {
std::vector<mx::array> flat_tree; std::vector<mx::array> flat_tree;
tree_visit(tree, [&](nb::handle obj) { tree_visit(tree, [&](nb::handle obj) {

View File

@ -10,7 +10,7 @@ namespace nb = nanobind;
void tree_visit( void tree_visit(
const std::vector<nb::object>& trees, const std::vector<nb::object>& trees,
std::function<void(const std::vector<nb::object>&)> visitor); std::function<void(const std::vector<nb::object>&)> visitor);
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor); void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor);
nb::object tree_map( nb::object tree_map(
const std::vector<nb::object>& trees, const std::vector<nb::object>& trees,
@ -42,7 +42,7 @@ void tree_replace(
* Flatten a tree into a vector of arrays. If strict is true, then the * Flatten a tree into a vector of arrays. If strict is true, then the
* function will throw if the tree contains a leaf which is not an array. * function will throw if the tree contains a leaf which is not an array.
*/ */
std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true); std::vector<mx::array> tree_flatten(nb::handle tree, bool strict = true);
/** /**
* Unflatten a tree from a vector of arrays. * Unflatten a tree from a vector of arrays.

View File

@ -139,6 +139,8 @@ class TestAutograd(mlx_tests.MLXTestCase):
mx.value_and_grad(fun, (None, None)) mx.value_and_grad(fun, (None, None))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.value_and_grad(fun, tuple()) mx.value_and_grad(fun, tuple())
with self.assertRaises(ValueError):
mx.grad(fun, argnums=(0, 0))
def test_auxiliary_values(self): def test_auxiliary_values(self):
def fun(x, y): def fun(x, y):

View File

@ -195,6 +195,20 @@ class TestBase(mlx_tests.MLXTestCase):
self.assertTrue(isinstance(m.layers[1], nn.ReLU)) self.assertTrue(isinstance(m.layers[1], nn.ReLU))
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
def test_grad_of_module(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.m1 = nn.Linear(3, 3)
model = Model()
def loss_fn(model):
return model.m1(x).sum()
x = mx.zeros((3,))
mx.grad(loss_fn)(model)
class TestLayers(mlx_tests.MLXTestCase): class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self): def test_identity(self):