mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-06 01:11:12 +08:00
Properly handle negative axes in python vmap (#944)
This commit is contained in:
parent
741eb28443
commit
3fc993f82d
@ -16,6 +16,7 @@
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
@ -265,26 +266,72 @@ auto py_vmap(
|
||||
const nb::object& out_axes) {
|
||||
return [fun, in_axes, out_axes](const nb::args& args) {
|
||||
auto axes_to_flat_tree = [](const nb::object& tree,
|
||||
const nb::object& axes) {
|
||||
auto tree_axes = tree_map(
|
||||
{tree, axes},
|
||||
[](const std::vector<nb::object>& inputs) { return inputs[1]; });
|
||||
const nb::object& axes,
|
||||
bool output_axes) {
|
||||
std::vector<int> flat_axes;
|
||||
tree_visit(tree_axes, [&flat_axes](nb::handle obj) {
|
||||
if (obj.is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
flat_axes.push_back(nb::cast<int>(nb::cast<nb::int_>(obj)));
|
||||
} else {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
});
|
||||
bool encountered_tuple = false;
|
||||
tree_visit(
|
||||
{tree, axes},
|
||||
[&flat_axes, &encountered_tuple, output_axes](
|
||||
const std::vector<nb::object>& inputs) {
|
||||
if (nb::isinstance<array>(inputs[0])) {
|
||||
if (inputs[1].is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else if (nb::isinstance<nb::int_>(inputs[1])) {
|
||||
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
|
||||
const array& x = nb::cast<array>(inputs[0]);
|
||||
if (axis < 0) {
|
||||
axis += x.ndim() + output_axes;
|
||||
}
|
||||
if (axis < 0 || axis >= (x.ndim() + output_axes)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
|
||||
<< "vectorization axis " << axis
|
||||
<< " for array with shape " << x.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
flat_axes.push_back(axis);
|
||||
} else if (nb::isinstance<nb::tuple>(inputs[1])) {
|
||||
encountered_tuple = true;
|
||||
auto l = nb::cast<nb::tuple>(inputs[1]);
|
||||
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
|
||||
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
|
||||
const array& x = nb::cast<array>(inputs[0]);
|
||||
if (axis < 0) {
|
||||
axis += x.ndim() + output_axes;
|
||||
}
|
||||
if (axis < 0 || axis >= (x.ndim() + output_axes)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
|
||||
<< "vectorization axis " << axis
|
||||
<< " for array with shape " << x.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
flat_axes.push_back(axis);
|
||||
} else if (l.size() == 1 && l[0].is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[vmap] axis must be int or None.");
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[vmap] The arguments should contain only arrays");
|
||||
}
|
||||
});
|
||||
if (encountered_tuple && !nb::isinstance<array>(tree)) {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
return flat_axes;
|
||||
};
|
||||
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
auto flat_in_axes = axes_to_flat_tree(args, in_axes);
|
||||
auto flat_in_axes =
|
||||
axes_to_flat_tree((args.size() == 1) ? args[0] : args, in_axes, false);
|
||||
|
||||
// py_value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
@ -302,7 +349,7 @@ auto py_vmap(
|
||||
auto [trace_inputs, trace_outputs] =
|
||||
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||
|
||||
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes);
|
||||
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
|
||||
|
||||
// Perform the vmap
|
||||
auto outputs = detail::vmap_replace(
|
||||
|
@ -2,26 +2,6 @@
|
||||
|
||||
#include "python/src/trees.h"
|
||||
|
||||
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
|
||||
std::function<void(nb::handle)> recurse;
|
||||
recurse = [&](nb::handle subtree) {
|
||||
if (nb::isinstance<nb::list>(subtree) ||
|
||||
nb::isinstance<nb::tuple>(subtree)) {
|
||||
for (auto item : subtree) {
|
||||
recurse(item);
|
||||
}
|
||||
} else if (nb::isinstance<nb::dict>(subtree)) {
|
||||
for (auto item : nb::cast<nb::dict>(subtree)) {
|
||||
recurse(item.second);
|
||||
}
|
||||
} else {
|
||||
visitor(subtree);
|
||||
}
|
||||
};
|
||||
|
||||
recurse(tree);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
void validate_subtrees(const std::vector<nb::object>& subtrees) {
|
||||
int len = nb::cast<T>(subtrees[0]).size();
|
||||
@ -107,6 +87,85 @@ nb::object tree_map(
|
||||
});
|
||||
}
|
||||
|
||||
void tree_visit(
|
||||
const std::vector<nb::object>& trees,
|
||||
std::function<void(const std::vector<nb::object>&)> visitor) {
|
||||
std::function<void(const std::vector<nb::object>&)> recurse;
|
||||
|
||||
recurse = [&](const std::vector<nb::object>& subtrees) {
|
||||
if (nb::isinstance<nb::list>(subtrees[0])) {
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
|
||||
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (nb::isinstance<nb::list>(subtrees[j])) {
|
||||
items[j] = nb::cast<nb::list>(subtrees[j])[i];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
recurse(items);
|
||||
}
|
||||
} else if (nb::isinstance<nb::tuple>(subtrees[0])) {
|
||||
// Check the rest of the subtrees
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
int len = nb::cast<nb::tuple>(subtrees[0]).size();
|
||||
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (nb::isinstance<nb::tuple>(subtrees[j])) {
|
||||
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
recurse(items);
|
||||
}
|
||||
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
|
||||
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
if (nb::isinstance<nb::dict>(subtrees[j])) {
|
||||
auto subdict = nb::cast<nb::dict>(subtrees[j]);
|
||||
if (!subdict.contains(item.first)) {
|
||||
throw std::invalid_argument(
|
||||
"[tree_visit] Tree is not a valid prefix tree of the first tree.");
|
||||
}
|
||||
items[j] = subdict[item.first];
|
||||
} else {
|
||||
items[j] = subtrees[j];
|
||||
}
|
||||
}
|
||||
recurse(items);
|
||||
}
|
||||
} else {
|
||||
visitor(subtrees);
|
||||
}
|
||||
};
|
||||
return recurse(trees);
|
||||
}
|
||||
|
||||
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
|
||||
std::function<void(nb::handle)> recurse;
|
||||
recurse = [&](nb::handle subtree) {
|
||||
if (nb::isinstance<nb::list>(subtree) ||
|
||||
nb::isinstance<nb::tuple>(subtree)) {
|
||||
for (auto item : subtree) {
|
||||
recurse(item);
|
||||
}
|
||||
} else if (nb::isinstance<nb::dict>(subtree)) {
|
||||
for (auto item : nb::cast<nb::dict>(subtree)) {
|
||||
recurse(item.second);
|
||||
}
|
||||
} else {
|
||||
visitor(subtree);
|
||||
}
|
||||
};
|
||||
|
||||
recurse(tree);
|
||||
}
|
||||
|
||||
void tree_visit_update(
|
||||
nb::object tree,
|
||||
std::function<nb::object(nb::handle)> visitor) {
|
||||
|
@ -7,6 +7,9 @@
|
||||
namespace nb = nanobind;
|
||||
using namespace mlx::core;
|
||||
|
||||
void tree_visit(
|
||||
const std::vector<nb::object>& trees,
|
||||
std::function<void(const std::vector<nb::object>&)> visitor);
|
||||
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
|
||||
|
||||
nb::object tree_map(
|
||||
|
@ -121,16 +121,13 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
expected = my_fun(tree)
|
||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)
|
||||
|
||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
|
||||
out = mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
|
||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||
|
||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
|
||||
out = mx.vmap(my_fun, in_axes={"a": 0, "b": (0, 0)}, out_axes=0)(tree)
|
||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||
|
||||
tree = {
|
||||
@ -140,7 +137,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
mx.random.uniform(shape=(4, 2)),
|
||||
),
|
||||
}
|
||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
|
||||
out = mx.vmap(my_fun, in_axes={"a": 0, "b": (1, 1)}, out_axes=0)(tree)
|
||||
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user