feat: add cross_product (#1252)

* feat: add cross_product

* lint

* python binding

* refactor: Improve error message for cross_product function

* refactor: more close to numpy cross product

* refactor: improve error message for cross_product function

* finish

* fix acks

* allow old numpy

* doc

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-09-17 21:12:43 +01:00
committed by GitHub
parent 4f46e9c997
commit 6af5ca35b2
7 changed files with 204 additions and 1 deletions

View File

@@ -390,3 +390,48 @@ TEST_CASE("test matrix pseudo-inverse") {
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
}
}
TEST_CASE("test cross product") {
using namespace mlx::core::linalg;
// Test for vectors of length 3
array a = array({1.0, 2.0, 3.0});
array b = array({4.0, 5.0, 6.0});
array expected = array(
{2.0 * 6.0 - 3.0 * 5.0, 3.0 * 4.0 - 1.0 * 6.0, 1.0 * 5.0 - 2.0 * 4.0});
array result = cross(a, b);
CHECK(allclose(result, expected).item<bool>());
// Test for vectors of length 3 with negative values
a = array({-1.0, -2.0, -3.0});
b = array({4.0, -5.0, 6.0});
expected = array(
{-2.0 * 6.0 - (-3.0 * -5.0),
-3.0 * 4.0 - (-1.0 * 6.0),
-1.0 * -5.0 - (-2.0 * 4.0)});
result = cross(a, b);
CHECK(allclose(result, expected).item<bool>());
// Test for incorrect vector size (should throw)
b = array({1.0, 2.0});
expected = array(
{-2.0 * 0.0 - (-3.0 * 2.0),
-3.0 * 1.0 - (-1.0 * 0.0),
-1.0 * 2.0 - (-2.0 * 1.0)});
result = cross(a, b);
CHECK(allclose(result, expected).item<bool>());
// Test for vectors of length 3 with integer values
a = array({1, 2, 3});
b = array({4, 5, 6});
expected = array({2 * 6 - 3 * 5, 3 * 4 - 1 * 6, 1 * 5 - 2 * 4});
result = cross(a, b);
CHECK(allclose(result, expected).item<bool>());
}