From 3fc993f82d8c1bdefd2969c0deff9776ec0a2054 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 2 Apr 2024 18:07:23 -0700 Subject: [PATCH] Properly handle negative axes in python vmap (#944) --- python/src/transforms.cpp | 77 ++++++++++++++++++++++++------ python/src/trees.cpp | 99 +++++++++++++++++++++++++++++++-------- python/src/trees.h | 3 ++ python/tests/test_vmap.py | 9 ++-- 4 files changed, 147 insertions(+), 41 deletions(-) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index fa9f9235a..01b67b6a3 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -16,6 +16,7 @@ #include "mlx/graph_utils.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" +#include "mlx/utils.h" #include "python/src/trees.h" namespace nb = nanobind; @@ -265,26 +266,72 @@ auto py_vmap( const nb::object& out_axes) { return [fun, in_axes, out_axes](const nb::args& args) { auto axes_to_flat_tree = [](const nb::object& tree, - const nb::object& axes) { - auto tree_axes = tree_map( - {tree, axes}, - [](const std::vector& inputs) { return inputs[1]; }); + const nb::object& axes, + bool output_axes) { std::vector flat_axes; - tree_visit(tree_axes, [&flat_axes](nb::handle obj) { - if (obj.is_none()) { - flat_axes.push_back(-1); - } else if (nb::isinstance(obj)) { - flat_axes.push_back(nb::cast(nb::cast(obj))); - } else { - throw std::invalid_argument("[vmap] axis must be int or None."); - } - }); + bool encountered_tuple = false; + tree_visit( + {tree, axes}, + [&flat_axes, &encountered_tuple, output_axes]( + const std::vector& inputs) { + if (nb::isinstance(inputs[0])) { + if (inputs[1].is_none()) { + flat_axes.push_back(-1); + } else if (nb::isinstance(inputs[1])) { + int axis = nb::cast(nb::cast(inputs[1])); + const array& x = nb::cast(inputs[0]); + if (axis < 0) { + axis += x.ndim() + output_axes; + } + if (axis < 0 || axis >= (x.ndim() + output_axes)) { + std::ostringstream msg; + msg << "[vmap] Invalid" << (output_axes ? " output " : " ") + << "vectorization axis " << axis + << " for array with shape " << x.shape(); + throw std::invalid_argument(msg.str()); + } + flat_axes.push_back(axis); + } else if (nb::isinstance(inputs[1])) { + encountered_tuple = true; + auto l = nb::cast(inputs[1]); + if (l.size() == 1 && nb::isinstance(l[0])) { + int axis = nb::cast(nb::cast(l[0])); + const array& x = nb::cast(inputs[0]); + if (axis < 0) { + axis += x.ndim() + output_axes; + } + if (axis < 0 || axis >= (x.ndim() + output_axes)) { + std::ostringstream msg; + msg << "[vmap] Invalid" << (output_axes ? " output " : " ") + << "vectorization axis " << axis + << " for array with shape " << x.shape(); + throw std::invalid_argument(msg.str()); + } + flat_axes.push_back(axis); + } else if (l.size() == 1 && l[0].is_none()) { + flat_axes.push_back(-1); + } else { + throw std::invalid_argument( + "[vmap] axis must be int or None."); + } + } else { + throw std::invalid_argument("[vmap] axis must be int or None."); + } + } else { + throw std::invalid_argument( + "[vmap] The arguments should contain only arrays"); + } + }); + if (encountered_tuple && !nb::isinstance(tree)) { + throw std::invalid_argument("[vmap] axis must be int or None."); + } return flat_axes; }; // Inputs must be array or tree of arrays auto inputs = tree_flatten(args, true); - auto flat_in_axes = axes_to_flat_tree(args, in_axes); + auto flat_in_axes = + axes_to_flat_tree((args.size() == 1) ? args[0] : args, in_axes, false); // py_value_out will hold the output of the python function in order to be // able to reconstruct the python tree of extra return values @@ -302,7 +349,7 @@ auto py_vmap( auto [trace_inputs, trace_outputs] = detail::vmap_trace(vmap_fn, inputs, flat_in_axes); - auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes); + auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true); // Perform the vmap auto outputs = detail::vmap_replace( diff --git a/python/src/trees.cpp b/python/src/trees.cpp index 29fe9d4bb..b4ae53746 100644 --- a/python/src/trees.cpp +++ b/python/src/trees.cpp @@ -2,26 +2,6 @@ #include "python/src/trees.h" -void tree_visit(nb::object tree, std::function visitor) { - std::function recurse; - recurse = [&](nb::handle subtree) { - if (nb::isinstance(subtree) || - nb::isinstance(subtree)) { - for (auto item : subtree) { - recurse(item); - } - } else if (nb::isinstance(subtree)) { - for (auto item : nb::cast(subtree)) { - recurse(item.second); - } - } else { - visitor(subtree); - } - }; - - recurse(tree); -} - template void validate_subtrees(const std::vector& subtrees) { int len = nb::cast(subtrees[0]).size(); @@ -107,6 +87,85 @@ nb::object tree_map( }); } +void tree_visit( + const std::vector& trees, + std::function&)> visitor) { + std::function&)> recurse; + + recurse = [&](const std::vector& subtrees) { + if (nb::isinstance(subtrees[0])) { + std::vector items(subtrees.size()); + validate_subtrees(subtrees); + for (int i = 0; i < nb::cast(subtrees[0]).size(); ++i) { + for (int j = 0; j < subtrees.size(); ++j) { + if (nb::isinstance(subtrees[j])) { + items[j] = nb::cast(subtrees[j])[i]; + } else { + items[j] = subtrees[j]; + } + } + recurse(items); + } + } else if (nb::isinstance(subtrees[0])) { + // Check the rest of the subtrees + std::vector items(subtrees.size()); + int len = nb::cast(subtrees[0]).size(); + validate_subtrees(subtrees); + for (int i = 0; i < len; ++i) { + for (int j = 0; j < subtrees.size(); ++j) { + if (nb::isinstance(subtrees[j])) { + items[j] = nb::cast(subtrees[j])[i]; + } else { + items[j] = subtrees[j]; + } + } + recurse(items); + } + } else if (nb::isinstance(subtrees[0])) { + std::vector items(subtrees.size()); + validate_subtrees(subtrees); + for (auto item : nb::cast(subtrees[0])) { + for (int j = 0; j < subtrees.size(); ++j) { + if (nb::isinstance(subtrees[j])) { + auto subdict = nb::cast(subtrees[j]); + if (!subdict.contains(item.first)) { + throw std::invalid_argument( + "[tree_visit] Tree is not a valid prefix tree of the first tree."); + } + items[j] = subdict[item.first]; + } else { + items[j] = subtrees[j]; + } + } + recurse(items); + } + } else { + visitor(subtrees); + } + }; + return recurse(trees); +} + +void tree_visit(nb::object tree, std::function visitor) { + std::function recurse; + recurse = [&](nb::handle subtree) { + if (nb::isinstance(subtree) || + nb::isinstance(subtree)) { + for (auto item : subtree) { + recurse(item); + } + } else if (nb::isinstance(subtree)) { + for (auto item : nb::cast(subtree)) { + recurse(item.second); + } + } else { + visitor(subtree); + } + }; + + recurse(tree); +} + void tree_visit_update( nb::object tree, std::function visitor) { diff --git a/python/src/trees.h b/python/src/trees.h index 44d9d9b0e..931b3ea6b 100644 --- a/python/src/trees.h +++ b/python/src/trees.h @@ -7,6 +7,9 @@ namespace nb = nanobind; using namespace mlx::core; +void tree_visit( + const std::vector& trees, + std::function&)> visitor); void tree_visit(nb::object tree, std::function visitor); nb::object tree_map( diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 08512fb29..5c4690640 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -121,16 +121,13 @@ class TestVmap(mlx_tests.MLXTestCase): expected = my_fun(tree) self.assertTrue(mx.array_equal(out, my_fun(tree))) - with self.assertRaises(ValueError): - mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree) - with self.assertRaises(ValueError): mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree) - out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree) + out = mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree) self.assertTrue(mx.array_equal(out, my_fun(tree))) - out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree) + out = mx.vmap(my_fun, in_axes={"a": 0, "b": (0, 0)}, out_axes=0)(tree) self.assertTrue(mx.array_equal(out, my_fun(tree))) tree = { @@ -140,7 +137,7 @@ class TestVmap(mlx_tests.MLXTestCase): mx.random.uniform(shape=(4, 2)), ), } - out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree) + out = mx.vmap(my_fun, in_axes={"a": 0, "b": (1, 1)}, out_axes=0)(tree) expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T self.assertTrue(mx.array_equal(out, expected))