mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Fix strided sort bug (#1236)
* Use output strides in sort kernel * fix zero strides bug
This commit is contained in:
@@ -1,81 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view block_sort_kernels = R"(
|
||||
template [[host_name("carg_{0}")]] [[kernel]] void
|
||||
block_sort<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("ncarg_{0}")]] [[kernel]] void
|
||||
block_sort_nc<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("c_{0}")]] [[kernel]] void
|
||||
block_sort<{1}, {2}, false, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("nc_{0}")]] [[kernel]] void
|
||||
block_sort_nc<{1}, {2}, false, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view multiblock_sort_kernels = R"(
|
||||
template [[host_name("sort_{0}")]] [[kernel]] void
|
||||
mb_block_sort<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {1}* out_vals [[buffer(1)]],
|
||||
device {2}* out_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("partition_{0}")]] [[kernel]] void
|
||||
mb_block_partition<{1}, {2}, true, {3}, {4}>(
|
||||
device {2}* block_partitions [[buffer(0)]],
|
||||
const device {1}* dev_vals [[buffer(1)]],
|
||||
const device {2}* dev_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& merge_tiles [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]);
|
||||
template [[host_name("merge_{0}")]] [[kernel]] void
|
||||
mb_block_merge<{1}, {2}, true, {3}, {4}>(
|
||||
const device {2}* block_partitions [[buffer(0)]],
|
||||
const device {1}* dev_vals_in [[buffer(1)]],
|
||||
const device {2}* dev_idxs_in [[buffer(2)]],
|
||||
device {1}* dev_vals_out [[buffer(3)]],
|
||||
device {2}* dev_idxs_out [[buffer(4)]],
|
||||
const constant int& size_sorted_axis [[buffer(5)]],
|
||||
const constant int& merge_tiles [[buffer(6)]],
|
||||
const constant int& num_tiles [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
@@ -8,7 +8,6 @@
|
||||
#include "mlx/backend/metal/jit/reduce.h"
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/jit/sort.h"
|
||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
@@ -251,14 +250,29 @@ MTL::ComputePipelineState* get_sort_kernel(
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::sort()
|
||||
<< fmt::format(
|
||||
block_sort_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
bn,
|
||||
tn);
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source << metal::utils() << metal::sort();
|
||||
for (bool is_argsort : {true, false}) {
|
||||
std::string bool_string = is_argsort ? "true" : "false";
|
||||
std::string func_string = is_argsort ? "carg_" : "c_";
|
||||
kernel_source << get_template_definition(
|
||||
func_string + lib_name,
|
||||
"block_sort",
|
||||
in_type,
|
||||
out_type,
|
||||
bool_string,
|
||||
bn,
|
||||
tn);
|
||||
kernel_source << get_template_definition(
|
||||
"n" + func_string + lib_name,
|
||||
"block_sort_nc",
|
||||
in_type,
|
||||
out_type,
|
||||
bool_string,
|
||||
bn,
|
||||
tn);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -275,14 +289,21 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::sort()
|
||||
<< fmt::format(
|
||||
multiblock_sort_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(idx.dtype()),
|
||||
bn,
|
||||
tn);
|
||||
kernel_source << metal::utils() << metal::sort();
|
||||
std::vector<std::pair<std::string, std::string>> kernel_types = {
|
||||
{"sort_", "mb_block_sort"},
|
||||
{"partition_", "mb_block_partition"},
|
||||
{"merge_", "mb_block_merge"}};
|
||||
for (auto [name, func] : kernel_types) {
|
||||
kernel_source << get_template_definition(
|
||||
name + lib_name,
|
||||
func,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(idx.dtype()),
|
||||
"true",
|
||||
bn,
|
||||
tn);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
|
@@ -235,19 +235,21 @@ struct KernelMergeSort {
|
||||
const device T* inp,
|
||||
device U* out,
|
||||
const constant int& size_sorted_axis,
|
||||
const constant int& stride_sorted_axis,
|
||||
const constant int& stride_segment_axis,
|
||||
const constant int& in_stride_sorted_axis,
|
||||
const constant int& out_stride_sorted_axis,
|
||||
const constant int& in_stride_segment_axis,
|
||||
const constant int& out_stride_segment_axis,
|
||||
threadgroup val_t* tgp_vals,
|
||||
threadgroup idx_t* tgp_idxs,
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
// tid.y tells us the segment index
|
||||
inp += tid.y * stride_segment_axis;
|
||||
out += tid.y * stride_segment_axis;
|
||||
inp += tid.y * in_stride_segment_axis;
|
||||
out += tid.y * out_stride_segment_axis;
|
||||
|
||||
// Copy into threadgroup memory
|
||||
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
|
||||
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
|
||||
tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
|
||||
: val_t(CompareOp::init);
|
||||
if (ARG_SORT) {
|
||||
tgp_idxs[i] = i;
|
||||
@@ -264,9 +266,9 @@ struct KernelMergeSort {
|
||||
// Write output
|
||||
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
|
||||
if (ARG_SORT) {
|
||||
out[i * stride_sorted_axis] = tgp_idxs[i];
|
||||
out[i * out_stride_sorted_axis] = tgp_idxs[i];
|
||||
} else {
|
||||
out[i * stride_sorted_axis] = tgp_vals[i];
|
||||
out[i * out_stride_sorted_axis] = tgp_vals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -282,8 +284,10 @@ template <
|
||||
const device T* inp [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
const constant int& in_stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& out_stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& in_stride_segment_axis [[buffer(5)]],
|
||||
const constant int& out_stride_segment_axis [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using sort_kernel =
|
||||
@@ -298,8 +302,10 @@ template <
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
stride_segment_axis,
|
||||
in_stride_sorted_axis,
|
||||
out_stride_sorted_axis,
|
||||
in_stride_segment_axis,
|
||||
out_stride_segment_axis,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
tid,
|
||||
@@ -310,8 +316,10 @@ template <
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
stride_segment_axis,
|
||||
in_stride_sorted_axis,
|
||||
out_stride_sorted_axis,
|
||||
in_stride_segment_axis,
|
||||
out_stride_segment_axis,
|
||||
tgp_vals,
|
||||
nullptr,
|
||||
tid,
|
||||
@@ -331,10 +339,12 @@ template <
|
||||
const device T* inp [[buffer(0)]],
|
||||
device U* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
const constant int& in_stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& out_stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* in_nc_strides [[buffer(7)]],
|
||||
const device size_t* out_nc_strides [[buffer(8)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
using sort_kernel =
|
||||
@@ -342,9 +352,10 @@ template <
|
||||
using val_t = typename sort_kernel::val_t;
|
||||
using idx_t = typename sort_kernel::idx_t;
|
||||
|
||||
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
|
||||
inp += block_idx;
|
||||
out += block_idx;
|
||||
auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
|
||||
auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
|
||||
inp += in_block_idx;
|
||||
out += out_block_idx;
|
||||
|
||||
if (ARG_SORT) {
|
||||
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
|
||||
@@ -353,7 +364,9 @@ template <
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
in_stride_sorted_axis,
|
||||
out_stride_sorted_axis,
|
||||
zero_helper,
|
||||
zero_helper,
|
||||
tgp_vals,
|
||||
tgp_idxs,
|
||||
@@ -365,7 +378,9 @@ template <
|
||||
inp,
|
||||
out,
|
||||
size_sorted_axis,
|
||||
stride_sorted_axis,
|
||||
in_stride_sorted_axis,
|
||||
out_stride_sorted_axis,
|
||||
zero_helper,
|
||||
zero_helper,
|
||||
tgp_vals,
|
||||
nullptr,
|
||||
|
@@ -10,28 +10,10 @@
|
||||
|
||||
#define instantiate_block_sort( \
|
||||
name, itname, itype, otname, otype, arg_sort, bn, tn) \
|
||||
template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
block_sort<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_segment_axis [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
|
||||
)]] [[kernel]] void \
|
||||
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
|
||||
const device itype* inp [[buffer(0)]], \
|
||||
device otype* out [[buffer(1)]], \
|
||||
const constant int& size_sorted_axis [[buffer(2)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(3)]], \
|
||||
const constant int& nc_dim [[buffer(4)]], \
|
||||
const device int* nc_shape [[buffer(5)]], \
|
||||
const device size_t* nc_strides [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
instantiate_kernel("c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
|
||||
block_sort, itype, otype, arg_sort, bn, tn) \
|
||||
instantiate_kernel("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
|
||||
block_sort_nc, itype, otype, arg_sort, bn, tn)
|
||||
|
||||
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
|
||||
instantiate_block_sort( \
|
||||
@@ -69,43 +51,12 @@ instantiate_block_sort_long(int64, int64_t)
|
||||
|
||||
#define instantiate_multi_block_sort( \
|
||||
vtname, vtype, itname, itype, arg_sort, bn, tn) \
|
||||
template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
|
||||
const device vtype* inp [[buffer(0)]], \
|
||||
device vtype* out_vals [[buffer(1)]], \
|
||||
device itype* out_idxs [[buffer(2)]], \
|
||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||
const constant int& stride_sorted_axis [[buffer(4)]], \
|
||||
const constant int& nc_dim [[buffer(5)]], \
|
||||
const device int* nc_shape [[buffer(6)]], \
|
||||
const device size_t* nc_strides [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||
device itype * block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals [[buffer(1)]], \
|
||||
const device itype* dev_idxs [[buffer(2)]], \
|
||||
const constant int& size_sorted_axis [[buffer(3)]], \
|
||||
const constant int& merge_tiles [[buffer(4)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]); \
|
||||
template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \
|
||||
"_tn" #tn)]] [[kernel]] void \
|
||||
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
|
||||
const device itype* block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals_in [[buffer(1)]], \
|
||||
const device itype* dev_idxs_in [[buffer(2)]], \
|
||||
device vtype* dev_vals_out [[buffer(3)]], \
|
||||
device itype* dev_idxs_out [[buffer(4)]], \
|
||||
const constant int& size_sorted_axis [[buffer(5)]], \
|
||||
const constant int& merge_tiles [[buffer(6)]], \
|
||||
const constant int& num_tiles [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
instantiate_kernel("sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
|
||||
mb_block_sort, vtype, itype, arg_sort, bn, tn) \
|
||||
instantiate_kernel("partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
|
||||
mb_block_partition, vtype, itype, arg_sort, bn, tn) \
|
||||
instantiate_kernel("merge_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
|
||||
mb_block_merge, vtype, itype, arg_sort, bn, tn)
|
||||
|
||||
#define instantiate_multi_block_sort_base(vtname, vtype) \
|
||||
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
|
||||
|
@@ -24,8 +24,11 @@ void single_block_sort(
|
||||
// Prepare shapes
|
||||
int n_rows = in.size() / in.shape(axis);
|
||||
|
||||
std::vector<size_t> nc_str = in.strides();
|
||||
nc_str.erase(nc_str.begin() + axis);
|
||||
std::vector<size_t> in_nc_str = in.strides();
|
||||
in_nc_str.erase(in_nc_str.begin() + axis);
|
||||
|
||||
std::vector<size_t> out_nc_str = out.strides();
|
||||
out_nc_str.erase(out_nc_str.begin() + axis);
|
||||
|
||||
std::vector<int> nc_shape = in.shape();
|
||||
nc_shape.erase(nc_shape.begin() + axis);
|
||||
@@ -33,21 +36,28 @@ void single_block_sort(
|
||||
int nc_dim = nc_shape.size();
|
||||
|
||||
int size_sorted_axis = in.shape(axis);
|
||||
int stride_sorted_axis = in.strides()[axis];
|
||||
int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());
|
||||
int in_stride_sorted_axis = in.strides()[axis];
|
||||
int out_stride_sorted_axis = out.strides()[axis];
|
||||
int in_stride_segment_axis =
|
||||
*std::min_element(in_nc_str.begin(), in_nc_str.end());
|
||||
int out_stride_segment_axis =
|
||||
*std::min_element(out_nc_str.begin(), out_nc_str.end());
|
||||
|
||||
// Check if remaining strides are contiguous
|
||||
bool contiguous_write = true;
|
||||
if (axis != in.ndim() - 1 && axis != 0) {
|
||||
for (int i = 0; i < nc_str.size() - 1; ++i) {
|
||||
size_t expected = nc_str[i + 1] * nc_str[i + 1];
|
||||
contiguous_write &= (nc_str[i] == expected);
|
||||
}
|
||||
}
|
||||
// We can only use the contiguous kernel if the sorted axis
|
||||
// has the largest or smallest stride.
|
||||
// We also need the input to be contiguous
|
||||
bool contiguous = in.flags().contiguous;
|
||||
auto check_strides = [](array x, int sort_stride) {
|
||||
int min_stride = *std::min_element(x.strides().begin(), x.strides().end());
|
||||
int max_stride = *std::max_element(x.strides().begin(), x.strides().end());
|
||||
return sort_stride == min_stride || sort_stride == max_stride;
|
||||
};
|
||||
contiguous &= check_strides(in, in_stride_sorted_axis);
|
||||
contiguous &= check_strides(out, out_stride_sorted_axis);
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << (contiguous_write ? "c" : "nc");
|
||||
kname << (contiguous ? "c" : "nc");
|
||||
if (argsort) {
|
||||
kname << "arg";
|
||||
}
|
||||
@@ -64,14 +74,17 @@ void single_block_sort(
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
|
||||
|
||||
if (contiguous_write) {
|
||||
compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
|
||||
if (contiguous) {
|
||||
compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
|
||||
} else {
|
||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
|
||||
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
|
||||
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
|
||||
compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
|
||||
}
|
||||
|
||||
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||
|
Reference in New Issue
Block a user