From e28b57e37194556b681c6bf044201f8a2b3e1cb5 Mon Sep 17 00:00:00 2001
From: Jason <5698518+j-csc@users.noreply.github.com>
Date: Thu, 14 Dec 2023 16:21:19 -0500
Subject: [PATCH] Added mx.stack c++ frontend impl (#123)

* stack C++ operation + python bindings
---
Review notes (text in this section is not part of the commit message):
- stack() forwards the already-normalized axis to concatenate(); the
  result is unchanged for every valid axis.
- The shape-mismatch test message now matches the text stack() throws.

 docs/src/python/ops.rst  |  1 +
 mlx/ops.cpp              | 31 +++++++++++++++++++++++++++----
 mlx/ops.h                |  4 ++++
 mlx/utils.cpp            | 25 +++++++++++++++++++++++++
 mlx/utils.h              |  9 +++++++++
 python/src/ops.cpp       | 30 ++++++++++++++++++++++++++++++
 python/tests/test_ops.py | 31 +++++++++++++++++++++++++++++++
 tests/ops_tests.cpp      | 29 +++++++++++++++++++++++++++++
 tests/utils_tests.cpp    | 35 +++++++++++++++++++++++++++++++++++
 9 files changed, 191 insertions(+), 4 deletions(-)

diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index f08be3e9d..c235d3b64 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -85,6 +85,7 @@ Operations
    sqrt
    square
    squeeze
+   stack
    stop_gradient
    subtract
    sum
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index a3fea63e5..dc852370b 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -574,11 +574,11 @@ array concatenate(
     shape[ax] += a.shape(ax);
   }
 
+  // Promote all the arrays to the same type
+  auto dtype = result_type(arrays);
+
   return array(
-      shape,
-      arrays[0].dtype(),
-      std::make_unique<Concatenate>(to_stream(s), ax),
-      arrays);
+      shape, dtype, std::make_unique<Concatenate>(to_stream(s), ax), arrays);
 }
 
 array concatenate(
@@ -591,6 +591,29 @@ array concatenate(
   return concatenate(flat_inputs, 0, s);
 }
 
+/** Stack arrays along a new axis */
+array stack(
+    const std::vector<array>& arrays,
+    int axis,
+    StreamOrDevice s /* = {} */) {
+  if (arrays.empty()) {
+    throw std::invalid_argument("No arrays provided for stacking");
+  }
+  if (!is_same_shape(arrays)) {
+    throw std::invalid_argument("All arrays must have the same shape");
+  }
+  int normalized_axis = normalize_axis(axis, arrays[0].ndim() + 1);
+  std::vector<array> new_arrays;
+  new_arrays.reserve(arrays.size());
+  for (auto& a : arrays) {
+    new_arrays.emplace_back(expand_dims(a, normalized_axis, s));
+  }
+  return concatenate(new_arrays, normalized_axis, s);
+}
+array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
+  return stack(arrays, 0, s);
+}
+
 /** Pad an array with a constant value */
 array pad(
     const array& a,
diff --git a/mlx/ops.h b/mlx/ops.h
index 81d457281..6b081ad9f 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -174,6 +174,10 @@ array concatenate(
     StreamOrDevice s = {});
 array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});
 
+/** Stack arrays along a new axis. */
+array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
+array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
+
 /** Permutes the dimensions according to the given axes. */
 array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
 inline array transpose(
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 5f38cd06f..f564cfe59 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -49,6 +49,31 @@ std::vector<int> broadcast_shapes(
   return out_shape;
 }
 
+bool is_same_shape(const std::vector<array>& arrays) {
+  if (arrays.empty())
+    return true;
+
+  return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) {
+    return (a.shape() == arrays[0].shape());
+  });
+}
+
+int normalize_axis(int axis, int ndim) {
+  if (ndim <= 0) {
+    throw std::invalid_argument("Number of dimensions must be positive.");
+  }
+  if (axis < -ndim || axis >= ndim) {
+    std::ostringstream msg;
+    msg << "Axis " << axis << " is out of bounds for array with " << ndim
+        << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  if (axis < 0) {
+    axis += ndim;
+  }
+  return axis;
+}
+
 std::ostream& operator<<(std::ostream& os, const Device& d) {
   os << "Device(";
   switch (d.type) {
diff --git a/mlx/utils.h b/mlx/utils.h
index f21e55fce..823b4c872 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -16,6 +16,15 @@ std::vector<int> broadcast_shapes(
     const std::vector<int>& s1,
     const std::vector<int>& s2);
 
+bool is_same_shape(const std::vector<array>& arrays);
+
+/**
+ * Returns the axis normalized to be in the range [0, ndim).
+ * Based on numpy's normalize_axis_index. See
+ * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
+ */
+int normalize_axis(int axis, int ndim);
+
 std::ostream& operator<<(std::ostream& os, const Device& d);
 std::ostream& operator<<(std::ostream& os, const Stream& s);
 std::ostream& operator<<(std::ostream& os, const Dtype& d);
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index ab77f3c42..6d4d80b97 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -2230,6 +2230,36 @@ void init_ops(py::module_& m) {
         Returns:
             array: The concatenated array.
       )pbdoc");
+  m.def(
+      "stack",
+      [](const std::vector<array>& arrays,
+         std::optional<int> axis,
+         StreamOrDevice s) {
+        if (axis.has_value()) {
+          return stack(arrays, axis.value(), s);
+        } else {
+          return stack(arrays, s);
+        }
+      },
+      "arrays"_a,
+      py::pos_only(),
+      "axis"_a = 0,
+      py::kw_only(),
+      "stream"_a = none,
+      R"pbdoc(
+      stack(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array
+
+      Stacks the arrays along a new axis.
+
+      Args:
+          arrays (list(array)): A list of arrays to stack.
+          axis (int, optional): The axis in the result array along which the
+            input arrays are stacked. Defaults to ``0``.
+          stream (Stream, optional): Stream or device. Defaults to ``None``.
+
+      Returns:
+          array: The resulting stacked array.
+    )pbdoc");
   m.def(
       "pad",
       [](const array& a,
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index a35a82dc6..db1830e16 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1371,6 +1371,37 @@ class TestOps(mlx_tests.MLXTestCase):
         np_eye_matrix = np.eye(5, 6, k=-2)
         self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix))
 
+    def test_stack(self):
+        a = mx.ones((2,))
+        np_a = np.ones((2,))
+        b = mx.ones((2,))
+        np_b = np.ones((2,))
+
+        # One dimensional stack axis=0
+        c = mx.stack([a, b])
+        np_c = np.stack([np_a, np_b])
+        self.assertTrue(np.array_equal(c, np_c))
+
+        # One dimensional stack axis=1
+        c = mx.stack([a, b], axis=1)
+        np_c = np.stack([np_a, np_b], axis=1)
+        self.assertTrue(np.array_equal(c, np_c))
+
+        a = mx.ones((1, 2))
+        np_a = np.ones((1, 2))
+        b = mx.ones((1, 2))
+        np_b = np.ones((1, 2))
+
+        # Two dimensional stack axis=0
+        c = mx.stack([a, b])
+        np_c = np.stack([np_a, np_b])
+        self.assertTrue(np.array_equal(c, np_c))
+
+        # Two dimensional stack axis=1
+        c = mx.stack([a, b], axis=1)
+        np_c = np.stack([np_a, np_b], axis=1)
+        self.assertTrue(np.array_equal(c, np_c))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 87437223c..0916eeafe 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -1989,6 +1989,35 @@ TEST_CASE("test where") {
   CHECK(array_equal(where(condition, x, y), expected).item<bool>());
 }
 
+TEST_CASE("test stack") {
+  auto x = array({});
+  CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 0});
+  CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{0, 1});
+
+  x = array({1, 2, 3}, {3});
+  CHECK_EQ(stack({x}, 0).shape(), std::vector<int>{1, 3});
+  CHECK_EQ(stack({x}, 1).shape(), std::vector<int>{3, 1});
+
+  auto y = array({4, 5, 6}, {3});
+  auto z = std::vector<array>{x, y};
+  CHECK_EQ(stack(z).shape(), std::vector<int>{2, 3});
+  CHECK_EQ(stack(z, 0).shape(), std::vector<int>{2, 3});
+  CHECK_EQ(stack(z, 1).shape(), std::vector<int>{3, 2});
+  CHECK_EQ(stack(z, -1).shape(), std::vector<int>{3, 2});
+  CHECK_EQ(stack(z, -2).shape(), std::vector<int>{2, 3});
+
+  CHECK_THROWS_MESSAGE(stack({}, 0), "No arrays provided for stacking");
+
+  x = array({1, 2, 3}, {3}, float16);
+  y = array({4, 5, 6}, {3}, int32);
+  CHECK_EQ(stack({x, y}, 0).dtype(), float16);
+
+  x = array({1, 2, 3}, {3}, int32);
+  y = array({4, 5, 6, 7}, {4}, int32);
+  CHECK_THROWS_MESSAGE(
+      stack({x, y}, 0), "All arrays must have the same shape");
+}
+
 TEST_CASE("test eye") {
   auto eye_3 = eye(3);
   CHECK_EQ(eye_3.shape(), std::vector<int>{3, 3});
diff --git a/tests/utils_tests.cpp b/tests/utils_tests.cpp
index 4c4cf31ac..e7bb35d21 100644
--- a/tests/utils_tests.cpp
+++ b/tests/utils_tests.cpp
@@ -25,3 +25,38 @@ TEST_CASE("test type promotion") {
     CHECK_EQ(result_type(arrs), float32);
   }
 }
+
+TEST_CASE("test normalize axis") {
+  struct TestCase {
+    int axis;
+    int ndim;
+    int expected;
+  };
+
+  std::vector<TestCase> testCases = {
+      {0, 3, 0}, {1, 3, 1}, {2, 3, 2}, {-1, 3, 2}, {-2, 3, 1}, {-3, 3, 0}};
+
+  for (const auto& tc : testCases) {
+    CHECK_EQ(normalize_axis(tc.axis, tc.ndim), tc.expected);
+  }
+
+  CHECK_THROWS(normalize_axis(3, 3));
+  CHECK_THROWS(normalize_axis(-4, 3));
+}
+
+TEST_CASE("test is same size and shape") {
+  struct TestCase {
+    std::vector<array> a;
+    bool expected;
+  };
+
+  std::vector<TestCase> testCases = {
+      {{array({}), array({})}, true},
+      {{array({1}), array({1})}, true},
+      {{array({1, 2, 3}), array({1, 2, 4})}, true},
+      {{array({1, 2, 3}), array({1, 2})}, false}};
+
+  for (const auto& tc : testCases) {
+    CHECK_EQ(is_same_shape(tc.a), tc.expected);
+  }
+}
\ No newline at end of file