mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth * add grad of module test
This commit is contained in:
		| @@ -1,16 +1,18 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <numeric> | ||||
| #include <sstream> | ||||
| #include <unordered_set> | ||||
|  | ||||
| #include <nanobind/nanobind.h> | ||||
| #include <nanobind/stl/optional.h> | ||||
| #include <nanobind/stl/pair.h> | ||||
| #include <nanobind/stl/string.h> | ||||
| #include <nanobind/stl/unordered_set.h> | ||||
| #include <nanobind/stl/variant.h> | ||||
| #include <nanobind/stl/vector.h> | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <numeric> | ||||
| #include <sstream> | ||||
|  | ||||
| #include "mlx/array.h" | ||||
| #include "mlx/compile.h" | ||||
| #include "mlx/compile_impl.h" | ||||
| @@ -27,44 +29,45 @@ using namespace nb::literals; | ||||
| using mx::operator<<; | ||||
|  | ||||
| using IntOrVec = std::variant<int, std::vector<int>>; | ||||
| using StrOrVec = std::variant<std::string, std::vector<std::string>>; | ||||
| using StrOrSet = std::variant<std::string, std::unordered_set<std::string>>; | ||||
|  | ||||
| inline std::string type_name_str(const nb::handle& o) { | ||||
|   return nb::cast<std::string>(nb::type_name(o.type())); | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) { | ||||
|   std::vector<T> vals; | ||||
|   if (auto pv = std::get_if<T>(&v); pv) { | ||||
|     vals.push_back(*pv); | ||||
|   } else { | ||||
|     vals = std::get<std::vector<T>>(v); | ||||
|   } | ||||
|   return vals; | ||||
| } | ||||
|  | ||||
| auto validate_argnums_argnames( | ||||
|     const std::optional<IntOrVec>& argnums, | ||||
|     const StrOrVec& argnames) { | ||||
|   auto vec_names = to_vector(argnames); | ||||
|     const StrOrSet& argnames) { | ||||
|   std::unordered_set<std::string> setnames; | ||||
|   if (auto pv = std::get_if<std::string>(&argnames); pv) { | ||||
|     setnames = {*pv}; | ||||
|   } else { | ||||
|     setnames = std::get<std::unordered_set<std::string>>(argnames); | ||||
|   } | ||||
|  | ||||
|   if (!argnums.has_value()) { | ||||
|     // argnums was not provided and argnames was empty | ||||
|     if (vec_names.empty()) { | ||||
|       return std::make_pair(std::vector<int>{0}, vec_names); | ||||
|     if (setnames.empty()) { | ||||
|       return std::make_pair(std::vector<int>{0}, setnames); | ||||
|     } else { | ||||
|       return std::make_pair(std::vector<int>{}, vec_names); | ||||
|       return std::make_pair(std::vector<int>{}, setnames); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return std::make_pair(to_vector(*argnums), vec_names); | ||||
|   std::vector<int> vecnums; | ||||
|   if (auto pv = std::get_if<int>(&(*argnums)); pv) { | ||||
|     vecnums = {*pv}; | ||||
|   } else { | ||||
|     vecnums = std::get<std::vector<int>>(*argnums); | ||||
|   } | ||||
|  | ||||
|   return std::make_pair(vecnums, setnames); | ||||
| } | ||||
|  | ||||
| auto py_value_and_grad( | ||||
|     const nb::callable& fun, | ||||
|     std::vector<int> argnums, | ||||
|     std::vector<std::string> argnames, | ||||
|     std::unordered_set<std::string> argnames, | ||||
|     const std::string& error_msg_tag, | ||||
|     bool scalar_func_only) { | ||||
|   // Sanitize argnums | ||||
| @@ -72,7 +75,7 @@ auto py_value_and_grad( | ||||
|     throw std::invalid_argument( | ||||
|         error_msg_tag + " Gradient wrt no argument requested"); | ||||
|   } | ||||
|   if (argnums.size() > 0) { | ||||
|   for (auto arg : argnums) { | ||||
|     std::sort(argnums.begin(), argnums.end()); | ||||
|     if (argnums[0] < 0) { | ||||
|       std::ostringstream msg; | ||||
| @@ -81,10 +84,18 @@ auto py_value_and_grad( | ||||
|           << argnums[0]; | ||||
|       throw std::invalid_argument(msg.str()); | ||||
|     } | ||||
|     for (int i = 1; i < argnums.size(); ++i) { | ||||
|       if (argnums[i] == argnums[i - 1]) { | ||||
|         std::ostringstream msg; | ||||
|         msg << error_msg_tag << " Duplicate argument index " << argnums[0] | ||||
|             << " is not allowed."; | ||||
|         throw std::invalid_argument(msg.str()); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   return [fun, argnums, argnames, error_msg_tag, scalar_func_only]( | ||||
|              const nb::args& args, const nb::kwargs& kwargs) { | ||||
|              nb::args& args, nb::kwargs& kwargs) { | ||||
|     // Sanitize the input | ||||
|     if (argnums.size() > 0 && argnums.back() >= args.size()) { | ||||
|       std::ostringstream msg; | ||||
| @@ -112,59 +123,59 @@ auto py_value_and_grad( | ||||
|     // Collect the arrays | ||||
|     std::vector<mx::array> arrays; | ||||
|     std::vector<int> counts(1, 0); | ||||
|     for (auto i : argnums) { | ||||
|       auto argsi = tree_flatten(args[i]); | ||||
|     std::vector<int> gradient_indices; | ||||
|     for (int i = 0, j = 0; i < args.size(); ++i) { | ||||
|       bool needs_grad = (j < argnums.size() && argnums[j] == i); | ||||
|       auto argsi = tree_flatten(args[i], /* strict = */ needs_grad); | ||||
|       if (needs_grad) { | ||||
|         auto old_size = gradient_indices.size(); | ||||
|         gradient_indices.resize(old_size + argsi.size()); | ||||
|         std::iota( | ||||
|             gradient_indices.begin() + old_size, | ||||
|             gradient_indices.end(), | ||||
|             arrays.size()); | ||||
|         j++; | ||||
|         counts.push_back(argsi.size()); | ||||
|       } | ||||
|       arrays.insert(arrays.end(), argsi.begin(), argsi.end()); | ||||
|       counts.push_back(argsi.size()); | ||||
|     } | ||||
|     for (auto& key : argnames) { | ||||
|       auto argsk = tree_flatten(kwargs[key.c_str()]); | ||||
|     for (auto item : kwargs) { | ||||
|       bool needs_grad = | ||||
|           (argnames.find(nb::cast<std::string>(item.first)) != argnames.end()); | ||||
|       auto argsk = tree_flatten(item.second, /* strict = */ needs_grad); | ||||
|       if (needs_grad) { | ||||
|         auto old_size = gradient_indices.size(); | ||||
|         gradient_indices.resize(old_size + argsk.size()); | ||||
|         std::iota( | ||||
|             gradient_indices.begin() + old_size, | ||||
|             gradient_indices.end(), | ||||
|             arrays.size()); | ||||
|         counts.push_back(argsk.size()); | ||||
|       } | ||||
|       arrays.insert(arrays.end(), argsk.begin(), argsk.end()); | ||||
|       counts.push_back(argsk.size()); | ||||
|     } | ||||
|     std::partial_sum(counts.cbegin(), counts.cend(), counts.begin()); | ||||
|     std::vector<int> gradient_indices(arrays.size()); | ||||
|     std::iota(gradient_indices.begin(), gradient_indices.end(), 0); | ||||
|  | ||||
|     // value_out will hold the output of the python function in order to be | ||||
|     // able to reconstruct the python tree of extra return values | ||||
|     nb::object py_value_out; | ||||
|     auto value_and_grads = mx::value_and_grad( | ||||
|         [&fun, | ||||
|          &arrays, | ||||
|          &args, | ||||
|          &kwargs, | ||||
|          &argnums, | ||||
|          &argnames, | ||||
|          &counts, | ||||
|          &py_value_out, | ||||
|          &error_msg_tag, | ||||
|          scalar_func_only](const std::vector<mx::array>& a) { | ||||
|           // Copy the arguments | ||||
|           nb::list args_cpy; | ||||
|           nb::kwargs kwargs_cpy = nb::kwargs(); | ||||
|           int j = 0; | ||||
|           for (int i = 0; i < args.size(); ++i) { | ||||
|             if (j < argnums.size() && i == argnums[j]) { | ||||
|               args_cpy.append(tree_unflatten(args[i], a, counts[j])); | ||||
|               j++; | ||||
|             } else { | ||||
|               args_cpy.append(args[i]); | ||||
|             } | ||||
|           } | ||||
|           for (auto& key : argnames) { | ||||
|             kwargs_cpy[key.c_str()] = | ||||
|                 tree_unflatten(kwargs[key.c_str()], a, counts[j]); | ||||
|             j++; | ||||
|           } | ||||
|           for (auto item : kwargs) { | ||||
|             if (kwargs_cpy.contains(item.first)) { | ||||
|               continue; | ||||
|             } | ||||
|             kwargs_cpy[item.first] = item.second; | ||||
|           } | ||||
|           nb::list tree; | ||||
|           tree.append(args); | ||||
|           tree.append(kwargs); | ||||
|           tree_replace(tree, arrays, a); | ||||
|  | ||||
|           // Call the python function | ||||
|           py_value_out = fun(*args_cpy, **kwargs_cpy); | ||||
|           py_value_out = fun(*tree[0], **tree[1]); | ||||
|  | ||||
|           tree_replace(tree, arrays, a); | ||||
|  | ||||
|           // Validate the return value of the python function | ||||
|           if (!nb::isinstance<mx::array>(py_value_out)) { | ||||
| @@ -247,10 +258,13 @@ auto py_value_and_grad( | ||||
|       py_grads = positional_grads; | ||||
|     } else { | ||||
|       nb::dict grads_; | ||||
|       for (int i = 0; i < argnames.size(); i++) { | ||||
|         auto& k = argnames[i]; | ||||
|         grads_[k.c_str()] = tree_unflatten( | ||||
|             kwargs[k.c_str()], gradients, counts[i + argnums.size()]); | ||||
|       int i = 0; | ||||
|       for (auto item : kwargs) { | ||||
|         auto k = nb::cast<std::string>(item.first); | ||||
|         if (argnames.find(k) != argnames.end()) { | ||||
|           grads_[k.c_str()] = tree_unflatten( | ||||
|               nb::borrow(item.second), gradients, counts[i++ + argnums.size()]); | ||||
|         } | ||||
|       } | ||||
|       keyword_grads = grads_; | ||||
|  | ||||
| @@ -1207,17 +1221,17 @@ void init_transforms(nb::module_& m) { | ||||
|       "value_and_grad", | ||||
|       [](const nb::callable& fun, | ||||
|          const std::optional<IntOrVec>& argnums, | ||||
|          const StrOrVec& argnames) { | ||||
|         auto [argnums_vec, argnames_vec] = | ||||
|          const StrOrSet& argnames) { | ||||
|         auto [argnums_vec, argnames_set] = | ||||
|             validate_argnums_argnames(argnums, argnames); | ||||
|         return nb::cpp_function(py_value_and_grad( | ||||
|             fun, argnums_vec, argnames_vec, "[value_and_grad]", false)); | ||||
|             fun, argnums_vec, argnames_set, "[value_and_grad]", false)); | ||||
|       }, | ||||
|       "fun"_a, | ||||
|       "argnums"_a = nb::none(), | ||||
|       "argnames"_a = std::vector<std::string>{}, | ||||
|       nb::sig( | ||||
|           "def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"), | ||||
|           "def value_and_grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"), | ||||
|       R"pbdoc( | ||||
|         Returns a function which computes the value and gradient of ``fun``. | ||||
|  | ||||
| @@ -1271,21 +1285,20 @@ void init_transforms(nb::module_& m) { | ||||
|       "grad", | ||||
|       [](const nb::callable& fun, | ||||
|          const std::optional<IntOrVec>& argnums, | ||||
|          const StrOrVec& argnames) { | ||||
|         auto [argnums_vec, argnames_vec] = | ||||
|          const StrOrSet& argnames) { | ||||
|         auto [argnums_vec, argnames_set] = | ||||
|             validate_argnums_argnames(argnums, argnames); | ||||
|         auto fn = | ||||
|             py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true); | ||||
|         return nb::cpp_function( | ||||
|             [fn](const nb::args& args, const nb::kwargs& kwargs) { | ||||
|               return fn(args, kwargs).second; | ||||
|             }); | ||||
|             py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true); | ||||
|         return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) { | ||||
|           return fn(args, kwargs).second; | ||||
|         }); | ||||
|       }, | ||||
|       "fun"_a, | ||||
|       "argnums"_a = nb::none(), | ||||
|       "argnames"_a = std::vector<std::string>{}, | ||||
|       nb::sig( | ||||
|           "def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"), | ||||
|           "def grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"), | ||||
|       R"pbdoc( | ||||
|         Returns a function which computes the gradient of ``fun``. | ||||
|  | ||||
|   | ||||
| @@ -146,7 +146,7 @@ void tree_visit( | ||||
|   return recurse(trees); | ||||
| } | ||||
|  | ||||
| void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) { | ||||
| void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) { | ||||
|   std::function<void(nb::handle)> recurse; | ||||
|   recurse = [&](nb::handle subtree) { | ||||
|     if (nb::isinstance<nb::list>(subtree) || | ||||
| @@ -178,10 +178,11 @@ void tree_visit_update( | ||||
|       } | ||||
|       return nb::cast<nb::object>(l); | ||||
|     } else if (nb::isinstance<nb::tuple>(subtree)) { | ||||
|       for (auto item : subtree) { | ||||
|         recurse(item); | ||||
|       nb::list l(subtree); | ||||
|       for (int i = 0; i < l.size(); ++i) { | ||||
|         l[i] = recurse(l[i]); | ||||
|       } | ||||
|       return nb::cast<nb::object>(subtree); | ||||
|       return nb::cast<nb::object>(nb::tuple(l)); | ||||
|     } else if (nb::isinstance<nb::dict>(subtree)) { | ||||
|       auto d = nb::cast<nb::dict>(subtree); | ||||
|       for (auto item : d) { | ||||
| @@ -224,7 +225,7 @@ void tree_replace( | ||||
|   }); | ||||
| } | ||||
|  | ||||
| std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) { | ||||
| std::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) { | ||||
|   std::vector<mx::array> flat_tree; | ||||
|  | ||||
|   tree_visit(tree, [&](nb::handle obj) { | ||||
|   | ||||
| @@ -10,7 +10,7 @@ namespace nb = nanobind; | ||||
| void tree_visit( | ||||
|     const std::vector<nb::object>& trees, | ||||
|     std::function<void(const std::vector<nb::object>&)> visitor); | ||||
| void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor); | ||||
| void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor); | ||||
|  | ||||
| nb::object tree_map( | ||||
|     const std::vector<nb::object>& trees, | ||||
| @@ -42,7 +42,7 @@ void tree_replace( | ||||
|  * Flatten a tree into a vector of arrays. If strict is true, then the | ||||
|  * function will throw if the tree contains a leaf which is not an array. | ||||
|  */ | ||||
| std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true); | ||||
| std::vector<mx::array> tree_flatten(nb::handle tree, bool strict = true); | ||||
|  | ||||
| /** | ||||
|  * Unflatten a tree from a vector of arrays. | ||||
|   | ||||
| @@ -139,6 +139,8 @@ class TestAutograd(mlx_tests.MLXTestCase): | ||||
|             mx.value_and_grad(fun, (None, None)) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.value_and_grad(fun, tuple()) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun, argnums=(0, 0)) | ||||
|  | ||||
|     def test_auxiliary_values(self): | ||||
|         def fun(x, y): | ||||
|   | ||||
| @@ -195,6 +195,20 @@ class TestBase(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(isinstance(m.layers[1], nn.ReLU)) | ||||
|         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) | ||||
|  | ||||
|     def test_grad_of_module(self): | ||||
|         class Model(nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.m1 = nn.Linear(3, 3) | ||||
|  | ||||
|         model = Model() | ||||
|  | ||||
|         def loss_fn(model): | ||||
|             return model.m1(x).sum() | ||||
|  | ||||
|         x = mx.zeros((3,)) | ||||
|         mx.grad(loss_fn)(model) | ||||
|  | ||||
|  | ||||
| class TestLayers(mlx_tests.MLXTestCase): | ||||
|     def test_identity(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun