diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 75d54be5a..ea25b90f9 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -41,6 +41,7 @@ Operations
    expand_dims
    eye
+   flatten
    floor
    full
    greater
    greater_equal
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 1e61745eb..19b165f3a 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -277,6 +277,38 @@ array reshape(
       shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
 }
 
+array flatten(
+    const array& a,
+    int start_axis,
+    int end_axis /* = -1 */,
+    StreamOrDevice s /* = {} */) {
+  auto ndim = static_cast<int>(a.ndim());
+  if (ndim == 0) {
+    // Scalars flatten to a 1-element 1D array.
+    return reshape(a, {1}, s);
+  }
+  start_axis += (start_axis < 0 ? ndim : 0);
+  end_axis += (end_axis < 0 ? ndim : 0);
+  start_axis = std::max(0, start_axis);
+  end_axis = std::min(ndim - 1, end_axis);
+  if (end_axis < start_axis) {
+    throw std::invalid_argument(
+        "[flatten] start_axis must be less than or equal to end_axis");
+  }
+  if (start_axis == end_axis) {
+    return a;
+  }
+  std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_axis);
+  new_shape.push_back(-1);
+  new_shape.insert(
+      new_shape.end(), a.shape().begin() + end_axis + 1, a.shape().end());
+  return reshape(a, new_shape, s);
+}
+
+array flatten(const array& a, StreamOrDevice s /* = {} */) {
+  return flatten(a, 0, a.ndim() - 1, s);
+}
+
 array squeeze(
     const array& a,
     const std::vector<int>& axes,
diff --git a/mlx/ops.h b/mlx/ops.h
index e5bdcf358..86c475e6e 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -123,6 +123,16 @@ array triu(array x, int k, StreamOrDevice s = {});
 /** Reshape an array to the given shape. */
 array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
 
+/** Flatten the dimensions in the range `[start_axis, end_axis]`. */
+array flatten(
+    const array& a,
+    int start_axis,
+    int end_axis = -1,
+    StreamOrDevice s = {});
+
+/** Flatten the array to 1D. */
+array flatten(const array& a, StreamOrDevice s = {});
+
 /** Remove singleton dimensions at the given axes. */
 array squeeze(
     const array& a,
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index f564cfe59..1fbc67c8e 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -50,9 +50,9 @@ std::vector<int> broadcast_shapes(
 }
 
 bool is_same_shape(const std::vector<array>& arrays) {
-  if (arrays.empty())
+  if (arrays.empty()) {
     return true;
-
+  }
   return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
     return (a.shape() == arrays[0].shape());
   });
diff --git a/python/src/array.cpp b/python/src/array.cpp
index 906632e5f..36f856c8c 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -728,6 +728,21 @@ void init_array(py::module_& m) {
             return power(a, to_array(v, a.dtype()));
           },
           "other"_a)
+      .def(
+          "flatten",
+          [](const array& a,
+             int start_axis,
+             int end_axis,
+             const StreamOrDevice& s) {
+            return flatten(a, start_axis, end_axis, s);
+          },
+          "start_axis"_a = 0,
+          "end_axis"_a = -1,
+          py::kw_only(),
+          "stream"_a = none,
+          R"pbdoc(
+            See :func:`flatten`.
+          )pbdoc")
       .def(
           "reshape",
           [](const array& a, py::args shape, StreamOrDevice s) {
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 4891052b6..58b15e1d6 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -61,6 +61,33 @@ void init_ops(py::module_& m) {
       Returns:
         array: The reshaped array.
       )pbdoc");
+  m.def(
+      "flatten",
+      [](const array& a,
+         int start_axis,
+         int end_axis,
+         const StreamOrDevice& s) { return flatten(a, start_axis, end_axis, s); },
+      "a"_a,
+      py::pos_only(),
+      "start_axis"_a = 0,
+      "end_axis"_a = -1,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+        flatten(a: array, /, start_axis: int = 0, end_axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array
+
+        Flatten an array.
+
+        Args:
+            a (array): Input array.
+            start_axis (int, optional): The first dimension to flatten. Defaults to ``0``.
+            end_axis (int, optional): The last dimension to flatten. Defaults to ``-1``.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The flattened array.
+      )pbdoc");
   m.def(
       "squeeze",
       [](const array& a, const IntOrVec& v, const StreamOrDevice& s) {
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 9ea0e80db..eea726b16 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1426,6 +1426,15 @@ class TestOps(mlx_tests.MLXTestCase):
         np_c = np.stack([np_a, np_b], axis=1)
         self.assertTrue(np.array_equal(c, np_c))
 
+    def test_flatten(self):
+        x = mx.zeros([2, 3, 4])
+        self.assertEqual(mx.flatten(x).shape, [2 * 3 * 4])
+        self.assertEqual(mx.flatten(x, start_axis=1).shape, [2, 3 * 4])
+        self.assertEqual(mx.flatten(x, end_axis=1).shape, [2 * 3, 4])
+        self.assertEqual(x.flatten().shape, [2 * 3 * 4])
+        self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
+        self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 921afc31d..d33c7340a 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -58,6 +58,27 @@ TEST_CASE("test reshape") {
   CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
 }
 
+TEST_CASE("test flatten") {
+  array x = zeros({2, 3, 4});
+  CHECK_EQ(flatten(x).shape(), std::vector<int>({2 * 3 * 4}));
+
+  CHECK_EQ(flatten(x, 1, 1).shape(), std::vector<int>({2, 3, 4}));
+  CHECK_EQ(flatten(x, 1, 2).shape(), std::vector<int>({2, 3 * 4}));
+  CHECK_EQ(flatten(x, 1, 3).shape(), std::vector<int>({2, 3 * 4}));
+  CHECK_EQ(flatten(x, 1, -1).shape(), std::vector<int>({2, 3 * 4}));
+  CHECK_EQ(flatten(x, -2, -1).shape(), std::vector<int>({2, 3 * 4}));
+  CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({2 * 3 * 4}));
+  CHECK_EQ(flatten(x, -4, -1).shape(), std::vector<int>({2 * 3 * 4}));
+
+  // Check start > end throws
+  CHECK_THROWS(flatten(x, 2, 1));
+
+  // Check scalar flattens to 1D
+  x = array(1);
+  CHECK_EQ(flatten(x, -3, -1).shape(), std::vector<int>({1}));
+  CHECK_EQ(flatten(x, 0, 0).shape(), std::vector<int>({1}));
+}
+
 TEST_CASE("test squeeze and expand") {
   array x = zeros({2, 1, 2, 1, 2, 1});
   CHECK_EQ(squeeze(x).shape(), std::vector<int>{2, 2, 2});