revert sort + flaky test

This commit is contained in:
Awni Hannun 2025-05-16 13:40:51 -07:00
parent 3b169acf50
commit ff1c6fc148
2 changed files with 30 additions and 12 deletions

View File

@ -21,8 +21,6 @@ void single_block_sort(
int bn,
int tn,
bool argsort) {
out.set_data(allocator::malloc(out.nbytes()));
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
@ -158,6 +156,9 @@ void multi_block_sort(
dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));
std::vector<array> copies = {
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
// Prepare command encoder
auto& compute_encoder = d.get_command_encoder(s.index);
@ -249,17 +250,25 @@ void multi_block_sort(
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
}
out.copy_shared_buffer(
argsort ? dev_idxs_out : dev_vals_out,
// Copy outputs with appropriate strides
auto strides = out.strides();
for (int ax = axis + 1; ax < strides.size(); ax++) {
strides[ax] *= out.shape(axis);
}
strides[axis] = 1;
copy_gpu_inplace(
(argsort) ? dev_idxs_out : dev_vals_out,
out,
out.shape(),
strides,
out.strides(),
out.flags(),
out.data_size());
d.add_temporaries(
{dev_vals_in,
dev_idxs_in,
argsort ? dev_vals_in : dev_idxs_in,
block_partitions},
s.index);
0,
0,
(axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
s);
d.add_temporaries(std::move(copies), s.index);
}
void gpu_merge_sort(
@ -309,6 +318,8 @@ void gpu_merge_sort(
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
@ -319,6 +330,8 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
@ -330,6 +343,8 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct arg partition to sort for now
assert(inputs.size() == 1);
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
@ -341,6 +356,8 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct partition to sort for now
assert(inputs.size() == 1);
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];

View File

@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(fy.shape, (4, 5, 6, 7))
def test_leaks(self):
mx.synchronize()
if mx.metal.is_available():
mem_pre = mx.get_active_memory()
else: