diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 5ef1f54c2..f08be3e9d 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -61,6 +61,7 @@ Operations mean min minimum + moveaxis multiply negative ones @@ -87,6 +88,7 @@ Operations stop_gradient subtract sum + swapaxes take take_along_axis tan diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 066242486..a3fea63e5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -677,6 +677,53 @@ array pad( s); } +array moveaxis( + const array& a, + int source, + int destination, + StreamOrDevice s /* = {} */) { + auto check_ax = [&a](int ax) { + auto ndim = static_cast(a.ndim()); + if (ax < -ndim || ax >= ndim) { + std::ostringstream msg; + msg << "[moveaxis] Invalid axis " << ax << " for array with " << ndim + << " dimensions."; + throw std::out_of_range(msg.str()); + } + return ax < 0 ? ax + ndim : ax; + }; + source = check_ax(source); + destination = check_ax(destination); + std::vector reorder(a.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + reorder.erase(reorder.begin() + source); + reorder.insert(reorder.begin() + destination, source); + return transpose(a, reorder, s); +} + +array swapaxes( + const array& a, + int axis1, + int axis2, + StreamOrDevice s /* = {} */) { + auto check_ax = [&a](int ax) { + auto ndim = static_cast(a.ndim()); + if (ax < -ndim || ax >= ndim) { + std::ostringstream msg; + msg << "[swapaxes] Invalid axis " << ax << " for array with " << ndim + << " dimensions."; + throw std::out_of_range(msg.str()); + } + return ax < 0 ? 
ax + ndim : ax; + }; + axis1 = check_ax(axis1); + axis2 = check_ax(axis2); + std::vector reorder(a.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + std::swap(reorder[axis1], reorder[axis2]); + return transpose(a, reorder, s); +} + array transpose( const array& a, std::vector axes, diff --git a/mlx/ops.h b/mlx/ops.h index 686dbfb79..81d457281 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -183,6 +183,16 @@ inline array transpose( return transpose(a, std::vector(axes), s); } +/** Swap two axes of an array. */ +array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {}); + +/** Move an axis of an array. */ +array moveaxis( + const array& a, + int source, + int destination, + StreamOrDevice s = {}); + /** Pad an array with a constant value */ array pad( const array& a, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 58d952e92..e1f17ae7d 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include #include #include @@ -512,7 +511,26 @@ array Concatenate::jvp( std::pair Concatenate::vmap( const std::vector& inputs, const std::vector& axes) { - throw std::runtime_error("Concatenate vmap is NYI."); + std::vector t_inputs; + // Find the first vmapped input + int i = 0; + for (; i < axes.size(); i++) { + t_inputs.push_back(inputs[i]); + if (axes[i] >= 0) { + break; + } + } + auto out_ax = axes[i++]; + // Move vmap axes to the same spot. 
+ for (; i < axes.size(); ++i) { + if (out_ax != axes[i] && axes[i] >= 0) { + t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream())); + } else { + t_inputs.push_back(inputs[i]); + } + } + auto axis = axis_ + (axis_ >= out_ax); + return {concatenate(t_inputs, axis, stream()), out_ax}; } bool Concatenate::is_equivalent(const Primitive& other) const { @@ -1054,7 +1072,53 @@ std::pair Full::vmap( std::pair Gather::vmap( const std::vector& inputs, const std::vector& axes) { - throw std::runtime_error("Gather vmap is NYI, please change slices instead"); + auto& src = inputs[0]; + std::vector indices(inputs.begin() + 1, inputs.end()); + auto gather_axes = axes_; + auto slice_sizes = slice_sizes_; + auto src_vmapped = axes[0] >= 0; + auto indices_vmapped = + std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; }); + auto out_ax = + *std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; }); + + // Reorder all the index arrays so the vmap axis is in the same spot. 
+ for (int i = 1; i < axes.size(); ++i) { + if (out_ax != axes[i] && axes[i] >= 0) { + indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream()); + } + } + + if (src_vmapped) { + int max_dims = 0; + for (auto& idx : indices) { + max_dims = std::max(static_cast(idx.ndim()), max_dims); + } + auto new_ax_loc = + std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) { + return a >= out_ax; + }); + for (; new_ax_loc < gather_axes.end(); new_ax_loc++) { + (*new_ax_loc)++; + } + if (indices_vmapped) { + // Make a new index array for the vmapped dimension + // Reshape it so it broadcasts with other index arrays + // Update gather axes and slice sizes accordingly + auto shape = std::vector(max_dims - out_ax, 1); + auto vmap_inds = arange(0, src.shape(out_ax), stream()); + shape[0] = vmap_inds.shape(0); + vmap_inds = reshape(vmap_inds, shape, stream()); + slice_sizes.insert(slice_sizes.begin() + out_ax, 1); + auto new_ax_idx = new_ax_loc - gather_axes.begin(); + gather_axes.insert(new_ax_loc, out_ax); + indices.insert(indices.begin() + new_ax_idx, vmap_inds); + } else { + slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0])); + out_ax = max_dims + axes[0]; + } + } + return {gather(src, indices, gather_axes, slice_sizes, stream()), out_ax}; } std::vector Gather::vjp( @@ -1997,8 +2061,15 @@ std::pair Sinh::vmap( std::pair Slice::vmap( const std::vector& inputs, const std::vector& axes) { - // TODO implement - return {array(1.0f), axes[0]}; + auto start = start_indices_; + auto stop = end_indices_; + auto strides = strides_; + auto ax = axes[0]; + auto& input = inputs[0]; + start.insert(start.begin() + ax, 0); + stop.insert(stop.begin() + ax, input.shape(ax)); + strides.insert(strides.begin() + ax, 1); + return {slice(input, start, stop, strides, stream()), ax}; } std::vector Slice::vjp( diff --git a/python/src/array.cpp b/python/src/array.cpp index af0afef4d..906632e5f 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ 
-862,6 +862,22 @@ void init_array(py::module_& m) { py::kw_only(), "stream"_a = none, "See :func:`any`.") + .def( + "moveaxis", + &moveaxis, + "source"_a, + "destination"_a, + py::kw_only(), + "stream"_a = none, + "See :func:`moveaxis`.") + .def( + "swapaxes", + &swapaxes, + "axis1"_a, + "axis2"_a, + py::kw_only(), + "stream"_a = none, + "See :func:`swapaxes`.") .def( "transpose", [](const array& a, py::args axes, StreamOrDevice s) { diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 76a057b20..ab77f3c42 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1591,6 +1591,50 @@ void init_ops(py::module_& m) { Returns: array: The ceil of ``a``. )pbdoc"); + m.def( + "moveaxis", + &moveaxis, + "a"_a, + py::pos_only(), + "source"_a, + "destination"_a, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array + + Move an axis to a new position. + + Args: + a (array): Input array. + source (int): Specifies the source axis. + destination (int): Specifies the destination axis. + + Returns: + array: The array with the axis moved. + )pbdoc"); + m.def( + "swapaxes", + &swapaxes, + "a"_a, + py::pos_only(), + "axis1"_a, + "axis2"_a, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + swapaxes(a: array, /, axis1: int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array + + Swap two axes of an array. + + Args: + a (array): Input array. + axis1 (int): Specifies the first axis. + axis2 (int): Specifies the second axis. + + Returns: + array: The array with swapped axes. 
+ )pbdoc"); m.def( "transpose", [](const array& a, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 09263fefe..a35a82dc6 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -375,6 +375,13 @@ class TestOps(mlx_tests.MLXTestCase): self.assertListEqual(mx.transpose(x, axes=(0, 2, 1)).tolist(), expected) + def test_move_swap_axes(self): + x = mx.zeros((2, 3, 4)) + self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2]) + self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2]) + self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2]) + self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2]) + def test_sum(self): x = mx.array( [ diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index c56eda80d..63f616aa3 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -163,6 +163,61 @@ class TestVmap(mlx_tests.MLXTestCase): self.assertTrue(mx.array_equal(out["a"].T, expected["a"])) self.assertTrue(mx.array_equal(out["b"], expected["b"])) + def test_vmap_indexing(self): + x = mx.arange(16).reshape(2, 2, 2, 2) + inds = mx.array([[0, 1, 0], [1, 1, 0]]) + + out = mx.vmap(lambda x, y: x[y], in_axes=(0, 0))(x, inds) + expected = mx.array( + [ + [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]], + [[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]], + ] + ) + self.assertTrue(mx.array_equal(out, expected)) + + out = mx.vmap(lambda x, y: x[y], in_axes=(0, None))(x, inds) + expected = mx.array( + [ + [ + [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]], + [[[4, 5], [6, 7]], [[4, 5], [6, 7]], [[0, 1], [2, 3]]], + ], + [ + [[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]], + [[[12, 13], [14, 15]], [[12, 13], [14, 15]], [[8, 9], [10, 11]]], + ], + ] + ) + self.assertTrue(mx.array_equal(out, expected)) + + out = mx.vmap(lambda x, y: x[y], in_axes=(None, 0))(x, inds) + expected = mx.array( + [ + [ + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], + [[[8, 9], [10, 11]], [[12, 13], [14, 15]]], + [[[0, 
1], [2, 3]], [[4, 5], [6, 7]]], + ], + [ + [[[8, 9], [10, 11]], [[12, 13], [14, 15]]], + [[[8, 9], [10, 11]], [[12, 13], [14, 15]]], + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], + ], + ] + ) + self.assertTrue(mx.array_equal(out, expected)) + + inds2 = mx.array([[0, 1, 0], [0, 1, 0]]) + out = mx.vmap(lambda x, y, z: x[y, z], in_axes=(None, 0, 0))(x, inds, inds2) + expected = mx.array( + [ + [[[0, 1], [2, 3]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]], + [[[8, 9], [10, 11]], [[12, 13], [14, 15]], [[0, 1], [2, 3]]], + ] + ) + self.assertTrue(mx.array_equal(out, expected)) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index d9df0262d..87437223c 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include #include @@ -205,6 +204,46 @@ TEST_CASE("test split") { CHECK(array_equal(out[3], array({2, 3, 4})).item()); } +TEST_CASE("test swap and move axes") { + // Test swapaxes + array a(0.0); + CHECK_THROWS(swapaxes(a, 0, 0)); + + a = zeros({2}); + CHECK_THROWS(swapaxes(a, 0, 1)); + CHECK_EQ(swapaxes(a, 0, 0).shape(), std::vector{2}); + CHECK_EQ(swapaxes(a, -1, -1).shape(), std::vector{2}); + + a = zeros({2, 3, 4}); + CHECK_THROWS(swapaxes(a, 0, -4)); + CHECK_THROWS(swapaxes(a, 0, 3)); + CHECK_THROWS(swapaxes(a, 3, 0)); + CHECK_THROWS(swapaxes(a, -4, 0)); + CHECK_EQ(swapaxes(a, 0, 2).shape(), std::vector{4, 3, 2}); + CHECK_EQ(swapaxes(a, 0, 1).shape(), std::vector{3, 2, 4}); + CHECK_EQ(swapaxes(a, 0, -1).shape(), std::vector{4, 3, 2}); + CHECK_EQ(swapaxes(a, -2, 2).shape(), std::vector{2, 4, 3}); + + // Test moveaxis + a = array(0.0); + CHECK_THROWS(moveaxis(a, 0, 0)); + + a = zeros({2}); + CHECK_THROWS(moveaxis(a, 0, 1)); + CHECK_EQ(moveaxis(a, 0, 0).shape(), std::vector{2}); + CHECK_EQ(moveaxis(a, -1, -1).shape(), std::vector{2}); + + a = zeros({2, 3, 4}); + CHECK_THROWS(moveaxis(a, 0, -4)); + CHECK_THROWS(moveaxis(a, 0, 3)); + CHECK_THROWS(moveaxis(a, 3, 0)); + 
CHECK_THROWS(moveaxis(a, -4, 0)); + CHECK_EQ(moveaxis(a, 0, 2).shape(), std::vector{3, 4, 2}); + CHECK_EQ(moveaxis(a, 0, 1).shape(), std::vector{3, 2, 4}); + CHECK_EQ(moveaxis(a, 0, -1).shape(), std::vector{3, 4, 2}); + CHECK_EQ(moveaxis(a, -2, 2).shape(), std::vector{2, 4, 3}); +} + TEST_CASE("test transpose") { array x(1); auto y = transpose(x); @@ -2003,4 +2042,4 @@ TEST_CASE("test eye with negative k offset") { {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {4, 3}); CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item()); -} \ No newline at end of file +} diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index 41175b0f5..b30fd0a21 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -248,3 +248,104 @@ TEST_CASE("test vmap creation ops") { CHECK(array_equal(out, expected).item()); } } + +TEST_CASE("test vmap slice") { + { + auto fun = [](array in) { return slice(in, {4}, {8}, {2}); }; + auto x = reshape(arange(16), {2, 8}); + auto out = vmap(fun)(x); + auto expected = reshape(array({4, 6, 12, 14}), {2, 2}); + CHECK(array_equal(out, expected).item()); + } + + { + auto fun = [](array in) { return slice(in, {0, 1}, {2, 3}); }; + auto x = reshape(arange(12), {2, 2, 3}); + auto out = vmap(fun, 1, 0)(x); + auto expected = reshape(array({1, 2, 7, 8, 4, 5, 10, 11}), {2, 2, 2}); + CHECK(array_equal(out, expected).item()); + } +} + +TEST_CASE("test vmap concatenate") { + auto fun = [](std::vector inputs) { + return std::vector{concatenate(inputs, 0)}; + }; + auto x = reshape(arange(4), {2, 2}); + auto y = reshape(arange(4), {2, 2}); + auto out = vmap(fun)({x, y})[0]; + auto expected = reshape(array({0, 1, 0, 1, 2, 3, 2, 3}), {2, 4}); + CHECK(array_equal(out, expected).item()); + out = vmap(fun, {1, 1})({x, y})[0]; + expected = reshape(array({0, 2, 0, 2, 1, 3, 1, 3}), {2, 4}); + CHECK(array_equal(out, expected).item()); + out = vmap(fun, {0, 1})({x, y})[0]; + expected = reshape(array({0, 1, 0, 2, 2, 3, 1, 3}), {2, 
4}); + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test vmap gather") { + { + auto fun = [](std::vector inputs) { + auto src = inputs[0]; + auto indices = inputs[1]; + std::vector slice_sizes = {1, 2, 2}; + auto out = squeeze(gather(src, indices, 0, slice_sizes), 2); + return std::vector{out}; + }; + auto x = zeros({2, 2, 2, 2}); + auto y = array({0, 1, 0, 0, 1, 0}, {2, 3}); + auto out = vmap(fun, {0, -1})({x, y})[0]; + CHECK_EQ(out.shape(), std::vector{2, 2, 3, 2, 2}); + out = vmap(fun, {0, -1}, {3})({x, y})[0]; + CHECK_EQ(out.shape(), std::vector{2, 3, 2, 2, 2}); + } + + { + auto fun = [](std::vector inputs) { + auto src = inputs[0]; + auto indices = inputs[1]; + std::vector slice_sizes = {1, 2, 2}; + auto out = squeeze(gather(src, indices, 0, slice_sizes), 1); + return std::vector{out}; + }; + auto x = zeros({2, 2, 2, 2}); + auto y = array({0, 1, 0, 0, 1, 0}, {2, 3}); + auto out = vmap(fun, {0, 0})({x, y})[0]; + CHECK_EQ(out.shape(), std::vector{2, 3, 2, 2}); + } + + { + auto fun = [](std::vector inputs) { + auto src = inputs[0]; + auto indices = inputs[1]; + std::vector slice_sizes = {1, 2, 2, 2}; + auto out = squeeze(gather(src, indices, 0, slice_sizes), 1); + return std::vector{out}; + }; + auto x = zeros({2, 2, 2, 2}); + auto y = array({0, 1, 0, 0, 1, 0}, {2, 3}); + + auto out = vmap(fun, {-1, 0})({x, y})[0]; + CHECK_EQ(out.shape(), std::vector{2, 3, 2, 2, 2}); + } + + { + auto fun = [](std::vector inputs) { + auto src = inputs[0]; + auto indices = std::vector(inputs.begin() + 1, inputs.end()); + std::vector slice_sizes = {1, 1, 2, 2}; + auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2}); + return std::vector{out}; + }; + auto x = zeros({2, 2, 2, 2}); + auto y = array({0, 1, 0, 0, 1, 0}, {2, 3}); + auto z = array({0, 1, 0, 0, 1, 0}, {2, 3}); + auto out = vmap(fun, {-1, 0, 0})({x, y, z})[0]; + CHECK_EQ(out.shape(), std::vector{2, 3, 2, 2}); + + z = array({0, 1, 0, 0, 1, 0}, {3, 2}); + out = vmap(fun, {-1, 0, 1})({x, y, 
z})[0]; + CHECK_EQ(out.shape(), std::vector{2, 3, 2, 2}); + } +}