From 08226ab4912cce8d7bee1ad18b43c6795c94eeb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Mon, 26 Feb 2024 14:17:59 -0500 Subject: [PATCH] added atleast *args input support (#710) * added atleast list(array) input support * function overloading implemented * Refactoring * fixed formatting * removed pos_only --- mlx/ops.cpp | 34 ++++++++++++++++++++++ mlx/ops.h | 9 ++++++ python/src/ops.cpp | 61 ++++++++++++++++++++++------------------ python/tests/test_ops.py | 18 ++++++++++-- tests/ops_tests.cpp | 39 +++++++++++++++++++++++++ 5 files changed, 131 insertions(+), 30 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index df933a515..dd496d181 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3414,6 +3414,17 @@ array atleast_1d(const array& a, StreamOrDevice s /* = {} */) { return a; } +std::vector atleast_1d( + const std::vector& arrays, + StreamOrDevice s /* = {} */) { + std::vector out; + out.reserve(arrays.size()); + for (const auto& a : arrays) { + out.push_back(atleast_1d(a, s)); + } + return out; +} + array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { switch (a.ndim()) { case 0: @@ -3425,6 +3436,17 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { } } +std::vector atleast_2d( + const std::vector& arrays, + StreamOrDevice s /* = {} */) { + std::vector out; + out.reserve(arrays.size()); + for (const auto& a : arrays) { + out.push_back(atleast_2d(a, s)); + } + return out; +} + array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { switch (a.ndim()) { case 0: @@ -3437,4 +3459,16 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { return a; } } + +std::vector atleast_3d( + const std::vector& arrays, + StreamOrDevice s /* = {} */) { + std::vector out; + out.reserve(arrays.size()); + for (const auto& a : arrays) { + out.push_back(atleast_3d(a, s)); + } + return out; +} + } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index b61224d65..a92a4f8c0 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1123,7 +1123,16 @@ std::vector depends( /** convert an array to an atleast ndim array */ array atleast_1d(const array& a, StreamOrDevice s = {}); +std::vector atleast_1d( + const std::vector& a, + StreamOrDevice s = {}); array atleast_2d(const array& a, StreamOrDevice s = {}); +std::vector atleast_2d( + const std::vector& a, + StreamOrDevice s = {}); array atleast_3d(const array& a, StreamOrDevice s = {}); +std::vector atleast_3d( + const std::vector& a, + StreamOrDevice s = {}); } // namespace mlx::core diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 2c2dcecfd..56a1ac8de 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3638,62 +3638,69 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "atleast_1d", - &atleast_1d, - "a"_a, - py::pos_only(), + [](const py::args& arys, StreamOrDevice s) -> py::object { + if (arys.size() == 1) { + return py::cast(atleast_1d(arys[0].cast(), s)); + } + return py::cast(atleast_1d(arys.cast>(), s)); + }, py::kw_only(), "stream"_a = none, R"pbdoc( - atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array + atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] - Convert array to have at least one dimension. + Convert all arrays to have at least one dimension. - args: - a (array): Input array + Args: + *arys: Input arrays. stream (Union[None, Stream, Device], optional): The stream to execute the operation on. Returns: - array: An array with at least one dimension. - + array or list(array): An array or list of arrays with at least one dimension. )pbdoc"); m.def( "atleast_2d", - &atleast_2d, - "a"_a, - py::pos_only(), + [](const py::args& arys, StreamOrDevice s) -> py::object { + if (arys.size() == 1) { + return py::cast(atleast_2d(arys[0].cast(), s)); + } + return py::cast(atleast_2d(arys.cast>(), s)); + }, py::kw_only(), "stream"_a = none, R"pbdoc( - atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array + atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] - Convert array to have at least two dimensions. + Convert all arrays to have at least two dimensions. - args: - a (array): Input array + Args: + *arys: Input arrays. stream (Union[None, Stream, Device], optional): The stream to execute the operation on. Returns: - array: An array with at least two dimensions. - + array or list(array): An array or list of arrays with at least two dimensions. )pbdoc"); + m.def( "atleast_3d", - &atleast_3d, - "a"_a, - py::pos_only(), + [](const py::args& arys, StreamOrDevice s) -> py::object { + if (arys.size() == 1) { + return py::cast(atleast_3d(arys[0].cast(), s)); + } + return py::cast(atleast_3d(arys.cast>(), s)); + }, py::kw_only(), "stream"_a = none, R"pbdoc( - atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array + atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] - Convert array to have at least three dimensions. + Convert all arrays to have at least three dimensions. - args: - a (array): Input array + Args: + *arys: Input arrays. stream (Union[None, Stream, Device], optional): The stream to execute the operation on. Returns: - array: An array with at least three dimensions. - + array or list(array): An array or list of arrays with at least three dimensions. )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 23a8c7bc1..1a504dd45 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1932,12 +1932,16 @@ class TestOps(mlx_tests.MLXTestCase): [[[[1]], [[2]], [[3]]]], ] - for array in arrays: + mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays] + atleast_arrays = mx.atleast_1d(*mx_arrays) + + for i, array in enumerate(arrays): mx_res = mx.atleast_1d(mx.array(array)) np_res = np.atleast_1d(np.array(array)) self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) + self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) def test_atleast_2d(self): def compare_nested_lists(x, y): @@ -1962,12 +1966,16 @@ class TestOps(mlx_tests.MLXTestCase): [[[[1]], [[2]], [[3]]]], ] - for array in arrays: + mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays] + atleast_arrays = mx.atleast_2d(*mx_arrays) + + for i, array in enumerate(arrays): mx_res = mx.atleast_2d(mx.array(array)) np_res = np.atleast_2d(np.array(array)) self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) + self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) def test_atleast_3d(self): def compare_nested_lists(x, y): @@ -1992,12 +2000,16 @@ class TestOps(mlx_tests.MLXTestCase): [[[[1]], [[2]], [[3]]]], ] - for array in arrays: + mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays] + atleast_arrays = mx.atleast_3d(*mx_arrays) + + for i, array in enumerate(arrays): mx_res = mx.atleast_3d(mx.array(array)) np_res = np.atleast_3d(np.array(array)) self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) + self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) if __name__ == "__main__": diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index d63c1f2cd..fb4bfb78f 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2787,6 +2787,19 @@ TEST_CASE("test atleast_1d") { CHECK_EQ(out.shape(), std::vector{3, 1}); } +TEST_CASE("test atleast_1d vector") { + auto x = std::vector{ + array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})}; + auto out = atleast_1d(x); + CHECK_EQ(out.size(), 3); + CHECK_EQ(out[0].ndim(), 1); + CHECK_EQ(out[0].shape(), std::vector{1}); + CHECK_EQ(out[1].ndim(), 1); + CHECK_EQ(out[1].shape(), std::vector{3}); + CHECK_EQ(out[2].ndim(), 2); + CHECK_EQ(out[2].shape(), std::vector{3, 1}); +} + TEST_CASE("test atleast_2d") { auto x = array(1); auto out = atleast_2d(x); @@ -2804,6 +2817,19 @@ TEST_CASE("test atleast_2d") { CHECK_EQ(out.shape(), std::vector{3, 1}); } +TEST_CASE("test atleast_2d vector") { + auto x = std::vector{ + array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})}; + auto out = atleast_2d(x); + CHECK_EQ(out.size(), 3); + CHECK_EQ(out[0].ndim(), 2); + CHECK_EQ(out[0].shape(), std::vector{1, 1}); + CHECK_EQ(out[1].ndim(), 2); + CHECK_EQ(out[1].shape(), std::vector{1, 3}); + CHECK_EQ(out[2].ndim(), 2); + CHECK_EQ(out[2].shape(), std::vector{3, 1}); +} + TEST_CASE("test atleast_3d") { auto x = array(1); auto out = atleast_3d(x); @@ -2820,3 +2846,16 @@ TEST_CASE("test atleast_3d") { CHECK_EQ(out.ndim(), 3); CHECK_EQ(out.shape(), std::vector{3, 1, 1}); } + +TEST_CASE("test atleast_3d vector") { + auto x = std::vector{ + array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})}; + auto out = atleast_3d(x); + CHECK_EQ(out.size(), 3); + CHECK_EQ(out[0].ndim(), 3); + CHECK_EQ(out[0].shape(), std::vector{1, 1, 1}); + CHECK_EQ(out[1].ndim(), 3); + CHECK_EQ(out[1].shape(), std::vector{1, 3, 1}); + CHECK_EQ(out[2].ndim(), 3); + CHECK_EQ(out[2].shape(), std::vector{3, 1, 1}); +} \ No newline at end of file