diff --git a/mlx/backend/common/sort.cpp b/mlx/backend/common/sort.cpp index edeee065b..d7f4895bf 100644 --- a/mlx/backend/common/sort.cpp +++ b/mlx/backend/common/sort.cpp @@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) { axis = axis < 0 ? axis + in.ndim() : axis; size_t n_rows = in.size() / in.shape(axis); - auto remaining_shape = in.shape(); + auto remaining_shape = out.shape(); remaining_shape.erase(remaining_shape.begin() + axis); - auto remaining_strides = in.strides(); + auto remaining_strides = out.strides(); remaining_strides.erase(remaining_strides.begin() + axis); - size_t axis_stride = in.strides()[axis]; - int axis_size = in.shape(axis); + size_t axis_stride = out.strides()[axis]; + int axis_size = out.shape(axis); // Perform sorting in place for (int i = 0; i < n_rows; i++) { @@ -143,34 +143,42 @@ void argsort(const array& in, array& out, int axis) { axis = axis < 0 ? axis + in.ndim() : axis; size_t n_rows = in.size() / in.shape(axis); - auto remaining_shape = in.shape(); - remaining_shape.erase(remaining_shape.begin() + axis); + auto in_remaining_shape = in.shape(); + in_remaining_shape.erase(in_remaining_shape.begin() + axis); - auto remaining_strides = in.strides(); - remaining_strides.erase(remaining_strides.begin() + axis); + auto in_remaining_strides = in.strides(); + in_remaining_strides.erase(in_remaining_strides.begin() + axis); - size_t axis_stride = in.strides()[axis]; + auto out_remaining_shape = out.shape(); + out_remaining_shape.erase(out_remaining_shape.begin() + axis); + + auto out_remaining_strides = out.strides(); + out_remaining_strides.erase(out_remaining_strides.begin() + axis); + + size_t in_stride = in.strides()[axis]; + size_t out_stride = out.strides()[axis]; int axis_size = in.shape(axis); // Perform sorting for (int i = 0; i < n_rows; i++) { - size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); - const T* data_ptr = in.data() + loc; - IdxT* idx_ptr = out.data() + loc; + size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides); + size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides); + const T* data_ptr = in.data() + in_loc; + IdxT* idx_ptr = out.data() + out_loc; - StridedIterator st_(idx_ptr, axis_stride, 0); - StridedIterator ed_(idx_ptr, axis_stride, axis_size); + StridedIterator st_(idx_ptr, out_stride, 0); + StridedIterator ed_(idx_ptr, out_stride, axis_size); // Initialize with iota std::iota(st_, ed_, IdxT(0)); // Sort according to vals - StridedIterator st(idx_ptr, axis_stride, 0); - StridedIterator ed(idx_ptr, axis_stride, axis_size); + StridedIterator st(idx_ptr, out_stride, 0); + StridedIterator ed(idx_ptr, out_stride, axis_size); - std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) { - auto v1 = data_ptr[a * axis_stride]; - auto v2 = data_ptr[b * axis_stride]; + std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { + auto v1 = data_ptr[a * in_stride]; + auto v2 = data_ptr[b * in_stride]; return v1 < v2 || (v1 == v2 && a < b); }); } diff --git a/mlx/backend/metal/jit/sort.h b/mlx/backend/metal/jit/sort.h deleted file mode 100644 index 023822f48..000000000 --- a/mlx/backend/metal/jit/sort.h +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright © 2024 Apple Inc. - -constexpr std::string_view block_sort_kernels = R"( -template [[host_name("carg_{0}")]] [[kernel]] void -block_sort<{1}, {2}, true, {3}, {4}>( - const device {1}* inp [[buffer(0)]], - device {2}* out [[buffer(1)]], - const constant int& size_sorted_axis [[buffer(2)]], - const constant int& stride_sorted_axis [[buffer(3)]], - const constant int& stride_segment_axis [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]); -template [[host_name("ncarg_{0}")]] [[kernel]] void -block_sort_nc<{1}, {2}, true, {3}, {4}>( - const device {1}* inp [[buffer(0)]], - device {2}* out [[buffer(1)]], - const constant int& size_sorted_axis [[buffer(2)]], - const constant int& stride_sorted_axis [[buffer(3)]], - const constant int& nc_dim [[buffer(4)]], - const device int* nc_shape [[buffer(5)]], - const device size_t* nc_strides [[buffer(6)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]); -template [[host_name("c_{0}")]] [[kernel]] void -block_sort<{1}, {2}, false, {3}, {4}>( - const device {1}* inp [[buffer(0)]], - device {2}* out [[buffer(1)]], - const constant int& size_sorted_axis [[buffer(2)]], - const constant int& stride_sorted_axis [[buffer(3)]], - const constant int& stride_segment_axis [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]); -template [[host_name("nc_{0}")]] [[kernel]] void -block_sort_nc<{1}, {2}, false, {3}, {4}>( - const device {1}* inp [[buffer(0)]], - device {2}* out [[buffer(1)]], - const constant int& size_sorted_axis [[buffer(2)]], - const constant int& stride_sorted_axis [[buffer(3)]], - const constant int& nc_dim [[buffer(4)]], - const device int* nc_shape [[buffer(5)]], - const device size_t* nc_strides [[buffer(6)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]); -)"; - -constexpr std::string_view multiblock_sort_kernels = R"( -template [[host_name("sort_{0}")]] [[kernel]] void -mb_block_sort<{1}, {2}, true, {3}, {4}>( - const device {1}* inp [[buffer(0)]], - device {1}* out_vals [[buffer(1)]], - device {2}* out_idxs [[buffer(2)]], - const constant int& size_sorted_axis [[buffer(3)]], - const constant int& stride_sorted_axis [[buffer(4)]], - const constant int& nc_dim [[buffer(5)]], - const device int* nc_shape [[buffer(6)]], - const device size_t* nc_strides [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]); -template [[host_name("partition_{0}")]] [[kernel]] void -mb_block_partition<{1}, {2}, true, {3}, {4}>( - device {2}* block_partitions [[buffer(0)]], - const device {1}* dev_vals [[buffer(1)]], - const device {2}* dev_idxs [[buffer(2)]], - const constant int& size_sorted_axis [[buffer(3)]], - const constant int& merge_tiles [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 tgp_dims [[threads_per_threadgroup]]); -template [[host_name("merge_{0}")]] [[kernel]] void -mb_block_merge<{1}, {2}, true, {3}, {4}>( - const device {2}* block_partitions [[buffer(0)]], - const device {1}* dev_vals_in [[buffer(1)]], - const device {2}* dev_idxs_in [[buffer(2)]], - device {1}* dev_vals_out [[buffer(3)]], - device {2}* dev_idxs_out [[buffer(4)]], - const constant int& size_sorted_axis [[buffer(5)]], - const constant int& merge_tiles [[buffer(6)]], - const constant int& num_tiles [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]); -)"; diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 6791c6685..12e843618 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -8,7 +8,6 @@ #include "mlx/backend/metal/jit/reduce.h" #include "mlx/backend/metal/jit/scan.h" #include "mlx/backend/metal/jit/softmax.h" -#include "mlx/backend/metal/jit/sort.h" #include "mlx/backend/metal/jit/steel_conv.h" #include "mlx/backend/metal/jit/steel_gemm.h" #include "mlx/backend/metal/kernels.h" @@ -251,14 +250,29 @@ MTL::ComputePipelineState* get_sort_kernel( auto lib = d.get_library(lib_name); if (lib == nullptr) { std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::sort() - << fmt::format( - block_sort_kernels, - lib_name, - get_type_string(in.dtype()), - get_type_string(out.dtype()), - bn, - tn); + auto in_type = get_type_string(in.dtype()); + auto out_type = get_type_string(out.dtype()); + kernel_source << metal::utils() << metal::sort(); + for (bool is_argsort : {true, false}) { + std::string bool_string = is_argsort ? "true" : "false"; + std::string func_string = is_argsort ? "carg_" : "c_"; + kernel_source << get_template_definition( + func_string + lib_name, + "block_sort", + in_type, + out_type, + bool_string, + bn, + tn); + kernel_source << get_template_definition( + "n" + func_string + lib_name, + "block_sort_nc", + in_type, + out_type, + bool_string, + bn, + tn); + } lib = d.get_library(lib_name, kernel_source.str()); } return d.get_kernel(kernel_name, lib); @@ -275,14 +289,21 @@ MTL::ComputePipelineState* get_mb_sort_kernel( auto lib = d.get_library(lib_name); if (lib == nullptr) { std::ostringstream kernel_source; - kernel_source << metal::utils() << metal::sort() - << fmt::format( - multiblock_sort_kernels, - lib_name, - get_type_string(in.dtype()), - get_type_string(idx.dtype()), - bn, - tn); + kernel_source << metal::utils() << metal::sort(); + std::vector> kernel_types = { + {"sort_", "mb_block_sort"}, + {"partition_", "mb_block_partition"}, + {"merge_", "mb_block_merge"}}; + for (auto [name, func] : kernel_types) { + kernel_source << get_template_definition( + name + lib_name, + func, + get_type_string(in.dtype()), + get_type_string(idx.dtype()), + "true", + bn, + tn); + } lib = d.get_library(lib_name, kernel_source.str()); } return d.get_kernel(kernel_name, lib); diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h index 5b1cf2675..dca5106de 100644 --- a/mlx/backend/metal/kernels/sort.h +++ b/mlx/backend/metal/kernels/sort.h @@ -235,19 +235,21 @@ struct KernelMergeSort { const device T* inp, device U* out, const constant int& size_sorted_axis, - const constant int& stride_sorted_axis, - const constant int& stride_segment_axis, + const constant int& in_stride_sorted_axis, + const constant int& out_stride_sorted_axis, + const constant int& in_stride_segment_axis, + const constant int& out_stride_segment_axis, threadgroup val_t* tgp_vals, threadgroup idx_t* tgp_idxs, uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { // tid.y tells us the segment index - inp += tid.y * stride_segment_axis; - out += tid.y * stride_segment_axis; + inp += tid.y * in_stride_segment_axis; + out += tid.y * out_stride_segment_axis; // Copy into threadgroup memory for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { - tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] + tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis] : val_t(CompareOp::init); if (ARG_SORT) { tgp_idxs[i] = i; @@ -264,9 +266,9 @@ struct KernelMergeSort { // Write output for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) { if (ARG_SORT) { - out[i * stride_sorted_axis] = tgp_idxs[i]; + out[i * out_stride_sorted_axis] = tgp_idxs[i]; } else { - out[i * stride_sorted_axis] = tgp_vals[i]; + out[i * out_stride_sorted_axis] = tgp_vals[i]; } } } @@ -282,8 +284,10 @@ template < const device T* inp [[buffer(0)]], device U* out [[buffer(1)]], const constant int& size_sorted_axis [[buffer(2)]], - const constant int& stride_sorted_axis [[buffer(3)]], - const constant int& stride_segment_axis [[buffer(4)]], + const constant int& in_stride_sorted_axis [[buffer(3)]], + const constant int& out_stride_sorted_axis [[buffer(4)]], + const constant int& in_stride_segment_axis [[buffer(5)]], + const constant int& out_stride_segment_axis [[buffer(6)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { using sort_kernel = @@ -298,8 +302,10 @@ template < inp, out, size_sorted_axis, - stride_sorted_axis, - stride_segment_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, tgp_vals, tgp_idxs, tid, @@ -310,8 +316,10 @@ template < inp, out, size_sorted_axis, - stride_sorted_axis, - stride_segment_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, tgp_vals, nullptr, tid, @@ -331,10 +339,12 @@ template < const device T* inp [[buffer(0)]], device U* out [[buffer(1)]], const constant int& size_sorted_axis [[buffer(2)]], - const constant int& stride_sorted_axis [[buffer(3)]], - const constant int& nc_dim [[buffer(4)]], - const device int* nc_shape [[buffer(5)]], - const device size_t* nc_strides [[buffer(6)]], + const constant int& in_stride_sorted_axis [[buffer(3)]], + const constant int& out_stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const device int* nc_shape [[buffer(6)]], + const device size_t* in_nc_strides [[buffer(7)]], + const device size_t* out_nc_strides [[buffer(8)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { using sort_kernel = @@ -342,9 +352,10 @@ template < using val_t = typename sort_kernel::val_t; using idx_t = typename sort_kernel::idx_t; - auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); - inp += block_idx; - out += block_idx; + auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim); + auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim); + inp += in_block_idx; + out += out_block_idx; if (ARG_SORT) { threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; @@ -353,7 +364,9 @@ template < inp, out, size_sorted_axis, - stride_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + zero_helper, zero_helper, tgp_vals, tgp_idxs, @@ -365,7 +378,9 @@ template < inp, out, size_sorted_axis, - stride_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + zero_helper, zero_helper, tgp_vals, nullptr, diff --git a/mlx/backend/metal/kernels/sort.metal b/mlx/backend/metal/kernels/sort.metal index ad13cc835..e0d9b6c69 100644 --- a/mlx/backend/metal/kernels/sort.metal +++ b/mlx/backend/metal/kernels/sort.metal @@ -10,28 +10,10 @@ #define instantiate_block_sort( \ name, itname, itype, otname, otype, arg_sort, bn, tn) \ - template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \ - "_tn" #tn)]] [[kernel]] void \ - block_sort( \ - const device itype* inp [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant int& size_sorted_axis [[buffer(2)]], \ - const constant int& stride_sorted_axis [[buffer(3)]], \ - const constant int& stride_segment_axis [[buffer(4)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); \ - template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \ - )]] [[kernel]] void \ - block_sort_nc( \ - const device itype* inp [[buffer(0)]], \ - device otype* out [[buffer(1)]], \ - const constant int& size_sorted_axis [[buffer(2)]], \ - const constant int& stride_sorted_axis [[buffer(3)]], \ - const constant int& nc_dim [[buffer(4)]], \ - const device int* nc_shape [[buffer(5)]], \ - const device size_t* nc_strides [[buffer(6)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); + instantiate_kernel("c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \ + block_sort, itype, otype, arg_sort, bn, tn) \ + instantiate_kernel("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \ + block_sort_nc, itype, otype, arg_sort, bn, tn) #define instantiate_arg_block_sort_base(itname, itype, bn, tn) \ instantiate_block_sort( \ @@ -69,43 +51,12 @@ instantiate_block_sort_long(int64, int64_t) #define instantiate_multi_block_sort( \ vtname, vtype, itname, itype, arg_sort, bn, tn) \ - template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \ - "_tn" #tn)]] [[kernel]] void \ - mb_block_sort( \ - const device vtype* inp [[buffer(0)]], \ - device vtype* out_vals [[buffer(1)]], \ - device itype* out_idxs [[buffer(2)]], \ - const constant int& size_sorted_axis [[buffer(3)]], \ - const constant int& stride_sorted_axis [[buffer(4)]], \ - const constant int& nc_dim [[buffer(5)]], \ - const device int* nc_shape [[buffer(6)]], \ - const device size_t* nc_strides [[buffer(7)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); \ - template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \ - "_tn" #tn)]] [[kernel]] void \ - mb_block_partition( \ - device itype * block_partitions [[buffer(0)]], \ - const device vtype* dev_vals [[buffer(1)]], \ - const device itype* dev_idxs [[buffer(2)]], \ - const constant int& size_sorted_axis [[buffer(3)]], \ - const constant int& merge_tiles [[buffer(4)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 tgp_dims [[threads_per_threadgroup]]); \ - template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \ - "_tn" #tn)]] [[kernel]] void \ - mb_block_merge( \ - const device itype* block_partitions [[buffer(0)]], \ - const device vtype* dev_vals_in [[buffer(1)]], \ - const device itype* dev_idxs_in [[buffer(2)]], \ - device vtype* dev_vals_out [[buffer(3)]], \ - device itype* dev_idxs_out [[buffer(4)]], \ - const constant int& size_sorted_axis [[buffer(5)]], \ - const constant int& merge_tiles [[buffer(6)]], \ - const constant int& num_tiles [[buffer(7)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); + instantiate_kernel("sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \ + mb_block_sort, vtype, itype, arg_sort, bn, tn) \ + instantiate_kernel("partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \ + mb_block_partition, vtype, itype, arg_sort, bn, tn) \ + instantiate_kernel("merge_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \ + mb_block_merge, vtype, itype, arg_sort, bn, tn) #define instantiate_multi_block_sort_base(vtname, vtype) \ instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8) diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index f4d6e63bd..457232654 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -24,8 +24,11 @@ void single_block_sort( // Prepare shapes int n_rows = in.size() / in.shape(axis); - std::vector nc_str = in.strides(); - nc_str.erase(nc_str.begin() + axis); + std::vector in_nc_str = in.strides(); + in_nc_str.erase(in_nc_str.begin() + axis); + + std::vector out_nc_str = out.strides(); + out_nc_str.erase(out_nc_str.begin() + axis); std::vector nc_shape = in.shape(); nc_shape.erase(nc_shape.begin() + axis); @@ -33,21 +36,28 @@ void single_block_sort( int nc_dim = nc_shape.size(); int size_sorted_axis = in.shape(axis); - int stride_sorted_axis = in.strides()[axis]; - int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end()); + int in_stride_sorted_axis = in.strides()[axis]; + int out_stride_sorted_axis = out.strides()[axis]; + int in_stride_segment_axis = + *std::min_element(in_nc_str.begin(), in_nc_str.end()); + int out_stride_segment_axis = + *std::min_element(out_nc_str.begin(), out_nc_str.end()); - // Check if remaining strides are contiguous - bool contiguous_write = true; - if (axis != in.ndim() - 1 && axis != 0) { - for (int i = 0; i < nc_str.size() - 1; ++i) { - size_t expected = nc_str[i + 1] * nc_str[i + 1]; - contiguous_write &= (nc_str[i] == expected); - } - } + // We can only use the contiguous kernel if the sorted axis + // has the largest or smallest stride. + // We also need the input to be contiguous + bool contiguous = in.flags().contiguous; + auto check_strides = [](array x, int sort_stride) { + int min_stride = *std::min_element(x.strides().begin(), x.strides().end()); + int max_stride = *std::max_element(x.strides().begin(), x.strides().end()); + return sort_stride == min_stride || sort_stride == max_stride; + }; + contiguous &= check_strides(in, in_stride_sorted_axis); + contiguous &= check_strides(out, out_stride_sorted_axis); // Prepare kernel name std::ostringstream kname; - kname << (contiguous_write ? "c" : "nc"); + kname << (contiguous ? "c" : "nc"); if (argsort) { kname << "arg"; } @@ -64,14 +74,17 @@ void single_block_sort( compute_encoder.set_input_array(in, 0); compute_encoder.set_output_array(out, 1); compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2); - compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3); + compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3); + compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4); - if (contiguous_write) { - compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4); + if (contiguous) { + compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5); + compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6); } else { - compute_encoder->setBytes(&nc_dim, sizeof(int), 4); - compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5); - compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6); + compute_encoder->setBytes(&nc_dim, sizeof(int), 5); + compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6); + compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7); + compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8); } MTL::Size group_dims = MTL::Size(bn, 1, 1); diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index c41c79c83..dace9c9aa 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3,7 +3,7 @@ import math import os import unittest -from itertools import permutations +from itertools import permutations, product import mlx.core as mx import mlx_tests @@ -1751,60 +1751,93 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1)) def test_sort(self): - shape = (3, 4, 5) - for dtype in ("int32", "float32"): - for axis in (None, 0, 1, 2): - with self.subTest(dtype=dtype, axis=axis): - np.random.seed(0) - np_dtype = getattr(np, dtype) - a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype) - a_mx = mx.array(a_np) + shape = (6, 4, 10) + tests = product( + ("int32", "float32"), # type + (None, 0, 1, 2), # axis + (True, False), # strided + ) + for dtype, axis, strided in tests: + with self.subTest(dtype=dtype, axis=axis, strided=strided): + np.random.seed(0) + np_dtype = getattr(np, dtype) + a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype) + a_mx = mx.array(a_np) + if strided: + a_mx = a_mx[::2, :, ::2] + a_np = a_np[::2, :, ::2] - b_np = np.sort(a_np, axis=axis) - b_mx = mx.sort(a_mx, axis=axis) + b_np = np.sort(a_np, axis=axis) + b_mx = mx.sort(a_mx, axis=axis) - self.assertTrue(np.array_equal(b_np, b_mx)) - self.assertEqual(b_mx.dtype, a_mx.dtype) + self.assertTrue(np.array_equal(b_np, b_mx)) + self.assertEqual(b_mx.dtype, a_mx.dtype) - c_np = np.argsort(a_np, axis=axis) - c_mx = mx.argsort(a_mx, axis=axis) - d_np = np.take_along_axis(a_np, c_np, axis=axis) - d_mx = mx.take_along_axis(a_mx, c_mx, axis=axis) + c_np = np.argsort(a_np, axis=axis) + c_mx = mx.argsort(a_mx, axis=axis) + d_np = np.take_along_axis(a_np, c_np, axis=axis) + d_mx = mx.take_along_axis(a_mx, c_mx, axis=axis) - self.assertTrue(np.array_equal(d_np, d_mx)) - self.assertEqual(c_mx.dtype, mx.uint32) + self.assertTrue(np.array_equal(d_np, d_mx)) + self.assertEqual(c_mx.dtype, mx.uint32) # Set random seed np.random.seed(0) # Test multi-block sort - a_np = np.random.normal(size=(32769,)).astype(np.float32) + for strided in (False, True): + with self.subTest(strided=strided): + a_np = np.random.normal(size=(32769,)).astype(np.float32) + a_mx = mx.array(a_np) + + if strided: + a_mx = a_mx[::3] + a_np = a_np[::3] + + b_np = np.sort(a_np) + b_mx = mx.sort(a_mx) + + self.assertTrue(np.array_equal(b_np, b_mx)) + self.assertEqual(b_mx.dtype, a_mx.dtype) + + # Test multi-dum multi-block sort + a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32) + a_mx = mx.array(a_np) + + if strided: + a_mx = a_mx[..., ::3] + a_np = a_np[..., ::3] + + b_np = np.sort(a_np, axis=-1) + b_mx = mx.sort(a_mx, axis=-1) + + self.assertTrue(np.array_equal(b_np, b_mx)) + self.assertEqual(b_mx.dtype, a_mx.dtype) + + a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32) + a_mx = mx.array(a_np) + + if strided: + a_mx = a_mx[:, ::3] + a_np = a_np[:, ::3] + + b_np = np.sort(a_np, axis=1) + b_mx = mx.sort(a_mx, axis=1) + + self.assertTrue(np.array_equal(b_np, b_mx)) + self.assertEqual(b_mx.dtype, a_mx.dtype) + + # test 0 strides + a_np = np.array([1, 0, 2, 1, 3, 0, 4, 0]) a_mx = mx.array(a_np) - - b_np = np.sort(a_np) - b_mx = mx.sort(a_mx) - - self.assertTrue(np.array_equal(b_np, b_mx)) - self.assertEqual(b_mx.dtype, a_mx.dtype) - - # Test multi-dum multi-block sort - a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32) - a_mx = mx.array(a_np) - - b_np = np.sort(a_np, axis=-1) - b_mx = mx.sort(a_mx, axis=-1) - - self.assertTrue(np.array_equal(b_np, b_mx)) - self.assertEqual(b_mx.dtype, a_mx.dtype) - - a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32) - a_mx = mx.array(a_np) - - b_np = np.sort(a_np, axis=1) - b_mx = mx.sort(a_mx, axis=1) - - self.assertTrue(np.array_equal(b_np, b_mx)) - self.assertEqual(b_mx.dtype, a_mx.dtype) + b_np = np.broadcast_to(a_np, (16, 8)) + b_mx = mx.broadcast_to(a_mx, (16, 8)) + mx.eval(b_mx) + for axis in (0, 1): + c_np = np.sort(b_np, axis=axis) + c_mx = mx.sort(b_mx, axis=axis) + self.assertTrue(np.array_equal(c_np, c_mx)) + self.assertEqual(b_mx.dtype, c_mx.dtype) def test_partition(self): shape = (3, 4, 5)