feat: add cross_product (#1252)

* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-09-17 21:12:43 +01:00
committed by GitHub
parent 4f46e9c997
commit 6af5ca35b2
7 changed files with 204 additions and 1 deletions

View File

@@ -377,4 +377,32 @@ void init_linalg(nb::module_& parent_module) {
Returns:
array: ``aplus`` such that ``a @ aplus @ a = a``
)pbdoc");
m.def(
"cross",
&cross,
"a"_a,
"b"_a,
"axis"_a = -1,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def cross(a: array, b: array, axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the cross product of two arrays along a specified axis.
The cross product is defined for arrays with size 2 or 3 in the
specified axis. If the size is 2 then the third value is assumed
to be zero.
Args:
a (array): Input array.
b (array): Input array.
axis (int, optional): Axis along which to compute the cross
product. Default: ``-1``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The cross product of ``a`` and ``b`` along the specified axis.
)pbdoc");
}

View File

@@ -220,6 +220,54 @@ class TestLinalg(mlx_tests.MLXTestCase):
for M, M_inv in zip(AB, AB_inv):
self.assertTrue(mx.allclose(M @ M_inv, mx.eye(N), atol=1e-4))
def test_cross_product(self):
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])
result = mx.linalg.cross(a, b)
expected = np.cross(a, b)
self.assertTrue(np.allclose(result, expected))
# Test with negative values
a = mx.array([-1.0, -2.0, -3.0])
b = mx.array([4.0, -5.0, 6.0])
result = mx.linalg.cross(a, b)
expected = np.cross(a, b)
self.assertTrue(np.allclose(result, expected))
# Test with integer values
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
result = mx.linalg.cross(a, b)
expected = np.cross(a, b)
self.assertTrue(np.allclose(result, expected))
# Test with 2D arrays and axis parameter
a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = mx.array([[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]])
result = mx.linalg.cross(a, b, axis=1)
expected = np.cross(a, b, axis=1)
self.assertTrue(np.allclose(result, expected))
# Test with broadcast
a = mx.random.uniform(shape=(2, 1, 3))
b = mx.random.uniform(shape=(1, 2, 3))
result = mx.linalg.cross(a, b)
expected = np.cross(a, b)
self.assertTrue(np.allclose(result, expected))
# Type promotion
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4, 5, 6])
result = mx.linalg.cross(a, b)
expected = np.cross(a, b)
self.assertTrue(np.allclose(result, expected))
# Test with incorrect vector size (should raise an exception)
a = mx.array([1.0])
b = mx.array([4.0])
with self.assertRaises(ValueError):
mx.linalg.cross(a, b)
if __name__ == "__main__":
unittest.main()