Properly handle negative axes in python vmap (#944)

This commit is contained in:
Angelos Katharopoulos 2024-04-02 18:07:23 -07:00 committed by GitHub
parent 741eb28443
commit 3fc993f82d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 147 additions and 41 deletions

View File

@ -16,6 +16,7 @@
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
#include "python/src/trees.h"
namespace nb = nanobind;
@ -265,26 +266,72 @@ auto py_vmap(
const nb::object& out_axes) {
return [fun, in_axes, out_axes](const nb::args& args) {
auto axes_to_flat_tree = [](const nb::object& tree,
const nb::object& axes) {
auto tree_axes = tree_map(
{tree, axes},
[](const std::vector<nb::object>& inputs) { return inputs[1]; });
const nb::object& axes,
bool output_axes) {
std::vector<int> flat_axes;
tree_visit(tree_axes, [&flat_axes](nb::handle obj) {
if (obj.is_none()) {
bool encountered_tuple = false;
tree_visit(
{tree, axes},
[&flat_axes, &encountered_tuple, output_axes](
const std::vector<nb::object>& inputs) {
if (nb::isinstance<array>(inputs[0])) {
if (inputs[1].is_none()) {
flat_axes.push_back(-1);
} else if (nb::isinstance<nb::int_>(obj)) {
flat_axes.push_back(nb::cast<int>(nb::cast<nb::int_>(obj)));
} else if (nb::isinstance<nb::int_>(inputs[1])) {
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
const array& x = nb::cast<array>(inputs[0]);
if (axis < 0) {
axis += x.ndim() + output_axes;
}
if (axis < 0 || axis >= (x.ndim() + output_axes)) {
std::ostringstream msg;
msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
<< "vectorization axis " << axis
<< " for array with shape " << x.shape();
throw std::invalid_argument(msg.str());
}
flat_axes.push_back(axis);
} else if (nb::isinstance<nb::tuple>(inputs[1])) {
encountered_tuple = true;
auto l = nb::cast<nb::tuple>(inputs[1]);
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
const array& x = nb::cast<array>(inputs[0]);
if (axis < 0) {
axis += x.ndim() + output_axes;
}
if (axis < 0 || axis >= (x.ndim() + output_axes)) {
std::ostringstream msg;
msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
<< "vectorization axis " << axis
<< " for array with shape " << x.shape();
throw std::invalid_argument(msg.str());
}
flat_axes.push_back(axis);
} else if (l.size() == 1 && l[0].is_none()) {
flat_axes.push_back(-1);
} else {
throw std::invalid_argument(
"[vmap] axis must be int or None.");
}
} else {
throw std::invalid_argument("[vmap] axis must be int or None.");
}
} else {
throw std::invalid_argument(
"[vmap] The arguments should contain only arrays");
}
});
if (encountered_tuple && !nb::isinstance<array>(tree)) {
throw std::invalid_argument("[vmap] axis must be int or None.");
}
return flat_axes;
};
// Inputs must be array or tree of arrays
auto inputs = tree_flatten(args, true);
auto flat_in_axes = axes_to_flat_tree(args, in_axes);
auto flat_in_axes =
axes_to_flat_tree((args.size() == 1) ? args[0] : args, in_axes, false);
// py_value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values
@ -302,7 +349,7 @@ auto py_vmap(
auto [trace_inputs, trace_outputs] =
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes);
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
// Perform the vmap
auto outputs = detail::vmap_replace(

View File

@ -2,26 +2,6 @@
#include "python/src/trees.h"
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
std::function<void(nb::handle)> recurse;
recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree) ||
nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (nb::isinstance<nb::dict>(subtree)) {
for (auto item : nb::cast<nb::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<nb::object>& subtrees) {
int len = nb::cast<T>(subtrees[0]).size();
@ -107,6 +87,85 @@ nb::object tree_map(
});
}
void tree_visit(
const std::vector<nb::object>& trees,
std::function<void(const std::vector<nb::object>&)> visitor) {
std::function<void(const std::vector<nb::object>&)> recurse;
recurse = [&](const std::vector<nb::object>& subtrees) {
if (nb::isinstance<nb::list>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (nb::isinstance<nb::list>(subtrees[j])) {
items[j] = nb::cast<nb::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
recurse(items);
}
} else if (nb::isinstance<nb::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<nb::object> items(subtrees.size());
int len = nb::cast<nb::tuple>(subtrees[0]).size();
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (nb::isinstance<nb::tuple>(subtrees[j])) {
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
recurse(items);
}
} else if (nb::isinstance<nb::dict>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (nb::isinstance<nb::dict>(subtrees[j])) {
auto subdict = nb::cast<nb::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_visit] Tree is not a valid prefix tree of the first tree.");
}
items[j] = subdict[item.first];
} else {
items[j] = subtrees[j];
}
}
recurse(items);
}
} else {
visitor(subtrees);
}
};
return recurse(trees);
}
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
std::function<void(nb::handle)> recurse;
recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree) ||
nb::isinstance<nb::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (nb::isinstance<nb::dict>(subtree)) {
for (auto item : nb::cast<nb::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
void tree_visit_update(
nb::object tree,
std::function<nb::object(nb::handle)> visitor) {

View File

@ -7,6 +7,9 @@
namespace nb = nanobind;
using namespace mlx::core;
void tree_visit(
const std::vector<nb::object>& trees,
std::function<void(const std::vector<nb::object>&)> visitor);
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
nb::object tree_map(

View File

@ -121,16 +121,13 @@ class TestVmap(mlx_tests.MLXTestCase):
expected = my_fun(tree)
self.assertTrue(mx.array_equal(out, my_fun(tree)))
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
out = mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
self.assertTrue(mx.array_equal(out, my_fun(tree)))
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
out = mx.vmap(my_fun, in_axes={"a": 0, "b": (0, 0)}, out_axes=0)(tree)
self.assertTrue(mx.array_equal(out, my_fun(tree)))
tree = {
@ -140,7 +137,7 @@ class TestVmap(mlx_tests.MLXTestCase):
mx.random.uniform(shape=(4, 2)),
),
}
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
out = mx.vmap(my_fun, in_axes={"a": 0, "b": (1, 1)}, out_axes=0)(tree)
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
self.assertTrue(mx.array_equal(out, expected))