Fix the top-k op (#768)

This commit is contained in:
Angelos Katharopoulos
2024-03-01 22:08:43 -08:00
committed by GitHub
parent d5964a2710
commit 8e281c76c3
3 changed files with 39 additions and 13 deletions

View File

@@ -1589,7 +1589,7 @@ class TestOps(mlx_tests.MLXTestCase):
shape = (3, 4, 5)
for dtype in ("int32", "float32"):
for axis in (None, 0, 1, 2):
for kth in (-2, 2):
for kth in (-2, 0, 2):
with self.subTest(dtype=dtype, axis=axis, kth=kth):
np.random.seed(0)
np_dtype = getattr(np, dtype)
@@ -1605,13 +1605,16 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(c_np, c_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
top_k_mx = mx.topk(a_mx, kth, axis=axis)
self.assertTrue(np.all(c_np <= top_k_mx))
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
if kth >= 0:
d_np = np.take(b_mx, np.arange(kth), axis=axis)
self.assertTrue(np.all(d_np <= c_mx))
top_k_mx = mx.topk(a_mx, kth, axis=axis)
top_k_np = np.take(
np.partition(a_np, -kth, axis=axis), (-kth,), axis=axis
)
self.assertTrue(np.all(top_k_np <= top_k_mx))
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
N = a_mx.shape[axis] if axis is not None else a_mx.size
M = top_k_mx.shape[axis or 0]
self.assertEqual(M, (kth + N) % N)
@unittest.skipIf(
os.getenv("LOW_MEMORY", None) is not None,