Fix the top-k op (#768)

This commit is contained in:
Angelos Katharopoulos 2024-03-01 22:08:43 -08:00 committed by GitHub
parent d5964a2710
commit 8e281c76c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 13 deletions

View File

@ -1721,24 +1721,28 @@ array topk(const array& a, int k, StreamOrDevice s /* = {}*/) {
array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) { array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
// Check for valid axis // Check for valid axis
int axis_ = axis < 0 ? axis + a.ndim() : axis; int axis_ = axis < 0 ? axis + a.ndim() : axis;
int kth_ = k < 0 ? k + a.shape(axis) : k;
if (axis_ < 0 || axis_ >= static_cast<int>(a.ndim())) { if (axis_ < 0 || axis_ >= static_cast<int>(a.ndim())) {
std::ostringstream msg; std::ostringstream msg;
msg << "[topk] Received invalid axis " << axis << " for array with " msg << "[topk] Received invalid axis " << axis << " for array with "
<< a.ndim() << " dimensions."; << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (kth_ < 0 || kth_ >= a.shape(axis_)) { if (k < 0 || k > a.shape(axis_)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[topk] Received invalid k " << k << "along axis " << axis msg << "[topk] Received invalid k=" << k << " along axis " << axis
<< " for array with shape: " << a.shape(); << " for array with shape: " << a.shape();
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
array a_partitioned = partition(a, kth_, axis_, s); // Return early if the whole input was requested.
if (k == a.shape(axis_)) {
return a;
}
array a_partitioned = partition(a, -k, axis_, s);
std::vector<int> slice_starts(a.ndim(), 0); std::vector<int> slice_starts(a.ndim(), 0);
std::vector<int> slice_ends = a.shape(); std::vector<int> slice_ends = a.shape();
slice_starts[axis_] = kth_; slice_starts[axis_] = a.shape(axis_) - k;
return slice(a_partitioned, slice_starts, slice_ends, s); return slice(a_partitioned, slice_starts, slice_ends, s);
} }

View File

@ -1589,7 +1589,7 @@ class TestOps(mlx_tests.MLXTestCase):
shape = (3, 4, 5) shape = (3, 4, 5)
for dtype in ("int32", "float32"): for dtype in ("int32", "float32"):
for axis in (None, 0, 1, 2): for axis in (None, 0, 1, 2):
for kth in (-2, 2): for kth in (-2, 0, 2):
with self.subTest(dtype=dtype, axis=axis, kth=kth): with self.subTest(dtype=dtype, axis=axis, kth=kth):
np.random.seed(0) np.random.seed(0)
np_dtype = getattr(np, dtype) np_dtype = getattr(np, dtype)
@ -1605,13 +1605,16 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(c_np, c_mx)) self.assertTrue(np.array_equal(c_np, c_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype) self.assertEqual(b_mx.dtype, a_mx.dtype)
top_k_mx = mx.topk(a_mx, kth, axis=axis)
self.assertTrue(np.all(c_np <= top_k_mx))
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
if kth >= 0: if kth >= 0:
d_np = np.take(b_mx, np.arange(kth), axis=axis) top_k_mx = mx.topk(a_mx, kth, axis=axis)
self.assertTrue(np.all(d_np <= c_mx)) top_k_np = np.take(
np.partition(a_np, -kth, axis=axis), (-kth,), axis=axis
)
self.assertTrue(np.all(top_k_np <= top_k_mx))
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
N = a_mx.shape[axis] if axis is not None else a_mx.size
M = top_k_mx.shape[axis or 0]
self.assertEqual(M, (kth + N) % N)
@unittest.skipIf( @unittest.skipIf(
os.getenv("LOW_MEMORY", None) is not None, os.getenv("LOW_MEMORY", None) is not None,

View File

@ -2858,4 +2858,23 @@ TEST_CASE("test atleast_3d vector") {
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1}); CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1});
CHECK_EQ(out[2].ndim(), 3); CHECK_EQ(out[2].ndim(), 3);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1}); CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
} }
TEST_CASE("test topk") {
auto x = reshape(arange(10), {2, 5});
{
auto y = topk(x, 1, 1);
CHECK(array_equal(y, array({4, 9}, {2, 1})).item<bool>());
}
{
auto y = topk(x, 2, 0);
CHECK(array_equal(y, x).item<bool>());
}
{
auto y = topk(x, 1, 0);
CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
}
}