From 6af5ca35b2ace3b88e83b1c2055bc0e6de5127c0 Mon Sep 17 00:00:00 2001
From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com>
Date: Tue, 17 Sep 2024 21:12:43 +0100
Subject: [PATCH] feat: add cross_product (#1252)

* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
---
 ACKNOWLEDGMENTS.md          |  2 +-
 docs/src/python/linalg.rst  |  1 +
 mlx/linalg.cpp              | 72 +++++++++++++++++++++++++++++++++++++
 mlx/linalg.h                |  9 +++++
 python/src/linalg.cpp       | 28 +++++++++++++++
 python/tests/test_linalg.py | 48 +++++++++++++++++++++++++
 tests/linalg_tests.cpp      | 45 +++++++++++++++++++++++
 7 files changed, 204 insertions(+), 1 deletion(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index f406c36bb..db3ddeecf 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 
 MLX was developed with contributions from the following individuals:
 
--- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst
index e7fd5ecee..227711c22 100644
--- a/docs/src/python/linalg.rst
+++ b/docs/src/python/linalg.rst
@@ -13,5 +13,6 @@ Linear Algebra
    norm
    cholesky
    cholesky_inv
+   cross
    qr
    svd
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 9a397b868..a64f98aa8 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -382,4 +382,76 @@ array cholesky_inv(
   }
 }
 
+array cross(
+    const array& a,
+    const array& b,
+    int axis /* = -1 */,
+    StreamOrDevice s /* = {} */) {
+  auto check_ax = [axis](const array& arr) {
+    if (axis >= static_cast<int>(arr.ndim()) || axis + static_cast<int>(arr.ndim()) < 0) {
+      std::ostringstream msg;
+      msg << "[linalg::cross] axis " << axis << " invalid for array with "
+          << arr.ndim() << " dimensions.";
+      throw std::invalid_argument(msg.str());
+    }
+    if (arr.shape(axis) < 2 || arr.shape(axis) > 3) {
+      throw std::invalid_argument(
+          "[linalg::cross] The specified axis must have size 2 or 3.");
+    }
+  };
+  check_ax(a);
+  check_ax(b);
+
+  bool a_2d = a.shape(axis) == 2;
+  bool b_2d = b.shape(axis) == 2;
+
+  auto out_type = promote_types(a.dtype(), b.dtype());
+  auto ashape = a.shape();
+  auto bshape = b.shape();
+
+  ashape[axis < 0 ? axis + a.ndim() : axis] = 3;
+  bshape[axis < 0 ? axis + b.ndim() : axis] = 3;
+  auto out_shape = broadcast_shapes(ashape, bshape);
+
+  if (axis < 0) {
+    axis += out_shape.size();
+  }
+
+  out_shape[axis] = a_2d ? 2 : 3;
+  auto a_ = broadcast_to(astype(a, out_type, s), out_shape, s);
+
+  out_shape[axis] = b_2d ? 2 : 3;
+  auto b_ = broadcast_to(astype(b, out_type, s), out_shape, s);
+
+  auto a_splits = split(a_, a_2d ? 2 : 3, axis);
+  auto b_splits = split(b_, b_2d ? 2 : 3, axis);
+
+  std::vector<array> outputs;
+  if (a_2d && b_2d) {
+    auto z = zeros_like(a_splits[0], s);
+    outputs.push_back(z);
+    outputs.push_back(z);
+  } else if (b_2d) {
+    outputs.push_back(negative(multiply(a_splits[2], b_splits[1], s), s));
+    outputs.push_back(multiply(a_splits[2], b_splits[0], s));
+  } else if (a_2d) {
+    outputs.push_back(multiply(a_splits[1], b_splits[2], s));
+    outputs.push_back(negative(multiply(a_splits[0], b_splits[2], s), s));
+  } else {
+    outputs.push_back(subtract(
+        multiply(a_splits[1], b_splits[2], s),
+        multiply(a_splits[2], b_splits[1], s),
+        s));
+    outputs.push_back(subtract(
+        multiply(a_splits[2], b_splits[0], s),
+        multiply(a_splits[0], b_splits[2], s),
+        s));
+  }
+  outputs.push_back(subtract(
+      multiply(a_splits[0], b_splits[1], s),
+      multiply(a_splits[1], b_splits[0], s),
+      s));
+  return concatenate(outputs, axis, s);
+}
+
 } // namespace mlx::core::linalg
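The hunk above implements the standard component formula a x b = (a2*b3 - a3*b2, a3*b1 - a1*b3, a1*b2 - a2*b1), treating a size-2 input as having a zero third component before splitting along `axis`. A minimal pure-Python sketch of that formula for plain 2- or 3-element sequences (`cross_ref` is a hypothetical helper for illustration, not part of the MLX API):

    # Reference sketch of the component formula the C++ hunk implements;
    # `cross_ref` is a hypothetical helper, not MLX code.
    def cross_ref(a, b):
        # A size-2 vector gets a zero third component, matching the
        # zero-padding behavior of the patch.
        a = list(a) + [0.0] * (3 - len(a))
        b = list(b) + [0.0] * (3 - len(b))
        return [
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0],
        ]

    assert cross_ref([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]) == [-3.0, 6.0, -3.0]
    assert cross_ref([1.0, 2.0], [3.0, 4.0]) == [0.0, 0.0, -2.0]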
diff --git a/mlx/linalg.h b/mlx/linalg.h
index 3ffca476c..acfcc1a41 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -74,4 +74,13 @@ array pinv(const array& a, StreamOrDevice s = {});
 
 array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
 
+/**
+ * Compute the cross product of two arrays along the given axis.
+ */
+array cross(
+    const array& a,
+    const array& b,
+    int axis = -1,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core::linalg
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index c175ebbfa..65dd8d0e4 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -377,4 +377,32 @@ void init_linalg(nb::module_& parent_module) {
       Returns:
         array: ``aplus`` such that ``a @ aplus @ a = a``
       )pbdoc");
+  m.def(
+      "cross",
+      &cross,
+      "a"_a,
+      "b"_a,
+      "axis"_a = -1,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def cross(a: array, b: array, axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Compute the cross product of two arrays along a specified axis.
+
+        The cross product is defined for arrays with size 2 or 3 in the
+        specified axis. If the size is 2 then the third value is assumed
+        to be zero.
+
+        Args:
+            a (array): Input array.
+            b (array): Input array.
+            axis (int, optional): Axis along which to compute the cross
+              product. Default: ``-1``.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The cross product of ``a`` and ``b`` along the specified axis.
+      )pbdoc");
 }
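A short usage sketch of the binding added above; the values in the comments are hand-computed from the component formula, not captured output:

    import mlx.core as mx

    a = mx.array([1.0, 2.0, 3.0])
    b = mx.array([4.0, 5.0, 6.0])
    mx.linalg.cross(a, b)  # [-3, 6, -3]

    # A size-2 input is treated as having a zero third component,
    # so this is equivalent to cross(a, [1, 2, 0]).
    mx.linalg.cross(a, mx.array([1.0, 2.0]))  # [-6, 3, 0]

    # `axis` selects the dimension that holds the vector components.
    x = mx.zeros((3, 4))
    y = mx.ones((3, 4))
    mx.linalg.cross(x, y, axis=0).shape  # (3, 4)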
+ )pbdoc"); } diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 294051077..6051beef7 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -220,6 +220,54 @@ class TestLinalg(mlx_tests.MLXTestCase): for M, M_inv in zip(AB, AB_inv): self.assertTrue(mx.allclose(M @ M_inv, mx.eye(N), atol=1e-4)) + def test_cross_product(self): + a = mx.array([1.0, 2.0, 3.0]) + b = mx.array([4.0, 5.0, 6.0]) + result = mx.linalg.cross(a, b) + expected = np.cross(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test with negative values + a = mx.array([-1.0, -2.0, -3.0]) + b = mx.array([4.0, -5.0, 6.0]) + result = mx.linalg.cross(a, b) + expected = np.cross(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test with integer values + a = mx.array([1, 2, 3]) + b = mx.array([4, 5, 6]) + result = mx.linalg.cross(a, b) + expected = np.cross(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test with 2D arrays and axis parameter + a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + b = mx.array([[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]) + result = mx.linalg.cross(a, b, axis=1) + expected = np.cross(a, b, axis=1) + self.assertTrue(np.allclose(result, expected)) + + # Test with broadcast + a = mx.random.uniform(shape=(2, 1, 3)) + b = mx.random.uniform(shape=(1, 2, 3)) + result = mx.linalg.cross(a, b) + expected = np.cross(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Type promotion + a = mx.array([1.0, 2.0, 3.0]) + b = mx.array([4, 5, 6]) + result = mx.linalg.cross(a, b) + expected = np.cross(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test with incorrect vector size (should raise an exception) + a = mx.array([1.0]) + b = mx.array([4.0]) + with self.assertRaises(ValueError): + mx.linalg.cross(a, b) + if __name__ == "__main__": unittest.main() diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 186bbc613..e9e196583 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -390,3 +390,48 @@ TEST_CASE("test matrix pseudo-inverse") { CHECK(allclose(A_pinv_again, A_pinv).item()); } } + +TEST_CASE("test cross product") { + using namespace mlx::core::linalg; + + // Test for vectors of length 3 + array a = array({1.0, 2.0, 3.0}); + array b = array({4.0, 5.0, 6.0}); + + array expected = array( + {2.0 * 6.0 - 3.0 * 5.0, 3.0 * 4.0 - 1.0 * 6.0, 1.0 * 5.0 - 2.0 * 4.0}); + + array result = cross(a, b); + CHECK(allclose(result, expected).item()); + + // Test for vectors of length 3 with negative values + a = array({-1.0, -2.0, -3.0}); + b = array({4.0, -5.0, 6.0}); + + expected = array( + {-2.0 * 6.0 - (-3.0 * -5.0), + -3.0 * 4.0 - (-1.0 * 6.0), + -1.0 * -5.0 - (-2.0 * 4.0)}); + + result = cross(a, b); + CHECK(allclose(result, expected).item()); + + // Test for incorrect vector size (should throw) + b = array({1.0, 2.0}); + expected = array( + {-2.0 * 0.0 - (-3.0 * 2.0), + -3.0 * 1.0 - (-1.0 * 0.0), + -1.0 * 2.0 - (-2.0 * 1.0)}); + + result = cross(a, b); + CHECK(allclose(result, expected).item()); + + // Test for vectors of length 3 with integer values + a = array({1, 2, 3}); + b = array({4, 5, 6}); + + expected = array({2 * 6 - 3 * 5, 3 * 4 - 1 * 6, 1 * 5 - 2 * 4}); + + result = cross(a, b); + CHECK(allclose(result, expected).item()); +}