mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Fix the top-k op (#768)
This commit is contained in:

committed by
GitHub

parent
d5964a2710
commit
8e281c76c3
14
mlx/ops.cpp
14
mlx/ops.cpp
@@ -1721,24 +1721,28 @@ array topk(const array& a, int k, StreamOrDevice s /* = {}*/) {
|
||||
array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
|
||||
// Check for valid axis
|
||||
int axis_ = axis < 0 ? axis + a.ndim() : axis;
|
||||
int kth_ = k < 0 ? k + a.shape(axis) : k;
|
||||
if (axis_ < 0 || axis_ >= static_cast<int>(a.ndim())) {
|
||||
std::ostringstream msg;
|
||||
msg << "[topk] Received invalid axis " << axis << " for array with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (kth_ < 0 || kth_ >= a.shape(axis_)) {
|
||||
if (k < 0 || k > a.shape(axis_)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[topk] Received invalid k " << k << "along axis " << axis
|
||||
msg << "[topk] Received invalid k=" << k << " along axis " << axis
|
||||
<< " for array with shape: " << a.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
array a_partitioned = partition(a, kth_, axis_, s);
|
||||
// Return early if the whole input was requested.
|
||||
if (k == a.shape(axis_)) {
|
||||
return a;
|
||||
}
|
||||
|
||||
array a_partitioned = partition(a, -k, axis_, s);
|
||||
std::vector<int> slice_starts(a.ndim(), 0);
|
||||
std::vector<int> slice_ends = a.shape();
|
||||
slice_starts[axis_] = kth_;
|
||||
slice_starts[axis_] = a.shape(axis_) - k;
|
||||
return slice(a_partitioned, slice_starts, slice_ends, s);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user