mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
Fix multiblock sort limits (#906)
* Fix multiblock sort limits
* Fix metal validation error
This commit is contained in:
parent
5611e1a95e
commit
925014b661
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
@ -102,6 +102,11 @@ void multi_block_sort(
|
|||||||
|
|
||||||
int nc_dim = nc_shape.size();
|
int nc_dim = nc_shape.size();
|
||||||
|
|
||||||
|
if (nc_dim == 0) {
|
||||||
|
nc_shape = {0};
|
||||||
|
nc_str = {1};
|
||||||
|
}
|
||||||
|
|
||||||
int size_sorted_axis = in.shape(axis);
|
int size_sorted_axis = in.shape(axis);
|
||||||
int stride_sorted_axis = in.strides()[axis];
|
int stride_sorted_axis = in.strides()[axis];
|
||||||
|
|
||||||
@ -143,8 +148,9 @@ void multi_block_sort(
|
|||||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
||||||
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
|
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
|
||||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
|
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
|
||||||
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
|
compute_encoder->setBytes(
|
||||||
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7);
|
nc_shape.data(), nc_shape.size() * sizeof(int), 6);
|
||||||
|
compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||||
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
|
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
|
||||||
@ -158,7 +164,8 @@ void multi_block_sort(
|
|||||||
array dev_idxs_in = dev_idxs_0;
|
array dev_idxs_in = dev_idxs_0;
|
||||||
array dev_vals_out = dev_vals_1;
|
array dev_vals_out = dev_vals_1;
|
||||||
array dev_idxs_out = dev_idxs_1;
|
array dev_idxs_out = dev_idxs_1;
|
||||||
for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) {
|
|
||||||
|
for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
|
||||||
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
|
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
|
||||||
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
|
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
|
||||||
dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
|
dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
|
||||||
|
@ -1597,6 +1597,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(np.array_equal(d_np, d_mx))
|
self.assertTrue(np.array_equal(d_np, d_mx))
|
||||||
self.assertEqual(c_mx.dtype, mx.uint32)
|
self.assertEqual(c_mx.dtype, mx.uint32)
|
||||||
|
|
||||||
|
# Test multi-block sort
|
||||||
|
a_np = np.random.normal(size=(32769,)).astype(np.float32)
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
|
||||||
|
b_np = np.sort(a_np)
|
||||||
|
b_mx = mx.sort(a_mx)
|
||||||
|
|
||||||
|
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||||
|
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||||
|
|
||||||
def test_partition(self):
|
def test_partition(self):
|
||||||
shape = (3, 4, 5)
|
shape = (3, 4, 5)
|
||||||
for dtype in ("int32", "float32"):
|
for dtype in ("int32", "float32"):
|
||||||
|
Loading…
Reference in New Issue
Block a user