mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Add mx.meshgrid (#961)
This commit is contained in:
@@ -3156,3 +3156,28 @@ TEST_CASE("test topk") {
|
||||
CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test meshgrid") {
|
||||
// Test default
|
||||
auto x = array({1, 2, 3}, {3});
|
||||
auto in = std::vector<array>{x};
|
||||
auto out = meshgrid(in);
|
||||
CHECK(array_equal(out[0], x).item<bool>());
|
||||
|
||||
// Test different lengths
|
||||
auto y = array({4, 5}, {2});
|
||||
in = std::vector<array>{x, y};
|
||||
out = meshgrid(in);
|
||||
auto expected_zero = array({1, 2, 3, 1, 2, 3}, {2, 3});
|
||||
auto expected_one = array({4, 4, 4, 5, 5, 5}, {2, 3});
|
||||
CHECK(array_equal(out[0], expected_zero).item<bool>());
|
||||
CHECK(array_equal(out[1], expected_one).item<bool>());
|
||||
|
||||
// Test sparse true
|
||||
in = std::vector<array>{x, x};
|
||||
out = meshgrid(in, true);
|
||||
expected_zero = array({1, 2, 3}, {1, 3});
|
||||
expected_one = array({1, 2, 3}, {3, 1});
|
||||
CHECK(array_equal(out[0], expected_zero).item<bool>());
|
||||
CHECK(array_equal(out[1], expected_one).item<bool>());
|
||||
}
|
Reference in New Issue
Block a user