Add mx.meshgrid (#961)

This commit is contained in:
Abe Leininger
2024-04-09 14:43:08 -04:00
committed by GitHub
parent ae812350f9
commit a1a31eed27
6 changed files with 161 additions and 1 deletions

View File

@@ -3156,3 +3156,28 @@ TEST_CASE("test topk") {
CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
}
}
TEST_CASE("test meshgrid") {
// Test default
auto x = array({1, 2, 3}, {3});
auto in = std::vector<array>{x};
auto out = meshgrid(in);
CHECK(array_equal(out[0], x).item<bool>());
// Test different lengths
auto y = array({4, 5}, {2});
in = std::vector<array>{x, y};
out = meshgrid(in);
auto expected_zero = array({1, 2, 3, 1, 2, 3}, {2, 3});
auto expected_one = array({4, 4, 4, 5, 5, 5}, {2, 3});
CHECK(array_equal(out[0], expected_zero).item<bool>());
CHECK(array_equal(out[1], expected_one).item<bool>());
// Test sparse true
in = std::vector<array>{x, x};
out = meshgrid(in, true);
expected_zero = array({1, 2, 3}, {1, 3});
expected_one = array({1, 2, 3}, {3, 1});
CHECK(array_equal(out[0], expected_zero).item<bool>());
CHECK(array_equal(out[1], expected_one).item<bool>());
}