mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Limit grad recursion depth by not recursing through non-grad inputs (#1764)
* limit grad recursion depth * add grad of module test
This commit is contained in:
parent
5cc5201914
commit
33421c1dd3
@ -278,7 +278,8 @@ void eval(std::vector<array> outputs) {
|
|||||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotans) {
|
const std::vector<array>& cotans,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
// Set the global tracing flag.
|
// Set the global tracing flag.
|
||||||
detail::InTracing in_tracing;
|
detail::InTracing in_tracing;
|
||||||
|
|
||||||
@ -330,10 +331,14 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
|||||||
// to the tape which need a gradient.
|
// to the tape which need a gradient.
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
std::unordered_set<std::uintptr_t> calc_grad;
|
std::unordered_set<std::uintptr_t> calc_grad;
|
||||||
for (auto& primal : primals_) {
|
for (int i = 0, j = 0; i < primals_.size(); ++i) {
|
||||||
|
auto& primal = primals_[i];
|
||||||
primal.set_tracer(false);
|
primal.set_tracer(false);
|
||||||
calc_grad.insert(primal.id());
|
|
||||||
cache.insert(primal.id());
|
cache.insert(primal.id());
|
||||||
|
if (j < argnums.size() && argnums[j] == i) {
|
||||||
|
j++;
|
||||||
|
calc_grad.insert(primal.id());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> tape;
|
std::vector<array> tape;
|
||||||
@ -435,7 +440,8 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<array> vjps;
|
std::vector<array> vjps;
|
||||||
for (auto& primal : primals_) {
|
for (auto arg : argnums) {
|
||||||
|
auto& primal = primals_[arg];
|
||||||
if (auto cotan_it = cotan_map.find(primal.id());
|
if (auto cotan_it = cotan_map.find(primal.id());
|
||||||
cotan_it != cotan_map.end()) {
|
cotan_it != cotan_map.end()) {
|
||||||
vjps.push_back(cotan_it->second);
|
vjps.push_back(cotan_it->second);
|
||||||
@ -448,6 +454,15 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
|||||||
return {outputs, vjps};
|
return {outputs, vjps};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||||
|
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& cotans) {
|
||||||
|
std::vector<int> argnums(primals.size());
|
||||||
|
std::iota(argnums.begin(), argnums.end(), 0);
|
||||||
|
return vjp(fun, primals, cotans, argnums);
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<array, array> vjp(
|
std::pair<array, array> vjp(
|
||||||
const std::function<array(const array&)>& fun,
|
const std::function<array(const array&)>& fun,
|
||||||
const array& primal,
|
const array& primal,
|
||||||
@ -606,15 +621,10 @@ ValueAndGradFn value_and_grad(
|
|||||||
<< inputs.size() << " inputs.";
|
<< inputs.size() << " inputs.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
std::vector<int> sorted_argnums(args.begin(), args.end());
|
||||||
|
|
||||||
auto gfun = [&fun, &inputs, &args](const std::vector<array>& ginputs) {
|
auto gfun = [&fun](const std::vector<array>& inputs) {
|
||||||
std::vector<array> inputs_(inputs);
|
auto outputs = fun(inputs);
|
||||||
auto argit = args.begin();
|
|
||||||
for (int i = 0; i < ginputs.size(); ++i) {
|
|
||||||
inputs_[*argit] = ginputs[i];
|
|
||||||
++argit;
|
|
||||||
}
|
|
||||||
auto outputs = fun(inputs_);
|
|
||||||
for (int i = 1; i < outputs.size(); i++) {
|
for (int i = 1; i < outputs.size(); i++) {
|
||||||
auto& out = outputs[i];
|
auto& out = outputs[i];
|
||||||
auto s = out.has_primitive() ? out.primitive().stream()
|
auto s = out.has_primitive() ? out.primitive().stream()
|
||||||
@ -624,12 +634,8 @@ ValueAndGradFn value_and_grad(
|
|||||||
return outputs;
|
return outputs;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<array> ginputs;
|
|
||||||
for (auto arg : args) {
|
|
||||||
ginputs.push_back(inputs[arg]);
|
|
||||||
}
|
|
||||||
// Set the incoming gradient to float32, vjp will cast it to the output type
|
// Set the incoming gradient to float32, vjp will cast it to the output type
|
||||||
auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
|
auto [outputs, grads] = vjp(gfun, inputs, {array(1.0f)}, sorted_argnums);
|
||||||
return std::make_pair(outputs, grads);
|
return std::make_pair(outputs, grads);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -1,16 +1,18 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <numeric>
|
||||||
|
#include <sstream>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
#include <nanobind/nanobind.h>
|
#include <nanobind/nanobind.h>
|
||||||
#include <nanobind/stl/optional.h>
|
#include <nanobind/stl/optional.h>
|
||||||
#include <nanobind/stl/pair.h>
|
#include <nanobind/stl/pair.h>
|
||||||
#include <nanobind/stl/string.h>
|
#include <nanobind/stl/string.h>
|
||||||
|
#include <nanobind/stl/unordered_set.h>
|
||||||
#include <nanobind/stl/variant.h>
|
#include <nanobind/stl/variant.h>
|
||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <numeric>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/compile.h"
|
#include "mlx/compile.h"
|
||||||
#include "mlx/compile_impl.h"
|
#include "mlx/compile_impl.h"
|
||||||
@ -27,44 +29,45 @@ using namespace nb::literals;
|
|||||||
using mx::operator<<;
|
using mx::operator<<;
|
||||||
|
|
||||||
using IntOrVec = std::variant<int, std::vector<int>>;
|
using IntOrVec = std::variant<int, std::vector<int>>;
|
||||||
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
|
using StrOrSet = std::variant<std::string, std::unordered_set<std::string>>;
|
||||||
|
|
||||||
inline std::string type_name_str(const nb::handle& o) {
|
inline std::string type_name_str(const nb::handle& o) {
|
||||||
return nb::cast<std::string>(nb::type_name(o.type()));
|
return nb::cast<std::string>(nb::type_name(o.type()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize a "single value or vector of values" variant into a plain vector.
// A lone T yields a one-element vector; a std::vector<T> is returned as a copy.
template <typename T>
std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
  if (std::holds_alternative<std::vector<T>>(v)) {
    return std::get<std::vector<T>>(v);
  }
  // Count-constructor form avoids any initializer_list ambiguity for T.
  return std::vector<T>(1, std::get<T>(v));
}
|
|
||||||
|
|
||||||
auto validate_argnums_argnames(
|
auto validate_argnums_argnames(
|
||||||
const std::optional<IntOrVec>& argnums,
|
const std::optional<IntOrVec>& argnums,
|
||||||
const StrOrVec& argnames) {
|
const StrOrSet& argnames) {
|
||||||
auto vec_names = to_vector(argnames);
|
std::unordered_set<std::string> setnames;
|
||||||
|
if (auto pv = std::get_if<std::string>(&argnames); pv) {
|
||||||
|
setnames = {*pv};
|
||||||
|
} else {
|
||||||
|
setnames = std::get<std::unordered_set<std::string>>(argnames);
|
||||||
|
}
|
||||||
|
|
||||||
if (!argnums.has_value()) {
|
if (!argnums.has_value()) {
|
||||||
// argnums was not provided and argnames was empty
|
// argnums was not provided and argnames was empty
|
||||||
if (vec_names.empty()) {
|
if (setnames.empty()) {
|
||||||
return std::make_pair(std::vector<int>{0}, vec_names);
|
return std::make_pair(std::vector<int>{0}, setnames);
|
||||||
} else {
|
} else {
|
||||||
return std::make_pair(std::vector<int>{}, vec_names);
|
return std::make_pair(std::vector<int>{}, setnames);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_pair(to_vector(*argnums), vec_names);
|
std::vector<int> vecnums;
|
||||||
|
if (auto pv = std::get_if<int>(&(*argnums)); pv) {
|
||||||
|
vecnums = {*pv};
|
||||||
|
} else {
|
||||||
|
vecnums = std::get<std::vector<int>>(*argnums);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_pair(vecnums, setnames);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto py_value_and_grad(
|
auto py_value_and_grad(
|
||||||
const nb::callable& fun,
|
const nb::callable& fun,
|
||||||
std::vector<int> argnums,
|
std::vector<int> argnums,
|
||||||
std::vector<std::string> argnames,
|
std::unordered_set<std::string> argnames,
|
||||||
const std::string& error_msg_tag,
|
const std::string& error_msg_tag,
|
||||||
bool scalar_func_only) {
|
bool scalar_func_only) {
|
||||||
// Sanitize argnums
|
// Sanitize argnums
|
||||||
@ -72,7 +75,7 @@ auto py_value_and_grad(
|
|||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
error_msg_tag + " Gradient wrt no argument requested");
|
error_msg_tag + " Gradient wrt no argument requested");
|
||||||
}
|
}
|
||||||
if (argnums.size() > 0) {
|
for (auto arg : argnums) {
|
||||||
std::sort(argnums.begin(), argnums.end());
|
std::sort(argnums.begin(), argnums.end());
|
||||||
if (argnums[0] < 0) {
|
if (argnums[0] < 0) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -81,10 +84,18 @@ auto py_value_and_grad(
|
|||||||
<< argnums[0];
|
<< argnums[0];
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
// Reject repeated argument indices. argnums is sorted by the std::sort call
// earlier in this function, so any duplicates sit next to each other.
for (int i = 1; i < argnums.size(); ++i) {
  if (argnums[i] == argnums[i - 1]) {
    std::ostringstream msg;
    // Report the index that is actually repeated (argnums[i]), not the
    // smallest requested index (argnums[0]).
    msg << error_msg_tag << " Duplicate argument index " << argnums[i]
        << " is not allowed.";
    throw std::invalid_argument(msg.str());
  }
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
|
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
|
||||||
const nb::args& args, const nb::kwargs& kwargs) {
|
nb::args& args, nb::kwargs& kwargs) {
|
||||||
// Sanitize the input
|
// Sanitize the input
|
||||||
if (argnums.size() > 0 && argnums.back() >= args.size()) {
|
if (argnums.size() > 0 && argnums.back() >= args.size()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -112,59 +123,59 @@ auto py_value_and_grad(
|
|||||||
// Collect the arrays
|
// Collect the arrays
|
||||||
std::vector<mx::array> arrays;
|
std::vector<mx::array> arrays;
|
||||||
std::vector<int> counts(1, 0);
|
std::vector<int> counts(1, 0);
|
||||||
for (auto i : argnums) {
|
std::vector<int> gradient_indices;
|
||||||
auto argsi = tree_flatten(args[i]);
|
for (int i = 0, j = 0; i < args.size(); ++i) {
|
||||||
|
bool needs_grad = (j < argnums.size() && argnums[j] == i);
|
||||||
|
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
|
||||||
|
if (needs_grad) {
|
||||||
|
auto old_size = gradient_indices.size();
|
||||||
|
gradient_indices.resize(old_size + argsi.size());
|
||||||
|
std::iota(
|
||||||
|
gradient_indices.begin() + old_size,
|
||||||
|
gradient_indices.end(),
|
||||||
|
arrays.size());
|
||||||
|
j++;
|
||||||
|
counts.push_back(argsi.size());
|
||||||
|
}
|
||||||
arrays.insert(arrays.end(), argsi.begin(), argsi.end());
|
arrays.insert(arrays.end(), argsi.begin(), argsi.end());
|
||||||
counts.push_back(argsi.size());
|
|
||||||
}
|
}
|
||||||
for (auto& key : argnames) {
|
for (auto item : kwargs) {
|
||||||
auto argsk = tree_flatten(kwargs[key.c_str()]);
|
bool needs_grad =
|
||||||
|
(argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
|
||||||
|
auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
|
||||||
|
if (needs_grad) {
|
||||||
|
auto old_size = gradient_indices.size();
|
||||||
|
gradient_indices.resize(old_size + argsk.size());
|
||||||
|
std::iota(
|
||||||
|
gradient_indices.begin() + old_size,
|
||||||
|
gradient_indices.end(),
|
||||||
|
arrays.size());
|
||||||
|
counts.push_back(argsk.size());
|
||||||
|
}
|
||||||
arrays.insert(arrays.end(), argsk.begin(), argsk.end());
|
arrays.insert(arrays.end(), argsk.begin(), argsk.end());
|
||||||
counts.push_back(argsk.size());
|
|
||||||
}
|
}
|
||||||
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
|
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
|
||||||
std::vector<int> gradient_indices(arrays.size());
|
|
||||||
std::iota(gradient_indices.begin(), gradient_indices.end(), 0);
|
|
||||||
|
|
||||||
// value_out will hold the output of the python function in order to be
|
// value_out will hold the output of the python function in order to be
|
||||||
// able to reconstruct the python tree of extra return values
|
// able to reconstruct the python tree of extra return values
|
||||||
nb::object py_value_out;
|
nb::object py_value_out;
|
||||||
auto value_and_grads = mx::value_and_grad(
|
auto value_and_grads = mx::value_and_grad(
|
||||||
[&fun,
|
[&fun,
|
||||||
|
&arrays,
|
||||||
&args,
|
&args,
|
||||||
&kwargs,
|
&kwargs,
|
||||||
&argnums,
|
|
||||||
&argnames,
|
|
||||||
&counts,
|
|
||||||
&py_value_out,
|
&py_value_out,
|
||||||
&error_msg_tag,
|
&error_msg_tag,
|
||||||
scalar_func_only](const std::vector<mx::array>& a) {
|
scalar_func_only](const std::vector<mx::array>& a) {
|
||||||
// Copy the arguments
|
nb::list tree;
|
||||||
nb::list args_cpy;
|
tree.append(args);
|
||||||
nb::kwargs kwargs_cpy = nb::kwargs();
|
tree.append(kwargs);
|
||||||
int j = 0;
|
tree_replace(tree, arrays, a);
|
||||||
for (int i = 0; i < args.size(); ++i) {
|
|
||||||
if (j < argnums.size() && i == argnums[j]) {
|
|
||||||
args_cpy.append(tree_unflatten(args[i], a, counts[j]));
|
|
||||||
j++;
|
|
||||||
} else {
|
|
||||||
args_cpy.append(args[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (auto& key : argnames) {
|
|
||||||
kwargs_cpy[key.c_str()] =
|
|
||||||
tree_unflatten(kwargs[key.c_str()], a, counts[j]);
|
|
||||||
j++;
|
|
||||||
}
|
|
||||||
for (auto item : kwargs) {
|
|
||||||
if (kwargs_cpy.contains(item.first)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
kwargs_cpy[item.first] = item.second;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call the python function
|
// Call the python function
|
||||||
py_value_out = fun(*args_cpy, **kwargs_cpy);
|
py_value_out = fun(*tree[0], **tree[1]);
|
||||||
|
|
||||||
|
tree_replace(tree, arrays, a);
|
||||||
|
|
||||||
// Validate the return value of the python function
|
// Validate the return value of the python function
|
||||||
if (!nb::isinstance<mx::array>(py_value_out)) {
|
if (!nb::isinstance<mx::array>(py_value_out)) {
|
||||||
@ -247,10 +258,13 @@ auto py_value_and_grad(
|
|||||||
py_grads = positional_grads;
|
py_grads = positional_grads;
|
||||||
} else {
|
} else {
|
||||||
nb::dict grads_;
|
nb::dict grads_;
|
||||||
for (int i = 0; i < argnames.size(); i++) {
|
int i = 0;
|
||||||
auto& k = argnames[i];
|
for (auto item : kwargs) {
|
||||||
grads_[k.c_str()] = tree_unflatten(
|
auto k = nb::cast<std::string>(item.first);
|
||||||
kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
|
if (argnames.find(k) != argnames.end()) {
|
||||||
|
grads_[k.c_str()] = tree_unflatten(
|
||||||
|
nb::borrow(item.second), gradients, counts[i++ + argnums.size()]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
keyword_grads = grads_;
|
keyword_grads = grads_;
|
||||||
|
|
||||||
@ -1207,17 +1221,17 @@ void init_transforms(nb::module_& m) {
|
|||||||
"value_and_grad",
|
"value_and_grad",
|
||||||
[](const nb::callable& fun,
|
[](const nb::callable& fun,
|
||||||
const std::optional<IntOrVec>& argnums,
|
const std::optional<IntOrVec>& argnums,
|
||||||
const StrOrVec& argnames) {
|
const StrOrSet& argnames) {
|
||||||
auto [argnums_vec, argnames_vec] =
|
auto [argnums_vec, argnames_set] =
|
||||||
validate_argnums_argnames(argnums, argnames);
|
validate_argnums_argnames(argnums, argnames);
|
||||||
return nb::cpp_function(py_value_and_grad(
|
return nb::cpp_function(py_value_and_grad(
|
||||||
fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
|
fun, argnums_vec, argnames_set, "[value_and_grad]", false));
|
||||||
},
|
},
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
"argnums"_a = nb::none(),
|
"argnums"_a = nb::none(),
|
||||||
"argnames"_a = std::vector<std::string>{},
|
"argnames"_a = std::vector<std::string>{},
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
|
"def value_and_grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Returns a function which computes the value and gradient of ``fun``.
|
Returns a function which computes the value and gradient of ``fun``.
|
||||||
|
|
||||||
@ -1271,21 +1285,20 @@ void init_transforms(nb::module_& m) {
|
|||||||
"grad",
|
"grad",
|
||||||
[](const nb::callable& fun,
|
[](const nb::callable& fun,
|
||||||
const std::optional<IntOrVec>& argnums,
|
const std::optional<IntOrVec>& argnums,
|
||||||
const StrOrVec& argnames) {
|
const StrOrSet& argnames) {
|
||||||
auto [argnums_vec, argnames_vec] =
|
auto [argnums_vec, argnames_set] =
|
||||||
validate_argnums_argnames(argnums, argnames);
|
validate_argnums_argnames(argnums, argnames);
|
||||||
auto fn =
|
auto fn =
|
||||||
py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
|
py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
|
||||||
return nb::cpp_function(
|
return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) {
|
||||||
[fn](const nb::args& args, const nb::kwargs& kwargs) {
|
return fn(args, kwargs).second;
|
||||||
return fn(args, kwargs).second;
|
});
|
||||||
});
|
|
||||||
},
|
},
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
"argnums"_a = nb::none(),
|
"argnums"_a = nb::none(),
|
||||||
"argnames"_a = std::vector<std::string>{},
|
"argnames"_a = std::vector<std::string>{},
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
|
"def grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Returns a function which computes the gradient of ``fun``.
|
Returns a function which computes the gradient of ``fun``.
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ void tree_visit(
|
|||||||
return recurse(trees);
|
return recurse(trees);
|
||||||
}
|
}
|
||||||
|
|
||||||
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
|
void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
|
||||||
std::function<void(nb::handle)> recurse;
|
std::function<void(nb::handle)> recurse;
|
||||||
recurse = [&](nb::handle subtree) {
|
recurse = [&](nb::handle subtree) {
|
||||||
if (nb::isinstance<nb::list>(subtree) ||
|
if (nb::isinstance<nb::list>(subtree) ||
|
||||||
@ -178,10 +178,11 @@ void tree_visit_update(
|
|||||||
}
|
}
|
||||||
return nb::cast<nb::object>(l);
|
return nb::cast<nb::object>(l);
|
||||||
} else if (nb::isinstance<nb::tuple>(subtree)) {
|
} else if (nb::isinstance<nb::tuple>(subtree)) {
|
||||||
for (auto item : subtree) {
|
nb::list l(subtree);
|
||||||
recurse(item);
|
for (int i = 0; i < l.size(); ++i) {
|
||||||
|
l[i] = recurse(l[i]);
|
||||||
}
|
}
|
||||||
return nb::cast<nb::object>(subtree);
|
return nb::cast<nb::object>(nb::tuple(l));
|
||||||
} else if (nb::isinstance<nb::dict>(subtree)) {
|
} else if (nb::isinstance<nb::dict>(subtree)) {
|
||||||
auto d = nb::cast<nb::dict>(subtree);
|
auto d = nb::cast<nb::dict>(subtree);
|
||||||
for (auto item : d) {
|
for (auto item : d) {
|
||||||
@ -224,7 +225,7 @@ void tree_replace(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
|
std::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) {
|
||||||
std::vector<mx::array> flat_tree;
|
std::vector<mx::array> flat_tree;
|
||||||
|
|
||||||
tree_visit(tree, [&](nb::handle obj) {
|
tree_visit(tree, [&](nb::handle obj) {
|
||||||
|
@ -10,7 +10,7 @@ namespace nb = nanobind;
|
|||||||
void tree_visit(
|
void tree_visit(
|
||||||
const std::vector<nb::object>& trees,
|
const std::vector<nb::object>& trees,
|
||||||
std::function<void(const std::vector<nb::object>&)> visitor);
|
std::function<void(const std::vector<nb::object>&)> visitor);
|
||||||
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
|
void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor);
|
||||||
|
|
||||||
nb::object tree_map(
|
nb::object tree_map(
|
||||||
const std::vector<nb::object>& trees,
|
const std::vector<nb::object>& trees,
|
||||||
@ -42,7 +42,7 @@ void tree_replace(
|
|||||||
* Flatten a tree into a vector of arrays. If strict is true, then the
|
* Flatten a tree into a vector of arrays. If strict is true, then the
|
||||||
* function will throw if the tree contains a leaf which is not an array.
|
* function will throw if the tree contains a leaf which is not an array.
|
||||||
*/
|
*/
|
||||||
std::vector<mx::array> tree_flatten(nb::object tree, bool strict = true);
|
std::vector<mx::array> tree_flatten(nb::handle tree, bool strict = true);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unflatten a tree from a vector of arrays.
|
* Unflatten a tree from a vector of arrays.
|
||||||
|
@ -139,6 +139,8 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
mx.value_and_grad(fun, (None, None))
|
mx.value_and_grad(fun, (None, None))
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.value_and_grad(fun, tuple())
|
mx.value_and_grad(fun, tuple())
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.grad(fun, argnums=(0, 0))
|
||||||
|
|
||||||
def test_auxiliary_values(self):
|
def test_auxiliary_values(self):
|
||||||
def fun(x, y):
|
def fun(x, y):
|
||||||
|
@ -195,6 +195,20 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
|
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
|
||||||
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
|
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
|
||||||
|
|
||||||
|
def test_grad_of_module(self):
    # Smoke test: differentiate a loss whose only argument is an nn.Module.
    # There are no assertions; the test fails only if the call raises.
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.m1 = nn.Linear(3, 3)

    inputs = mx.zeros((3,))

    def loss_fn(model):
        return model.m1(inputs).sum()

    mx.grad(loss_fn)(TinyModel())
|
||||||
|
|
||||||
|
|
||||||
class TestLayers(mlx_tests.MLXTestCase):
|
class TestLayers(mlx_tests.MLXTestCase):
|
||||||
def test_identity(self):
|
def test_identity(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user