Properly handle negative axes in python vmap (#944)

This commit is contained in:
Angelos Katharopoulos
2024-04-02 18:07:23 -07:00
committed by GitHub
parent 741eb28443
commit 3fc993f82d
4 changed files with 147 additions and 41 deletions

View File

@@ -2,26 +2,6 @@
#include "python/src/trees.h"
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
std::function<void(nb::handle)> recurse;
recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree) ||
nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (nb::isinstance<nb::dict>(subtree)) {
for (auto item : nb::cast<nb::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<nb::object>& subtrees) {
int len = nb::cast<T>(subtrees[0]).size();
@@ -107,6 +87,85 @@ nb::object tree_map(
});
}
void tree_visit(
const std::vector<nb::object>& trees,
std::function<void(const std::vector<nb::object>&)> visitor) {
std::function<void(const std::vector<nb::object>&)> recurse;
recurse = [&](const std::vector<nb::object>& subtrees) {
if (nb::isinstance<nb::list>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (nb::isinstance<nb::list>(subtrees[j])) {
items[j] = nb::cast<nb::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
recurse(items);
}
} else if (nb::isinstance<nb::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<nb::object> items(subtrees.size());
int len = nb::cast<nb::tuple>(subtrees[0]).size();
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (nb::isinstance<nb::tuple>(subtrees[j])) {
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
recurse(items);
}
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (nb::isinstance<nb::dict>(subtrees[j])) {
auto subdict = nb::cast<nb::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_visit] Tree is not a valid prefix tree of the first tree.");
}
items[j] = subdict[item.first];
} else {
items[j] = subtrees[j];
}
}
recurse(items);
}
} else {
visitor(subtrees);
}
};
return recurse(trees);
}
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
std::function<void(nb::handle)> recurse;
recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree) ||
nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (nb::isinstance<nb::dict>(subtree)) {
for (auto item : nb::cast<nb::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
void tree_visit_update(
nb::object tree,
std::function<nb::object(nb::handle)> visitor) {