mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix the top-k op (#768)
This commit is contained in:
parent
d5964a2710
commit
8e281c76c3
14
mlx/ops.cpp
14
mlx/ops.cpp
@ -1721,24 +1721,28 @@ array topk(const array& a, int k, StreamOrDevice s /* = {}*/) {
|
|||||||
array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
|
array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
|
||||||
// Check for valid axis
|
// Check for valid axis
|
||||||
int axis_ = axis < 0 ? axis + a.ndim() : axis;
|
int axis_ = axis < 0 ? axis + a.ndim() : axis;
|
||||||
int kth_ = k < 0 ? k + a.shape(axis) : k;
|
|
||||||
if (axis_ < 0 || axis_ >= static_cast<int>(a.ndim())) {
|
if (axis_ < 0 || axis_ >= static_cast<int>(a.ndim())) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[topk] Received invalid axis " << axis << " for array with "
|
msg << "[topk] Received invalid axis " << axis << " for array with "
|
||||||
<< a.ndim() << " dimensions.";
|
<< a.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
if (kth_ < 0 || kth_ >= a.shape(axis_)) {
|
if (k < 0 || k > a.shape(axis_)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[topk] Received invalid k " << k << "along axis " << axis
|
msg << "[topk] Received invalid k=" << k << " along axis " << axis
|
||||||
<< " for array with shape: " << a.shape();
|
<< " for array with shape: " << a.shape();
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
array a_partitioned = partition(a, kth_, axis_, s);
|
// Return early if the whole input was requested.
|
||||||
|
if (k == a.shape(axis_)) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
array a_partitioned = partition(a, -k, axis_, s);
|
||||||
std::vector<int> slice_starts(a.ndim(), 0);
|
std::vector<int> slice_starts(a.ndim(), 0);
|
||||||
std::vector<int> slice_ends = a.shape();
|
std::vector<int> slice_ends = a.shape();
|
||||||
slice_starts[axis_] = kth_;
|
slice_starts[axis_] = a.shape(axis_) - k;
|
||||||
return slice(a_partitioned, slice_starts, slice_ends, s);
|
return slice(a_partitioned, slice_starts, slice_ends, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1589,7 +1589,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
shape = (3, 4, 5)
|
shape = (3, 4, 5)
|
||||||
for dtype in ("int32", "float32"):
|
for dtype in ("int32", "float32"):
|
||||||
for axis in (None, 0, 1, 2):
|
for axis in (None, 0, 1, 2):
|
||||||
for kth in (-2, 2):
|
for kth in (-2, 0, 2):
|
||||||
with self.subTest(dtype=dtype, axis=axis, kth=kth):
|
with self.subTest(dtype=dtype, axis=axis, kth=kth):
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
np_dtype = getattr(np, dtype)
|
np_dtype = getattr(np, dtype)
|
||||||
@ -1605,13 +1605,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.array_equal(c_np, c_mx))
|
self.assertTrue(np.array_equal(c_np, c_mx))
|
||||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||||
|
|
||||||
top_k_mx = mx.topk(a_mx, kth, axis=axis)
|
|
||||||
self.assertTrue(np.all(c_np <= top_k_mx))
|
|
||||||
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
|
|
||||||
|
|
||||||
if kth >= 0:
|
if kth >= 0:
|
||||||
d_np = np.take(b_mx, np.arange(kth), axis=axis)
|
top_k_mx = mx.topk(a_mx, kth, axis=axis)
|
||||||
self.assertTrue(np.all(d_np <= c_mx))
|
top_k_np = np.take(
|
||||||
|
np.partition(a_np, -kth, axis=axis), (-kth,), axis=axis
|
||||||
|
)
|
||||||
|
self.assertTrue(np.all(top_k_np <= top_k_mx))
|
||||||
|
self.assertEqual(top_k_mx.dtype, a_mx.dtype)
|
||||||
|
N = a_mx.shape[axis] if axis is not None else a_mx.size
|
||||||
|
M = top_k_mx.shape[axis or 0]
|
||||||
|
self.assertEqual(M, (kth + N) % N)
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
os.getenv("LOW_MEMORY", None) is not None,
|
os.getenv("LOW_MEMORY", None) is not None,
|
||||||
|
@ -2859,3 +2859,22 @@ TEST_CASE("test atleast_3d vector") {
|
|||||||
CHECK_EQ(out[2].ndim(), 3);
|
CHECK_EQ(out[2].ndim(), 3);
|
||||||
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
|
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test topk") {
|
||||||
|
auto x = reshape(arange(10), {2, 5});
|
||||||
|
|
||||||
|
{
|
||||||
|
auto y = topk(x, 1, 1);
|
||||||
|
CHECK(array_equal(y, array({4, 9}, {2, 1})).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto y = topk(x, 2, 0);
|
||||||
|
CHECK(array_equal(y, x).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto y = topk(x, 1, 0);
|
||||||
|
CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user