mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
parent
ebd7135b50
commit
7f914365fd
@ -522,13 +522,13 @@ template <
|
|||||||
bool ARG_SORT,
|
bool ARG_SORT,
|
||||||
short BLOCK_THREADS,
|
short BLOCK_THREADS,
|
||||||
short N_PER_THREAD>
|
short N_PER_THREAD>
|
||||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
|
[[kernel]] void mb_block_partition(
|
||||||
mb_block_partition(
|
|
||||||
device idx_t* block_partitions [[buffer(0)]],
|
device idx_t* block_partitions [[buffer(0)]],
|
||||||
const device val_t* dev_vals [[buffer(1)]],
|
const device val_t* dev_vals [[buffer(1)]],
|
||||||
const device idx_t* dev_idxs [[buffer(2)]],
|
const device idx_t* dev_idxs [[buffer(2)]],
|
||||||
const constant int& size_sorted_axis [[buffer(3)]],
|
const constant int& size_sorted_axis [[buffer(3)]],
|
||||||
const constant int& merge_tiles [[buffer(4)]],
|
const constant int& merge_tiles [[buffer(4)]],
|
||||||
|
const constant int& n_blocks [[buffer(5)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
uint3 lid [[thread_position_in_threadgroup]],
|
||||||
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
uint3 tgp_dims [[threads_per_threadgroup]]) {
|
||||||
@ -543,23 +543,29 @@ mb_block_partition(
|
|||||||
dev_vals += tid.y * size_sorted_axis;
|
dev_vals += tid.y * size_sorted_axis;
|
||||||
dev_idxs += tid.y * size_sorted_axis;
|
dev_idxs += tid.y * size_sorted_axis;
|
||||||
|
|
||||||
// Find location in merge step
|
for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
|
||||||
int merge_group = lid.x / merge_tiles;
|
// Find location in merge step
|
||||||
int merge_lane = lid.x % merge_tiles;
|
int merge_group = i / merge_tiles;
|
||||||
|
int merge_lane = i % merge_tiles;
|
||||||
|
|
||||||
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
|
||||||
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
|
||||||
|
|
||||||
int A_st = min(size_sorted_axis, sort_st);
|
int A_st = min(size_sorted_axis, sort_st);
|
||||||
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
|
||||||
int B_st = A_ed;
|
int B_st = A_ed;
|
||||||
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
|
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
|
||||||
|
|
||||||
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
|
||||||
int partition = sort_kernel::merge_partition(
|
int partition = sort_kernel::merge_partition(
|
||||||
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
|
dev_vals + A_st,
|
||||||
|
dev_vals + B_st,
|
||||||
|
A_ed - A_st,
|
||||||
|
B_ed - B_st,
|
||||||
|
partition_at);
|
||||||
|
|
||||||
block_partitions[lid.x] = A_st + partition;
|
block_partitions[i] = A_st + partition;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <
|
template <
|
||||||
|
@ -177,6 +177,8 @@ void multi_block_sort(
|
|||||||
array dev_vals_out = dev_vals_1;
|
array dev_vals_out = dev_vals_1;
|
||||||
array dev_idxs_out = dev_idxs_1;
|
array dev_idxs_out = dev_idxs_1;
|
||||||
|
|
||||||
|
int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;
|
||||||
|
|
||||||
for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
|
for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
|
||||||
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
|
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
|
||||||
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
|
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
|
||||||
@ -199,8 +201,9 @@ void multi_block_sort(
|
|||||||
compute_encoder.set_input_array(dev_idxs_in, 2);
|
compute_encoder.set_input_array(dev_idxs_in, 2);
|
||||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
||||||
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
|
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
|
||||||
|
compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
|
MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
|
||||||
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
|
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
|
||||||
|
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
@ -1785,15 +1785,6 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Fix GPU kernel
|
|
||||||
if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[sort] GPU sort cannot handle sort axis of >= 2M elements,"
|
|
||||||
<< " got array with sort axis size " << a.shape(axis) << "."
|
|
||||||
<< " Please place this operation on the CPU instead.";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
return array(
|
return array(
|
||||||
a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
|
a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
|
||||||
}
|
}
|
||||||
|
@ -1840,6 +1840,15 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.array_equal(c_np, c_mx))
|
self.assertTrue(np.array_equal(c_np, c_mx))
|
||||||
self.assertEqual(b_mx.dtype, c_mx.dtype)
|
self.assertEqual(b_mx.dtype, c_mx.dtype)
|
||||||
|
|
||||||
|
# Test very large array
|
||||||
|
if mx.default_device() == mx.gpu:
|
||||||
|
a_np = np.random.normal(20, 20, size=(2**22)).astype(np.float32)
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
|
||||||
|
b_np = np.sort(a_np)
|
||||||
|
b_mx = mx.sort(a_mx)
|
||||||
|
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||||
|
|
||||||
def test_partition(self):
|
def test_partition(self):
|
||||||
shape = (3, 4, 5)
|
shape = (3, 4, 5)
|
||||||
for dtype in ("int32", "float32"):
|
for dtype in ("int32", "float32"):
|
||||||
|
Loading…
Reference in New Issue
Block a user