From 8e281c76c34b1440d9d488eda42ce1b73ef2e4b8 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 1 Mar 2024 22:08:43 -0800 Subject: [PATCH] Fix the top-k op (#768) --- mlx/ops.cpp | 14 +++++++++----- python/tests/test_ops.py | 17 ++++++++++------- tests/ops_tests.cpp | 21 ++++++++++++++++++++- 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 774bb2285..04b2a8e5b 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1721,24 +1721,28 @@ array topk(const array& a, int k, StreamOrDevice s /* = {}*/) { array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) { // Check for valid axis int axis_ = axis < 0 ? axis + a.ndim() : axis; - int kth_ = k < 0 ? k + a.shape(axis) : k; if (axis_ < 0 || axis_ >= static_cast(a.ndim())) { std::ostringstream msg; msg << "[topk] Received invalid axis " << axis << " for array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } - if (kth_ < 0 || kth_ >= a.shape(axis_)) { + if (k < 0 || k > a.shape(axis_)) { std::ostringstream msg; - msg << "[topk] Received invalid k " << k << "along axis " << axis + msg << "[topk] Received invalid k=" << k << " along axis " << axis << " for array with shape: " << a.shape(); throw std::invalid_argument(msg.str()); } - array a_partitioned = partition(a, kth_, axis_, s); + // Return early if the whole input was requested. + if (k == a.shape(axis_)) { + return a; + } + + array a_partitioned = partition(a, -k, axis_, s); std::vector slice_starts(a.ndim(), 0); std::vector slice_ends = a.shape(); - slice_starts[axis_] = kth_; + slice_starts[axis_] = a.shape(axis_) - k; return slice(a_partitioned, slice_starts, slice_ends, s); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 1a504dd45..fe935ebc8 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1589,7 +1589,7 @@ class TestOps(mlx_tests.MLXTestCase): shape = (3, 4, 5) for dtype in ("int32", "float32"): for axis in (None, 0, 1, 2): - for kth in (-2, 2): + for kth in (-2, 0, 2): with self.subTest(dtype=dtype, axis=axis, kth=kth): np.random.seed(0) np_dtype = getattr(np, dtype) @@ -1605,13 +1605,16 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.array_equal(c_np, c_mx)) self.assertEqual(b_mx.dtype, a_mx.dtype) - top_k_mx = mx.topk(a_mx, kth, axis=axis) - self.assertTrue(np.all(c_np <= top_k_mx)) - self.assertEqual(top_k_mx.dtype, a_mx.dtype) - if kth >= 0: - d_np = np.take(b_mx, np.arange(kth), axis=axis) - self.assertTrue(np.all(d_np <= c_mx)) + top_k_mx = mx.topk(a_mx, kth, axis=axis) + top_k_np = np.take( + np.partition(a_np, -kth, axis=axis), (-kth,), axis=axis + ) + self.assertTrue(np.all(top_k_np <= top_k_mx)) + self.assertEqual(top_k_mx.dtype, a_mx.dtype) + N = a_mx.shape[axis] if axis is not None else a_mx.size + M = top_k_mx.shape[axis or 0] + self.assertEqual(M, (kth + N) % N) @unittest.skipIf( os.getenv("LOW_MEMORY", None) is not None, diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index fb4bfb78f..4b05425e7 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2858,4 +2858,23 @@ TEST_CASE("test atleast_3d vector") { CHECK_EQ(out[1].shape(), std::vector{1, 3, 1}); CHECK_EQ(out[2].ndim(), 3); CHECK_EQ(out[2].shape(), std::vector{3, 1, 1}); -} \ No newline at end of file +} + +TEST_CASE("test topk") { + auto x = reshape(arange(10), {2, 5}); + + { + auto y = topk(x, 1, 1); + CHECK(array_equal(y, array({4, 9}, {2, 1})).item()); + } + + { + auto y = topk(x, 2, 0); + CHECK(array_equal(y, x).item()); + } + + { + auto y = topk(x, 1, 0); + CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item()); + } +}