mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
@@ -522,13 +522,13 @@ template <
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
|
||||
mb_block_partition(
|
||||
[[kernel]] void mb_block_partition(
|
||||
device idx_t* block_partitions [[buffer(0)]],
|
||||
const device val_t* dev_vals [[buffer(1)]],
|
||||
const device idx_t* dev_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& merge_tiles [[buffer(4)]],
|
||||
const constant int& n_blocks [[buffer(5)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
||||
@@ -543,23 +543,29 @@ mb_block_partition(
|
||||
dev_vals += tid.y * size_sorted_axis;
|
||||
dev_idxs += tid.y * size_sorted_axis;
|
||||
|
||||
// Find location in merge step
|
||||
int merge_group = lid.x / merge_tiles;
|
||||
int merge_lane = lid.x % merge_tiles;
|
||||
for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
|
||||
// Find location in merge step
|
||||
int merge_group = i / merge_tiles;
|
||||
int merge_lane = i % merge_tiles;
|
||||
|
||||
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
||||
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
||||
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
||||
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
||||
|
||||
int A_st = min(size_sorted_axis, sort_st);
|
||||
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
||||
int B_st = A_ed;
|
||||
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
|
||||
int A_st = min(size_sorted_axis, sort_st);
|
||||
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
||||
int B_st = A_ed;
|
||||
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
|
||||
|
||||
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
||||
int partition = sort_kernel::merge_partition(
|
||||
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
|
||||
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
||||
int partition = sort_kernel::merge_partition(
|
||||
dev_vals + A_st,
|
||||
dev_vals + B_st,
|
||||
A_ed - A_st,
|
||||
B_ed - B_st,
|
||||
partition_at);
|
||||
|
||||
block_partitions[lid.x] = A_st + partition;
|
||||
block_partitions[i] = A_st + partition;
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
|
@@ -177,6 +177,8 @@ void multi_block_sort(
|
||||
array dev_vals_out = dev_vals_1;
|
||||
array dev_idxs_out = dev_idxs_1;
|
||||
|
||||
int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;
|
||||
|
||||
for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
|
||||
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
|
||||
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
|
||||
@@ -199,8 +201,9 @@ void multi_block_sort(
|
||||
compute_encoder.set_input_array(dev_idxs_in, 2);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
|
||||
|
||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
Reference in New Issue
Block a user