divmod, partition, sort fixes (#2302)

This commit is contained in:
Awni Hannun
2025-06-16 18:49:32 -07:00
committed by GitHub
parent bc53f8293f
commit b8022c578a
8 changed files with 271 additions and 49 deletions

View File

@@ -86,7 +86,6 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
axis += in.ndim();
}
int nsort = in.shape(axis);
int nsegments = in.data_size() / nsort;
int last_dim = in.ndim() - 1;
// If we are not sorting the innermost dimension of a contiguous array,
@@ -100,7 +99,11 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
encoder.launch_kernel([&](cudaStream_t stream) {
@@ -134,7 +137,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
indices.data<uint32_t>(),
out.data<uint32_t>(),
in.data_size(),
nsegments,
in.data_size() / nsort,
offsets,
offsets + 1,
stream);
@@ -144,7 +147,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.data<Type>(),
out.data<Type>(),
in.data_size(),
nsegments,
in.data_size() / nsort,
offsets,
offsets + 1,
stream);
@@ -177,4 +180,14 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
gpu_sort(stream(), inputs[0], out, axis_, false);
}
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgPartition::eval_gpu");
gpu_sort(stream(), inputs[0], out, axis_, true);
}
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Partition::eval_gpu");
gpu_sort(stream(), inputs[0], out, axis_, false);
}
} // namespace mlx::core