mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 08:24:39 +08:00
Fix the top-k op (#768)
This commit is contained in:

committed by
GitHub

parent
d5964a2710
commit
8e281c76c3
@@ -2858,4 +2858,23 @@ TEST_CASE("test atleast_3d vector") {
|
||||
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1});
|
||||
CHECK_EQ(out[2].ndim(), 3);
|
||||
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test topk") {
|
||||
auto x = reshape(arange(10), {2, 5});
|
||||
|
||||
{
|
||||
auto y = topk(x, 1, 1);
|
||||
CHECK(array_equal(y, array({4, 9}, {2, 1})).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto y = topk(x, 2, 0);
|
||||
CHECK(array_equal(y, x).item<bool>());
|
||||
}
|
||||
|
||||
{
|
||||
auto y = topk(x, 1, 0);
|
||||
CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user