mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
feat: add cross_product (#1252)
* feat: add cross_product * lint * python binding * refactor: Improve error message for cross_product function * refactor: more close to numpy cross product * refactor: improve error message for cross_product function * finish * fix acks * allow old numpy * doc --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -390,3 +390,48 @@ TEST_CASE("test matrix pseudo-inverse") {
|
||||
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test cross product") {
|
||||
using namespace mlx::core::linalg;
|
||||
|
||||
// Test for vectors of length 3
|
||||
array a = array({1.0, 2.0, 3.0});
|
||||
array b = array({4.0, 5.0, 6.0});
|
||||
|
||||
array expected = array(
|
||||
{2.0 * 6.0 - 3.0 * 5.0, 3.0 * 4.0 - 1.0 * 6.0, 1.0 * 5.0 - 2.0 * 4.0});
|
||||
|
||||
array result = cross(a, b);
|
||||
CHECK(allclose(result, expected).item<bool>());
|
||||
|
||||
// Test for vectors of length 3 with negative values
|
||||
a = array({-1.0, -2.0, -3.0});
|
||||
b = array({4.0, -5.0, 6.0});
|
||||
|
||||
expected = array(
|
||||
{-2.0 * 6.0 - (-3.0 * -5.0),
|
||||
-3.0 * 4.0 - (-1.0 * 6.0),
|
||||
-1.0 * -5.0 - (-2.0 * 4.0)});
|
||||
|
||||
result = cross(a, b);
|
||||
CHECK(allclose(result, expected).item<bool>());
|
||||
|
||||
// Test for incorrect vector size (should throw)
|
||||
b = array({1.0, 2.0});
|
||||
expected = array(
|
||||
{-2.0 * 0.0 - (-3.0 * 2.0),
|
||||
-3.0 * 1.0 - (-1.0 * 0.0),
|
||||
-1.0 * 2.0 - (-2.0 * 1.0)});
|
||||
|
||||
result = cross(a, b);
|
||||
CHECK(allclose(result, expected).item<bool>());
|
||||
|
||||
// Test for vectors of length 3 with integer values
|
||||
a = array({1, 2, 3});
|
||||
b = array({4, 5, 6});
|
||||
|
||||
expected = array({2 * 6 - 3 * 5, 3 * 4 - 1 * 6, 1 * 5 - 2 * 4});
|
||||
|
||||
result = cross(a, b);
|
||||
CHECK(allclose(result, expected).item<bool>());
|
||||
}
|
||||
|
Reference in New Issue
Block a user