Properly handle negative axes in python vmap (#944)

This commit is contained in:
Angelos Katharopoulos
2024-04-02 18:07:23 -07:00
committed by GitHub
parent 741eb28443
commit 3fc993f82d
4 changed files with 147 additions and 41 deletions

View File

@@ -16,6 +16,7 @@
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
#include "python/src/trees.h"
namespace nb = nanobind;
@@ -265,26 +266,72 @@ auto py_vmap(
const nb::object& out_axes) {
return [fun, in_axes, out_axes](const nb::args& args) {
auto axes_to_flat_tree = [](const nb::object& tree,
const nb::object& axes) {
auto tree_axes = tree_map(
{tree, axes},
[](const std::vector<nb::object>& inputs) { return inputs[1]; });
const nb::object& axes,
bool output_axes) {
std::vector<int> flat_axes;
tree_visit(tree_axes, [&flat_axes](nb::handle obj) {
if (obj.is_none()) {
flat_axes.push_back(-1);
} else if (nb::isinstance<nb::int_>(obj)) {
flat_axes.push_back(nb::cast<int>(nb::cast<nb::int_>(obj)));
} else {
throw std::invalid_argument("[vmap] axis must be int or None.");
}
});
bool encountered_tuple = false;
tree_visit(
{tree, axes},
[&flat_axes, &encountered_tuple, output_axes](
const std::vector<nb::object>& inputs) {
if (nb::isinstance<array>(inputs[0])) {
if (inputs[1].is_none()) {
flat_axes.push_back(-1);
} else if (nb::isinstance<nb::int_>(inputs[1])) {
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
const array& x = nb::cast<array>(inputs[0]);
if (axis < 0) {
axis += x.ndim() + output_axes;
}
if (axis < 0 || axis >= (x.ndim() + output_axes)) {
std::ostringstream msg;
msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
<< "vectorization axis " << axis
<< " for array with shape " << x.shape();
throw std::invalid_argument(msg.str());
}
flat_axes.push_back(axis);
} else if (nb::isinstance<nb::tuple>(inputs[1])) {
encountered_tuple = true;
auto l = nb::cast<nb::tuple>(inputs[1]);
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
const array& x = nb::cast<array>(inputs[0]);
if (axis < 0) {
axis += x.ndim() + output_axes;
}
if (axis < 0 || axis >= (x.ndim() + output_axes)) {
std::ostringstream msg;
msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
<< "vectorization axis " << axis
<< " for array with shape " << x.shape();
throw std::invalid_argument(msg.str());
}
flat_axes.push_back(axis);
} else if (l.size() == 1 && l[0].is_none()) {
flat_axes.push_back(-1);
} else {
throw std::invalid_argument(
"[vmap] axis must be int or None.");
}
} else {
throw std::invalid_argument("[vmap] axis must be int or None.");
}
} else {
throw std::invalid_argument(
"[vmap] The arguments should contain only arrays");
}
});
if (encountered_tuple && !nb::isinstance<array>(tree)) {
throw std::invalid_argument("[vmap] axis must be int or None.");
}
return flat_axes;
};
// Inputs must be array or tree of arrays
auto inputs = tree_flatten(args, true);
auto flat_in_axes = axes_to_flat_tree(args, in_axes);
auto flat_in_axes =
axes_to_flat_tree((args.size() == 1) ? args[0] : args, in_axes, false);
// py_value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values
@@ -302,7 +349,7 @@ auto py_vmap(
auto [trace_inputs, trace_outputs] =
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes);
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
// Perform the vmap
auto outputs = detail::vmap_replace(