mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
faster sort (#1831)
This commit is contained in:
@@ -274,16 +274,20 @@ void gpu_merge_sort(
|
||||
int size_sorted_axis = in.shape(axis);
|
||||
|
||||
// Get kernel size
|
||||
int tn = 8;
|
||||
int bn = 128;
|
||||
int tn = 4;
|
||||
int potential_bn = (size_sorted_axis + tn - 1) / tn;
|
||||
|
||||
int bn;
|
||||
if (potential_bn > 256) {
|
||||
bn = 512;
|
||||
} else if (potential_bn > 128) {
|
||||
bn = 256;
|
||||
} else {
|
||||
} else if (potential_bn > 64) {
|
||||
bn = 128;
|
||||
} else if (potential_bn > 32) {
|
||||
bn = 64;
|
||||
} else {
|
||||
bn = 32;
|
||||
}
|
||||
|
||||
if (bn == 512 && size_of(in.dtype()) > 4) {
|
||||
|
||||
Reference in New Issue
Block a user