mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
revert sort + flaky test
This commit is contained in:
parent
3b169acf50
commit
ff1c6fc148
@ -21,8 +21,6 @@ void single_block_sort(
|
||||
int bn,
|
||||
int tn,
|
||||
bool argsort) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Prepare shapes
|
||||
int n_rows = in.size() / in.shape(axis);
|
||||
|
||||
@ -158,6 +156,9 @@ void multi_block_sort(
|
||||
dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
|
||||
block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));
|
||||
|
||||
std::vector<array> copies = {
|
||||
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
|
||||
|
||||
// Prepare command encoder
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
@ -249,17 +250,25 @@ void multi_block_sort(
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
out.copy_shared_buffer(
|
||||
argsort ? dev_idxs_out : dev_vals_out,
|
||||
|
||||
// Copy outputs with appropriate strides
|
||||
auto strides = out.strides();
|
||||
for (int ax = axis + 1; ax < strides.size(); ax++) {
|
||||
strides[ax] *= out.shape(axis);
|
||||
}
|
||||
strides[axis] = 1;
|
||||
copy_gpu_inplace(
|
||||
(argsort) ? dev_idxs_out : dev_vals_out,
|
||||
out,
|
||||
out.shape(),
|
||||
strides,
|
||||
out.strides(),
|
||||
out.flags(),
|
||||
out.data_size());
|
||||
d.add_temporaries(
|
||||
{dev_vals_in,
|
||||
dev_idxs_in,
|
||||
argsort ? dev_vals_in : dev_idxs_in,
|
||||
block_partitions},
|
||||
s.index);
|
||||
0,
|
||||
0,
|
||||
(axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
|
||||
s);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void gpu_merge_sort(
|
||||
@ -309,6 +318,8 @@ void gpu_merge_sort(
|
||||
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
@ -319,6 +330,8 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
@ -330,6 +343,8 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// We direct arg partition to sort for now
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
@ -341,6 +356,8 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// We direct partition to sort for now
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto& in = inputs[0];
|
||||
|
@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(fy.shape, (4, 5, 6, 7))
|
||||
|
||||
def test_leaks(self):
|
||||
mx.synchronize()
|
||||
if mx.metal.is_available():
|
||||
mem_pre = mx.get_active_memory()
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user