mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Allow no copy negative strides in as_strided and slice (#1688)
* allow no copy negative strides in as_strided and slice * fix jit * fix jit
This commit is contained in:
parent
4d595a2a39
commit
6bd28d246e
@ -507,34 +507,16 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
|
||||
auto copy_needed = std::any_of(
|
||||
strides_.begin(), strides_.end(), [](auto i) { return i < 0; });
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
Strides ostrides{out.strides().begin(), out.strides().end()};
|
||||
copy_inplace(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ out.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ inp_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ ostrides,
|
||||
/* int64_t i_offset = */ data_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General);
|
||||
} else {
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < end_indices_.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
Strides ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
Strides ostrides{inp_strides.begin(), inp_strides.end()};
|
||||
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
@ -81,13 +81,15 @@ void binary_op_gpu_inplace(
|
||||
};
|
||||
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
|
||||
|
||||
bool large = out.data_size() > UINT32_MAX;
|
||||
bool large;
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread;
|
||||
if (bopt == BinaryOpType::General) {
|
||||
large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
|
||||
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out.size() > INT32_MAX;
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > UINT32_MAX;
|
||||
work_per_thread = 1;
|
||||
}
|
||||
std::string kernel_name =
|
||||
|
@ -53,9 +53,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
bool large_index = nidx && inputs[1].size() > UINT32_MAX;
|
||||
bool large_src = src.size() > UINT32_MAX;
|
||||
bool large_out = out.size() > UINT32_MAX;
|
||||
bool large_index = nidx && inputs[1].size() > INT32_MAX;
|
||||
bool large_src = src.size() > INT32_MAX;
|
||||
bool large_out = out.size() > INT32_MAX;
|
||||
bool large = large_index || large_src || large_out;
|
||||
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
@ -65,7 +65,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_type_name,
|
||||
nidx,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
@ -86,7 +86,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_args,
|
||||
idx_arr,
|
||||
idx_ndim,
|
||||
large ? "int64_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
@ -234,9 +234,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
break;
|
||||
}
|
||||
auto upd_contig = upd.flags().row_contiguous;
|
||||
bool large_out = out.size() > UINT32_MAX;
|
||||
bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
|
||||
bool large_upd = upd.size() > UINT32_MAX;
|
||||
bool large_out = out.size() > INT32_MAX;
|
||||
bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
|
||||
bool large_upd = upd.size() > INT32_MAX;
|
||||
bool large = large_out || large_idx || large_upd;
|
||||
std::string kernel_name = fmt::format(
|
||||
"scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
|
||||
@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nidx,
|
||||
upd_contig ? "updc_true" : "updc_false",
|
||||
nwork,
|
||||
large ? "int64_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
std::string lib_name = kernel_name;
|
||||
|
||||
auto lib = d.get_library(lib_name, [&]() {
|
||||
@ -290,7 +290,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
idx_arr,
|
||||
upd_contig,
|
||||
nwork,
|
||||
large ? "int64_t" : "uint");
|
||||
large ? "int64_t" : "int");
|
||||
return kernel_source;
|
||||
});
|
||||
|
||||
|
@ -52,7 +52,7 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
|
||||
kernel_source += get_template_definition(
|
||||
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
|
||||
"gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
|
||||
return kernel_source;
|
||||
@ -74,7 +74,7 @@ void append_binary_kernels(
|
||||
{"vs2", "binary_vs2"},
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g1large", "binary_g_nd1"},
|
||||
{"g2large", "binary_g_nd2"},
|
||||
{"g3large", "binary_g_nd3"},
|
||||
}};
|
||||
@ -86,11 +86,13 @@ void append_binary_kernels(
|
||||
get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
|
||||
"g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
|
||||
"g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
|
||||
"g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
|
||||
}
|
||||
@ -141,7 +143,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g1large", "ternary_g_nd1"},
|
||||
{"g2large", "ternary_g_nd2"},
|
||||
{"g3large", "ternary_g_nd3"},
|
||||
}};
|
||||
@ -150,11 +152,13 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
get_template_definition(name + "_" + lib_name, func, t_str, op);
|
||||
}
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
|
||||
"g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
|
||||
"g2_" + lib_name, "ternary_g_nd2", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
|
||||
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
|
||||
return kernel_source;
|
||||
@ -178,7 +182,7 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
kernel_source +=
|
||||
get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
"g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
@ -186,19 +190,23 @@ MTL::ComputePipelineState* get_copy_kernel(
|
||||
kernel_source += get_template_definition(
|
||||
"gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
|
||||
kernel_source += get_template_definition(
|
||||
"g1large_" + lib_name, "copy_g_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
|
||||
kernel_source += get_template_definition(
|
||||
"gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
|
||||
kernel_source += get_template_definition(
|
||||
|
@ -9,21 +9,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
|
@ -7,21 +7,21 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \
|
||||
instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
|
||||
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \
|
||||
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \
|
||||
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \
|
||||
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
|
@ -53,10 +53,10 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
|
||||
col_reduce_longcolumn, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
|
||||
col_reduce_small, \
|
||||
itype, otype, op, int64_t, dim) \
|
||||
@ -67,7 +67,7 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, uint, dim, bm, bn) \
|
||||
itype, otype, op, int, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_looped, \
|
||||
itype, otype, op, int64_t, dim, bm, bn)
|
||||
@ -75,7 +75,7 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, uint, dim, bm, bn) \
|
||||
itype, otype, op, int, dim, bm, bn) \
|
||||
instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
|
||||
col_reduce_2pass, \
|
||||
itype, otype, op, int64_t, dim, bm, bn)
|
||||
@ -95,7 +95,7 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
|
||||
row_reduce_small, \
|
||||
itype, otype, op, int64_t, dim)
|
||||
@ -103,7 +103,7 @@ instantiate_init_min_max(max, Max)
|
||||
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
|
||||
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, uint, dim) \
|
||||
itype, otype, op, int, dim) \
|
||||
instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
|
||||
row_reduce_looped, \
|
||||
itype, otype, op, int64_t, dim)
|
||||
|
@ -8,17 +8,17 @@
|
||||
#include "mlx/backend/metal/kernels/ternary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/ternary.h"
|
||||
|
||||
#define instantiate_ternary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \
|
||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \
|
||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \
|
||||
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
|
||||
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
|
||||
#define instantiate_ternary_all(op, tname, type) \
|
||||
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
|
||||
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
|
||||
instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, int) \
|
||||
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
|
||||
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
|
||||
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \
|
||||
instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
|
||||
instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
|
||||
instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
|
||||
instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \
|
||||
|
||||
#define instantiate_ternary_types(op) \
|
||||
instantiate_ternary_all(op, bool_, bool) \
|
||||
|
@ -9,19 +9,19 @@
|
||||
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
|
||||
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
|
||||
instantiate_kernel( \
|
||||
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \
|
||||
"gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \
|
||||
instantiate_kernel( \
|
||||
"gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
|
||||
|
||||
#define instantiate_unary_all_same(op, tname, type) \
|
||||
instantiate_unary_all(op, tname, tname, type, type)
|
||||
|
||||
#define instantiate_unary_float(op) \
|
||||
#define instantiate_unary_float(op) \
|
||||
instantiate_unary_all_same(op, float16, half) \
|
||||
instantiate_unary_all_same(op, float32, float) \
|
||||
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
|
||||
|
||||
#define instantiate_unary_types(op) \
|
||||
#define instantiate_unary_types(op) \
|
||||
instantiate_unary_all_same(op, bool_, bool) \
|
||||
instantiate_unary_all_same(op, uint8, uint8_t) \
|
||||
instantiate_unary_all_same(op, uint16, uint16_t) \
|
||||
|
@ -91,21 +91,7 @@ struct Limits<complex64_t> {
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
METAL_FUNC IdxT elem_to_loc(
|
||||
uint elem,
|
||||
constant const int* shape,
|
||||
constant const int64_t* strides,
|
||||
int ndim) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
METAL_FUNC IdxT elem_to_loc(
|
||||
int64_t elem,
|
||||
IdxT elem,
|
||||
constant const int* shape,
|
||||
constant const int64_t* strides,
|
||||
int ndim) {
|
||||
@ -187,9 +173,12 @@ METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
|
||||
constant const int64_t* c_strides,
|
||||
int ndim) {
|
||||
vec<IdxT, 3> loc = {
|
||||
elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
|
||||
elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
|
||||
elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
|
||||
IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
|
||||
IdxT(elem.y * IdxT(a_strides[ndim - 2])),
|
||||
IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
|
||||
IdxT(elem.y * IdxT(b_strides[ndim - 2])),
|
||||
IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
|
||||
IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
|
||||
for (int d = ndim - 3; d >= 0; --d) {
|
||||
uint l = elem.z % shape[d];
|
||||
loc.x += l * IdxT(a_strides[d]);
|
||||
|
@ -393,7 +393,7 @@ void row_reduce_small(
|
||||
auto [in_type, out_type] = remap_reduce_types(in, op_name);
|
||||
const std::string func_name = "row_reduce_small";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@ -411,7 +411,7 @@ void row_reduce_small(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "size_t" : "uint",
|
||||
large ? "size_t" : "int",
|
||||
n);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@ -490,7 +490,7 @@ void row_reduce_looped(
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
const std::string func_name = "row_reduce_looped";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@ -508,7 +508,7 @@ void row_reduce_looped(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "size_t" : "uint",
|
||||
large ? "size_t" : "int",
|
||||
n);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@ -574,7 +574,7 @@ void strided_reduce_small(
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
const std::string func_name = "col_reduce_small";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@ -592,7 +592,7 @@ void strided_reduce_small(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "size_t" : "uint",
|
||||
large ? "size_t" : "int",
|
||||
n);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@ -665,7 +665,7 @@ void strided_reduce_longcolumn(
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
std::string func_name = "col_reduce_longcolumn";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@ -683,7 +683,7 @@ void strided_reduce_longcolumn(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "int64_t" : "uint",
|
||||
large ? "int64_t" : "int",
|
||||
n);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
@ -706,7 +706,7 @@ void strided_reduce_longcolumn(
|
||||
// Set the 2nd kernel
|
||||
func_name = "col_reduce_looped";
|
||||
kname = func_name;
|
||||
large = intermediate.size() > UINT32_MAX;
|
||||
large = intermediate.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@ -718,7 +718,7 @@ void strided_reduce_longcolumn(
|
||||
op_name,
|
||||
intermediate.dtype(),
|
||||
out_type,
|
||||
large ? "int64_t" : "uint",
|
||||
large ? "int64_t" : "int",
|
||||
1,
|
||||
32,
|
||||
32);
|
||||
@ -760,7 +760,7 @@ void strided_reduce_looped(
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
std::string func_name = "col_reduce_looped";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@ -782,7 +782,7 @@ void strided_reduce_looped(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "int64_t" : "uint",
|
||||
large ? "int64_t" : "int",
|
||||
n,
|
||||
BM,
|
||||
BN);
|
||||
@ -837,7 +837,7 @@ void strided_reduce_2pass(
|
||||
int n = get_kernel_reduce_ndim(args.reduce_ndim);
|
||||
std::string func_name = "col_reduce_2pass";
|
||||
std::string kname = func_name;
|
||||
bool large = in.size() > UINT32_MAX;
|
||||
bool large = in.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@ -859,7 +859,7 @@ void strided_reduce_2pass(
|
||||
op_name,
|
||||
in_type,
|
||||
out_type,
|
||||
large ? "int64_t" : "uint",
|
||||
large ? "int64_t" : "int",
|
||||
n,
|
||||
BM,
|
||||
BN);
|
||||
@ -882,7 +882,7 @@ void strided_reduce_2pass(
|
||||
// Set the 2nd kernel
|
||||
func_name = "col_reduce_looped";
|
||||
kname = func_name;
|
||||
large = intermediate.size() > UINT32_MAX;
|
||||
large = intermediate.size() > INT32_MAX;
|
||||
if (large) {
|
||||
kname += "_large";
|
||||
}
|
||||
@ -894,7 +894,7 @@ void strided_reduce_2pass(
|
||||
op_name,
|
||||
intermediate.dtype(),
|
||||
out_type,
|
||||
large ? "int64_t" : "uint",
|
||||
large ? "int64_t" : "int",
|
||||
1,
|
||||
32,
|
||||
32);
|
||||
|
@ -16,33 +16,16 @@ void slice_gpu(
|
||||
const Stream& s) {
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
auto copy_needed =
|
||||
std::any_of(strides.begin(), strides.end(), [](auto i) { return i < 0; });
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
/* const array& in = */ in,
|
||||
/* array& out = */ out,
|
||||
/* const std::vector<int>& data_shape = */ out.shape(),
|
||||
/* const std::vector<stride_t>& i_strides = */ inp_strides,
|
||||
/* const std::vector<stride_t>& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ data_offset,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::General,
|
||||
/* const Stream& s = */ s);
|
||||
} else {
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < strides.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
size_t data_end = 1;
|
||||
for (int i = 0; i < strides.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
size_t data_size = data_end - data_offset;
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
void concatenate_gpu(
|
||||
|
@ -36,15 +36,15 @@ void ternary_op_gpu_inplace(
|
||||
};
|
||||
auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse();
|
||||
|
||||
bool large = out.data_size() > UINT_MAX;
|
||||
bool large;
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread;
|
||||
if (topt == TernaryOpType::General) {
|
||||
large |=
|
||||
(a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX ||
|
||||
c.data_size() > UINT32_MAX);
|
||||
large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.size() > INT32_MAX;
|
||||
work_per_thread = large ? 4 : 2;
|
||||
} else {
|
||||
large = out.data_size() > INT32_MAX;
|
||||
work_per_thread = 1;
|
||||
}
|
||||
std::string kernel_name;
|
||||
|
@ -36,9 +36,11 @@ void unary_op_gpu_inplace(
|
||||
auto [shape, strides] = maybe_collapse();
|
||||
int ndim = shape.size();
|
||||
size_t nthreads = contig ? in.data_size() : in.size();
|
||||
bool large = in.data_size() > UINT32_MAX;
|
||||
bool large;
|
||||
if (!contig) {
|
||||
large |= in.size() > UINT32_MAX;
|
||||
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
|
||||
} else {
|
||||
large = in.data_size() > UINT32_MAX;
|
||||
}
|
||||
int work_per_thread = !contig && large ? 4 : 1;
|
||||
std::string kernel_name;
|
||||
|
@ -1758,6 +1758,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
y_mlx = mx.as_strided(x_mlx, shape, stride, offset)
|
||||
self.assertTrue(np.array_equal(y_npy, y_mlx))
|
||||
|
||||
x = mx.random.uniform(shape=(32,))
|
||||
y = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
|
||||
self.assertTrue(mx.array_equal(y, x[::-1]))
|
||||
|
||||
def test_scans(self):
|
||||
a_npy = np.random.randn(32, 32, 32).astype(np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
|
Loading…
Reference in New Issue
Block a user