Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)

Limit grad recursion depth by not recursing through non-grad inputs (#1764)

* limit grad recursion depth
* add grad of module test

This commit is contained in:
parent 5cc5201914
commit 33421c1dd3
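As context for the hunks that follow: `vjp` (and through it `value_and_grad`) now receives the indices of the inputs that actually require gradients, so the backward tape no longer recurses through the computation history of the remaining inputs. A minimal usage sketch of the behavior this enables, using the public `mlx.core` Python API (illustrative values, not taken from the commit):

import mlx.core as mx

def loss(w, x):
    # Gradient is requested for `w` only; `x` is treated as a constant input.
    return (x @ w).sum()

w = mx.zeros((4, 2))
x = mx.ones((3, 4))

# Only argument 0 (`w`) is differentiated; the history of `x` is not walked
# when the gradient tape is built.
value, dw = mx.value_and_grad(loss, argnums=0)(w, x)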
@@ -278,7 +278,8 @@ void eval(std::vector<array> outputs) {
 std::pair<std::vector<array>, std::vector<array>> vjp(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
     const std::vector<array>& primals,
-    const std::vector<array>& cotans) {
+    const std::vector<array>& cotans,
+    const std::vector<int>& argnums) {
   // Set the global tracing flag.
   detail::InTracing in_tracing;
 
@@ -330,10 +331,14 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // to the tape which need a gradient.
   std::unordered_set<std::uintptr_t> cache;
   std::unordered_set<std::uintptr_t> calc_grad;
-  for (auto& primal : primals_) {
+  for (int i = 0, j = 0; i < primals_.size(); ++i) {
+    auto& primal = primals_[i];
     primal.set_tracer(false);
-    calc_grad.insert(primal.id());
     cache.insert(primal.id());
+    if (j < argnums.size() && argnums[j] == i) {
+      j++;
+      calc_grad.insert(primal.id());
+    }
   }
 
   std::vector<array> tape;
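The loop above seeds `calc_grad` only with the primals selected by `argnums`, while every primal still goes into `cache` so the upward traversal can stop there without descending into non-grad inputs. A rough Python analogue of that bookkeeping (hypothetical helper, not MLX code):

# Hypothetical sketch of the seeding step above; `argnums` is assumed sorted.
def seed_grad_sets(primals, argnums):
    cache = set()       # nodes the traversal never needs to recurse past
    calc_grad = set()   # nodes that actually require a gradient
    j = 0
    for i, primal in enumerate(primals):
        cache.add(id(primal))
        if j < len(argnums) and argnums[j] == i:
            j += 1
            calc_grad.add(id(primal))
    return cache, calc_grad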
@@ -435,7 +440,8 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
     }
   }
   std::vector<array> vjps;
-  for (auto& primal : primals_) {
+  for (auto arg : argnums) {
+    auto& primal = primals_[arg];
     if (auto cotan_it = cotan_map.find(primal.id());
         cotan_it != cotan_map.end()) {
       vjps.push_back(cotan_it->second);
@@ -448,6 +454,15 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   return {outputs, vjps};
 }
 
+std::pair<std::vector<array>, std::vector<array>> vjp(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    const std::vector<array>& primals,
+    const std::vector<array>& cotans) {
+  std::vector<int> argnums(primals.size());
+  std::iota(argnums.begin(), argnums.end(), 0);
+  return vjp(fun, primals, cotans, argnums);
+}
+
 std::pair<array, array> vjp(
     const std::function<array(const array&)>& fun,
     const array& primal,
@@ -606,15 +621,10 @@ ValueAndGradFn value_and_grad(
           << inputs.size() << " inputs.";
       throw std::invalid_argument(msg.str());
     }
+    std::vector<int> sorted_argnums(args.begin(), args.end());
 
-    auto gfun = [&fun, &inputs, &args](const std::vector<array>& ginputs) {
-      std::vector<array> inputs_(inputs);
-      auto argit = args.begin();
-      for (int i = 0; i < ginputs.size(); ++i) {
-        inputs_[*argit] = ginputs[i];
-        ++argit;
-      }
-      auto outputs = fun(inputs_);
+    auto gfun = [&fun](const std::vector<array>& inputs) {
+      auto outputs = fun(inputs);
       for (int i = 1; i < outputs.size(); i++) {
         auto& out = outputs[i];
         auto s = out.has_primitive() ? out.primitive().stream()
@@ -624,12 +634,8 @@ ValueAndGradFn value_and_grad(
       return outputs;
     };
 
-    std::vector<array> ginputs;
-    for (auto arg : args) {
-      ginputs.push_back(inputs[arg]);
-    }
     // Set the incoming gradient to float32, vjp will cast it to the output type
-    auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
+    auto [outputs, grads] = vjp(gfun, inputs, {array(1.0f)}, sorted_argnums);
    return std::make_pair(outputs, grads);
  };
}
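The three-argument `vjp` overload added above forwards to the new one with `argnums = 0, 1, ..., n-1`, so existing callers still differentiate with respect to every primal. For reference, a small sketch of the corresponding Python-level call (assuming the usual `mlx.core.vjp(fn, primals, cotangents)` signature):

import mlx.core as mx

def f(x, y):
    return [x * y]

x = mx.array(2.0)
y = mx.array(3.0)

# The Python-level vjp differentiates with respect to all primals; the
# argnums-aware overload above is used internally by value_and_grad.
outputs, vjps = mx.vjp(f, [x, y], [mx.array(1.0)])
# Expect vjps to hold the gradients with respect to x and y (3.0 and 2.0).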
@@ -1,16 +1,18 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#include <algorithm>
+#include <numeric>
+#include <sstream>
+#include <unordered_set>
+
 #include <nanobind/nanobind.h>
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/pair.h>
 #include <nanobind/stl/string.h>
+#include <nanobind/stl/unordered_set.h>
 #include <nanobind/stl/variant.h>
 #include <nanobind/stl/vector.h>
 
-#include <algorithm>
-#include <numeric>
-#include <sstream>
-
 #include "mlx/array.h"
 #include "mlx/compile.h"
 #include "mlx/compile_impl.h"
@@ -27,44 +29,45 @@ using namespace nb::literals;
 using mx::operator<<;
 
 using IntOrVec = std::variant<int, std::vector<int>>;
 using StrOrVec = std::variant<std::string, std::vector<std::string>>;
+using StrOrSet = std::variant<std::string, std::unordered_set<std::string>>;
 
 inline std::string type_name_str(const nb::handle& o) {
   return nb::cast<std::string>(nb::type_name(o.type()));
 }
 
 template <typename T>
 std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
   std::vector<T> vals;
   if (auto pv = std::get_if<T>(&v); pv) {
     vals.push_back(*pv);
   } else {
     vals = std::get<std::vector<T>>(v);
   }
   return vals;
 }
 
 auto validate_argnums_argnames(
     const std::optional<IntOrVec>& argnums,
-    const StrOrVec& argnames) {
-  auto vec_names = to_vector(argnames);
+    const StrOrSet& argnames) {
+  std::unordered_set<std::string> setnames;
+  if (auto pv = std::get_if<std::string>(&argnames); pv) {
+    setnames = {*pv};
+  } else {
+    setnames = std::get<std::unordered_set<std::string>>(argnames);
+  }
 
   if (!argnums.has_value()) {
     // argnums was not provided and argnames was empty
-    if (vec_names.empty()) {
-      return std::make_pair(std::vector<int>{0}, vec_names);
+    if (setnames.empty()) {
+      return std::make_pair(std::vector<int>{0}, setnames);
     } else {
-      return std::make_pair(std::vector<int>{}, vec_names);
+      return std::make_pair(std::vector<int>{}, setnames);
     }
   }
 
-  return std::make_pair(to_vector(*argnums), vec_names);
+  std::vector<int> vecnums;
+  if (auto pv = std::get_if<int>(&(*argnums)); pv) {
+    vecnums = {*pv};
+  } else {
+    vecnums = std::get<std::vector<int>>(*argnums);
+  }
+
+  return std::make_pair(vecnums, setnames);
 }
 
 auto py_value_and_grad(
     const nb::callable& fun,
     std::vector<int> argnums,
-    std::vector<std::string> argnames,
+    std::unordered_set<std::string> argnames,
     const std::string& error_msg_tag,
     bool scalar_func_only) {
   // Sanitize argnums
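With `argnames` now carried as a set (`StrOrSet`), a single string or a collection of names can be passed, and the binding looks those names up among the call's keyword arguments. A usage sketch (illustrative values; the exact structure of the returned gradients is not asserted here):

import mlx.core as mx

def loss(x, w):
    return (x * w).sum()

x = mx.ones((3,))
w = mx.full((3,), 2.0)

# Named arguments must be passed as keywords at call time; their gradients
# come back keyed by argument name.
value, grads = mx.value_and_grad(loss, argnames="w")(x, w=w)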
@@ -72,7 +75,7 @@ auto py_value_and_grad(
     throw std::invalid_argument(
         error_msg_tag + " Gradient wrt no argument requested");
   }
   if (argnums.size() > 0) {
-    for (auto arg : argnums) {
+    std::sort(argnums.begin(), argnums.end());
+    if (argnums[0] < 0) {
       std::ostringstream msg;
@@ -81,10 +84,18 @@ auto py_value_and_grad(
           << argnums[0];
       throw std::invalid_argument(msg.str());
     }
+    for (int i = 1; i < argnums.size(); ++i) {
+      if (argnums[i] == argnums[i - 1]) {
+        std::ostringstream msg;
+        msg << error_msg_tag << " Duplicate argument index " << argnums[0]
+            << " is not allowed.";
+        throw std::invalid_argument(msg.str());
+      }
+    }
   }
 
   return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
-             const nb::args& args, const nb::kwargs& kwargs) {
+             nb::args& args, nb::kwargs& kwargs) {
     // Sanitize the input
     if (argnums.size() > 0 && argnums.back() >= args.size()) {
       std::ostringstream msg;
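With the sort and adjacent-duplicate check above, repeated argument indices are rejected as soon as the transform is constructed. A sketch of the observable behavior, mirroring the test added later in this commit:

import mlx.core as mx

def fun(x, y):
    return x + y

# Raises ValueError at construction time, before the function is ever called.
try:
    mx.grad(fun, argnums=(0, 0))
except ValueError as e:
    print(e)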
@@ -112,59 +123,59 @@ auto py_value_and_grad(
     // Collect the arrays
     std::vector<mx::array> arrays;
     std::vector<int> counts(1, 0);
-    for (auto i : argnums) {
-      auto argsi = tree_flatten(args[i]);
+    std::vector<int> gradient_indices;
+    for (int i = 0, j = 0; i < args.size(); ++i) {
+      bool needs_grad = (j < argnums.size() && argnums[j] == i);
+      auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
+      if (needs_grad) {
+        auto old_size = gradient_indices.size();
+        gradient_indices.resize(old_size + argsi.size());
+        std::iota(
+            gradient_indices.begin() + old_size,
+            gradient_indices.end(),
+            arrays.size());
+        j++;
+        counts.push_back(argsi.size());
+      }
       arrays.insert(arrays.end(), argsi.begin(), argsi.end());
-      counts.push_back(argsi.size());
     }
-    for (auto& key : argnames) {
-      auto argsk = tree_flatten(kwargs[key.c_str()]);
+    for (auto item : kwargs) {
+      bool needs_grad =
+          (argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
+      auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
+      if (needs_grad) {
+        auto old_size = gradient_indices.size();
+        gradient_indices.resize(old_size + argsk.size());
+        std::iota(
+            gradient_indices.begin() + old_size,
+            gradient_indices.end(),
+            arrays.size());
+        counts.push_back(argsk.size());
+      }
       arrays.insert(arrays.end(), argsk.begin(), argsk.end());
-      counts.push_back(argsk.size());
     }
     std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
-    std::vector<int> gradient_indices(arrays.size());
-    std::iota(gradient_indices.begin(), gradient_indices.end(), 0);
 
     // value_out will hold the output of the python function in order to be
     // able to reconstruct the python tree of extra return values
     nb::object py_value_out;
     auto value_and_grads = mx::value_and_grad(
         [&fun,
          &arrays,
          &args,
          &kwargs,
          &argnums,
          &argnames,
          &counts,
          &py_value_out,
          &error_msg_tag,
          scalar_func_only](const std::vector<mx::array>& a) {
-          // Copy the arguments
-          nb::list args_cpy;
-          nb::kwargs kwargs_cpy = nb::kwargs();
-          int j = 0;
-          for (int i = 0; i < args.size(); ++i) {
-            if (j < argnums.size() && i == argnums[j]) {
-              args_cpy.append(tree_unflatten(args[i], a, counts[j]));
-              j++;
-            } else {
-              args_cpy.append(args[i]);
-            }
-          }
-          for (auto& key : argnames) {
-            kwargs_cpy[key.c_str()] =
-                tree_unflatten(kwargs[key.c_str()], a, counts[j]);
-            j++;
-          }
-          for (auto item : kwargs) {
-            if (kwargs_cpy.contains(item.first)) {
-              continue;
-            }
-            kwargs_cpy[item.first] = item.second;
-          }
+          nb::list tree;
+          tree.append(args);
+          tree.append(kwargs);
+          tree_replace(tree, arrays, a);
 
           // Call the python function
-          py_value_out = fun(*args_cpy, **kwargs_cpy);
+          py_value_out = fun(*tree[0], **tree[1]);
 
+          tree_replace(tree, arrays, a);
+
           // Validate the return value of the python function
@@ -247,10 +258,13 @@ auto py_value_and_grad(
           py_grads = positional_grads;
         } else {
           nb::dict grads_;
-          for (int i = 0; i < argnames.size(); i++) {
-            auto& k = argnames[i];
-            grads_[k.c_str()] = tree_unflatten(
-                kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
+          int i = 0;
+          for (auto item : kwargs) {
+            auto k = nb::cast<std::string>(item.first);
+            if (argnames.find(k) != argnames.end()) {
+              grads_[k.c_str()] = tree_unflatten(
+                  nb::borrow(item.second), gradients, counts[i++ + argnums.size()]);
+            }
           }
           keyword_grads = grads_;
 
@@ -1207,17 +1221,17 @@ void init_transforms(nb::module_& m) {
       "value_and_grad",
       [](const nb::callable& fun,
          const std::optional<IntOrVec>& argnums,
-         const StrOrVec& argnames) {
-        auto [argnums_vec, argnames_vec] =
+         const StrOrSet& argnames) {
+        auto [argnums_vec, argnames_set] =
             validate_argnums_argnames(argnums, argnames);
         return nb::cpp_function(py_value_and_grad(
-            fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
+            fun, argnums_vec, argnames_set, "[value_and_grad]", false));
       },
       "fun"_a,
       "argnums"_a = nb::none(),
       "argnames"_a = std::vector<std::string>{},
       nb::sig(
-          "def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
+          "def value_and_grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
       R"pbdoc(
         Returns a function which computes the value and gradient of ``fun``.
 
@@ -1271,21 +1285,20 @@ void init_transforms(nb::module_& m) {
       "grad",
       [](const nb::callable& fun,
          const std::optional<IntOrVec>& argnums,
-         const StrOrVec& argnames) {
-        auto [argnums_vec, argnames_vec] =
+         const StrOrSet& argnames) {
+        auto [argnums_vec, argnames_set] =
             validate_argnums_argnames(argnums, argnames);
         auto fn =
-            py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
-        return nb::cpp_function(
-            [fn](const nb::args& args, const nb::kwargs& kwargs) {
-              return fn(args, kwargs).second;
-            });
+            py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
+        return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) {
+          return fn(args, kwargs).second;
+        });
       },
       "fun"_a,
       "argnums"_a = nb::none(),
       "argnames"_a = std::vector<std::string>{},
       nb::sig(
-          "def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
+          "def grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
       R"pbdoc(
         Returns a function which computes the gradient of ``fun``.
 
@@ -146,7 +146,7 @@ void tree_visit(
   return recurse(trees);
 }
 
-void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
+void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
   std::function<void(nb::handle)> recurse;
   recurse = [&](nb::handle subtree) {
     if (nb::isinstance<nb::list>(subtree) ||
@@ -178,10 +178,11 @@ void tree_visit_update(
       }
       return nb::cast<nb::object>(l);
     } else if (nb::isinstance<nb::tuple>(subtree)) {
-      for (auto item : subtree) {
-        recurse(item);
+      nb::list l(subtree);
+      for (int i = 0; i < l.size(); ++i) {
+        l[i] = recurse(l[i]);
       }
-      return nb::cast<nb::object>(subtree);
+      return nb::cast<nb::object>(nb::tuple(l));
     } else if (nb::isinstance<nb::dict>(subtree)) {
       auto d = nb::cast<nb::dict>(subtree);
       for (auto item : d) {
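The tuple branch above now rebuilds the tuple through a `nb::list`, so updated leaves are actually written back instead of being silently dropped. At the API level this is what lets gradients flow through tuple-structured arguments; a small sketch (illustrative values):

import mlx.core as mx

def loss(params):
    w, b = params
    return (w * 2.0 + b).sum()

params = (mx.ones((2,)), mx.zeros((2,)))

# The gradient tree mirrors the structure of `params`.
grads = mx.grad(loss)(params)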
@@ -224,7 +225,7 @@ void tree_replace(
   });
 }
 
-std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
+std::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) {
   std::vector<mx::array> flat_tree;
 
   tree_visit(tree, [&](nb::handle obj) {
@@ -10,7 +10,7 @@ namespace nb = nanobind;
 void tree_visit(
     const std::vector<nb::object>& trees,
     std::function<void(const std::vector<nb::object>&)> visitor);
-void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
+void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor);
 
 nb::object tree_map(
     const std::vector<nb::object>& trees,
@@ -42,7 +42,7 @@ void tree_replace(
  * Flatten a tree into a vector of arrays. If strict is true, then the
  * function will throw if the tree contains a leaf which is not an array.
  */
-std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true);
+std::vector<mx::array> tree_flatten(nb::handle tree, bool strict = true);
 
 /**
  * Unflatten a tree from a vector of arrays.
@@ -139,6 +139,8 @@ class TestAutograd(mlx_tests.MLXTestCase):
             mx.value_and_grad(fun, (None, None))
         with self.assertRaises(ValueError):
             mx.value_and_grad(fun, tuple())
+        with self.assertRaises(ValueError):
+            mx.grad(fun, argnums=(0, 0))
 
     def test_auxiliary_values(self):
         def fun(x, y):
@@ -195,6 +195,20 @@ class TestBase(mlx_tests.MLXTestCase):
         self.assertTrue(isinstance(m.layers[1], nn.ReLU))
         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
 
+    def test_grad_of_module(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.m1 = nn.Linear(3, 3)
+
+        model = Model()
+
+        def loss_fn(model):
+            return model.m1(x).sum()
+
+        x = mx.zeros((3,))
+        mx.grad(loss_fn)(model)
+
 
 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):
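The new test takes a gradient directly with respect to an `nn.Module` instance. A slightly expanded sketch of that pattern (the parameter names mentioned in the comment are assumptions based on `nn.Linear`, not something the test asserts):

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(3, 3)
x = mx.zeros((3,))

def loss_fn(model):
    return model(x).sum()

# The returned gradients mirror the module's parameter tree, e.g. entries
# for the layer's "weight" and "bias".
grads = mx.grad(loss_fn)(model)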