Added support for pytree types that inherit from tuple and typing.namedtuple (#2845)

This commit is contained in:
romanoneg
2025-12-04 11:06:45 -08:00
committed by GitHub
parent 50d3914c67
commit 9abb0b8123
6 changed files with 196 additions and 3 deletions

View File

@@ -41,6 +41,7 @@ nb::object tree_map(
int len = nb::cast<nb::tuple>(subtrees[0]).size();
nb::list l;
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
auto type = subtrees[0].type();
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (nb::isinstance<nb::tuple>(subtrees[j])) {
@@ -51,7 +52,10 @@ nb::object tree_map(
}
l.append(recurse(items));
}
return nb::cast<nb::object>(nb::tuple(l));
if (PyTuple_CheckExact(subtrees[0].ptr())) {
return nb::cast<nb::object>(nb::tuple(l));
}
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
@@ -178,11 +182,15 @@ void tree_visit_update(
}
return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtree)) {
auto type = subtree.type();
nb::list l(subtree);
for (int i = 0; i < l.size(); ++i) {
l[i] = recurse(l[i]);
}
return nb::cast<nb::object>(nb::tuple(l));
if (PyTuple_CheckExact(subtree.ptr())) {
return nb::cast<nb::object>(nb::tuple(l));
}
return nb::hasattr(type, "_fields") ? type(*l) : type(l);
} else if (nb::isinstance<nb::dict>(subtree)) {
auto d = nb::cast<nb::dict>(subtree);
for (auto item : d) {