diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 585183967..8e3015790 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -85,6 +85,7 @@ void binary_op_gpu_inplace( auto ndim = shape.size(); int work_per_thread; if (bopt == BinaryOpType::General) { + large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX); work_per_thread = large ? 4 : 2; } else { work_per_thread = 1; diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index f7b2fd865..60d63a5ac 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -77,7 +77,7 @@ void copy_gpu_inplace( bool large; if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { // Allow for negative strides - large = out.data_size() > INT32_MAX; + large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; } else { large = out.data_size() > UINT32_MAX; } @@ -134,13 +134,13 @@ void copy_gpu_inplace( compute_encoder.set_vector_bytes(strides_out, ndim, 4); } - int dim0 = ndim > 0 ? shape[ndim - 1] : 1; - int dim1 = ndim > 1 ? shape[ndim - 2] : 1; + size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; + size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t data_size = 1; for (auto& s : shape) data_size *= s; - int rest = data_size / (dim0 * dim1); + size_t rest = data_size / (dim0 * dim1); if (ndim > MAX_COPY_SPECIALIZED_DIMS) { compute_encoder.set_bytes(ndim, 5); diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index 4b260bc30..d9f0a9710 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -69,7 +69,7 @@ template c[offset] = Op()(a[offset], b[offset]); } -template +template [[kernel]] void binary_g_nd1( device const T* a, device const T* b, @@ -77,8 +77,8 @@ template constant const size_t& a_stride, constant const size_t& b_stride, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); c[index] = Op()(a[a_idx], b[b_idx]); } diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal index a9d7044d8..ba5654703 100644 --- a/mlx/backend/metal/kernels/binary.metal +++ b/mlx/backend/metal/kernels/binary.metal @@ -19,7 +19,8 @@ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \ instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \ - instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ + instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \ + instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \ instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \ diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index 6057dd41b..17dfb0f62 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -90,7 +90,7 @@ template d[offset] = out[1]; } -template +template [[kernel]] void binary_g_nd1( device const T* a, device const T* b, @@ -99,8 +99,8 @@ template constant const size_t& a_stride, constant const size_t& b_stride, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); auto out = Op()(a[a_idx], b[b_idx]); c[index] = out[0]; d[index] = out[1]; diff --git a/mlx/backend/metal/kernels/binary_two.metal b/mlx/backend/metal/kernels/binary_two.metal index da9ac3a5d..a12906f8e 100644 --- a/mlx/backend/metal/kernels/binary_two.metal +++ b/mlx/backend/metal/kernels/binary_two.metal @@ -17,9 +17,10 @@ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \ instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \ - instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ + instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \ instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \ instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \ + instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op) diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index 2113c825a..7b664b3a4 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -36,13 +36,13 @@ template dst[offset] = static_cast(src[offset]); } -template +template [[kernel]] void copy_g_nd1( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t& src_stride [[buffer(3)]], uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); + auto src_idx = elem_to_loc_1(index, src_stride); dst[index] = static_cast(src[src_idx]); } @@ -97,15 +97,15 @@ template } } -template +template [[kernel]] void copy_gg_nd1( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& dst_stride [[buffer(4)]], uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); - auto dst_idx = elem_to_loc_1(index, dst_stride); + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); dst[dst_idx] = static_cast(src[src_idx]); } diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index 5c444b30d..298b48fe9 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -9,16 +9,18 @@ instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ - instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \ + instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \ instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \ instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \ - instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \ + instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \ instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \ instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \ instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \ + instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \ instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \ instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \ + instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \ instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \ instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \ instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \ diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index e19ea23df..3cf776ad5 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -22,7 +22,7 @@ template d[offset] = Op()(a[offset], b[offset], c[offset]); } -template +template [[kernel]] void ternary_g_nd1( device const bool* a, device const T* b, @@ -32,9 +32,9 @@ template constant const size_t& b_strides, constant const size_t& c_strides, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_strides); - auto b_idx = elem_to_loc_1(index, b_strides); - auto c_idx = elem_to_loc_1(index, c_strides); + auto a_idx = elem_to_loc_1(index, a_strides); + auto b_idx = elem_to_loc_1(index, b_strides); + auto c_idx = elem_to_loc_1(index, c_strides); d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); } diff --git a/mlx/backend/metal/kernels/ternary.metal b/mlx/backend/metal/kernels/ternary.metal index a509dacce..d188c8ff4 100644 --- a/mlx/backend/metal/kernels/ternary.metal +++ b/mlx/backend/metal/kernels/ternary.metal @@ -12,9 +12,10 @@ instantiate_kernel("v_" #op #tname, ternary_v, type, op) \ instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \ instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 1, uint) \ - instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \ + instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, uint) \ instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, uint) \ instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, uint) \ + instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \ instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \ instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \ instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \ diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 50c4c8cce..f81ea9240 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -40,6 +40,9 @@ void ternary_op_gpu_inplace( auto ndim = shape.size(); int work_per_thread; if (topt == TernaryOpType::General) { + large |= + (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX || + c.data_size() > UINT32_MAX); work_per_thread = large ? 4 : 2; } else { work_per_thread = 1; diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index c1004df22..d1d833dc7 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -36,7 +36,10 @@ void unary_op_gpu_inplace( auto [shape, strides] = maybe_collapse(); int ndim = shape.size(); size_t nthreads = contig ? in.data_size() : in.size(); - bool large = nthreads > UINT32_MAX; + bool large = in.data_size() > UINT32_MAX; + if (!contig) { + large |= in.size() > UINT32_MAX; + } int work_per_thread = !contig && large ? 4 : 1; std::string kernel_name; if (contig) {