From ff1c6fc1488eea5528fd38d48935d2f788c76ee7 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 16 May 2025 13:40:51 -0700
Subject: [PATCH] revert sort + flaky test

---
 mlx/backend/metal/sort.cpp | 41 +++++++++++++++++++++++++++-----------
 python/tests/test_vmap.py  |  1 +
 2 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp
index d7cfd12b4..3c84022f2 100644
--- a/mlx/backend/metal/sort.cpp
+++ b/mlx/backend/metal/sort.cpp
@@ -21,8 +21,6 @@ void single_block_sort(
     int bn,
     int tn,
     bool argsort) {
-  out.set_data(allocator::malloc(out.nbytes()));
-
   // Prepare shapes
   int n_rows = in.size() / in.shape(axis);
 
@@ -158,6 +156,9 @@ void multi_block_sort(
   dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
   block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));
 
+  std::vector<array> copies = {
+      dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
+
   // Prepare command encoder
   auto& compute_encoder = d.get_command_encoder(s.index);
 
@@ -249,17 +250,25 @@
       compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
     }
   }
-  out.copy_shared_buffer(
-      argsort ? dev_idxs_out : dev_vals_out,
+
+  // Copy outputs with appropriate strides
+  auto strides = out.strides();
+  for (int ax = axis + 1; ax < strides.size(); ax++) {
+    strides[ax] *= out.shape(axis);
+  }
+  strides[axis] = 1;
+  copy_gpu_inplace(
+      (argsort) ? dev_idxs_out : dev_vals_out,
+      out,
+      out.shape(),
+      strides,
       out.strides(),
-      out.flags(),
-      out.data_size());
-  d.add_temporaries(
-      {dev_vals_in,
-       dev_idxs_in,
-       argsort ? dev_vals_in : dev_idxs_in,
-       block_partitions},
-      s.index);
+      0,
+      0,
+      (axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
+      s);
+
+  d.add_temporaries(std::move(copies), s.index);
 }
 
 void gpu_merge_sort(
@@ -309,6 +318,8 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
+  out.set_data(allocator::malloc(out.nbytes()));
+
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto& in = inputs[0];
 
@@ -319,6 +330,8 @@ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
+  out.set_data(allocator::malloc(out.nbytes()));
+
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto& in = inputs[0];
 
@@ -330,6 +343,8 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We direct arg partition to sort for now
   assert(inputs.size() == 1);
 
+  out.set_data(allocator::malloc(out.nbytes()));
+
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto& in = inputs[0];
@@ -341,6 +356,8 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We direct partition to sort for now
   assert(inputs.size() == 1);
 
+  out.set_data(allocator::malloc(out.nbytes()));
+
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto& in = inputs[0];
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index e571678d3..ddfceb0a1 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase):
         self.assertEqual(fy.shape, (4, 5, 6, 7))
 
     def test_leaks(self):
+        mx.synchronize()
         if mx.metal.is_available():
             mem_pre = mx.get_active_memory()
         else:
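
A note on the test change, with a minimal sketch of the pattern it relies on (the helper name, shapes, and body below are illustrative, not the actual test_leaks body): MLX evaluates lazily and dispatches GPU work asynchronously, so sampling mx.get_active_memory() before earlier work has drained can produce a flaky baseline. Synchronizing before the first measurement, and again before the final one, makes the comparison deterministic. Only mx.synchronize, mx.get_active_memory, and standard mlx.core ops are used.

    import mlx.core as mx

    def check_sort_leak():
        # Drain any previously queued GPU work so the baseline is stable.
        mx.synchronize()
        mem_pre = mx.get_active_memory()

        # Run a vmapped sort, the kind of op the test above exercises.
        out = mx.vmap(lambda x: mx.sort(x, axis=-1))(mx.zeros((4, 1024)))
        mx.eval(out)
        del out

        # Wait for completion so temporaries are released before measuring.
        mx.synchronize()
        assert mx.get_active_memory() <= mem_pre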