Fix the top-k op (#768)

This commit is contained in:
Angelos Katharopoulos
2024-03-01 22:08:43 -08:00
committed by GitHub
parent d5964a2710
commit 8e281c76c3
3 changed files with 39 additions and 13 deletions

View File

@@ -1721,24 +1721,28 @@ array topk(const array& a, int k, StreamOrDevice s /* = {}*/) {
array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
// Check for valid axis
int axis_ = axis < 0 ? axis + a.ndim() : axis;
int kth_ = k < 0 ? k + a.shape(axis) : k;
if (axis_ < 0 || axis_ >= static_cast<int>(a.ndim())) {
std::ostringstream msg;
msg << "[topk] Received invalid axis " << axis << " for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (kth_ < 0 || kth_ >= a.shape(axis_)) {
if (k < 0 || k > a.shape(axis_)) {
std::ostringstream msg;
msg << "[topk] Received invalid k " << k << "along axis " << axis
msg << "[topk] Received invalid k=" << k << " along axis " << axis
<< " for array with shape: " << a.shape();
throw std::invalid_argument(msg.str());
}
array a_partitioned = partition(a, kth_, axis_, s);
// Return early if the whole input was requested.
if (k == a.shape(axis_)) {
return a;
}
array a_partitioned = partition(a, -k, axis_, s);
std::vector<int> slice_starts(a.ndim(), 0);
std::vector<int> slice_ends = a.shape();
slice_starts[axis_] = kth_;
slice_starts[axis_] = a.shape(axis_) - k;
return slice(a_partitioned, slice_starts, slice_ends, s);
}