Fix strided sort bug (#1236)

* Use output strides in sort kernel

* Fix zero strides bug
This commit is contained in:
Alex Barron 2024-06-26 14:32:11 -07:00 committed by GitHub
parent 5b0af4cdb1
commit 2615660e62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 222 additions and 262 deletions

View File

@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape(); auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis); remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides(); auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis); remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis]; size_t axis_stride = out.strides()[axis];
int axis_size = in.shape(axis); int axis_size = out.shape(axis);
// Perform sorting in place // Perform sorting in place
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
@ -143,34 +143,42 @@ void argsort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis; axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis); size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape(); auto in_remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis); in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto remaining_strides = in.strides(); auto in_remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis); in_remaining_strides.erase(in_remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis]; auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis); int axis_size = in.shape(axis);
// Perform sorting // Perform sorting
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides); size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
const T* data_ptr = in.data<T>() + loc; size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
IdxT* idx_ptr = out.data<IdxT>() + loc; const T* data_ptr = in.data<T>() + in_loc;
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
StridedIterator st_(idx_ptr, axis_stride, 0); StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size); StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota // Initialize with iota
std::iota(st_, ed_, IdxT(0)); std::iota(st_, ed_, IdxT(0));
// Sort according to vals // Sort according to vals
StridedIterator st(idx_ptr, axis_stride, 0); StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, axis_stride, axis_size); StridedIterator ed(idx_ptr, out_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) { std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * axis_stride]; auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }

View File

@ -1,81 +0,0 @@
// Copyright © 2024 Apple Inc.
// Format template containing explicit Metal instantiations of the
// single-block sort kernels. The text is expanded with fmt::format and the
// result is appended to the JIT kernel source, so everything inside the raw
// string is runtime data: changing a character changes the generated code.
//
// Placeholders, in the order they are supplied to fmt::format:
//   {0}      - library name, appended to each host_name prefix
//   {1}      - type string of the input elements
//   {2}      - type string of the output elements
//   {3}, {4} - the two trailing template arguments (passed as bn and tn)
//
// Four kernels are declared:
//   carg_{0}  -> block_sort    with the third template argument set to true
//   ncarg_{0} -> block_sort_nc with the third template argument set to true
//   c_{0}     -> block_sort    with the third template argument set to false
//   nc_{0}    -> block_sort_nc with the third template argument set to false
// The block_sort variants take a single segment stride in buffer(4); the
// block_sort_nc variants instead take nc_dim, nc_shape and nc_strides in
// buffers 4-6.
constexpr std::string_view block_sort_kernels = R"(
template [[host_name("carg_{0}")]] [[kernel]] void
block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("ncarg_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("c_{0}")]] [[kernel]] void
block_sort<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("nc_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
// Format template containing explicit Metal instantiations of the
// multi-block sort kernels. Like the single-block template, it is expanded
// with fmt::format and appended to the JIT kernel source, so the raw string
// contents are runtime data and must be kept exactly as written.
//
// Placeholders, in the order they are supplied to fmt::format:
//   {0}      - library name, appended to each host_name prefix
//   {1}      - type string of the values being sorted
//   {2}      - type string of the index array
//   {3}, {4} - the two trailing template arguments (passed as bn and tn)
//
// Three kernels are declared, all with the third template argument set to
// true:
//   sort_{0}      -> mb_block_sort      (reads inp, writes out_vals/out_idxs)
//   partition_{0} -> mb_block_partition (writes block_partitions)
//   merge_{0}     -> mb_block_merge     (reads *_in, writes *_out buffers)
constexpr std::string_view multiblock_sort_kernels = R"(
template [[host_name("sort_{0}")]] [[kernel]] void
mb_block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {1}* out_vals [[buffer(1)]],
device {2}* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("partition_{0}")]] [[kernel]] void
mb_block_partition<{1}, {2}, true, {3}, {4}>(
device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals [[buffer(1)]],
const device {2}* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]);
template [[host_name("merge_{0}")]] [[kernel]] void
mb_block_merge<{1}, {2}, true, {3}, {4}>(
const device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals_in [[buffer(1)]],
const device {2}* dev_idxs_in [[buffer(2)]],
device {1}* dev_vals_out [[buffer(3)]],
device {2}* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";

View File

@ -8,7 +8,6 @@
#include "mlx/backend/metal/jit/reduce.h" #include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h" #include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h" #include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h" #include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h" #include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
@ -251,14 +250,29 @@ MTL::ComputePipelineState* get_sort_kernel(
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name);
if (lib == nullptr) { if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort() auto in_type = get_type_string(in.dtype());
<< fmt::format( auto out_type = get_type_string(out.dtype());
block_sort_kernels, kernel_source << metal::utils() << metal::sort();
lib_name, for (bool is_argsort : {true, false}) {
get_type_string(in.dtype()), std::string bool_string = is_argsort ? "true" : "false";
get_type_string(out.dtype()), std::string func_string = is_argsort ? "carg_" : "c_";
kernel_source << get_template_definition(
func_string + lib_name,
"block_sort",
in_type,
out_type,
bool_string,
bn, bn,
tn); tn);
kernel_source << get_template_definition(
"n" + func_string + lib_name,
"block_sort_nc",
in_type,
out_type,
bool_string,
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str()); lib = d.get_library(lib_name, kernel_source.str());
} }
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
@ -275,14 +289,21 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
auto lib = d.get_library(lib_name); auto lib = d.get_library(lib_name);
if (lib == nullptr) { if (lib == nullptr) {
std::ostringstream kernel_source; std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort() kernel_source << metal::utils() << metal::sort();
<< fmt::format( std::vector<std::pair<std::string, std::string>> kernel_types = {
multiblock_sort_kernels, {"sort_", "mb_block_sort"},
lib_name, {"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}};
for (auto [name, func] : kernel_types) {
kernel_source << get_template_definition(
name + lib_name,
func,
get_type_string(in.dtype()), get_type_string(in.dtype()),
get_type_string(idx.dtype()), get_type_string(idx.dtype()),
"true",
bn, bn,
tn); tn);
}
lib = d.get_library(lib_name, kernel_source.str()); lib = d.get_library(lib_name, kernel_source.str());
} }
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);

View File

@ -235,19 +235,21 @@ struct KernelMergeSort {
const device T* inp, const device T* inp,
device U* out, device U* out,
const constant int& size_sorted_axis, const constant int& size_sorted_axis,
const constant int& stride_sorted_axis, const constant int& in_stride_sorted_axis,
const constant int& stride_segment_axis, const constant int& out_stride_sorted_axis,
const constant int& in_stride_segment_axis,
const constant int& out_stride_segment_axis,
threadgroup val_t* tgp_vals, threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs, threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index // tid.y tells us the segment index
inp += tid.y * stride_segment_axis; inp += tid.y * in_stride_segment_axis;
out += tid.y * stride_segment_axis; out += tid.y * out_stride_segment_axis;
// Copy into threadgroup memory // Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
: val_t(CompareOp::init); : val_t(CompareOp::init);
if (ARG_SORT) { if (ARG_SORT) {
tgp_idxs[i] = i; tgp_idxs[i] = i;
@ -264,9 +266,9 @@ struct KernelMergeSort {
// Write output // Write output
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) { for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
if (ARG_SORT) { if (ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i]; out[i * out_stride_sorted_axis] = tgp_idxs[i];
} else { } else {
out[i * stride_sorted_axis] = tgp_vals[i]; out[i * out_stride_sorted_axis] = tgp_vals[i];
} }
} }
} }
@ -282,8 +284,10 @@ template <
const device T* inp [[buffer(0)]], const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]], const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]], const constant int& in_stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]], const constant int& out_stride_sorted_axis [[buffer(4)]],
const constant int& in_stride_segment_axis [[buffer(5)]],
const constant int& out_stride_segment_axis [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = using sort_kernel =
@ -298,8 +302,10 @@ template <
inp, inp,
out, out,
size_sorted_axis, size_sorted_axis,
stride_sorted_axis, in_stride_sorted_axis,
stride_segment_axis, out_stride_sorted_axis,
in_stride_segment_axis,
out_stride_segment_axis,
tgp_vals, tgp_vals,
tgp_idxs, tgp_idxs,
tid, tid,
@ -310,8 +316,10 @@ template <
inp, inp,
out, out,
size_sorted_axis, size_sorted_axis,
stride_sorted_axis, in_stride_sorted_axis,
stride_segment_axis, out_stride_sorted_axis,
in_stride_segment_axis,
out_stride_segment_axis,
tgp_vals, tgp_vals,
nullptr, nullptr,
tid, tid,
@ -331,10 +339,12 @@ template <
const device T* inp [[buffer(0)]], const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]], const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]], const constant int& in_stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]], const constant int& out_stride_sorted_axis [[buffer(4)]],
const device int* nc_shape [[buffer(5)]], const constant int& nc_dim [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]], const device int* nc_shape [[buffer(6)]],
const device size_t* in_nc_strides [[buffer(7)]],
const device size_t* out_nc_strides [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = using sort_kernel =
@ -342,9 +352,10 @@ template <
using val_t = typename sort_kernel::val_t; using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t; using idx_t = typename sort_kernel::idx_t;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
inp += block_idx; auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
out += block_idx; inp += in_block_idx;
out += out_block_idx;
if (ARG_SORT) { if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
@ -353,7 +364,9 @@ template <
inp, inp,
out, out,
size_sorted_axis, size_sorted_axis,
stride_sorted_axis, in_stride_sorted_axis,
out_stride_sorted_axis,
zero_helper,
zero_helper, zero_helper,
tgp_vals, tgp_vals,
tgp_idxs, tgp_idxs,
@ -365,7 +378,9 @@ template <
inp, inp,
out, out,
size_sorted_axis, size_sorted_axis,
stride_sorted_axis, in_stride_sorted_axis,
out_stride_sorted_axis,
zero_helper,
zero_helper, zero_helper,
tgp_vals, tgp_vals,
nullptr, nullptr,

View File

@ -10,28 +10,10 @@
#define instantiate_block_sort( \ #define instantiate_block_sort( \
name, itname, itype, otname, otype, arg_sort, bn, tn) \ name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \ instantiate_kernel("c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
"_tn" #tn)]] [[kernel]] void \ block_sort, itype, otype, arg_sort, bn, tn) \
block_sort<itype, otype, arg_sort, bn, tn>( \ instantiate_kernel("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
const device itype* inp [[buffer(0)]], \ block_sort_nc, itype, otype, arg_sort, bn, tn)
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
)]] [[kernel]] void \
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& nc_dim [[buffer(4)]], \
const device int* nc_shape [[buffer(5)]], \
const device size_t* nc_strides [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \ #define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort( \ instantiate_block_sort( \
@ -69,43 +51,12 @@ instantiate_block_sort_long(int64, int64_t)
#define instantiate_multi_block_sort( \ #define instantiate_multi_block_sort( \
vtname, vtype, itname, itype, arg_sort, bn, tn) \ vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \ instantiate_kernel("sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
"_tn" #tn)]] [[kernel]] void \ mb_block_sort, vtype, itype, arg_sort, bn, tn) \
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \ instantiate_kernel("partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
const device vtype* inp [[buffer(0)]], \ mb_block_partition, vtype, itype, arg_sort, bn, tn) \
device vtype* out_vals [[buffer(1)]], \ instantiate_kernel("merge_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
device itype* out_idxs [[buffer(2)]], \ mb_block_merge, vtype, itype, arg_sort, bn, tn)
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& stride_sorted_axis [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype * block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& merge_tiles [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \
device vtype* dev_vals_out [[buffer(3)]], \
device itype* dev_idxs_out [[buffer(4)]], \
const constant int& size_sorted_axis [[buffer(5)]], \
const constant int& merge_tiles [[buffer(6)]], \
const constant int& num_tiles [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_multi_block_sort_base(vtname, vtype) \ #define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8) instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)

View File

@ -24,8 +24,11 @@ void single_block_sort(
// Prepare shapes // Prepare shapes
int n_rows = in.size() / in.shape(axis); int n_rows = in.size() / in.shape(axis);
std::vector<size_t> nc_str = in.strides(); std::vector<size_t> in_nc_str = in.strides();
nc_str.erase(nc_str.begin() + axis); in_nc_str.erase(in_nc_str.begin() + axis);
std::vector<size_t> out_nc_str = out.strides();
out_nc_str.erase(out_nc_str.begin() + axis);
std::vector<int> nc_shape = in.shape(); std::vector<int> nc_shape = in.shape();
nc_shape.erase(nc_shape.begin() + axis); nc_shape.erase(nc_shape.begin() + axis);
@ -33,21 +36,28 @@ void single_block_sort(
int nc_dim = nc_shape.size(); int nc_dim = nc_shape.size();
int size_sorted_axis = in.shape(axis); int size_sorted_axis = in.shape(axis);
int stride_sorted_axis = in.strides()[axis]; int in_stride_sorted_axis = in.strides()[axis];
int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end()); int out_stride_sorted_axis = out.strides()[axis];
int in_stride_segment_axis =
*std::min_element(in_nc_str.begin(), in_nc_str.end());
int out_stride_segment_axis =
*std::min_element(out_nc_str.begin(), out_nc_str.end());
// Check if remaining strides are contiguous // We can only use the contiguous kernel if the sorted axis
bool contiguous_write = true; // has the largest or smallest stride.
if (axis != in.ndim() - 1 && axis != 0) { // We also need the input to be contiguous
for (int i = 0; i < nc_str.size() - 1; ++i) { bool contiguous = in.flags().contiguous;
size_t expected = nc_str[i + 1] * nc_str[i + 1]; auto check_strides = [](array x, int sort_stride) {
contiguous_write &= (nc_str[i] == expected); int min_stride = *std::min_element(x.strides().begin(), x.strides().end());
} int max_stride = *std::max_element(x.strides().begin(), x.strides().end());
} return sort_stride == min_stride || sort_stride == max_stride;
};
contiguous &= check_strides(in, in_stride_sorted_axis);
contiguous &= check_strides(out, out_stride_sorted_axis);
// Prepare kernel name // Prepare kernel name
std::ostringstream kname; std::ostringstream kname;
kname << (contiguous_write ? "c" : "nc"); kname << (contiguous ? "c" : "nc");
if (argsort) { if (argsort) {
kname << "arg"; kname << "arg";
} }
@ -64,14 +74,17 @@ void single_block_sort(
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2); compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3); compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
if (contiguous_write) { if (contiguous) {
compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4); compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
} else { } else {
compute_encoder->setBytes(&nc_dim, sizeof(int), 4); compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5); compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6); compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
} }
MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size group_dims = MTL::Size(bn, 1, 1);

View File

@ -3,7 +3,7 @@
import math import math
import os import os
import unittest import unittest
from itertools import permutations from itertools import permutations, product
import mlx.core as mx import mlx.core as mx
import mlx_tests import mlx_tests
@ -1751,14 +1751,21 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1)) self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
def test_sort(self): def test_sort(self):
shape = (3, 4, 5) shape = (6, 4, 10)
for dtype in ("int32", "float32"): tests = product(
for axis in (None, 0, 1, 2): ("int32", "float32"), # type
with self.subTest(dtype=dtype, axis=axis): (None, 0, 1, 2), # axis
(True, False), # strided
)
for dtype, axis, strided in tests:
with self.subTest(dtype=dtype, axis=axis, strided=strided):
np.random.seed(0) np.random.seed(0)
np_dtype = getattr(np, dtype) np_dtype = getattr(np, dtype)
a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype) a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype)
a_mx = mx.array(a_np) a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[::2, :, ::2]
a_np = a_np[::2, :, ::2]
b_np = np.sort(a_np, axis=axis) b_np = np.sort(a_np, axis=axis)
b_mx = mx.sort(a_mx, axis=axis) b_mx = mx.sort(a_mx, axis=axis)
@ -1778,9 +1785,15 @@ class TestOps(mlx_tests.MLXTestCase):
np.random.seed(0) np.random.seed(0)
# Test multi-block sort # Test multi-block sort
for strided in (False, True):
with self.subTest(strided=strided):
a_np = np.random.normal(size=(32769,)).astype(np.float32) a_np = np.random.normal(size=(32769,)).astype(np.float32)
a_mx = mx.array(a_np) a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[::3]
a_np = a_np[::3]
b_np = np.sort(a_np) b_np = np.sort(a_np)
b_mx = mx.sort(a_mx) b_mx = mx.sort(a_mx)
@ -1791,6 +1804,10 @@ class TestOps(mlx_tests.MLXTestCase):
a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32) a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32)
a_mx = mx.array(a_np) a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[..., ::3]
a_np = a_np[..., ::3]
b_np = np.sort(a_np, axis=-1) b_np = np.sort(a_np, axis=-1)
b_mx = mx.sort(a_mx, axis=-1) b_mx = mx.sort(a_mx, axis=-1)
@ -1800,12 +1817,28 @@ class TestOps(mlx_tests.MLXTestCase):
a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32) a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32)
a_mx = mx.array(a_np) a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[:, ::3]
a_np = a_np[:, ::3]
b_np = np.sort(a_np, axis=1) b_np = np.sort(a_np, axis=1)
b_mx = mx.sort(a_mx, axis=1) b_mx = mx.sort(a_mx, axis=1)
self.assertTrue(np.array_equal(b_np, b_mx)) self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype) self.assertEqual(b_mx.dtype, a_mx.dtype)
# test 0 strides
a_np = np.array([1, 0, 2, 1, 3, 0, 4, 0])
a_mx = mx.array(a_np)
b_np = np.broadcast_to(a_np, (16, 8))
b_mx = mx.broadcast_to(a_mx, (16, 8))
mx.eval(b_mx)
for axis in (0, 1):
c_np = np.sort(b_np, axis=axis)
c_mx = mx.sort(b_mx, axis=axis)
self.assertTrue(np.array_equal(c_np, c_mx))
self.assertEqual(b_mx.dtype, c_mx.dtype)
def test_partition(self): def test_partition(self):
shape = (3, 4, 5) shape = (3, 4, 5)
for dtype in ("int32", "float32"): for dtype in ("int32", "float32"):