mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-05 08:41:13 +08:00
Properly handle negative axes in python vmap (#944)
This commit is contained in:
parent
741eb28443
commit
3fc993f82d
@ -16,6 +16,7 @@
|
|||||||
#include "mlx/graph_utils.h"
|
#include "mlx/graph_utils.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
#include "mlx/transforms_impl.h"
|
#include "mlx/transforms_impl.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
#include "python/src/trees.h"
|
#include "python/src/trees.h"
|
||||||
|
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
@ -265,26 +266,72 @@ auto py_vmap(
|
|||||||
const nb::object& out_axes) {
|
const nb::object& out_axes) {
|
||||||
return [fun, in_axes, out_axes](const nb::args& args) {
|
return [fun, in_axes, out_axes](const nb::args& args) {
|
||||||
auto axes_to_flat_tree = [](const nb::object& tree,
|
auto axes_to_flat_tree = [](const nb::object& tree,
|
||||||
const nb::object& axes) {
|
const nb::object& axes,
|
||||||
auto tree_axes = tree_map(
|
bool output_axes) {
|
||||||
{tree, axes},
|
|
||||||
[](const std::vector<nb::object>& inputs) { return inputs[1]; });
|
|
||||||
std::vector<int> flat_axes;
|
std::vector<int> flat_axes;
|
||||||
tree_visit(tree_axes, [&flat_axes](nb::handle obj) {
|
bool encountered_tuple = false;
|
||||||
if (obj.is_none()) {
|
tree_visit(
|
||||||
flat_axes.push_back(-1);
|
{tree, axes},
|
||||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
[&flat_axes, &encountered_tuple, output_axes](
|
||||||
flat_axes.push_back(nb::cast<int>(nb::cast<nb::int_>(obj)));
|
const std::vector<nb::object>& inputs) {
|
||||||
} else {
|
if (nb::isinstance<array>(inputs[0])) {
|
||||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
if (inputs[1].is_none()) {
|
||||||
}
|
flat_axes.push_back(-1);
|
||||||
});
|
} else if (nb::isinstance<nb::int_>(inputs[1])) {
|
||||||
|
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
|
||||||
|
const array& x = nb::cast<array>(inputs[0]);
|
||||||
|
if (axis < 0) {
|
||||||
|
axis += x.ndim() + output_axes;
|
||||||
|
}
|
||||||
|
if (axis < 0 || axis >= (x.ndim() + output_axes)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
|
||||||
|
<< "vectorization axis " << axis
|
||||||
|
<< " for array with shape " << x.shape();
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
flat_axes.push_back(axis);
|
||||||
|
} else if (nb::isinstance<nb::tuple>(inputs[1])) {
|
||||||
|
encountered_tuple = true;
|
||||||
|
auto l = nb::cast<nb::tuple>(inputs[1]);
|
||||||
|
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
|
||||||
|
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
|
||||||
|
const array& x = nb::cast<array>(inputs[0]);
|
||||||
|
if (axis < 0) {
|
||||||
|
axis += x.ndim() + output_axes;
|
||||||
|
}
|
||||||
|
if (axis < 0 || axis >= (x.ndim() + output_axes)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
|
||||||
|
<< "vectorization axis " << axis
|
||||||
|
<< " for array with shape " << x.shape();
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
flat_axes.push_back(axis);
|
||||||
|
} else if (l.size() == 1 && l[0].is_none()) {
|
||||||
|
flat_axes.push_back(-1);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[vmap] axis must be int or None.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[vmap] The arguments should contain only arrays");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if (encountered_tuple && !nb::isinstance<array>(tree)) {
|
||||||
|
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||||
|
}
|
||||||
return flat_axes;
|
return flat_axes;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Inputs must be array or tree of arrays
|
// Inputs must be array or tree of arrays
|
||||||
auto inputs = tree_flatten(args, true);
|
auto inputs = tree_flatten(args, true);
|
||||||
auto flat_in_axes = axes_to_flat_tree(args, in_axes);
|
auto flat_in_axes =
|
||||||
|
axes_to_flat_tree((args.size() == 1) ? args[0] : args, in_axes, false);
|
||||||
|
|
||||||
// py_value_out will hold the output of the python function in order to be
|
// py_value_out will hold the output of the python function in order to be
|
||||||
// able to reconstruct the python tree of extra return values
|
// able to reconstruct the python tree of extra return values
|
||||||
@ -302,7 +349,7 @@ auto py_vmap(
|
|||||||
auto [trace_inputs, trace_outputs] =
|
auto [trace_inputs, trace_outputs] =
|
||||||
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||||
|
|
||||||
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes);
|
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
|
||||||
|
|
||||||
// Perform the vmap
|
// Perform the vmap
|
||||||
auto outputs = detail::vmap_replace(
|
auto outputs = detail::vmap_replace(
|
||||||
|
@ -2,26 +2,6 @@
|
|||||||
|
|
||||||
#include "python/src/trees.h"
|
#include "python/src/trees.h"
|
||||||
|
|
||||||
// Depth-first walk over a Python tree. Lists and tuples are descended into
// element by element, dicts are descended into through their values, and
// every other object is treated as a leaf and handed to `visitor`.
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
  std::function<void(nb::handle)> walk;
  walk = [&walk, &visitor](nb::handle node) {
    const bool is_sequence =
        nb::isinstance<nb::list>(node) || nb::isinstance<nb::tuple>(node);
    if (is_sequence) {
      for (auto child : node) {
        walk(child);
      }
      return;
    }

    if (nb::isinstance<nb::dict>(node)) {
      // Only the values are visited; the keys are never passed to `visitor`.
      for (auto entry : nb::cast<nb::dict>(node)) {
        walk(entry.second);
      }
      return;
    }

    // Anything that is not a list, tuple or dict is a leaf.
    visitor(node);
  };

  walk(tree);
}
|
|
||||||
|
|
||||||
template <typename T, typename U, typename V>
|
template <typename T, typename U, typename V>
|
||||||
void validate_subtrees(const std::vector<nb::object>& subtrees) {
|
void validate_subtrees(const std::vector<nb::object>& subtrees) {
|
||||||
int len = nb::cast<T>(subtrees[0]).size();
|
int len = nb::cast<T>(subtrees[0]).size();
|
||||||
@ -107,6 +87,85 @@ nb::object tree_map(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Walk several Python trees in lockstep, using the first tree to drive the
// traversal.
//
// At every level the container type of `trees[0]` decides what happens:
//   - list / tuple: each index is visited in order. Any other tree that is
//     the same container type contributes its element at that index; any
//     other tree that is not is passed down unchanged to every child.
//   - dict: each key of the first tree is visited. Any other tree that is
//     also a dict must contain that key and contributes the matching value;
//     a non-dict tree is passed down unchanged.
//   - anything else: the current objects are leaves and `visitor` is called
//     with one object per input tree, in the same order as `trees`.
//
// `validate_subtrees` is called before descending into a container; its body
// is not visible here, presumably it rejects incompatible trees -- TODO
// confirm.
//
// Throws std::invalid_argument when `trees` is empty or when a dict in one of
// the other trees lacks a key that is present in the first tree.
void tree_visit(
    const std::vector<nb::object>& trees,
    std::function<void(const std::vector<nb::object>&)> visitor) {
  // Every branch below reads subtrees[0], so an empty input must be rejected
  // up front instead of indexing out of bounds.
  if (trees.empty()) {
    throw std::invalid_argument(
        "[tree_visit] At least one tree must be provided.");
  }

  std::function<void(const std::vector<nb::object>&)> recurse;

  recurse = [&](const std::vector<nb::object>& subtrees) {
    const size_t num_trees = subtrees.size();

    if (nb::isinstance<nb::list>(subtrees[0])) {
      std::vector<nb::object> items(num_trees);
      validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
      // Cast once; the size is still re-read on every iteration so the loop
      // bound keeps tracking the live Python list.
      const auto first = nb::cast<nb::list>(subtrees[0]);
      for (size_t i = 0; i < first.size(); ++i) {
        for (size_t j = 0; j < num_trees; ++j) {
          if (nb::isinstance<nb::list>(subtrees[j])) {
            items[j] = nb::cast<nb::list>(subtrees[j])[i];
          } else {
            // Not a list: hand the same object to every child.
            items[j] = subtrees[j];
          }
        }
        recurse(items);
      }
    } else if (nb::isinstance<nb::tuple>(subtrees[0])) {
      std::vector<nb::object> items(num_trees);
      // The length is taken before validation, as tuples are immutable.
      const size_t len = nb::cast<nb::tuple>(subtrees[0]).size();
      validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
      for (size_t i = 0; i < len; ++i) {
        for (size_t j = 0; j < num_trees; ++j) {
          if (nb::isinstance<nb::tuple>(subtrees[j])) {
            items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
          } else {
            // Not a tuple: hand the same object to every child.
            items[j] = subtrees[j];
          }
        }
        recurse(items);
      }
    } else if (nb::isinstance<nb::dict>(subtrees[0])) {
      std::vector<nb::object> items(num_trees);
      validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
      for (auto item : nb::cast<nb::dict>(subtrees[0])) {
        for (size_t j = 0; j < num_trees; ++j) {
          if (nb::isinstance<nb::dict>(subtrees[j])) {
            auto subdict = nb::cast<nb::dict>(subtrees[j]);
            if (!subdict.contains(item.first)) {
              throw std::invalid_argument(
                  "[tree_visit] Tree is not a valid prefix tree of the first tree.");
            }
            items[j] = subdict[item.first];
          } else {
            // Not a dict: hand the same object to every child.
            items[j] = subtrees[j];
          }
        }
        recurse(items);
      }
    } else {
      // The first tree has reached a leaf; report the aligned objects.
      visitor(subtrees);
    }
  };

  recurse(trees);
}
|
||||||
|
|
||||||
|
// Depth-first walk over a Python tree. Lists and tuples are descended into
// element by element, dicts are descended into through their values, and
// every other object is treated as a leaf and handed to `visitor`.
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
  std::function<void(nb::handle)> walk;
  walk = [&walk, &visitor](nb::handle node) {
    const bool is_sequence =
        nb::isinstance<nb::list>(node) || nb::isinstance<nb::tuple>(node);
    if (is_sequence) {
      for (auto child : node) {
        walk(child);
      }
      return;
    }

    if (nb::isinstance<nb::dict>(node)) {
      // Only the values are visited; the keys are never passed to `visitor`.
      for (auto entry : nb::cast<nb::dict>(node)) {
        walk(entry.second);
      }
      return;
    }

    // Anything that is not a list, tuple or dict is a leaf.
    visitor(node);
  };

  walk(tree);
}
|
||||||
|
|
||||||
void tree_visit_update(
|
void tree_visit_update(
|
||||||
nb::object tree,
|
nb::object tree,
|
||||||
std::function<nb::object(nb::handle)> visitor) {
|
std::function<nb::object(nb::handle)> visitor) {
|
||||||
|
@ -7,6 +7,9 @@
|
|||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
void tree_visit(
|
||||||
|
const std::vector<nb::object>& trees,
|
||||||
|
std::function<void(const std::vector<nb::object>&)> visitor);
|
||||||
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
|
void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor);
|
||||||
|
|
||||||
nb::object tree_map(
|
nb::object tree_map(
|
||||||
|
@ -121,16 +121,13 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
expected = my_fun(tree)
|
expected = my_fun(tree)
|
||||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
|
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)
|
mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)
|
||||||
|
|
||||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
|
out = mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
|
||||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||||
|
|
||||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
|
out = mx.vmap(my_fun, in_axes={"a": 0, "b": (0, 0)}, out_axes=0)(tree)
|
||||||
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
self.assertTrue(mx.array_equal(out, my_fun(tree)))
|
||||||
|
|
||||||
tree = {
|
tree = {
|
||||||
@ -140,7 +137,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
|||||||
mx.random.uniform(shape=(4, 2)),
|
mx.random.uniform(shape=(4, 2)),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
|
out = mx.vmap(my_fun, in_axes={"a": 0, "b": (1, 1)}, out_axes=0)(tree)
|
||||||
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
|
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
|
||||||
self.assertTrue(mx.array_equal(out, expected))
|
self.assertTrue(mx.array_equal(out, expected))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user