mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix strided sort bug (#1236)
* Use output strides in sort kernel * fix zero strides bug
This commit is contained in:
parent
5b0af4cdb1
commit
2615660e62
@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) {
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = in.strides();
|
||||
auto remaining_strides = out.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
size_t axis_stride = in.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
size_t axis_stride = out.strides()[axis];
|
||||
int axis_size = out.shape(axis);
|
||||
|
||||
// Perform sorting in place
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
@ -143,34 +143,42 @@ void argsort(const array& in, array& out, int axis) {
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
auto in_remaining_shape = in.shape();
|
||||
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = in.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
auto in_remaining_strides = in.strides();
|
||||
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
|
||||
|
||||
size_t axis_stride = in.strides()[axis];
|
||||
auto out_remaining_shape = out.shape();
|
||||
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
|
||||
|
||||
auto out_remaining_strides = out.strides();
|
||||
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
|
||||
|
||||
size_t in_stride = in.strides()[axis];
|
||||
size_t out_stride = out.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
// Perform sorting
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||
const T* data_ptr = in.data<T>() + loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + loc;
|
||||
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
|
||||
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
|
||||
const T* data_ptr = in.data<T>() + in_loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
|
||||
|
||||
StridedIterator st_(idx_ptr, axis_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, axis_stride, 0);
|
||||
StridedIterator ed(idx_ptr, axis_stride, axis_size);
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * axis_stride];
|
||||
auto v2 = data_ptr[b * axis_stride];
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
|
@ -1,81 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view block_sort_kernels = R"(
|
||||
template [[host_name("carg_{0}")]] [[kernel]] void
|
||||
block_sort<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("ncarg_{0}")]] [[kernel]] void
|
||||
block_sort_nc<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("c_{0}")]] [[kernel]] void
|
||||
block_sort<{1}, {2}, false, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("nc_{0}")]] [[kernel]] void
|
||||
block_sort_nc<{1}, {2}, false, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view multiblock_sort_kernels = R"(
|
||||
template [[host_name("sort_{0}")]] [[kernel]] void
|
||||
mb_block_sort<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {1}* out_vals [[buffer(1)]],
|
||||
device {2}* out_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("partition_{0}")]] [[kernel]] void
|
||||
mb_block_partition<{1}, {2}, true, {3}, {4}>(
|
||||
device {2}* block_partitions [[buffer(0)]],
|
||||
const device {1}* dev_vals [[buffer(1)]],
|
||||
const device {2}* dev_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& merge_tiles [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]);
|
||||
template [[host_name("merge_{0}")]] [[kernel]] void
|
||||
mb_block_merge<{1}, {2}, true, {3}, {4}>(
|
||||
const device {2}* block_partitions [[buffer(0)]],
|
||||
const device {1}* dev_vals_in [[buffer(1)]],
|
||||
const device {2}* dev_idxs_in [[buffer(2)]],
|
||||
device {1}* dev_vals_out [[buffer(3)]],
|
||||
device {2}* dev_idxs_out [[buffer(4)]],
|
||||
const constant int& size_sorted_axis [[buffer(5)]],
|
||||
const constant int& merge_tiles [[buffer(6)]],
|
||||
const constant int& num_tiles [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
@ -8,7 +8,6 @@
|
||||
#include "mlx/backend/metal/jit/reduce.h"
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/jit/sort.h"
|
||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
@ -251,14 +250,29 @@ MTL::ComputePipelineState* get_sort_kernel(
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::sort()
|
||||
<< fmt::format(
|
||||
block_sort_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source << metal::utils() << metal::sort();
|
||||
for (bool is_argsort : {true, false}) {
|
||||
std::string bool_string = is_argsort ? "true" : "false";
|
||||
std::string func_string = is_argsort ? "carg_" : "c_";
|
||||
kernel_source << get_template_definition(
|
||||
func_string + lib_name,
|
||||
"block_sort",
|
||||
in_type,
|
||||
out_type,
|
||||
bool_string,
|
||||
bn,
|
||||
tn);
|
||||
kernel_source << get_template_definition(
|
||||
"n" + func_string + lib_name,
|
||||
"block_sort_nc",
|
||||
in_type,
|
||||
out_type,
|
||||
bool_string,
|
||||
bn,
|
||||
tn);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@ -275,14 +289,21 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::sort()
|
||||
<< fmt::format(
|
||||
multiblock_sort_kernels,
|
||||
lib_name,
|
||||
kernel_source << metal::utils() << metal::sort();
|
||||
std::vector<std::pair<std::string, std::string>> kernel_types = {
|
||||
{"sort_", "mb_block_sort"},
|
||||
{"partition_", "mb_block_partition"},
|
||||
{"merge_", "mb_block_merge"}};
|
||||
for (auto [name, func] : kernel_types) {
|
||||
kernel_source << get_template_definition(
|
||||
name + lib_name,
|
||||
func,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(idx.dtype()),
|
||||
"true",
|
||||
bn,
|
||||
tn);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
@ -235,19 +235,21 @@ struct KernelMergeSort {
|
||||
const device T* inp,
|
||||
device U* out,
|
||||
const constant int& size_sorted_axis,
|
||||
const constant int& stride_sorted_axis,
|
||||
const constant int& stride_segment_axis,
|
||||
const constant int& in_stride_sorted_axis,
|
||||
const constant int& out_stride_sorted_axis,
|
||||
const constant int& in_stride_segment_axis,
|
||||
const constant int& out_stride_segment_axis,
|
||||
threadgroup val_t* tgp_vals,
|
||||
threadgroup idx_t* tgp_idxs,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// tid.y tells us the segment index
|
||||
inp += tid.y * stride_segment_axis;
|
||||
out += tid.y * stride_segment_axis;
|
||||
inp += tid.y * in_stride_segment_axis;
|
||||
out += tid.y * out_stride_segment_axis;
|
||||
|
||||
// Copy into threadgroup memory
|
||||
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
|
||||
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
|
||||
tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
|
||||
: val_t(CompareOp::init);
|
||||
if (ARG_SORT) {
|
||||
tgp_idxs[i] = i;
|
||||
@ -264,9 +266,9 @@ struct KernelMergeSort {
|
||||
// Write output
|
||||
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
|
||||
if (ARG_SORT) {
|
||||
out[i * stride_sorted_axis] = tgp_idxs[i];
|
||||
out[i * out_stride_sorted_axis] = tgp_idxs[i];
|
||||
} else {
|
||||
out[i * stride_sorted_axis] = tgp_vals[i];
|
||||
out[i * out_stride_sorted_axis] = tgp_vals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -282,8 +284,10 @@ template <
|
||||
const device T* inp [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
const constant int& in_stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& out_stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& in_stride_segment_axis [[buffer(5)]],
|
||||
const constant int& out_stride_segment_axis [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using sort_kernel =
|
||||
@ -298,8 +302,10 @@ template <
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
stride_segment_axis,
|
||||
in_stride_sorted_axis,
|
||||
out_stride_sorted_axis,
|
||||
in_stride_segment_axis,
|
||||
out_stride_segment_axis,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
tid,
|
||||
@ -310,8 +316,10 @@ template <
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
stride_segment_axis,
|
||||
in_stride_sorted_axis,
|
||||
out_stride_sorted_axis,
|
||||
in_stride_segment_axis,
|
||||
out_stride_segment_axis,
|
||||
tgp_vals,
|
||||
nullptr,
|
||||
tid,
|
||||
@ -331,10 +339,12 @@ template <
|
||||
const device T* inp [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
const constant int& in_stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& out_stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* in_nc_strides [[buffer(7)]],
|
||||
const device size_t* out_nc_strides [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using sort_kernel =
|
||||
@ -342,9 +352,10 @@ template <
|
||||
using val_t = typename sort_kernel::val_t;
|
||||
using idx_t = typename sort_kernel::idx_t;
|
||||
|
||||
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||
inp += block_idx;
|
||||
out += block_idx;
|
||||
auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
|
||||
auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
|
||||
inp += in_block_idx;
|
||||
out += out_block_idx;
|
||||
|
||||
if (ARG_SORT) {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
@ -353,7 +364,9 @@ template <
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
in_stride_sorted_axis,
|
||||
out_stride_sorted_axis,
|
||||
zero_helper,
|
||||
zero_helper,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
@ -365,7 +378,9 @@ template <
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
in_stride_sorted_axis,
|
||||
out_stride_sorted_axis,
|
||||
zero_helper,
|
||||
zero_helper,
|
||||
tgp_vals,
|
||||
nullptr,
|
||||
|
@ -10,28 +10,10 @@
|
||||
|
||||
#define instantiate_block_sort( \
|
||||
name, itname, itype, otname, otype, arg_sort, bn, tn) \
|
||||
template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
block_sort<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_segment_axis [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
|
||||
)]] [[kernel]] void \
|
||||
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& nc_dim [[buffer(4)]], \
|
||||
const device int* nc_shape [[buffer(5)]], \
|
||||
const device size_t* nc_strides [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
instantiate_kernel("c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
|
||||
block_sort, itype, otype, arg_sort, bn, tn) \
|
||||
instantiate_kernel("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
|
||||
block_sort_nc, itype, otype, arg_sort, bn, tn)
|
||||
|
||||
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
|
||||
instantiate_block_sort( \
|
||||
@ -69,43 +51,12 @@ instantiate_block_sort_long(int64, int64_t)
|
||||
|
||||
#define instantiate_multi_block_sort( \
|
||||
vtname, vtype, itname, itype, arg_sort, bn, tn) \
|
||||
template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
|
||||
const device vtype* inp [[buffer(0)]], \
|
||||
device vtype* out_vals [[buffer(1)]], \
|
||||
device itype* out_idxs [[buffer(2)]], \
|
||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||
device itype * block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals [[buffer(1)]], \
|
||||
const device itype* dev_idxs [[buffer(2)]], \
|
||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||
const constant int& merge_tiles [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]); \
|
||||
template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
|
||||
const device itype* block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals_in [[buffer(1)]], \
|
||||
const device itype* dev_idxs_in [[buffer(2)]], \
|
||||
device vtype* dev_vals_out [[buffer(3)]], \
|
||||
device itype* dev_idxs_out [[buffer(4)]], \
|
||||
const constant int& size_sorted_axis [[buffer(5)]], \
|
||||
const constant int& merge_tiles [[buffer(6)]], \
|
||||
const constant int& num_tiles [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
instantiate_kernel("sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
|
||||
mb_block_sort, vtype, itype, arg_sort, bn, tn) \
|
||||
instantiate_kernel("partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
|
||||
mb_block_partition, vtype, itype, arg_sort, bn, tn) \
|
||||
instantiate_kernel("merge_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
|
||||
mb_block_merge, vtype, itype, arg_sort, bn, tn)
|
||||
|
||||
#define instantiate_multi_block_sort_base(vtname, vtype) \
|
||||
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
|
||||
|
@ -24,8 +24,11 @@ void single_block_sort(
|
||||
// Prepare shapes
|
||||
int n_rows = in.size() / in.shape(axis);
|
||||
|
||||
std::vector<size_t> nc_str = in.strides();
|
||||
nc_str.erase(nc_str.begin() + axis);
|
||||
std::vector<size_t> in_nc_str = in.strides();
|
||||
in_nc_str.erase(in_nc_str.begin() + axis);
|
||||
|
||||
std::vector<size_t> out_nc_str = out.strides();
|
||||
out_nc_str.erase(out_nc_str.begin() + axis);
|
||||
|
||||
std::vector<int> nc_shape = in.shape();
|
||||
nc_shape.erase(nc_shape.begin() + axis);
|
||||
@ -33,21 +36,28 @@ void single_block_sort(
|
||||
int nc_dim = nc_shape.size();
|
||||
|
||||
int size_sorted_axis = in.shape(axis);
|
||||
int stride_sorted_axis = in.strides()[axis];
|
||||
int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());
|
||||
int in_stride_sorted_axis = in.strides()[axis];
|
||||
int out_stride_sorted_axis = out.strides()[axis];
|
||||
int in_stride_segment_axis =
|
||||
*std::min_element(in_nc_str.begin(), in_nc_str.end());
|
||||
int out_stride_segment_axis =
|
||||
*std::min_element(out_nc_str.begin(), out_nc_str.end());
|
||||
|
||||
// Check if remaining strides are contiguous
|
||||
bool contiguous_write = true;
|
||||
if (axis != in.ndim() - 1 && axis != 0) {
|
||||
for (int i = 0; i < nc_str.size() - 1; ++i) {
|
||||
size_t expected = nc_str[i + 1] * nc_str[i + 1];
|
||||
contiguous_write &= (nc_str[i] == expected);
|
||||
}
|
||||
}
|
||||
// We can only use the contiguous kernel if the sorted axis
|
||||
// has the largest or smallest stride.
|
||||
// We also need the input to be contiguous
|
||||
bool contiguous = in.flags().contiguous;
|
||||
auto check_strides = [](array x, int sort_stride) {
|
||||
int min_stride = *std::min_element(x.strides().begin(), x.strides().end());
|
||||
int max_stride = *std::max_element(x.strides().begin(), x.strides().end());
|
||||
return sort_stride == min_stride || sort_stride == max_stride;
|
||||
};
|
||||
contiguous &= check_strides(in, in_stride_sorted_axis);
|
||||
contiguous &= check_strides(out, out_stride_sorted_axis);
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << (contiguous_write ? "c" : "nc");
|
||||
kname << (contiguous ? "c" : "nc");
|
||||
if (argsort) {
|
||||
kname << "arg";
|
||||
}
|
||||
@ -64,14 +74,17 @@ void single_block_sort(
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
|
||||
|
||||
if (contiguous_write) {
|
||||
compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
|
||||
if (contiguous) {
|
||||
compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
|
||||
} else {
|
||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
|
||||
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
|
||||
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
|
||||
compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
|
||||
}
|
||||
|
||||
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||
|
@ -3,7 +3,7 @@
|
||||
import math
|
||||
import os
|
||||
import unittest
|
||||
from itertools import permutations
|
||||
from itertools import permutations, product
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
@ -1751,14 +1751,21 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
|
||||
|
||||
def test_sort(self):
|
||||
shape = (3, 4, 5)
|
||||
for dtype in ("int32", "float32"):
|
||||
for axis in (None, 0, 1, 2):
|
||||
with self.subTest(dtype=dtype, axis=axis):
|
||||
shape = (6, 4, 10)
|
||||
tests = product(
|
||||
("int32", "float32"), # type
|
||||
(None, 0, 1, 2), # axis
|
||||
(True, False), # strided
|
||||
)
|
||||
for dtype, axis, strided in tests:
|
||||
with self.subTest(dtype=dtype, axis=axis, strided=strided):
|
||||
np.random.seed(0)
|
||||
np_dtype = getattr(np, dtype)
|
||||
a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype)
|
||||
a_mx = mx.array(a_np)
|
||||
if strided:
|
||||
a_mx = a_mx[::2, :, ::2]
|
||||
a_np = a_np[::2, :, ::2]
|
||||
|
||||
b_np = np.sort(a_np, axis=axis)
|
||||
b_mx = mx.sort(a_mx, axis=axis)
|
||||
@ -1778,9 +1785,15 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
np.random.seed(0)
|
||||
|
||||
# Test multi-block sort
|
||||
for strided in (False, True):
|
||||
with self.subTest(strided=strided):
|
||||
a_np = np.random.normal(size=(32769,)).astype(np.float32)
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
if strided:
|
||||
a_mx = a_mx[::3]
|
||||
a_np = a_np[::3]
|
||||
|
||||
b_np = np.sort(a_np)
|
||||
b_mx = mx.sort(a_mx)
|
||||
|
||||
@ -1791,6 +1804,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32)
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
if strided:
|
||||
a_mx = a_mx[..., ::3]
|
||||
a_np = a_np[..., ::3]
|
||||
|
||||
b_np = np.sort(a_np, axis=-1)
|
||||
b_mx = mx.sort(a_mx, axis=-1)
|
||||
|
||||
@ -1800,12 +1817,28 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32)
|
||||
a_mx = mx.array(a_np)
|
||||
|
||||
if strided:
|
||||
a_mx = a_mx[:, ::3]
|
||||
a_np = a_np[:, ::3]
|
||||
|
||||
b_np = np.sort(a_np, axis=1)
|
||||
b_mx = mx.sort(a_mx, axis=1)
|
||||
|
||||
self.assertTrue(np.array_equal(b_np, b_mx))
|
||||
self.assertEqual(b_mx.dtype, a_mx.dtype)
|
||||
|
||||
# test 0 strides
|
||||
a_np = np.array([1, 0, 2, 1, 3, 0, 4, 0])
|
||||
a_mx = mx.array(a_np)
|
||||
b_np = np.broadcast_to(a_np, (16, 8))
|
||||
b_mx = mx.broadcast_to(a_mx, (16, 8))
|
||||
mx.eval(b_mx)
|
||||
for axis in (0, 1):
|
||||
c_np = np.sort(b_np, axis=axis)
|
||||
c_mx = mx.sort(b_mx, axis=axis)
|
||||
self.assertTrue(np.array_equal(c_np, c_mx))
|
||||
self.assertEqual(b_mx.dtype, c_mx.dtype)
|
||||
|
||||
def test_partition(self):
|
||||
shape = (3, 4, 5)
|
||||
for dtype in ("int32", "float32"):
|
||||
|
Loading…
Reference in New Issue
Block a user