Allow no-copy negative strides in as_strided and slice (#1688)

* allow no-copy negative strides in as_strided and slice

* fix jit

* fix jit
This commit is contained in:
Awni Hannun 2024-12-12 08:59:45 -08:00 committed by GitHub
parent 4d595a2a39
commit 6bd28d246e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 133 additions and 163 deletions

View File

@ -507,34 +507,16 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
auto copy_needed = std::any_of(
strides_.begin(), strides_.end(), [](auto i) { return i < 0; });
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Strides ostrides{out.strides().begin(), out.strides().end()};
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_end = 1;
for (int i = 0; i < end_indices_.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
size_t data_size = data_end - data_offset;
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {

View File

@ -81,13 +81,15 @@ void binary_op_gpu_inplace(
};
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool large = out.data_size() > UINT32_MAX;
bool large;
auto ndim = shape.size();
int work_per_thread;
if (bopt == BinaryOpType::General) {
large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
out.size() > INT32_MAX;
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > UINT32_MAX;
work_per_thread = 1;
}
std::string kernel_name =

View File

@ -53,9 +53,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
bool large_src = src.size() > UINT32_MAX;
bool large_out = out.size() > UINT32_MAX;
bool large_index = nidx && inputs[1].size() > INT32_MAX;
bool large_src = src.size() > INT32_MAX;
bool large_out = out.size() > INT32_MAX;
bool large = large_index || large_src || large_out;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
@ -65,7 +65,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_type_name,
nidx,
idx_ndim,
large ? "int64_t" : "uint");
large ? "int64_t" : "int");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
@ -86,7 +86,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args,
idx_arr,
idx_ndim,
large ? "int64_t" : "uint");
large ? "int64_t" : "int");
return kernel_source;
});
@ -234,9 +234,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
auto upd_contig = upd.flags().row_contiguous;
bool large_out = out.size() > UINT32_MAX;
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
bool large_upd = upd.size() > UINT32_MAX;
bool large_out = out.size() > INT32_MAX;
bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
bool large_upd = upd.size() > INT32_MAX;
bool large = large_out || large_idx || large_upd;
std::string kernel_name = fmt::format(
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx,
upd_contig ? "updc_true" : "updc_false",
nwork,
large ? "int64_t" : "uint");
large ? "int64_t" : "int");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() {
@ -290,7 +290,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_arr,
upd_contig,
nwork,
large ? "int64_t" : "uint");
large ? "int64_t" : "int");
return kernel_source;
});

View File

@ -52,7 +52,7 @@ MTL::ComputePipelineState* get_unary_kernel(
kernel_source +=
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source += get_template_definition(
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source;
@ -74,7 +74,7 @@ void append_binary_kernels(
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"},
{"g1large", "binary_g_nd1"},
{"g2large", "binary_g_nd2"},
{"g3large", "binary_g_nd3"},
}};
@ -86,11 +86,13 @@ void append_binary_kernels(
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
}
kernel_source += get_template_definition(
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
}
@ -141,7 +143,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g1", "ternary_g_nd1"},
{"g1large", "ternary_g_nd1"},
{"g2large", "ternary_g_nd2"},
{"g3large", "ternary_g_nd3"},
}};
@ -150,11 +152,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
get_template_definition(name + "_" + lib_name, func, t_str, op);
}
kernel_source += get_template_definition(
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
kernel_source += get_template_definition(
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "int");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "int");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
return kernel_source;
@ -178,7 +182,7 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source +=
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
kernel_source += get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
@ -186,19 +190,23 @@ MTL::ComputePipelineState* get_copy_kernel(
kernel_source += get_template_definition(
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
kernel_source += get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
kernel_source += get_template_definition(
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
kernel_source += get_template_definition(
"g1large_" + lib_name, "copy_g_nd1", in_type, out_type);
kernel_source += get_template_definition(
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
kernel_source += get_template_definition(
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
kernel_source += get_template_definition(
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
kernel_source += get_template_definition(
"gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
kernel_source += get_template_definition(

View File

@ -9,21 +9,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \

View File

@ -7,21 +7,21 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_float(op) \

View File

@ -53,10 +53,10 @@ instantiate_init_min_max(max, Max)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, int64_t, dim) \
@ -67,7 +67,7 @@ instantiate_init_min_max(max, Max)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, uint, dim, bm, bn) \
itype, otype, op, int, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, int64_t, dim, bm, bn)
@ -75,7 +75,7 @@ instantiate_init_min_max(max, Max)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, uint, dim, bm, bn) \
itype, otype, op, int, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, int64_t, dim, bm, bn)
@ -95,7 +95,7 @@ instantiate_init_min_max(max, Max)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, int64_t, dim)
@ -103,7 +103,7 @@ instantiate_init_min_max(max, Max)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, uint, dim) \
itype, otype, op, int, dim) \
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, int64_t, dim)

View File

@ -8,17 +8,17 @@
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, int) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \

View File

@ -9,19 +9,19 @@
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel( \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
instantiate_kernel( \
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)
#define instantiate_unary_float(op) \
#define instantiate_unary_float(op) \
instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
#define instantiate_unary_types(op) \
#define instantiate_unary_types(op) \
instantiate_unary_all_same(op, bool_, bool) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \

View File

@ -91,21 +91,7 @@ struct Limits<complex64_t> {
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
uint elem,
constant const int* shape,
constant const int64_t* strides,
int ndim) {
IdxT loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * IdxT(strides[i]);
elem /= shape[i];
}
return loc;
}
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
int64_t elem,
IdxT elem,
constant const int* shape,
constant const int64_t* strides,
int ndim) {
@ -187,9 +173,12 @@ METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
constant const int64_t* c_strides,
int ndim) {
vec<IdxT, 3> loc = {
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
IdxT(elem.y * IdxT(a_strides[ndim - 2])),
IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
IdxT(elem.y * IdxT(b_strides[ndim - 2])),
IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * IdxT(a_strides[d]);

View File

@ -393,7 +393,7 @@ void row_reduce_small(
auto [in_type, out_type] = remap_reduce_types(in, op_name);
const std::string func_name = "row_reduce_small";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@ -411,7 +411,7 @@ void row_reduce_small(
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
large ? "size_t" : "int",
n);
compute_encoder.set_compute_pipeline_state(kernel);
@ -490,7 +490,7 @@ void row_reduce_looped(
int n = get_kernel_reduce_ndim(args.reduce_ndim);
const std::string func_name = "row_reduce_looped";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@ -508,7 +508,7 @@ void row_reduce_looped(
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
large ? "size_t" : "int",
n);
compute_encoder.set_compute_pipeline_state(kernel);
@ -574,7 +574,7 @@ void strided_reduce_small(
int n = get_kernel_reduce_ndim(args.reduce_ndim);
const std::string func_name = "col_reduce_small";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@ -592,7 +592,7 @@ void strided_reduce_small(
op_name,
in_type,
out_type,
large ? "size_t" : "uint",
large ? "size_t" : "int",
n);
compute_encoder.set_compute_pipeline_state(kernel);
@ -665,7 +665,7 @@ void strided_reduce_longcolumn(
int n = get_kernel_reduce_ndim(args.reduce_ndim);
std::string func_name = "col_reduce_longcolumn";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@ -683,7 +683,7 @@ void strided_reduce_longcolumn(
op_name,
in_type,
out_type,
large ? "int64_t" : "uint",
large ? "int64_t" : "int",
n);
compute_encoder.set_compute_pipeline_state(kernel);
@ -706,7 +706,7 @@ void strided_reduce_longcolumn(
// Set the 2nd kernel
func_name = "col_reduce_looped";
kname = func_name;
large = intermediate.size() > UINT32_MAX;
large = intermediate.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@ -718,7 +718,7 @@ void strided_reduce_longcolumn(
op_name,
intermediate.dtype(),
out_type,
large ? "int64_t" : "uint",
large ? "int64_t" : "int",
1,
32,
32);
@ -760,7 +760,7 @@ void strided_reduce_looped(
int n = get_kernel_reduce_ndim(args.reduce_ndim);
std::string func_name = "col_reduce_looped";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@ -782,7 +782,7 @@ void strided_reduce_looped(
op_name,
in_type,
out_type,
large ? "int64_t" : "uint",
large ? "int64_t" : "int",
n,
BM,
BN);
@ -837,7 +837,7 @@ void strided_reduce_2pass(
int n = get_kernel_reduce_ndim(args.reduce_ndim);
std::string func_name = "col_reduce_2pass";
std::string kname = func_name;
bool large = in.size() > UINT32_MAX;
bool large = in.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@ -859,7 +859,7 @@ void strided_reduce_2pass(
op_name,
in_type,
out_type,
large ? "int64_t" : "uint",
large ? "int64_t" : "int",
n,
BM,
BN);
@ -882,7 +882,7 @@ void strided_reduce_2pass(
// Set the 2nd kernel
func_name = "col_reduce_looped";
kname = func_name;
large = intermediate.size() > UINT32_MAX;
large = intermediate.size() > INT32_MAX;
if (large) {
kname += "_large";
}
@ -894,7 +894,7 @@ void strided_reduce_2pass(
op_name,
intermediate.dtype(),
out_type,
large ? "int64_t" : "uint",
large ? "int64_t" : "int",
1,
32,
32);

View File

@ -16,33 +16,16 @@ void slice_gpu(
const Stream& s) {
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
auto copy_needed =
std::any_of(strides.begin(), strides.end(), [](auto i) { return i < 0; });
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
copy_gpu_inplace(
/* const array& in = */ in,
/* array& out = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ out.strides(),
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General,
/* const Stream& s = */ s);
} else {
size_t data_end = 1;
for (int i = 0; i < strides.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_end = 1;
for (int i = 0; i < strides.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
}
size_t data_size = data_end - data_offset;
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}
size_t data_size = data_end - data_offset;
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}
void concatenate_gpu(

View File

@ -36,15 +36,15 @@ void ternary_op_gpu_inplace(
};
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
bool large = out.data_size() > UINT_MAX;
bool large;
auto ndim = shape.size();
int work_per_thread;
if (topt == TernaryOpType::General) {
large |=
(a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
c.data_size() > UINT32_MAX);
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
c.data_size() > INT32_MAX || out.size() > INT32_MAX;
work_per_thread = large ? 4 : 2;
} else {
large = out.data_size() > INT32_MAX;
work_per_thread = 1;
}
std::string kernel_name;

View File

@ -36,9 +36,11 @@ void unary_op_gpu_inplace(
auto [shape, strides] = maybe_collapse();
int ndim = shape.size();
size_t nthreads = contig ? in.data_size() : in.size();
bool large = in.data_size() > UINT32_MAX;
bool large;
if (!contig) {
large |= in.size() > UINT32_MAX;
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
} else {
large = in.data_size() > UINT32_MAX;
}
int work_per_thread = !contig && large ? 4 : 1;
std::string kernel_name;

View File

@ -1758,6 +1758,10 @@ class TestOps(mlx_tests.MLXTestCase):
y_mlx = mx.as_strided(x_mlx, shape, stride, offset)
self.assertTrue(np.array_equal(y_npy, y_mlx))
x = mx.random.uniform(shape=(32,))
y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
self.assertTrue(mx.array_equal(y, x[::-1]))
def test_scans(self):
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
a_mlx = mx.array(a_npy)