Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
revert sort + flaky test
commit ff1c6fc148
parent 3b169acf50
@@ -21,8 +21,6 @@ void single_block_sort(
     int bn,
     int tn,
     bool argsort) {
-  out.set_data(allocator::malloc(out.nbytes()));
-
   // Prepare shapes
   int n_rows = in.size() / in.shape(axis);
 
@@ -158,6 +156,9 @@ void multi_block_sort(
   dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
   block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));
 
+  std::vector<array> copies = {
+      dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
+
   // Prepare command encoder
   auto& compute_encoder = d.get_command_encoder(s.index);
 
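The `copies` vector introduced above exists so that the ping-pong value/index buffers and the block-partition array stay alive until the GPU work that reads them has finished; the hunk further down hands them to `d.add_temporaries(std::move(copies), s.index)`. Below is a minimal, generic sketch of that lifetime pattern in plain C++. The `FakeDevice` type, its method names, and the use of `std::shared_ptr<std::vector<float>>` as a stand-in for GPU buffers are illustrative assumptions, not mlx's API.

#include <memory>
#include <vector>

struct FakeDevice {
  // Scratch buffers parked per command buffer; released on completion.
  std::vector<std::shared_ptr<std::vector<float>>> temporaries;

  void add_temporaries(std::vector<std::shared_ptr<std::vector<float>>> bufs) {
    for (auto& b : bufs) {
      temporaries.push_back(std::move(b));
    }
  }

  void command_buffer_completed() {
    temporaries.clear();
  }
};

int main() {
  FakeDevice d;
  {
    auto dev_vals_0 = std::make_shared<std::vector<float>>(1024);
    auto dev_vals_1 = std::make_shared<std::vector<float>>(1024);
    // ... encode kernels that read/write these buffers ...
    d.add_temporaries({dev_vals_0, dev_vals_1});
  } // the locals go out of scope here, but the storage survives on the device
  d.command_buffer_completed(); // only now is the scratch memory actually released
  return 0;
}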
@@ -249,17 +250,25 @@ void multi_block_sort(
       compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
     }
   }
-  out.copy_shared_buffer(
-      argsort ? dev_idxs_out : dev_vals_out,
+
+  // Copy outputs with appropriate strides
+  auto strides = out.strides();
+  for (int ax = axis + 1; ax < strides.size(); ax++) {
+    strides[ax] *= out.shape(axis);
+  }
+  strides[axis] = 1;
+  copy_gpu_inplace(
+      (argsort) ? dev_idxs_out : dev_vals_out,
+      out,
+      out.shape(),
+      strides,
       out.strides(),
-      out.flags(),
-      out.data_size());
-  d.add_temporaries(
-      {dev_vals_in,
-       dev_idxs_in,
-       argsort ? dev_vals_in : dev_idxs_in,
-       block_partitions},
-      s.index);
+      0,
+      0,
+      (axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
+      s);
+
+  d.add_temporaries(std::move(copies), s.index);
 }
 
 void gpu_merge_sort(
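The stride computation restored by this hunk is the subtle part: the multi-block sort writes its result as a contiguous (n_rows, size_sorted_axis) buffer, i.e. as if `axis` had been moved to the innermost position, and the computed `strides` describe how that buffer must be read so `copy_gpu_inplace` can scatter it back into `out` in the original axis order. The standalone program below reproduces the arithmetic for one concrete case; the shape (2, 4, 3) and axis = 1 are assumed for illustration and do not come from the commit.

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape = {2, 4, 3};
  int axis = 1;

  // Row-major (contiguous) strides of `out` for shape (2, 4, 3).
  std::vector<long> out_strides = {12, 3, 1};

  // Same computation as in the hunk: start from out's strides, scale every axis
  // after `axis` by shape[axis], then make `axis` the unit-stride dimension.
  std::vector<long> strides = out_strides;
  for (int ax = axis + 1; ax < (int)strides.size(); ax++) {
    strides[ax] *= shape[axis];
  }
  strides[axis] = 1;

  // Result: {12, 1, 4} -- exactly the strides of a (2, 4, 3) array whose axis 1 is
  // stored innermost, i.e. the contiguous (n_rows = 6, 4) buffer the sort produced.
  assert((strides == std::vector<long>{12, 1, 4}));

  // Element (i, j, k) of the output is read from offset i*12 + j*1 + k*4 in the
  // sorted buffer and written to offset i*12 + j*3 + k*1 in `out`.
  std::printf("strides = {%ld, %ld, %ld}\n", strides[0], strides[1], strides[2]);
  return 0;
}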
@@ -309,6 +318,8 @@ void gpu_merge_sort(
 void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
+  out.set_data(allocator::malloc(out.nbytes()));
+
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto& in = inputs[0];
@@ -319,6 +330,8 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
 void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
+  out.set_data(allocator::malloc(out.nbytes()));
+
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto& in = inputs[0];
@@ -330,6 +343,8 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We direct arg partition to sort for now
   assert(inputs.size() == 1);
 
+  out.set_data(allocator::malloc(out.nbytes()));
+
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto& in = inputs[0];
@@ -341,6 +356,8 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We direct partition to sort for now
   assert(inputs.size() == 1);
 
+  out.set_data(allocator::malloc(out.nbytes()));
+
   auto& s = stream();
   auto& d = metal::device(s.device);
   auto& in = inputs[0];
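Taken together, the first hunk and the four eval_gpu hunks move output allocation out of single_block_sort and into each entry point: every sort/partition path now materializes `out` with `out.set_data(allocator::malloc(out.nbytes()))` before calling gpu_merge_sort, and the multi-block path fills that buffer via the strided copy shown above. A minimal sketch of that "caller allocates, helper fills" structure follows; `eval_gpu_like` and `gpu_merge_sort_like` are hypothetical stand-ins, not mlx functions.

#include <cstddef>
#include <vector>

// Stand-in for gpu_merge_sort: it only fills the buffer it is handed.
void gpu_merge_sort_like(std::vector<int>& out) {
  for (std::size_t i = 0; i < out.size(); i++) {
    out[i] = static_cast<int>(i);
  }
}

// Stand-in for an eval_gpu entry point: it owns the output allocation,
// the counterpart of out.set_data(allocator::malloc(out.nbytes())).
void eval_gpu_like(std::vector<int>& out, std::size_t nbytes) {
  out.resize(nbytes / sizeof(int));
  gpu_merge_sort_like(out);
}

int main() {
  std::vector<int> out;
  eval_gpu_like(out, 64);
  return 0;
}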
@@ -634,6 +634,7 @@ class TestVmap(mlx_tests.MLXTestCase):
         self.assertEqual(fy.shape, (4, 5, 6, 7))
 
     def test_leaks(self):
+        mx.synchronize()
         if mx.metal.is_available():
             mem_pre = mx.get_active_memory()
         else: