faster sort (#1831)

This commit is contained in:
Awni Hannun
2025-02-05 06:10:22 -08:00
committed by GitHub
parent a229c8cef0
commit fe5987b81d
3 changed files with 98 additions and 91 deletions

View File

@@ -274,16 +274,20 @@ void gpu_merge_sort(
  int size_sorted_axis = in.shape(axis);
  // Get kernel size
- int tn = 8;
- int bn = 128;
+ int tn = 4;
  int potential_bn = (size_sorted_axis + tn - 1) / tn;
+ int bn;
  if (potential_bn > 256) {
  bn = 512;
  } else if (potential_bn > 128) {
  bn = 256;
- } else {
+ } else if (potential_bn > 64) {
  bn = 128;
+ } else if (potential_bn > 32) {
+ bn = 64;
+ } else {
+ bn = 32;
  }
  if (bn == 512 && size_of(in.dtype()) > 4) {