mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Properly handle negative axes in python vmap (#944)
This commit is contained in:

committed by
GitHub

parent
741eb28443
commit
3fc993f82d
@@ -2,26 +2,6 @@
|
||||
|
||||
#include "python/src/trees.h"
|
||||
|
||||
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
|
||||
std::function<void(nb::handle)> recurse;
|
||||
recurse = [&](nb::handle subtree) {
|
||||
if (nb::isinstance<nb::list>(subtree) ||
|
||||
nb::isinstance<nb::tuple>(subtree)) {
|
||||
for (auto item : subtree) {
|
||||
recurse(item);
|
||||
}
|
||||
} else if (nb::isinstance<nb::dict>(subtree)) {
|
||||
for (auto item : nb::cast<nb::dict>(subtree)) {
|
||||
recurse(item.second);
|
||||
}
|
||||
} else {
|
||||
visitor(subtree);
|
||||
}
|
||||
};
|
||||
|
||||
recurse(tree);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
void validate_subtrees(const std::vector<nb::object>& subtrees) {
|
||||
int len = nb::cast<T>(subtrees[0]).size();
|
||||
@@ -107,6 +87,85 @@ nb::object tree_map(
|
||||
});
|
||||
}
|
||||
|
||||
void tree_visit(
|
||||
const std::vector<nb::object>& trees,
|
||||
std::function<void(const std::vector<nb::object>&)> visitor) {
|
||||
std::function<void(const std::vector<nb::object>&)> recurse;
|
||||
|
||||
recurse = [&](const std::vector<nb::object>& subtrees) {
|
||||
if (nb::isinstance<nb::list>(subtrees[0])) {
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
|
||||
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (nb::isinstance<nb::list>(subtrees[j])) {
|
||||
items[j] = nb::cast<nb::list>(subtrees[j])[i];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
recurse(items);
|
||||
}
|
||||
} else if (nb::isinstance<nb::tuple>(subtrees[0])) {
|
||||
// Check the rest of the subtrees
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
int len = nb::cast<nb::tuple>(subtrees[0]).size();
|
||||
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (nb::isinstance<nb::tuple>(subtrees[j])) {
|
||||
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
recurse(items);
|
||||
}
|
||||
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
|
||||
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (nb::isinstance<nb::dict>(subtrees[j])) {
|
||||
auto subdict = nb::cast<nb::dict>(subtrees[j]);
|
||||
if (!subdict.contains(item.first)) {
|
||||
throw std::invalid_argument(
|
||||
"[tree_visit] Tree is not a valid prefix tree of the first tree.");
|
||||
}
|
||||
items[j] = subdict[item.first];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
recurse(items);
|
||||
}
|
||||
} else {
|
||||
visitor(subtrees);
|
||||
}
|
||||
};
|
||||
return recurse(trees);
|
||||
}
|
||||
|
||||
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
|
||||
std::function<void(nb::handle)> recurse;
|
||||
recurse = [&](nb::handle subtree) {
|
||||
if (nb::isinstance<nb::list>(subtree) ||
|
||||
nb::isinstance<nb::tuple>(subtree)) {
|
||||
for (auto item : subtree) {
|
||||
recurse(item);
|
||||
}
|
||||
} else if (nb::isinstance<nb::dict>(subtree)) {
|
||||
for (auto item : nb::cast<nb::dict>(subtree)) {
|
||||
recurse(item.second);
|
||||
}
|
||||
} else {
|
||||
visitor(subtree);
|
||||
}
|
||||
};
|
||||
|
||||
recurse(tree);
|
||||
}
|
||||
|
||||
void tree_visit_update(
|
||||
nb::object tree,
|
||||
std::function<nb::object(nb::handle)> visitor) {
|
||||
|
Reference in New Issue
Block a user