From 33421c1dd31e3632f9032fadc8363dadc738fffe Mon Sep 17 00:00:00 2001
From: Awni Hannun <awni@apple.com>
Date: Tue, 14 Jan 2025 14:33:18 -0800
Subject: [PATCH] Limit grad recursion depth by not recursing through non-grad
 inputs (#1764)

* limit grad recursion depth

* add grad of module test
---
 mlx/transforms.cpp            |  40 +++++----
 python/src/transforms.cpp     | 165 ++++++++++++++++++----------------
 python/src/trees.cpp          |  11 +--
 python/src/trees.h            |   4 +-
 python/tests/test_autograd.py |   2 +
 python/tests/test_nn.py       |  14 +++
 6 files changed, 136 insertions(+), 100 deletions(-)

diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 3ac0f8c68..06bdd1cd3 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -278,7 +278,8 @@ void eval(std::vector<array> outputs) {
 std::pair<std::vector<array>, std::vector<array>> vjp(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
     const std::vector<array>& primals,
-    const std::vector<array>& cotans) {
+    const std::vector<array>& cotans,
+    const std::vector<int>& argnums) {
   // Set the global tracing flag.
   detail::InTracing in_tracing;
 
@@ -330,10 +331,14 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // to the tape which need a gradient.
   std::unordered_set<std::uintptr_t> cache;
   std::unordered_set<std::uintptr_t> calc_grad;
-  for (auto& primal : primals_) {
+  for (int i = 0, j = 0; i < primals_.size(); ++i) {
+    auto& primal = primals_[i];
     primal.set_tracer(false);
-    calc_grad.insert(primal.id());
     cache.insert(primal.id());
+    if (j < argnums.size() && argnums[j] == i) {
+      j++;
+      calc_grad.insert(primal.id());
+    }
   }
 
   std::vector<array> tape;
@@ -435,7 +440,8 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
     }
   }
   std::vector<array> vjps;
-  for (auto& primal : primals_) {
+  for (auto arg : argnums) {
+    auto& primal = primals_[arg];
     if (auto cotan_it = cotan_map.find(primal.id());
         cotan_it != cotan_map.end()) {
       vjps.push_back(cotan_it->second);
@@ -448,6 +454,15 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   return {outputs, vjps};
 }
 
+std::pair<std::vector<array>, std::vector<array>> vjp(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    const std::vector<array>& primals,
+    const std::vector<array>& cotans) {
+  std::vector<int> argnums(primals.size());
+  std::iota(argnums.begin(), argnums.end(), 0);
+  return vjp(fun, primals, cotans, argnums);
+}
+
 std::pair<array, array> vjp(
     const std::function<array(const array&)>& fun,
     const array& primal,
@@ -606,15 +621,10 @@ ValueAndGradFn value_and_grad(
           << inputs.size() << " inputs.";
       throw std::invalid_argument(msg.str());
     }
+    std::vector<int> sorted_argnums(args.begin(), args.end());
 
-    auto gfun = [&fun, &inputs, &args](const std::vector<array>& ginputs) {
-      std::vector<array> inputs_(inputs);
-      auto argit = args.begin();
-      for (int i = 0; i < ginputs.size(); ++i) {
-        inputs_[*argit] = ginputs[i];
-        ++argit;
-      }
-      auto outputs = fun(inputs_);
+    auto gfun = [&fun](const std::vector<array>& inputs) {
+      auto outputs = fun(inputs);
       for (int i = 1; i < outputs.size(); i++) {
         auto& out = outputs[i];
         auto s = out.has_primitive() ? out.primitive().stream()
@@ -624,12 +634,8 @@ ValueAndGradFn value_and_grad(
       return outputs;
     };
 
-    std::vector<array> ginputs;
-    for (auto arg : args) {
-      ginputs.push_back(inputs[arg]);
-    }
     // Set the incoming gradient to float32, vjp will cast it to the output type
-    auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
+    auto [outputs, grads] = vjp(gfun, inputs, {array(1.0f)}, sorted_argnums);
    return std::make_pair(outputs, grads);
  };
 }
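The heart of the change is in this first file: the C++ `vjp` gains an explicit `argnums` parameter, and only the primals listed there are inserted into `calc_grad`. All other inputs are cached as constants, so tape construction no longer recurses through the history of non-grad inputs. Seen from Python the semantics are unchanged; a minimal sketch (the function and values below are illustrative, not part of the patch):

```python
import mlx.core as mx

def fun(x, y):
    return (x * y).sum()

x = mx.ones((4,))
y = mx.array([0.0, 1.0, 2.0, 3.0])

# Only argument 0 is marked for gradient computation; y is treated as a
# constant, so the backward pass never walks the graph that produced it.
dfdx = mx.grad(fun, argnums=0)(x, y)  # equals y
```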
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index fdf1dc0e5..e351b9f68 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -1,16 +1,18 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#include <algorithm>
+#include <fstream>
+#include <numeric>
+#include <sstream>
+
 #include <nanobind/nanobind.h>
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/pair.h>
 #include <nanobind/stl/string.h>
+#include <nanobind/stl/unordered_set.h>
 #include <nanobind/stl/variant.h>
 #include <nanobind/stl/vector.h>
 
-#include <algorithm>
-#include <fstream>
-#include <numeric>
-
 #include "mlx/array.h"
 #include "mlx/compile.h"
 #include "mlx/compile_impl.h"
@@ -27,44 +29,45 @@ using namespace nb::literals;
 using mx::operator<<;
 
 using IntOrVec = std::variant<int, std::vector<int>>;
-using StrOrVec = std::variant<std::string, std::vector<std::string>>;
+using StrOrSet = std::variant<std::string, std::unordered_set<std::string>>;
 
 inline std::string type_name_str(const nb::handle& o) {
   return nb::cast<std::string>(nb::type_name(o.type()));
 }
 
-template <typename T>
-std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
-  std::vector<T> vals;
-  if (auto pv = std::get_if<T>(&v); pv) {
-    vals.push_back(*pv);
-  } else {
-    vals = std::get<std::vector<T>>(v);
-  }
-  return vals;
-}
-
 auto validate_argnums_argnames(
     const std::optional<IntOrVec>& argnums,
-    const StrOrVec& argnames) {
-  auto vec_names = to_vector(argnames);
+    const StrOrSet& argnames) {
+  std::unordered_set<std::string> setnames;
+  if (auto pv = std::get_if<std::string>(&argnames); pv) {
+    setnames = {*pv};
+  } else {
+    setnames = std::get<std::unordered_set<std::string>>(argnames);
+  }
 
   if (!argnums.has_value()) {
     // argnums was not provided and argnames was empty
-    if (vec_names.empty()) {
-      return std::make_pair(std::vector<int>{0}, vec_names);
+    if (setnames.empty()) {
+      return std::make_pair(std::vector<int>{0}, setnames);
     } else {
-      return std::make_pair(std::vector<int>{}, vec_names);
+      return std::make_pair(std::vector<int>{}, setnames);
    }
  }
 
-  return std::make_pair(to_vector(*argnums), vec_names);
+  std::vector<int> vecnums;
+  if (auto pv = std::get_if<int>(&(*argnums)); pv) {
+    vecnums = {*pv};
+  } else {
+    vecnums = std::get<std::vector<int>>(*argnums);
+  }
+
+  return std::make_pair(vecnums, setnames);
 }
 
 auto py_value_and_grad(
     const nb::callable& fun,
     std::vector<int> argnums,
-    std::vector<std::string> argnames,
+    std::unordered_set<std::string> argnames,
     const std::string& error_msg_tag,
     bool scalar_func_only) {
   // Sanitize argnums
@@ -72,7 +75,7 @@ auto py_value_and_grad(
     throw std::invalid_argument(
         error_msg_tag + " Gradient wrt no argument requested");
   }
   if (argnums.size() > 0) {
     std::sort(argnums.begin(), argnums.end());
     if (argnums[0] < 0) {
       std::ostringstream msg;
@@ -81,10 +84,18 @@ auto py_value_and_grad(
           << argnums[0];
       throw std::invalid_argument(msg.str());
     }
+    for (int i = 1; i < argnums.size(); ++i) {
+      if (argnums[i] == argnums[i - 1]) {
+        std::ostringstream msg;
+        msg << error_msg_tag << " Duplicate argument index " << argnums[i]
+            << " is not allowed.";
+        throw std::invalid_argument(msg.str());
+      }
+    }
   }
 
   return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
-             const nb::args& args, const nb::kwargs& kwargs) {
+             nb::args& args, nb::kwargs& kwargs) {
     // Sanitize the input
     if (argnums.size() > 0 && argnums.back() >= args.size()) {
       std::ostringstream msg;
@@ -112,59 +123,59 @@ auto py_value_and_grad(
     // Collect the arrays
     std::vector<mx::array> arrays;
     std::vector<int> counts(1, 0);
-    for (auto i : argnums) {
-      auto argsi = tree_flatten(args[i]);
+    std::vector<int> gradient_indices;
+    for (int i = 0, j = 0; i < args.size(); ++i) {
+      bool needs_grad = (j < argnums.size() && argnums[j] == i);
+      auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
+      if (needs_grad) {
+        auto old_size = gradient_indices.size();
+        gradient_indices.resize(old_size + argsi.size());
+        std::iota(
+            gradient_indices.begin() + old_size,
+            gradient_indices.end(),
+            arrays.size());
+        j++;
+        counts.push_back(argsi.size());
+      }
       arrays.insert(arrays.end(), argsi.begin(), argsi.end());
-      counts.push_back(argsi.size());
     }
-    for (auto& key : argnames) {
-      auto argsk = tree_flatten(kwargs[key.c_str()]);
+    for (auto item : kwargs) {
+      bool needs_grad =
+          (argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
+      auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
+      if (needs_grad) {
+        auto old_size = gradient_indices.size();
+        gradient_indices.resize(old_size + argsk.size());
+        std::iota(
+            gradient_indices.begin() + old_size,
+            gradient_indices.end(),
+            arrays.size());
+        counts.push_back(argsk.size());
+      }
       arrays.insert(arrays.end(), argsk.begin(), argsk.end());
-      counts.push_back(argsk.size());
     }
     std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
-    std::vector<int> gradient_indices(arrays.size());
-    std::iota(gradient_indices.begin(), gradient_indices.end(), 0);
 
     // value_out will hold the output of the python function in order to be
     // able to reconstruct the python tree of extra return values
     nb::object py_value_out;
     auto value_and_grads = mx::value_and_grad(
         [&fun,
+         &arrays,
          &args,
          &kwargs,
-         &argnums,
-         &argnames,
-         &counts,
          &py_value_out,
          &error_msg_tag,
          scalar_func_only](const std::vector<mx::array>& a) {
-          // Copy the arguments
-          nb::list args_cpy;
-          nb::kwargs kwargs_cpy = nb::kwargs();
-          int j = 0;
-          for (int i = 0; i < args.size(); ++i) {
-            if (j < argnums.size() && i == argnums[j]) {
-              args_cpy.append(tree_unflatten(args[i], a, counts[j]));
-              j++;
-            } else {
-              args_cpy.append(args[i]);
-            }
-          }
-          for (auto& key : argnames) {
-            kwargs_cpy[key.c_str()] =
-                tree_unflatten(kwargs[key.c_str()], a, counts[j]);
-            j++;
-          }
-          for (auto item : kwargs) {
-            if (kwargs_cpy.contains(item.first)) {
-              continue;
-            }
-            kwargs_cpy[item.first] = item.second;
-          }
+          nb::list tree;
+          tree.append(args);
+          tree.append(kwargs);
+          tree_replace(tree, arrays, a);
 
           // Call the python function
-          py_value_out = fun(*args_cpy, **kwargs_cpy);
+          py_value_out = fun(*tree[0], **tree[1]);
+
+          // Put the original arrays back in the args and kwargs
+          tree_replace(tree, a, arrays);
 
           // Validate the return value of the python function
           if (!nb::isinstance<mx::array>(py_value_out)) {
@@ -247,10 +258,13 @@ auto py_value_and_grad(
       py_grads = positional_grads;
     } else {
       nb::dict grads_;
-      for (int i = 0; i < argnames.size(); i++) {
-        auto& k = argnames[i];
-        grads_[k.c_str()] = tree_unflatten(
-            kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
+      int i = 0;
+      for (auto item : kwargs) {
+        auto k = nb::cast<std::string>(item.first);
+        if (argnames.find(k) != argnames.end()) {
+          grads_[k.c_str()] = tree_unflatten(
+              nb::borrow(item.second), gradients, counts[i++ + argnums.size()]);
+        }
       }
 
       keyword_grads = grads_;
@@ -1207,17 +1221,17 @@ void init_transforms(nb::module_& m) {
       "value_and_grad",
       [](const nb::callable& fun,
          const std::optional<IntOrVec>& argnums,
-         const StrOrVec& argnames) {
-        auto [argnums_vec, argnames_vec] =
+         const StrOrSet& argnames) {
+        auto [argnums_vec, argnames_set] =
             validate_argnums_argnames(argnums, argnames);
         return nb::cpp_function(py_value_and_grad(
-            fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
+            fun, argnums_vec, argnames_set, "[value_and_grad]", false));
       },
       "fun"_a,
       "argnums"_a = nb::none(),
       "argnames"_a = std::vector<std::string>{},
       nb::sig(
-          "def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
+          "def value_and_grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
       R"pbdoc(
         Returns a function which computes the value and gradient of ``fun``.
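With this rework, `py_value_and_grad` flattens every argument exactly once, strictly only for the arguments selected by `argnums`/`argnames`, and records `gradient_indices` so the differentiated arrays can be picked out of the flat list. Duplicate indices are now rejected up front. A short sketch of both behaviors (argument names and values are illustrative, not from the patch):

```python
import mlx.core as mx

def loss(x, *, bias):
    return (x + bias).sum()

x = mx.ones((3,))
bias = mx.array(0.5)

# Gradients for named arguments come back keyed by argument name.
value, grads = mx.value_and_grad(loss, argnames="bias")(x, bias=bias)

# Duplicate argnums now fail fast instead of being silently accepted.
try:
    mx.grad(loss, argnums=(0, 0))
except ValueError:
    pass  # "[grad] Duplicate argument index 0 is not allowed."
```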
@@ -1271,21 +1285,20 @@ void init_transforms(nb::module_& m) {
       "grad",
       [](const nb::callable& fun,
          const std::optional<IntOrVec>& argnums,
-         const StrOrVec& argnames) {
-        auto [argnums_vec, argnames_vec] =
+         const StrOrSet& argnames) {
+        auto [argnums_vec, argnames_set] =
             validate_argnums_argnames(argnums, argnames);
         auto fn =
-            py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
-        return nb::cpp_function(
-            [fn](const nb::args& args, const nb::kwargs& kwargs) {
-              return fn(args, kwargs).second;
-            });
+            py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
+        return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) {
+          return fn(args, kwargs).second;
+        });
       },
       "fun"_a,
       "argnums"_a = nb::none(),
       "argnames"_a = std::vector<std::string>{},
       nb::sig(
-          "def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
+          "def grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
       R"pbdoc(
         Returns a function which computes the gradient of ``fun``.
 
diff --git a/python/src/trees.cpp b/python/src/trees.cpp
index d9fe6d2d3..b75d1187c 100644
--- a/python/src/trees.cpp
+++ b/python/src/trees.cpp
@@ -146,7 +146,7 @@ void tree_visit(
   return recurse(trees);
 }
 
-void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
+void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
   std::function<void(nb::handle)> recurse;
   recurse = [&](nb::handle subtree) {
     if (nb::isinstance<nb::list>(subtree) ||
@@ -178,10 +178,11 @@ void tree_visit_update(
       }
       return nb::cast<nb::object>(l);
     } else if (nb::isinstance<nb::tuple>(subtree)) {
-      for (auto item : subtree) {
-        recurse(item);
+      nb::list l(subtree);
+      for (int i = 0; i < l.size(); ++i) {
+        l[i] = recurse(l[i]);
       }
-      return nb::cast<nb::object>(subtree);
+      return nb::cast<nb::object>(nb::tuple(l));
     } else if (nb::isinstance<nb::dict>(subtree)) {
       auto d = nb::cast<nb::dict>(subtree);
       for (auto item : d) {
@@ -224,7 +225,7 @@ void tree_replace(
     });
 }
 
-std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
+std::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) {
   std::vector<mx::array> flat_tree;
 
   tree_visit(tree, [&](nb::handle obj) {
diff --git a/python/src/trees.h b/python/src/trees.h
index fc146c29d..3faa3e39c 100644
--- a/python/src/trees.h
+++ b/python/src/trees.h
@@ -10,7 +10,7 @@ namespace nb = nanobind;
 void tree_visit(
     const std::vector<nb::object>& trees,
     std::function<void(const std::vector<nb::object>&)> visitor);
-void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
+void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor);
 
 nb::object tree_map(
     const std::vector<nb::object>& trees,
@@ -42,7 +42,7 @@ void tree_replace(
 /**
  * Flatten a tree into a vector of arrays. If strict is true, then the
  * function will throw if the tree contains a leaf which is not an array.
  */
-std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true);
+std::vector<mx::array> tree_flatten(nb::handle tree, bool strict = true);
 
 /**
  * Unflatten a tree from a vector of arrays.
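The `trees.cpp` change makes `tree_visit_update` rebuild tuples element by element instead of recursing and discarding the results, so updated leaves inside tuples actually land in the output tree; `tree_visit` and `tree_flatten` also switch to `nb::handle` so they work on borrowed references. This is what lets `tree_replace` swap tracers into tuple-structured arguments and lets gradients mirror tuple inputs. A minimal sketch (not from the patch):

```python
import mlx.core as mx

def fun(params):
    w, b = params  # params is a (w, b) tuple of arrays
    return (2 * w + b).sum()

params = (mx.ones((2,)), mx.zeros((2,)))
grads = mx.grad(fun)(params)
# grads mirrors the input structure: a tuple of two arrays.
```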
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 3553824aa..727d3c060 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -139,6 +139,8 @@ class TestAutograd(mlx_tests.MLXTestCase):
             mx.value_and_grad(fun, (None, None))
         with self.assertRaises(ValueError):
             mx.value_and_grad(fun, tuple())
+        with self.assertRaises(ValueError):
+            mx.grad(fun, argnums=(0, 0))
 
     def test_auxiliary_values(self):
         def fun(x, y):
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 5aa230175..9d632b488 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -195,6 +195,20 @@ class TestBase(mlx_tests.MLXTestCase):
         self.assertTrue(isinstance(m.layers[1], nn.ReLU))
         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
 
+    def test_grad_of_module(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.m1 = nn.Linear(3, 3)
+
+        model = Model()
+
+        def loss_fn(model):
+            return model.m1(x).sum()
+
+        x = mx.zeros((3,))
+        mx.grad(loss_fn)(model)
+
 
 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):
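The new `test_grad_of_module` pins down the case motivating the patch title: taking `mx.grad` of a function whose differentiated argument is an `nn.Module`. A larger, illustrative variant (layer count and sizes are made up, not part of the test) shows why limiting traversal to grad inputs matters for deep models:

```python
import mlx.core as mx
import mlx.nn as nn

# Many stacked layers produce a deep compute graph; building the VJP tape
# used to recurse through non-grad inputs as well, which could get very
# deep for a model like this.
model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(64)])
x = mx.zeros((16,))

def loss_fn(model):
    return model(x).sum()

grads = mx.grad(loss_fn)(model)  # gradients with the module's structure
```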