feat: add cross_product (#1252)

* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-09-17 21:12:43 +01:00
committed by GitHub
parent 4f46e9c997
commit 6af5ca35b2
7 changed files with 204 additions and 1 deletions

View File

@@ -220,6 +220,54 @@ class TestLinalg(mlx_tests.MLXTestCase):
for M, M_inv in zip(AB, AB_inv):
self.assertTrue(mx.allclose(M @ M_inv, mx.eye(N), atol=1e-4))
def test_cross_product(self):
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])
result = mx.linalg.cross(a, b)
expected = np.cross(a, b)
self.assertTrue(np.allclose(result, expected))
# Test with negative values
a = mx.array([-1.0, -2.0, -3.0])
b = mx.array([4.0, -5.0, 6.0])
result = mx.linalg.cross(a, b)
expected = np.cross(a, b)
self.assertTrue(np.allclose(result, expected))
# Test with integer values
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
result = mx.linalg.cross(a, b)
expected = np.cross(a, b)
self.assertTrue(np.allclose(result, expected))
# Test with 2D arrays and axis parameter
a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = mx.array([[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]])
result = mx.linalg.cross(a, b, axis=1)
expected = np.cross(a, b, axis=1)
self.assertTrue(np.allclose(result, expected))
# Test with broadcast
a = mx.random.uniform(shape=(2, 1, 3))
b = mx.random.uniform(shape=(1, 2, 3))
result = mx.linalg.cross(a, b)
expected = np.cross(a, b)
self.assertTrue(np.allclose(result, expected))
# Type promotion
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4, 5, 6])
result = mx.linalg.cross(a, b)
expected = np.cross(a, b)
self.assertTrue(np.allclose(result, expected))
# Test with incorrect vector size (should raise an exception)
a = mx.array([1.0])
b = mx.array([4.0])
with self.assertRaises(ValueError):
mx.linalg.cross(a, b)
if __name__ == "__main__":
unittest.main()