From 6bd28d246e5dbe80abfa38b006bbceb8b863ec18 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Dec 2024 08:59:45 -0800 Subject: [PATCH] Allow no copy negative strides in as_strided and slice (#1688) * allow no copy negative strides in as_strided and slice * fix jit * fix jit --- mlx/backend/common/primitives.cpp | 34 +++++----------------- mlx/backend/metal/binary.cpp | 6 ++-- mlx/backend/metal/indexing.cpp | 20 ++++++------- mlx/backend/metal/jit_kernels.cpp | 30 ++++++++++++------- mlx/backend/metal/kernels/binary.metal | 30 +++++++++---------- mlx/backend/metal/kernels/binary_two.metal | 30 +++++++++---------- mlx/backend/metal/kernels/reduce.metal | 12 ++++---- mlx/backend/metal/kernels/ternary.metal | 22 +++++++------- mlx/backend/metal/kernels/unary.metal | 6 ++-- mlx/backend/metal/kernels/utils.h | 25 +++++----------- mlx/backend/metal/reduce.cpp | 32 ++++++++++---------- mlx/backend/metal/slicing.cpp | 31 +++++--------------- mlx/backend/metal/ternary.cpp | 8 ++--- mlx/backend/metal/unary.cpp | 6 ++-- python/tests/test_ops.py | 4 +++ 15 files changed, 133 insertions(+), 163 deletions(-) diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 2f1f95054..c371eb7aa 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -507,34 +507,16 @@ void Slice::eval(const std::vector& inputs, array& out) { // Calculate out strides, initial offset and if copy needs to be made auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_); - auto copy_needed = std::any_of( - strides_.begin(), strides_.end(), [](auto i) { return i < 0; }); - - // Do copy if needed - if (copy_needed) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - Strides ostrides{out.strides().begin(), out.strides().end()}; - copy_inplace( - /* const array& src = */ in, - /* array& dst = */ out, - /* const std::vector& data_shape = */ out.shape(), - /* const std::vector& i_strides = */ inp_strides, - /* const std::vector& o_strides = */ ostrides, - /* int64_t i_offset = */ data_offset, - /* int64_t o_offset = */ 0, - /* CopyType ctype = */ CopyType::General); - } else { - size_t data_end = 1; - for (int i = 0; i < end_indices_.size(); ++i) { - if (in.shape()[i] > 1) { - auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1; - data_end += end_idx * in.strides()[i]; - } + size_t data_end = 1; + for (int i = 0; i < end_indices_.size(); ++i) { + if (in.shape()[i] > 1) { + auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1; + data_end += end_idx * in.strides()[i]; } - size_t data_size = data_end - data_offset; - Strides ostrides{inp_strides.begin(), inp_strides.end()}; - shared_buffer_slice(in, ostrides, data_offset, data_size, out); } + size_t data_size = data_end - data_offset; + Strides ostrides{inp_strides.begin(), inp_strides.end()}; + shared_buffer_slice(in, ostrides, data_offset, data_size, out); } void SliceUpdate::eval(const std::vector& inputs, array& out) { diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index aeb1df354..c7ebc9635 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -81,13 +81,15 @@ void binary_op_gpu_inplace( }; auto [shape, strides_a, strides_b, strides_out] = maybe_collapse(); - bool large = out.data_size() > UINT32_MAX; + bool large; auto ndim = shape.size(); int work_per_thread; if (bopt == BinaryOpType::General) { - large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX); + large = a.data_size() > INT32_MAX 
|| b.data_size() > INT32_MAX || + out.size() > INT32_MAX; work_per_thread = large ? 4 : 2; } else { + large = out.data_size() > UINT32_MAX; work_per_thread = 1; } std::string kernel_name = diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 85d7a711a..b27765a40 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -53,9 +53,9 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { int idx_ndim = nidx ? inputs[1].ndim() : 0; size_t ndim = src.ndim(); - bool large_index = nidx && inputs[1].size() > UINT32_MAX; - bool large_src = src.size() > UINT32_MAX; - bool large_out = out.size() > UINT32_MAX; + bool large_index = nidx && inputs[1].size() > INT32_MAX; + bool large_src = src.size() > INT32_MAX; + bool large_out = out.size() > INT32_MAX; bool large = large_index || large_src || large_out; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; @@ -65,7 +65,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { idx_type_name, nidx, idx_ndim, - large ? "int64_t" : "uint"); + large ? "int64_t" : "int"); std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { @@ -86,7 +86,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { idx_args, idx_arr, idx_ndim, - large ? "int64_t" : "uint"); + large ? "int64_t" : "int"); return kernel_source; }); @@ -234,9 +234,9 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { break; } auto upd_contig = upd.flags().row_contiguous; - bool large_out = out.size() > UINT32_MAX; - bool large_idx = nidx && (inputs[1].size() > UINT32_MAX); - bool large_upd = upd.size() > UINT32_MAX; + bool large_out = out.size() > INT32_MAX; + bool large_idx = nidx && (inputs[1].size() > INT32_MAX); + bool large_upd = upd.size() > INT32_MAX; bool large = large_out || large_idx || large_upd; std::string kernel_name = fmt::format( "scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}", @@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { nidx, upd_contig ? "updc_true" : "updc_false", nwork, - large ? "int64_t" : "uint"); + large ? "int64_t" : "int"); std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { @@ -290,7 +290,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { idx_arr, upd_contig, nwork, - large ? "int64_t" : "uint"); + large ? 
"int64_t" : "int"); return kernel_source; }); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index b09f920dd..f364cb8ee 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -52,7 +52,7 @@ MTL::ComputePipelineState* get_unary_kernel( kernel_source += get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op); kernel_source += get_template_definition( - "gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint"); + "gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "int"); kernel_source += get_template_definition( "gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4); return kernel_source; @@ -74,7 +74,7 @@ void append_binary_kernels( {"vs2", "binary_vs2"}, {"sv2", "binary_sv2"}, {"vv2", "binary_vv2"}, - {"g1", "binary_g_nd1"}, + {"g1large", "binary_g_nd1"}, {"g2large", "binary_g_nd2"}, {"g3large", "binary_g_nd3"}, }}; @@ -86,11 +86,13 @@ void append_binary_kernels( get_template_definition(name + "_" + lib_name, func, in_t, out_t, op); } kernel_source += get_template_definition( - "g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint"); + "g1_" + lib_name, "binary_g_nd1", in_t, out_t, op, "int"); kernel_source += get_template_definition( - "g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint"); + "g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "int"); kernel_source += get_template_definition( - "gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint"); + "g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "int"); + kernel_source += get_template_definition( + "gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "int"); kernel_source += get_template_definition( "gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4); } @@ -141,7 +143,7 @@ MTL::ComputePipelineState* get_ternary_kernel( const std::array, 5> kernel_types = {{ {"v", "ternary_v"}, {"v2", "ternary_v2"}, - {"g1", "ternary_g_nd1"}, + {"g1large", "ternary_g_nd1"}, {"g2large", "ternary_g_nd2"}, {"g3large", "ternary_g_nd3"}, }}; @@ -150,11 +152,13 @@ MTL::ComputePipelineState* get_ternary_kernel( get_template_definition(name + "_" + lib_name, func, t_str, op); } kernel_source += get_template_definition( - "g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint"); + "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int"); kernel_source += get_template_definition( - "g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint"); + "g2_" + lib_name, "ternary_g_nd2", t_str, op, "int"); kernel_source += get_template_definition( - "gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint"); + "g3_" + lib_name, "ternary_g_nd3", t_str, op, "int"); + kernel_source += get_template_definition( + "gn2_" + lib_name, "ternary_g", t_str, op, 2, "int"); kernel_source += get_template_definition( "gn4large_" + lib_name, "ternary_g", t_str, op, 4); return kernel_source; @@ -178,7 +182,7 @@ MTL::ComputePipelineState* get_copy_kernel( kernel_source += get_template_definition("v_" + lib_name, "copy_v", in_type, out_type); kernel_source += get_template_definition( - "g1_" + lib_name, "copy_g_nd1", in_type, out_type); + "g1_" + lib_name, "copy_g_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( "g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int"); kernel_source += get_template_definition( @@ -186,19 +190,23 @@ MTL::ComputePipelineState* get_copy_kernel( kernel_source += get_template_definition( "gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int"); kernel_source += get_template_definition( - "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type); + "gg1_" + 
lib_name, "copy_gg_nd1", in_type, out_type, "int"); kernel_source += get_template_definition( "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int"); kernel_source += get_template_definition( "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int"); kernel_source += get_template_definition( "ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int"); + kernel_source += get_template_definition( + "g1large_" + lib_name, "copy_g_nd1", in_type, out_type); kernel_source += get_template_definition( "g2large_" + lib_name, "copy_g_nd2", in_type, out_type); kernel_source += get_template_definition( "g3large_" + lib_name, "copy_g_nd3", in_type, out_type); kernel_source += get_template_definition( "gn4large_" + lib_name, "copy_g", in_type, out_type, 4); + kernel_source += get_template_definition( + "gg1large_" + lib_name, "copy_gg_nd1", in_type, out_type); kernel_source += get_template_definition( "gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type); kernel_source += get_template_definition( diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index ba5654703..3ef8e6269 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -9,21 +9,21 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary.h" -#define instantiate_binary_all(op, tname, itype, otype) \ - instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ - instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ - instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ - instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ - instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \ - instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \ - instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \ - instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \ - instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \ - instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ - instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \ +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ + instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ + instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ + instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \ + instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \ + instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \ + instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \ + instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \ + instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ + instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \ 
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) #define instantiate_binary_integer(op) \ diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index a12906f8e..984a28320 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -7,21 +7,21 @@ #include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_two.h" -#define instantiate_binary_all(op, tname, itype, otype) \ - instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ - instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ - instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ - instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ - instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ - instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ - instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ - instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \ - instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \ - instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \ - instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \ - instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \ - instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \ - instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ +#define instantiate_binary_all(op, tname, itype, otype) \ + instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ + instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ + instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ + instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ + instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ + instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ + instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ + instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, int) \ + instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \ + instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, int) \ + instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, int) \ + instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, int) \ + instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \ + instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) #define instantiate_binary_float(op) \ diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index dc2ce157c..428f65012 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -53,10 +53,10 @@ instantiate_init_min_max(max, Max) #define instantiate_col_reduce_small(name, itype, otype, op, dim) \ instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \ col_reduce_small, \ - itype, otype, op, uint, dim) \ + itype, otype, op, int, dim) \ instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \ col_reduce_longcolumn, \ - itype, otype, op, uint, dim) \ + itype, otype, op, int, dim) \ instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \ col_reduce_small, \ itype, otype, op, int64_t, dim) 
\ @@ -67,7 +67,7 @@ instantiate_init_min_max(max, Max) #define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \ instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \ col_reduce_looped, \ - itype, otype, op, uint, dim, bm, bn) \ + itype, otype, op, int, dim, bm, bn) \ instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \ col_reduce_looped, \ itype, otype, op, int64_t, dim, bm, bn) @@ -75,7 +75,7 @@ instantiate_init_min_max(max, Max) #define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \ instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \ col_reduce_2pass, \ - itype, otype, op, uint, dim, bm, bn) \ + itype, otype, op, int, dim, bm, bn) \ instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \ col_reduce_2pass, \ itype, otype, op, int64_t, dim, bm, bn) @@ -95,7 +95,7 @@ instantiate_init_min_max(max, Max) #define instantiate_row_reduce_small(name, itype, otype, op, dim) \ instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \ row_reduce_small, \ - itype, otype, op, uint, dim) \ + itype, otype, op, int, dim) \ instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \ row_reduce_small, \ itype, otype, op, int64_t, dim) @@ -103,7 +103,7 @@ instantiate_init_min_max(max, Max) #define instantiate_row_reduce_looped(name, itype, otype, op, dim) \ instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \ row_reduce_looped, \ - itype, otype, op, uint, dim) \ + itype, otype, op, int, dim) \ instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \ row_reduce_looped, \ itype, otype, op, int64_t, dim) diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index d188c8ff4..761faa9b8 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -8,17 +8,17 @@ #include "mlx/backend/metal/kernels/ternary_ops.h" #include "mlx/backend/metal/kernels/ternary.h" -#define instantiate_ternary_all(op, tname, type) \ - instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ - instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ - instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \ - instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \ - instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \ - instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \ - instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \ - instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \ - instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ - instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \ +#define instantiate_ternary_all(op, tname, type) \ + instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ + instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ + instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, int) \ + instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \ + instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \ + instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \ + instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \ + instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \ + instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ + instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \ 
#define instantiate_ternary_types(op) \ instantiate_ternary_all(op, bool_, bool) \ diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal index b196f023f..29220c53b 100644 --- a/mlx/backend/metal/kernels/unary.metal +++ b/mlx/backend/metal/kernels/unary.metal @@ -9,19 +9,19 @@ instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \ instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \ instantiate_kernel( \ - "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, uint) \ + "gn1_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 1, int) \ instantiate_kernel( \ "gn4large" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4) #define instantiate_unary_all_same(op, tname, type) \ instantiate_unary_all(op, tname, tname, type, type) -#define instantiate_unary_float(op) \ +#define instantiate_unary_float(op) \ instantiate_unary_all_same(op, float16, half) \ instantiate_unary_all_same(op, float32, float) \ instantiate_unary_all_same(op, bfloat16, bfloat16_t) -#define instantiate_unary_types(op) \ +#define instantiate_unary_types(op) \ instantiate_unary_all_same(op, bool_, bool) \ instantiate_unary_all_same(op, uint8, uint8_t) \ instantiate_unary_all_same(op, uint16, uint16_t) \ diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index d96656075..b31cd20d6 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -91,21 +91,7 @@ struct Limits { template METAL_FUNC IdxT elem_to_loc( - uint elem, - constant const int* shape, - constant const int64_t* strides, - int ndim) { - IdxT loc = 0; - for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - loc += (elem % shape[i]) * IdxT(strides[i]); - elem /= shape[i]; - } - return loc; -} - -template -METAL_FUNC IdxT elem_to_loc( - int64_t elem, + IdxT elem, constant const int* shape, constant const int64_t* strides, int ndim) { @@ -187,9 +173,12 @@ METAL_FUNC vec elem_to_loc_3_nd( constant const int64_t* c_strides, int ndim) { vec loc = { - elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]), - elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]), - elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])}; + IdxT(elem.x * IdxT(a_strides[ndim - 1])) + + IdxT(elem.y * IdxT(a_strides[ndim - 2])), + IdxT(elem.x * IdxT(b_strides[ndim - 1])) + + IdxT(elem.y * IdxT(b_strides[ndim - 2])), + IdxT(elem.x * IdxT(c_strides[ndim - 1])) + + IdxT(elem.y * IdxT(c_strides[ndim - 2]))}; for (int d = ndim - 3; d >= 0; --d) { uint l = elem.z % shape[d]; loc.x += l * IdxT(a_strides[d]); diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 6720051d5..05deff057 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -393,7 +393,7 @@ void row_reduce_small( auto [in_type, out_type] = remap_reduce_types(in, op_name); const std::string func_name = "row_reduce_small"; std::string kname = func_name; - bool large = in.size() > UINT32_MAX; + bool large = in.size() > INT32_MAX; if (large) { kname += "_large"; } @@ -411,7 +411,7 @@ void row_reduce_small( op_name, in_type, out_type, - large ? "size_t" : "uint", + large ? 
"size_t" : "int", n); compute_encoder.set_compute_pipeline_state(kernel); @@ -490,7 +490,7 @@ void row_reduce_looped( int n = get_kernel_reduce_ndim(args.reduce_ndim); const std::string func_name = "row_reduce_looped"; std::string kname = func_name; - bool large = in.size() > UINT32_MAX; + bool large = in.size() > INT32_MAX; if (large) { kname += "_large"; } @@ -508,7 +508,7 @@ void row_reduce_looped( op_name, in_type, out_type, - large ? "size_t" : "uint", + large ? "size_t" : "int", n); compute_encoder.set_compute_pipeline_state(kernel); @@ -574,7 +574,7 @@ void strided_reduce_small( int n = get_kernel_reduce_ndim(args.reduce_ndim); const std::string func_name = "col_reduce_small"; std::string kname = func_name; - bool large = in.size() > UINT32_MAX; + bool large = in.size() > INT32_MAX; if (large) { kname += "_large"; } @@ -592,7 +592,7 @@ void strided_reduce_small( op_name, in_type, out_type, - large ? "size_t" : "uint", + large ? "size_t" : "int", n); compute_encoder.set_compute_pipeline_state(kernel); @@ -665,7 +665,7 @@ void strided_reduce_longcolumn( int n = get_kernel_reduce_ndim(args.reduce_ndim); std::string func_name = "col_reduce_longcolumn"; std::string kname = func_name; - bool large = in.size() > UINT32_MAX; + bool large = in.size() > INT32_MAX; if (large) { kname += "_large"; } @@ -683,7 +683,7 @@ void strided_reduce_longcolumn( op_name, in_type, out_type, - large ? "int64_t" : "uint", + large ? "int64_t" : "int", n); compute_encoder.set_compute_pipeline_state(kernel); @@ -706,7 +706,7 @@ void strided_reduce_longcolumn( // Set the 2nd kernel func_name = "col_reduce_looped"; kname = func_name; - large = intermediate.size() > UINT32_MAX; + large = intermediate.size() > INT32_MAX; if (large) { kname += "_large"; } @@ -718,7 +718,7 @@ void strided_reduce_longcolumn( op_name, intermediate.dtype(), out_type, - large ? "int64_t" : "uint", + large ? "int64_t" : "int", 1, 32, 32); @@ -760,7 +760,7 @@ void strided_reduce_looped( int n = get_kernel_reduce_ndim(args.reduce_ndim); std::string func_name = "col_reduce_looped"; std::string kname = func_name; - bool large = in.size() > UINT32_MAX; + bool large = in.size() > INT32_MAX; if (large) { kname += "_large"; } @@ -782,7 +782,7 @@ void strided_reduce_looped( op_name, in_type, out_type, - large ? "int64_t" : "uint", + large ? "int64_t" : "int", n, BM, BN); @@ -837,7 +837,7 @@ void strided_reduce_2pass( int n = get_kernel_reduce_ndim(args.reduce_ndim); std::string func_name = "col_reduce_2pass"; std::string kname = func_name; - bool large = in.size() > UINT32_MAX; + bool large = in.size() > INT32_MAX; if (large) { kname += "_large"; } @@ -859,7 +859,7 @@ void strided_reduce_2pass( op_name, in_type, out_type, - large ? "int64_t" : "uint", + large ? "int64_t" : "int", n, BM, BN); @@ -882,7 +882,7 @@ void strided_reduce_2pass( // Set the 2nd kernel func_name = "col_reduce_looped"; kname = func_name; - large = intermediate.size() > UINT32_MAX; + large = intermediate.size() > INT32_MAX; if (large) { kname += "_large"; } @@ -894,7 +894,7 @@ void strided_reduce_2pass( op_name, intermediate.dtype(), out_type, - large ? "int64_t" : "uint", + large ? 
"int64_t" : "int", 1, 32, 32); diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index f2241f607..26664b78e 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -16,33 +16,16 @@ void slice_gpu( const Stream& s) { // Calculate out strides, initial offset and if copy needs to be made auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides); - auto copy_needed = - std::any_of(strides.begin(), strides.end(), [](auto i) { return i < 0; }); - // Do copy if needed - if (copy_needed) { - out.set_data(allocator::malloc_or_wait(out.nbytes())); - copy_gpu_inplace( - /* const array& in = */ in, - /* array& out = */ out, - /* const std::vector& data_shape = */ out.shape(), - /* const std::vector& i_strides = */ inp_strides, - /* const std::vector& o_strides = */ out.strides(), - /* int64_t i_offset = */ data_offset, - /* int64_t o_offset = */ 0, - /* CopyType ctype = */ CopyType::General, - /* const Stream& s = */ s); - } else { - size_t data_end = 1; - for (int i = 0; i < strides.size(); ++i) { - if (in.shape()[i] > 1) { - auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1; - data_end += end_idx * in.strides()[i]; - } + size_t data_end = 1; + for (int i = 0; i < strides.size(); ++i) { + if (in.shape()[i] > 1) { + auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1; + data_end += end_idx * in.strides()[i]; } - size_t data_size = data_end - data_offset; - shared_buffer_slice(in, inp_strides, data_offset, data_size, out); } + size_t data_size = data_end - data_offset; + shared_buffer_slice(in, inp_strides, data_offset, data_size, out); } void concatenate_gpu( diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index d44a5151e..6c55ca02d 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -36,15 +36,15 @@ void ternary_op_gpu_inplace( }; auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse(); - bool large = out.data_size() > UINT_MAX; + bool large; auto ndim = shape.size(); int work_per_thread; if (topt == TernaryOpType::General) { - large |= - (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX || - c.data_size() > UINT32_MAX); + large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + c.data_size() > INT32_MAX || out.size() > INT32_MAX; work_per_thread = large ? 4 : 2; } else { + large = out.data_size() > INT32_MAX; work_per_thread = 1; } std::string kernel_name; diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index ce7dc969b..1496bcb5d 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -36,9 +36,11 @@ void unary_op_gpu_inplace( auto [shape, strides] = maybe_collapse(); int ndim = shape.size(); size_t nthreads = contig ? in.data_size() : in.size(); - bool large = in.data_size() > UINT32_MAX; + bool large; if (!contig) { - large |= in.size() > UINT32_MAX; + large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; + } else { + large = in.data_size() > UINT32_MAX; } int work_per_thread = !contig && large ? 
4 : 1; std::string kernel_name; diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 7763a2e58..9a5918011 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1758,6 +1758,10 @@ class TestOps(mlx_tests.MLXTestCase): y_mlx = mx.as_strided(x_mlx, shape, stride, offset) self.assertTrue(np.array_equal(y_npy, y_mlx)) + x = mx.random.uniform(shape=(32,)) + y = mx.as_strided(x, (x.size,), (-1,), x.size - 1) + self.assertTrue(mx.array_equal(y, x[::-1])) + def test_scans(self): a_npy = np.random.randn(32, 32, 32).astype(np.float32) a_mlx = mx.array(a_npy)
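
For reference, the behavior checked by the new test above can be reproduced interactively. The snippet below is an illustrative sketch, not part of the patch; the names x and rev are chosen here, and it only verifies values, not the absence of a copy (which is what the removed copy paths in Slice::eval and slice_gpu used to perform):

import mlx.core as mx

x = mx.arange(8, dtype=mx.float32)

# Reversed view via as_strided: the offset points at the last element and
# the stride of -1 walks backwards through the same buffer.
rev = mx.as_strided(x, (x.size,), (-1,), x.size - 1)
assert mx.array_equal(rev, x[::-1])

# Negative-stride basic slicing now takes the same shared-buffer view path
# as positive strides instead of falling back to a general copy.
assert mx.array_equal(x[::-2], mx.array([7.0, 5.0, 3.0, 1.0]))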
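
A note on the uint-to-int changes that make up most of the diff: once views may carry negative strides, the stride * index terms in the general-strided kernels can be negative, so the small-index kernel specializations switch their offset type from uint to int, and the "large" 64-bit fallback is selected against INT32_MAX rather than UINT32_MAX. The sketch below is a minimal mirror of that dispatch rule under those assumptions; index_type is a hypothetical helper, while the real selection is written inline in binary.cpp, ternary.cpp, indexing.cpp, and reduce.cpp:

INT32_MAX = 2**31 - 1

def index_type(*num_elements):
    # Use 64-bit signed offsets only when some participating array is too
    # large to index with a signed 32-bit int; otherwise use "int".
    return "int64_t" if any(n > INT32_MAX for n in num_elements) else "int"

# e.g. a binary op over two inputs and an output:
assert index_type(1024, 1024, 1024) == "int"
assert index_type(2**31, 1024, 1024) == "int64_t"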