Fix strided sort bug (#1236)

* Use output strides in sort kernel

* fix zero strides bug
This commit is contained in:
Alex Barron 2024-06-26 14:32:11 -07:00 committed by GitHub
parent 5b0af4cdb1
commit 2615660e62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 222 additions and 262 deletions

View File

@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
size_t axis_stride = out.strides()[axis];
int axis_size = out.shape(axis);
// Perform sorting in place
for (int i = 0; i < n_rows; i++) {
@ -143,34 +143,42 @@ void argsort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto in_remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
// Perform sorting
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
const T* data_ptr = in.data<T>() + loc;
IdxT* idx_ptr = out.data<IdxT>() + loc;
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
const T* data_ptr = in.data<T>() + in_loc;
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator ed(idx_ptr, axis_stride, axis_size);
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * axis_stride];
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}

View File

@ -1,81 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view block_sort_kernels = R"(
template [[host_name("carg_{0}")]] [[kernel]] void
block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("ncarg_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("c_{0}")]] [[kernel]] void
block_sort<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("nc_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view multiblock_sort_kernels = R"(
template [[host_name("sort_{0}")]] [[kernel]] void
mb_block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {1}* out_vals [[buffer(1)]],
device {2}* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("partition_{0}")]] [[kernel]] void
mb_block_partition<{1}, {2}, true, {3}, {4}>(
device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals [[buffer(1)]],
const device {2}* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]);
template [[host_name("merge_{0}")]] [[kernel]] void
mb_block_merge<{1}, {2}, true, {3}, {4}>(
const device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals_in [[buffer(1)]],
const device {2}* dev_idxs_in [[buffer(2)]],
device {1}* dev_vals_out [[buffer(3)]],
device {2}* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";

View File

@ -8,7 +8,6 @@
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/kernels.h"
@ -251,14 +250,29 @@ MTL::ComputePipelineState* get_sort_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
block_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
bn,
tn);
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source << metal::utils() << metal::sort();
for (bool is_argsort : {true, false}) {
std::string bool_string = is_argsort ? "true" : "false";
std::string func_string = is_argsort ? "carg_" : "c_";
kernel_source << get_template_definition(
func_string + lib_name,
"block_sort",
in_type,
out_type,
bool_string,
bn,
tn);
kernel_source << get_template_definition(
"n" + func_string + lib_name,
"block_sort_nc",
in_type,
out_type,
bool_string,
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@ -275,14 +289,21 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
multiblock_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
bn,
tn);
kernel_source << metal::utils() << metal::sort();
std::vector<std::pair<std::string, std::string>> kernel_types = {
{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}};
for (auto [name, func] : kernel_types) {
kernel_source << get_template_definition(
name + lib_name,
func,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
"true",
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);

View File

@ -235,19 +235,21 @@ struct KernelMergeSort {
const device T* inp,
device U* out,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
const constant int& stride_segment_axis,
const constant int& in_stride_sorted_axis,
const constant int& out_stride_sorted_axis,
const constant int& in_stride_segment_axis,
const constant int& out_stride_segment_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
inp += tid.y * stride_segment_axis;
out += tid.y * stride_segment_axis;
inp += tid.y * in_stride_segment_axis;
out += tid.y * out_stride_segment_axis;
// Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
: val_t(CompareOp::init);
if (ARG_SORT) {
tgp_idxs[i] = i;
@ -264,9 +266,9 @@ struct KernelMergeSort {
// Write output
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
if (ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i];
out[i * out_stride_sorted_axis] = tgp_idxs[i];
} else {
out[i * stride_sorted_axis] = tgp_vals[i];
out[i * out_stride_sorted_axis] = tgp_vals[i];
}
}
}
@ -282,8 +284,10 @@ template <
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
const constant int& in_stride_sorted_axis [[buffer(3)]],
const constant int& out_stride_sorted_axis [[buffer(4)]],
const constant int& in_stride_segment_axis [[buffer(5)]],
const constant int& out_stride_segment_axis [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
@ -298,8 +302,10 @@ template <
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
in_stride_segment_axis,
out_stride_segment_axis,
tgp_vals,
tgp_idxs,
tid,
@ -310,8 +316,10 @@ template <
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
in_stride_segment_axis,
out_stride_segment_axis,
tgp_vals,
nullptr,
tid,
@ -331,10 +339,12 @@ template <
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
const constant int& in_stride_sorted_axis [[buffer(3)]],
const constant int& out_stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* in_nc_strides [[buffer(7)]],
const device size_t* out_nc_strides [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
@ -342,9 +352,10 @@ template <
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out += block_idx;
auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
inp += in_block_idx;
out += out_block_idx;
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
@ -353,7 +364,9 @@ template <
inp,
out,
size_sorted_axis,
stride_sorted_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
zero_helper,
zero_helper,
tgp_vals,
tgp_idxs,
@ -365,7 +378,9 @@ template <
inp,
out,
size_sorted_axis,
stride_sorted_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
zero_helper,
zero_helper,
tgp_vals,
nullptr,

View File

@ -10,28 +10,10 @@
#define instantiate_block_sort( \
name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
block_sort<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
)]] [[kernel]] void \
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& nc_dim [[buffer(4)]], \
const device int* nc_shape [[buffer(5)]], \
const device size_t* nc_strides [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_kernel("c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
block_sort, itype, otype, arg_sort, bn, tn) \
instantiate_kernel("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
block_sort_nc, itype, otype, arg_sort, bn, tn)
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort( \
@ -69,43 +51,12 @@ instantiate_block_sort_long(int64, int64_t)
#define instantiate_multi_block_sort( \
vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \
device vtype* out_vals [[buffer(1)]], \
device itype* out_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& stride_sorted_axis [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype * block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& merge_tiles [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \
device vtype* dev_vals_out [[buffer(3)]], \
device itype* dev_idxs_out [[buffer(4)]], \
const constant int& size_sorted_axis [[buffer(5)]], \
const constant int& merge_tiles [[buffer(6)]], \
const constant int& num_tiles [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_kernel("sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
mb_block_sort, vtype, itype, arg_sort, bn, tn) \
instantiate_kernel("partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
mb_block_partition, vtype, itype, arg_sort, bn, tn) \
instantiate_kernel("merge_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
mb_block_merge, vtype, itype, arg_sort, bn, tn)
#define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)

View File

@ -24,8 +24,11 @@ void single_block_sort(
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
std::vector<size_t> nc_str = in.strides();
nc_str.erase(nc_str.begin() + axis);
std::vector<size_t> in_nc_str = in.strides();
in_nc_str.erase(in_nc_str.begin() + axis);
std::vector<size_t> out_nc_str = out.strides();
out_nc_str.erase(out_nc_str.begin() + axis);
std::vector<int> nc_shape = in.shape();
nc_shape.erase(nc_shape.begin() + axis);
@ -33,21 +36,28 @@ void single_block_sort(
int nc_dim = nc_shape.size();
int size_sorted_axis = in.shape(axis);
int stride_sorted_axis = in.strides()[axis];
int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());
int in_stride_sorted_axis = in.strides()[axis];
int out_stride_sorted_axis = out.strides()[axis];
int in_stride_segment_axis =
*std::min_element(in_nc_str.begin(), in_nc_str.end());
int out_stride_segment_axis =
*std::min_element(out_nc_str.begin(), out_nc_str.end());
// Check if remaining strides are contiguous
bool contiguous_write = true;
if (axis != in.ndim() - 1 && axis != 0) {
for (int i = 0; i < nc_str.size() - 1; ++i) {
size_t expected = nc_str[i + 1] * nc_str[i + 1];
contiguous_write &= (nc_str[i] == expected);
}
}
// We can only use the contiguous kernel if the sorted axis
// has the largest or smallest stride.
// We also need the input to be contiguous
bool contiguous = in.flags().contiguous;
auto check_strides = [](array x, int sort_stride) {
int min_stride = *std::min_element(x.strides().begin(), x.strides().end());
int max_stride = *std::max_element(x.strides().begin(), x.strides().end());
return sort_stride == min_stride || sort_stride == max_stride;
};
contiguous &= check_strides(in, in_stride_sorted_axis);
contiguous &= check_strides(out, out_stride_sorted_axis);
// Prepare kernel name
std::ostringstream kname;
kname << (contiguous_write ? "c" : "nc");
kname << (contiguous ? "c" : "nc");
if (argsort) {
kname << "arg";
}
@ -64,14 +74,17 @@ void single_block_sort(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
if (contiguous_write) {
compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
if (contiguous) {
compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
} else {
compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
}
MTL::Size group_dims = MTL::Size(bn, 1, 1);

View File

@ -3,7 +3,7 @@
import math
import os
import unittest
from itertools import permutations
from itertools import permutations, product
import mlx.core as mx
import mlx_tests
@ -1751,60 +1751,93 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
def test_sort(self):
shape = (3, 4, 5)
for dtype in ("int32", "float32"):
for axis in (None, 0, 1, 2):
with self.subTest(dtype=dtype, axis=axis):
np.random.seed(0)
np_dtype = getattr(np, dtype)
a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype)
a_mx = mx.array(a_np)
shape = (6, 4, 10)
tests = product(
("int32", "float32"), # type
(None, 0, 1, 2), # axis
(True, False), # strided
)
for dtype, axis, strided in tests:
with self.subTest(dtype=dtype, axis=axis, strided=strided):
np.random.seed(0)
np_dtype = getattr(np, dtype)
a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype)
a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[::2, :, ::2]
a_np = a_np[::2, :, ::2]
b_np = np.sort(a_np, axis=axis)
b_mx = mx.sort(a_mx, axis=axis)
b_np = np.sort(a_np, axis=axis)
b_mx = mx.sort(a_mx, axis=axis)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
c_np = np.argsort(a_np, axis=axis)
c_mx = mx.argsort(a_mx, axis=axis)
d_np = np.take_along_axis(a_np, c_np, axis=axis)
d_mx = mx.take_along_axis(a_mx, c_mx, axis=axis)
c_np = np.argsort(a_np, axis=axis)
c_mx = mx.argsort(a_mx, axis=axis)
d_np = np.take_along_axis(a_np, c_np, axis=axis)
d_mx = mx.take_along_axis(a_mx, c_mx, axis=axis)
self.assertTrue(np.array_equal(d_np, d_mx))
self.assertEqual(c_mx.dtype, mx.uint32)
self.assertTrue(np.array_equal(d_np, d_mx))
self.assertEqual(c_mx.dtype, mx.uint32)
# Set random seed
np.random.seed(0)
# Test multi-block sort
a_np = np.random.normal(size=(32769,)).astype(np.float32)
for strided in (False, True):
with self.subTest(strided=strided):
a_np = np.random.normal(size=(32769,)).astype(np.float32)
a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[::3]
a_np = a_np[::3]
b_np = np.sort(a_np)
b_mx = mx.sort(a_mx)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
# Test multi-dum multi-block sort
a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32)
a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[..., ::3]
a_np = a_np[..., ::3]
b_np = np.sort(a_np, axis=-1)
b_mx = mx.sort(a_mx, axis=-1)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32)
a_mx = mx.array(a_np)
if strided:
a_mx = a_mx[:, ::3]
a_np = a_np[:, ::3]
b_np = np.sort(a_np, axis=1)
b_mx = mx.sort(a_mx, axis=1)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
# test 0 strides
a_np = np.array([1, 0, 2, 1, 3, 0, 4, 0])
a_mx = mx.array(a_np)
b_np = np.sort(a_np)
b_mx = mx.sort(a_mx)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
# Test multi-dum multi-block sort
a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32)
a_mx = mx.array(a_np)
b_np = np.sort(a_np, axis=-1)
b_mx = mx.sort(a_mx, axis=-1)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32)
a_mx = mx.array(a_np)
b_np = np.sort(a_np, axis=1)
b_mx = mx.sort(a_mx, axis=1)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
b_np = np.broadcast_to(a_np, (16, 8))
b_mx = mx.broadcast_to(a_mx, (16, 8))
mx.eval(b_mx)
for axis in (0, 1):
c_np = np.sort(b_np, axis=axis)
c_mx = mx.sort(b_mx, axis=axis)
self.assertTrue(np.array_equal(c_np, c_mx))
self.assertEqual(b_mx.dtype, c_mx.dtype)
def test_partition(self):
shape = (3, 4, 5)