From f883fcede082c7198f32b4793f60d26d3195c2fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Mon, 19 Feb 2024 12:40:52 -0500 Subject: [PATCH] Added support for atleast_1d, atleast_2d, atleast_3d (#694) --- ACKNOWLEDGMENTS.md | 3 +- docs/src/python/ops.rst | 3 ++ mlx/ops.cpp | 30 ++++++++++++++ mlx/ops.h | 5 +++ python/src/ops.cpp | 60 +++++++++++++++++++++++++++ python/tests/test_ops.py | 90 ++++++++++++++++++++++++++++++++++++++++ tests/ops_tests.cpp | 51 +++++++++++++++++++++++ 7 files changed, 241 insertions(+), 1 deletion(-) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 36aedc77a..c2cad615e 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -10,8 +10,9 @@ MLX was developed with contributions from the following individuals: - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. -- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support +- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support. - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. +- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. 
diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 09e2d5f71..7ec7defc9 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -25,6 +25,9 @@ Operations argpartition argsort array_equal + atleast_1d + atleast_2d + atleast_3d broadcast_to ceil clip diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 32af8a078..97d4a3a2d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3381,4 +3381,34 @@ std::vector<array> depends( shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs); } +array atleast_1d(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() == 0) { + return reshape(a, {1}, s); + } + return a; +} + +array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { + switch (a.ndim()) { + case 0: + return reshape(a, {1, 1}, s); + case 1: + return reshape(a, {1, static_cast<int>(a.size())}, s); + default: + return a; + } +} + +array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { + switch (a.ndim()) { + case 0: + return reshape(a, {1, 1, 1}, s); + case 1: + return reshape(a, {1, static_cast<int>(a.size()), 1}, s); + case 2: + return reshape(a, {a.shape(0), a.shape(1), 1}, s); + default: + return a; + } +} } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index f7036b8c6..b61224d65 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1121,4 +1121,9 @@ std::vector<array> depends( const std::vector<array>& inputs, const std::vector<array>& dependencies); +/** convert an array to an atleast ndim array */ +array atleast_1d(const array& a, StreamOrDevice s = {}); +array atleast_2d(const array& a, StreamOrDevice s = {}); +array atleast_3d(const array& a, StreamOrDevice s = {}); + } // namespace mlx::core diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 8e08e6ca9..2c2dcecfd 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3636,4 +3636,64 @@ void init_ops(py::module_& m) { Returns: array: The extracted diagonal or the constructed diagonal matrix. 
)pbdoc"); + m.def( + "atleast_1d", + &atleast_1d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least one dimension. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least one dimension. + + )pbdoc"); + m.def( + "atleast_2d", + &atleast_2d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least two dimensions. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least two dimensions. + + )pbdoc"); + m.def( + "atleast_3d", + &atleast_3d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least three dimensions. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least three dimensions. 
+ + )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 66e683303..3401338f8 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1883,6 +1883,96 @@ class TestOps(mlx_tests.MLXTestCase): expected = mx.array(np.diag(x, k=-1)) self.assertTrue(mx.array_equal(result, expected)) + def test_atleast_1d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_1d(mx.array(array)) + np_res = np.atleast_1d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + + def test_atleast_2d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_2d(mx.array(array)) + np_res = np.atleast_2d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + + def test_atleast_3d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return 
 True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_3d(mx.array(array)) + np_res = np.atleast_3d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 41db064be..ba4ab552f 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2716,3 +2716,54 @@ TEST_CASE("test diag") { out = diag(x, -1); CHECK(array_equal(out, array({3, 7}, {2})).item<bool>()); } + +TEST_CASE("test atleast_1d") { + auto x = array(1); + auto out = atleast_1d(x); + CHECK_EQ(out.ndim(), 1); + CHECK_EQ(out.shape(), std::vector<int>{1}); + + x = array({1, 2, 3}, {3}); + out = atleast_1d(x); + CHECK_EQ(out.ndim(), 1); + CHECK_EQ(out.shape(), std::vector<int>{3}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_1d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector<int>{3, 1}); +} + +TEST_CASE("test atleast_2d") { + auto x = array(1); + auto out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector<int>{1, 1}); + + x = array({1, 2, 3}, {3}); + out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector<int>{1, 3}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector<int>{3, 1}); +} + +TEST_CASE("test atleast_3d") { + auto x = array(1); + auto out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector<int>{1, 1, 1}); + + x = array({1, 2, 3}, {3}); + out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector<int>{1, 3, 1}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector<int>{3, 
1, 1}); +} \ No newline at end of file