mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Added support for pytree types that inherit from tuple and typing.namedtuple (#2845)
This commit is contained in:
@@ -41,6 +41,7 @@ nb::object tree_map(
|
||||
int len = nb::cast<nb::tuple>(subtrees[0]).size();
|
||||
nb::list l;
|
||||
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
|
||||
auto type = subtrees[0].type();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (nb::isinstance<nb::tuple>(subtrees[j])) {
|
||||
@@ -51,7 +52,10 @@ nb::object tree_map(
|
||||
}
|
||||
l.append(recurse(items));
|
||||
}
|
||||
return nb::cast<nb::object>(nb::tuple(l));
|
||||
if (PyTuple_CheckExact(subtrees[0].ptr())) {
|
||||
return nb::cast<nb::object>(nb::tuple(l));
|
||||
}
|
||||
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
|
||||
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
|
||||
@@ -178,11 +182,15 @@ void tree_visit_update(
|
||||
}
|
||||
return nb::cast<nb::object>(l);
|
||||
} else if (nb::isinstance<nb::tuple>(subtree)) {
|
||||
auto type = subtree.type();
|
||||
nb::list l(subtree);
|
||||
for (int i = 0; i < l.size(); ++i) {
|
||||
l[i] = recurse(l[i]);
|
||||
}
|
||||
return nb::cast<nb::object>(nb::tuple(l));
|
||||
if (PyTuple_CheckExact(subtree.ptr())) {
|
||||
return nb::cast<nb::object>(nb::tuple(l));
|
||||
}
|
||||
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
|
||||
} else if (nb::isinstance<nb::dict>(subtree)) {
|
||||
auto d = nb::cast<nb::dict>(subtree);
|
||||
for (auto item : d) {
|
||||
|
||||
Reference in New Issue
Block a user