Fix the top-k op (#768)

This commit is contained in:
Angelos Katharopoulos
2024-03-01 22:08:43 -08:00
committed by GitHub
parent d5964a2710
commit 8e281c76c3
3 changed files with 39 additions and 13 deletions

View File

@@ -2858,4 +2858,23 @@ TEST_CASE("test atleast_3d vector") {
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1});
CHECK_EQ(out[2].ndim(), 3);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
}
}
TEST_CASE("test topk") {
auto x = reshape(arange(10), {2, 5});
{
auto y = topk(x, 1, 1);
CHECK(array_equal(y, array({4, 9}, {2, 1})).item<bool>());
}
{
auto y = topk(x, 2, 0);
CHECK(array_equal(y, x).item<bool>());
}
{
auto y = topk(x, 1, 0);
CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
}
}