mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 01:19:21 +08:00
Properly handle negative axes in python vmap (#944)
This commit is contained in:
committed by
GitHub
parent
741eb28443
commit
3fc993f82d
@@ -16,6 +16,7 @@
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
@@ -265,26 +266,72 @@ auto py_vmap(
|
||||
const nb::object& out_axes) {
|
||||
return [fun, in_axes, out_axes](const nb::args& args) {
|
||||
auto axes_to_flat_tree = [](const nb::object& tree,
|
||||
const nb::object& axes) {
|
||||
auto tree_axes = tree_map(
|
||||
{tree, axes},
|
||||
[](const std::vector<nb::object>& inputs) { return inputs[1]; });
|
||||
const nb::object& axes,
|
||||
bool output_axes) {
|
||||
std::vector<int> flat_axes;
|
||||
tree_visit(tree_axes, [&flat_axes](nb::handle obj) {
|
||||
if (obj.is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
flat_axes.push_back(nb::cast<int>(nb::cast<nb::int_>(obj)));
|
||||
} else {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
});
|
||||
bool encountered_tuple = false;
|
||||
tree_visit(
|
||||
{tree, axes},
|
||||
[&flat_axes, &encountered_tuple, output_axes](
|
||||
const std::vector<nb::object>& inputs) {
|
||||
if (nb::isinstance<array>(inputs[0])) {
|
||||
if (inputs[1].is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else if (nb::isinstance<nb::int_>(inputs[1])) {
|
||||
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
|
||||
const array& x = nb::cast<array>(inputs[0]);
|
||||
if (axis < 0) {
|
||||
axis += x.ndim() + output_axes;
|
||||
}
|
||||
if (axis < 0 || axis >= (x.ndim() + output_axes)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
|
||||
<< "vectorization axis " << axis
|
||||
<< " for array with shape " << x.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
flat_axes.push_back(axis);
|
||||
} else if (nb::isinstance<nb::tuple>(inputs[1])) {
|
||||
encountered_tuple = true;
|
||||
auto l = nb::cast<nb::tuple>(inputs[1]);
|
||||
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
|
||||
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
|
||||
const array& x = nb::cast<array>(inputs[0]);
|
||||
if (axis < 0) {
|
||||
axis += x.ndim() + output_axes;
|
||||
}
|
||||
if (axis < 0 || axis >= (x.ndim() + output_axes)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
|
||||
<< "vectorization axis " << axis
|
||||
<< " for array with shape " << x.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
flat_axes.push_back(axis);
|
||||
} else if (l.size() == 1 && l[0].is_none()) {
|
||||
flat_axes.push_back(-1);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[vmap] axis must be int or None.");
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[vmap] The arguments should contain only arrays");
|
||||
}
|
||||
});
|
||||
if (encountered_tuple && !nb::isinstance<array>(tree)) {
|
||||
throw std::invalid_argument("[vmap] axis must be int or None.");
|
||||
}
|
||||
return flat_axes;
|
||||
};
|
||||
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
auto flat_in_axes = axes_to_flat_tree(args, in_axes);
|
||||
auto flat_in_axes =
|
||||
axes_to_flat_tree((args.size() == 1) ? args[0] : args, in_axes, false);
|
||||
|
||||
// py_value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
@@ -302,7 +349,7 @@ auto py_vmap(
|
||||
auto [trace_inputs, trace_outputs] =
|
||||
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||
|
||||
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes);
|
||||
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
|
||||
|
||||
// Perform the vmap
|
||||
auto outputs = detail::vmap_replace(
|
||||
|
||||
Reference in New Issue
Block a user