mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	feat: add cross_product (#1252)
* feat: add cross_product * lint * python binding * refactor: Improve error message for cross_product function * refactor: more close to numpy cross product * refactor: improve error message for cross_product function * finish * fix acks * allow old numpy * doc --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -377,4 +377,32 @@ void init_linalg(nb::module_& parent_module) { | ||||
|         Returns: | ||||
|             array: ``aplus`` such that ``a @ aplus @ a = a`` | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "cross", | ||||
|       &cross, | ||||
|       "a"_a, | ||||
|       "b"_a, | ||||
|       "axis"_a = -1, | ||||
|       nb::kw_only(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def cross(a: array, b: array, axis: int = -1, *, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|       R"pbdoc( | ||||
|         Compute the cross product of two arrays along a specified axis. | ||||
|  | ||||
|         The cross product is defined for arrays with size 2 or 3 in the | ||||
|         specified axis. If the size is 2 then the third value is assumed | ||||
|         to be zero. | ||||
|  | ||||
|         Args: | ||||
|             a (array): Input array. | ||||
|             b (array): Input array. | ||||
|             axis (int, optional): Axis along which to compute the cross | ||||
|               product. Default: ``-1``. | ||||
|             stream (Stream, optional): Stream or device. Defaults to ``None`` | ||||
|               in which case the default stream of the default device is used. | ||||
|  | ||||
|         Returns: | ||||
|             array: The cross product of ``a`` and ``b`` along the specified axis. | ||||
|       )pbdoc"); | ||||
| } | ||||
|   | ||||
| @@ -220,6 +220,54 @@ class TestLinalg(mlx_tests.MLXTestCase): | ||||
|             for M, M_inv in zip(AB, AB_inv): | ||||
|                 self.assertTrue(mx.allclose(M @ M_inv, mx.eye(N), atol=1e-4)) | ||||
|  | ||||
|     def test_cross_product(self): | ||||
|         a = mx.array([1.0, 2.0, 3.0]) | ||||
|         b = mx.array([4.0, 5.0, 6.0]) | ||||
|         result = mx.linalg.cross(a, b) | ||||
|         expected = np.cross(a, b) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         # Test with negative values | ||||
|         a = mx.array([-1.0, -2.0, -3.0]) | ||||
|         b = mx.array([4.0, -5.0, 6.0]) | ||||
|         result = mx.linalg.cross(a, b) | ||||
|         expected = np.cross(a, b) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         # Test with integer values | ||||
|         a = mx.array([1, 2, 3]) | ||||
|         b = mx.array([4, 5, 6]) | ||||
|         result = mx.linalg.cross(a, b) | ||||
|         expected = np.cross(a, b) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         # Test with 2D arrays and axis parameter | ||||
|         a = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) | ||||
|         b = mx.array([[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]) | ||||
|         result = mx.linalg.cross(a, b, axis=1) | ||||
|         expected = np.cross(a, b, axis=1) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         # Test with broadcast | ||||
|         a = mx.random.uniform(shape=(2, 1, 3)) | ||||
|         b = mx.random.uniform(shape=(1, 2, 3)) | ||||
|         result = mx.linalg.cross(a, b) | ||||
|         expected = np.cross(a, b) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         # Type promotion | ||||
|         a = mx.array([1.0, 2.0, 3.0]) | ||||
|         b = mx.array([4, 5, 6]) | ||||
|         result = mx.linalg.cross(a, b) | ||||
|         expected = np.cross(a, b) | ||||
|         self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|         # Test with incorrect vector size (should raise an exception) | ||||
|         a = mx.array([1.0]) | ||||
|         b = mx.array([4.0]) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.linalg.cross(a, b) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Nripesh Niketan
					Nripesh Niketan