mirror of https://github.com/ml-explore/mlx.git synced 2025-06-24 09:21:16 +08:00
parent ebd7135b50
commit 7f914365fd
@@ -522,13 +522,13 @@ template <
     bool ARG_SORT,
     short BLOCK_THREADS,
     short N_PER_THREAD>
-[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
-mb_block_partition(
+[[kernel]] void mb_block_partition(
     device idx_t* block_partitions [[buffer(0)]],
     const device val_t* dev_vals [[buffer(1)]],
     const device idx_t* dev_idxs [[buffer(2)]],
     const constant int& size_sorted_axis [[buffer(3)]],
     const constant int& merge_tiles [[buffer(4)]],
+    const constant int& n_blocks [[buffer(5)]],
     uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint3 tgp_dims [[threads_per_threadgroup]]) {
@@ -543,23 +543,29 @@ mb_block_partition(
   dev_vals += tid.y * size_sorted_axis;
   dev_idxs += tid.y * size_sorted_axis;

-  // Find location in merge step
-  int merge_group = lid.x / merge_tiles;
-  int merge_lane = lid.x % merge_tiles;
+  for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
+    // Find location in merge step
+    int merge_group = i / merge_tiles;
+    int merge_lane = i % merge_tiles;

-  int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
-  int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
+    int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
+    int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;

-  int A_st = min(size_sorted_axis, sort_st);
-  int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
-  int B_st = A_ed;
-  int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
+    int A_st = min(size_sorted_axis, sort_st);
+    int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
+    int B_st = A_ed;
+    int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);

-  int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
-  int partition = sort_kernel::merge_partition(
-      dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
+    int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
+    int partition = sort_kernel::merge_partition(
+        dev_vals + A_st,
+        dev_vals + B_st,
+        A_ed - A_st,
+        B_ed - B_st,
+        partition_at);

-  block_partitions[lid.x] = A_st + partition;
+    block_partitions[i] = A_st + partition;
+  }
 }

 template <
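Note: the kernel change above replaces the one-thread-per-partition mapping (block_partitions[lid.x]) with a strided loop, so thread lid.x now handles partition indices lid.x, lid.x + tgp_dims.x, lid.x + 2 * tgp_dims.x, ... up to and including n_blocks. A threadgroup of bounded size can therefore compute any number of block partitions. A minimal CPU-side sketch of the same loop and index arithmetic (the N_PER_BLOCK constant here is a hypothetical stand-in for sort_kernel::N_PER_BLOCK):

#include <algorithm>
#include <cstdio>

// Hypothetical stand-in for sort_kernel::N_PER_BLOCK.
constexpr int N_PER_BLOCK = 8;

// Bounds of the two sorted runs A and B that partition index i splits.
void partition_bounds(int i, int merge_tiles, int size_sorted_axis) {
  int merge_group = i / merge_tiles; // which pairwise merge this index belongs to
  int merge_lane = i % merge_tiles;  // offset of the partition inside that merge

  int sort_sz = N_PER_BLOCK * merge_tiles;               // length of the merged run
  int sort_st = N_PER_BLOCK * merge_tiles * merge_group; // start of the merged run

  int A_st = std::min(size_sorted_axis, sort_st);
  int A_ed = std::min(size_sorted_axis, sort_st + sort_sz / 2);
  int B_st = A_ed;
  int B_ed = std::min(size_sorted_axis, B_st + sort_sz / 2);

  std::printf("i=%d lane=%d A=[%d,%d) B=[%d,%d)\n",
              i, merge_lane, A_st, A_ed, B_st, B_ed);
}

int main() {
  int n_blocks = 6, merge_tiles = 2, size_sorted_axis = 50, tgp_dims_x = 4;
  // One pass of the outer loop per thread; the strided inner loop is the new
  // kernel pattern and covers indices 0..n_blocks (inclusive) exactly once.
  for (int lid = 0; lid < tgp_dims_x; ++lid) {
    for (int i = lid; i <= n_blocks; i += tgp_dims_x) {
      partition_bounds(i, merge_tiles, size_sorted_axis);
    }
  }
  return 0;
}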
@@ -177,6 +177,8 @@ void multi_block_sort(
   array dev_vals_out = dev_vals_1;
   array dev_idxs_out = dev_idxs_1;

+  int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;
+
   for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
     dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
     dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
@@ -199,8 +201,9 @@ void multi_block_sort(
     compute_encoder.set_input_array(dev_idxs_in, 2);
     compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
     compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
+    compute_encoder->setBytes(&n_blocks, sizeof(int), 5);

-    MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
+    MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
     MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

     compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
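Note: on the host side the threadgroup width is now clamped instead of requesting n_blocks + 1 threads outright (the 1024 cap matches Metal's usual per-threadgroup thread limit), and n_blocks is bound at buffer index 5 so the kernel's strided loop knows its upper bound. A quick standalone check, with a hypothetical n_blocks value, that the clamped group still visits every partition index exactly once:

#include <cassert>
#include <vector>

int main() {
  int n_blocks = 5000; // hypothetical: far more partitions than threads
  int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;

  // Replay the kernel's strided loop for each thread in the group and count
  // how often every partition index 0..n_blocks is visited.
  std::vector<int> visits(n_blocks + 1, 0);
  for (int lid = 0; lid < n_thr_per_group; ++lid) {
    for (int i = lid; i <= n_blocks; i += n_thr_per_group) {
      visits[i]++;
    }
  }
  for (int v : visits) {
    assert(v == 1); // every partition is computed exactly once
  }
  return 0;
}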
@@ -1785,15 +1785,6 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
     throw std::invalid_argument(msg.str());
   }

-  // TODO: Fix GPU kernel
-  if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {
-    std::ostringstream msg;
-    msg << "[sort] GPU sort cannot handle sort axis of >= 2M elements,"
-        << " got array with sort axis size " << a.shape(axis) << "."
-        << " Please place this operation on the CPU instead.";
-    throw std::runtime_error(msg.str());
-  }
-
   return array(
       a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
 }
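Note: this hunk drops the old size guard. Since the partition kernel now loops over block partitions rather than needing one thread per partition, sort axes of 1u << 21 (2,097,152) elements or more no longer have to fall back to the CPU, and the runtime error goes away with the limit.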
@@ -1840,6 +1840,15 @@ class TestOps(mlx_tests.MLXTestCase):
                 self.assertTrue(np.array_equal(c_np, c_mx))
                 self.assertEqual(b_mx.dtype, c_mx.dtype)

+        # Test very large array
+        if mx.default_device() == mx.gpu:
+            a_np = np.random.normal(20, 20, size=(2**22)).astype(np.float32)
+            a_mx = mx.array(a_np)
+
+            b_np = np.sort(a_np)
+            b_mx = mx.sort(a_mx)
+            self.assertTrue(np.array_equal(b_np, b_mx))
+
     def test_partition(self):
         shape = (3, 4, 5)
         for dtype in ("int32", "float32"):
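Note: the regression test runs only when the default device is the GPU and sorts a 2**22-element (4,194,304) array, twice the old 2**21 limit, comparing the result element-for-element against np.sort.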