mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	feat: add cross_product (#1252)
* feat: add cross_product * lint * python binding * refactor: Improve error message for cross_product function * refactor: more close to numpy cross product * refactor: improve error message for cross_product function * finish * fix acks * allow old numpy * doc --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -390,3 +390,48 @@ TEST_CASE("test matrix pseudo-inverse") { | ||||
|     CHECK(allclose(A_pinv_again, A_pinv).item<bool>()); | ||||
|   } | ||||
| } | ||||
|  | ||||
| TEST_CASE("test cross product") { | ||||
|   using namespace mlx::core::linalg; | ||||
|  | ||||
|   // Test for vectors of length 3 | ||||
|   array a = array({1.0, 2.0, 3.0}); | ||||
|   array b = array({4.0, 5.0, 6.0}); | ||||
|  | ||||
|   array expected = array( | ||||
|       {2.0 * 6.0 - 3.0 * 5.0, 3.0 * 4.0 - 1.0 * 6.0, 1.0 * 5.0 - 2.0 * 4.0}); | ||||
|  | ||||
|   array result = cross(a, b); | ||||
|   CHECK(allclose(result, expected).item<bool>()); | ||||
|  | ||||
|   // Test for vectors of length 3 with negative values | ||||
|   a = array({-1.0, -2.0, -3.0}); | ||||
|   b = array({4.0, -5.0, 6.0}); | ||||
|  | ||||
|   expected = array( | ||||
|       {-2.0 * 6.0 - (-3.0 * -5.0), | ||||
|        -3.0 * 4.0 - (-1.0 * 6.0), | ||||
|        -1.0 * -5.0 - (-2.0 * 4.0)}); | ||||
|  | ||||
|   result = cross(a, b); | ||||
|   CHECK(allclose(result, expected).item<bool>()); | ||||
|  | ||||
|   // Test for incorrect vector size (should throw) | ||||
|   b = array({1.0, 2.0}); | ||||
|   expected = array( | ||||
|       {-2.0 * 0.0 - (-3.0 * 2.0), | ||||
|        -3.0 * 1.0 - (-1.0 * 0.0), | ||||
|        -1.0 * 2.0 - (-2.0 * 1.0)}); | ||||
|  | ||||
|   result = cross(a, b); | ||||
|   CHECK(allclose(result, expected).item<bool>()); | ||||
|  | ||||
|   // Test for vectors of length 3 with integer values | ||||
|   a = array({1, 2, 3}); | ||||
|   b = array({4, 5, 6}); | ||||
|  | ||||
|   expected = array({2 * 6 - 3 * 5, 3 * 4 - 1 * 6, 1 * 5 - 2 * 4}); | ||||
|  | ||||
|   result = cross(a, b); | ||||
|   CHECK(allclose(result, expected).item<bool>()); | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Nripesh Niketan
					Nripesh Niketan