From 67b6bf530d1ad895495dc48650afb172a79c3a39 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 17 Sep 2024 17:59:51 -0700
Subject: [PATCH] Optimization for general ND copies (#1421)

---
 mlx/backend/metal/copy.cpp           | 23 +++++++++------
 mlx/backend/metal/jit_kernels.cpp    |  6 +++-
 mlx/backend/metal/kernels/copy.h     | 43 ++++++++++++++++++++++------
 mlx/backend/metal/kernels/copy.metal |  4 ++-
 mlx/backend/metal/kernels/utils.h    |  9 +++---
 5 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp
index 19596f6b4..a58e4c467 100644
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -73,9 +73,11 @@ void copy_gpu_inplace(
     }
   };
   auto [shape, strides_in_, strides_out_] = maybe_collapse();
+  int ndim = shape.size();
 
   bool use_2d = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
+  int work_per_thread = 1;
   std::string kernel_name;
   {
     std::ostringstream kname;
@@ -93,9 +95,13 @@ void copy_gpu_inplace(
         kname << "gg";
         break;
     }
-    if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
-        shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
-      kname << shape.size();
+    if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
+      if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
+        kname << shape.size();
+      } else if (shape[ndim - 1] >= 4) {
+        work_per_thread = 4;
+        kname << "n4";
+      }
     }
     kname << "_copy";
     kname << type_to_name(in) << type_to_name(out);
@@ -115,10 +121,8 @@ void copy_gpu_inplace(
   compute_encoder.set_output_array(out, 1, out_offset);
 
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
-    int ndim = shape.size();
     std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
     std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
-
     if (ndim > 3) {
       set_vector_bytes(compute_encoder, shape, ndim, 2);
     }
@@ -127,10 +131,6 @@ void copy_gpu_inplace(
       set_vector_bytes(compute_encoder, strides_out, ndim, 4);
     }
 
-    if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(&ndim, sizeof(int), 5);
-    }
-
     int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
 
@@ -139,6 +139,11 @@ void copy_gpu_inplace(
       data_size *= s;
     int rest = data_size / (dim0 * dim1);
 
+    if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
+      compute_encoder->setBytes(&ndim, sizeof(int), 5);
+      dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
+    }
+
     // NB assuming thread_group_size is a power of 2 larger than 32 x 32
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index 2c22e9668..74957f150 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -176,6 +176,8 @@ MTL::ComputePipelineState* get_copy_kernel(
         << get_template_definition(
                "g3_" + lib_name, "copy_g_nd3", in_type, out_type)
         << get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
+        << get_template_definition(
+               "gn4_" + lib_name, "copy_g", in_type, out_type, 4)
         << get_template_definition(
                "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
         << get_template_definition(
@@ -183,7 +185,9 @@ MTL::ComputePipelineState* get_copy_kernel(
         << get_template_definition(
                "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
         << get_template_definition(
-               "gg_" + lib_name, "copy_gg", in_type, out_type);
+               "gg_" + lib_name, "copy_gg", in_type, out_type)
+        << get_template_definition(
+               "ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
     lib = d.get_library(lib_name, kernel_source.str());
   }
   return d.get_kernel(kernel_name, lib);
diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h
index 2d836ff65..914aebfd6 100644
--- a/mlx/backend/metal/kernels/copy.h
+++ b/mlx/backend/metal/kernels/copy.h
@@ -71,7 +71,7 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
-template <typename T, typename U>
+template <typename T, typename U, int N = 1>
 [[kernel]] void copy_g(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
@@ -80,10 +80,22 @@ template <typename T, typename U>
     constant const int& ndim [[buffer(5)]],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
-  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
+  auto src_idx = elem_to_loc(
+      {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
+  if (N == 1) {
+    int64_t dst_idx =
+        index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
+    dst[dst_idx] = static_cast<U>(src[src_idx]);
+    return;
+  }
+  auto xshape = src_shape[ndim - 1];
   int64_t dst_idx =
-      index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
+      N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
+  auto src_xstride = src_strides[ndim - 1];
+  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+    dst[dst_idx + i] = static_cast<U>(src[src_idx]);
+    src_idx += src_xstride;
+  }
 }
 
 template <typename T, typename U>
@@ -122,7 +134,7 @@ template <typename T, typename U>
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
-template <typename T, typename U>
+template <typename T, typename U, int N = 1>
 [[kernel]] void copy_gg(
     device const T* src [[buffer(0)]],
     device U* dst [[buffer(1)]],
@@ -131,7 +143,22 @@ template <typename T, typename U>
     constant const int64_t* dst_strides [[buffer(4)]],
     constant const int& ndim [[buffer(5)]],
     uint3 index [[thread_position_in_grid]]) {
-  auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
-  auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
-  dst[dst_idx] = static_cast<U>(src[src_idx]);
+  auto idx = elem_to_loc_2_nd(
+      {N * index.x, index.y, index.z},
+      src_shape,
+      src_strides,
+      dst_strides,
+      ndim);
+  if (N == 1) {
+    dst[idx.y] = static_cast<U>(src[idx.x]);
+    return;
+  }
+  auto src_xstride = src_strides[ndim - 1];
+  auto dst_xstride = dst_strides[ndim - 1];
+  auto xshape = src_shape[ndim - 1];
+  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
+    dst[idx.y] = static_cast<U>(src[idx.x]);
+    idx.x += src_xstride;
+    idx.y += dst_xstride;
+  }
 }
diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal
index 76cfbb867..a631183b7 100644
--- a/mlx/backend/metal/kernels/copy.metal
+++ b/mlx/backend/metal/kernels/copy.metal
@@ -17,7 +17,9 @@
   instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype)  \
   instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype)  \
   instantiate_kernel("g_copy" #tname, copy_g, itype, otype)         \
-  instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
+  instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4)    \
+  instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)       \
+  instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
 
 #define instantiate_copy_itype(itname, itype)       \
   instantiate_copy_all(itname ##bool_, itype, bool) \
diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h
index 0ec69e191..721c094ca 100644
--- a/mlx/backend/metal/kernels/utils.h
+++ b/mlx/backend/metal/kernels/utils.h
@@ -149,15 +149,16 @@ elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
 ///////////////////////////////////////////////////////////////////////////////
 // Multiple Arrays with generic dims
 
+template <typename stride_t>
 METAL_FUNC ulong2 elem_to_loc_2_nd(
     uint3 elem,
     constant const int* shape,
-    constant const size_t* a_strides,
-    constant const size_t* b_strides,
+    constant const stride_t* a_strides,
+    constant const stride_t* b_strides,
     int ndim) {
   ulong2 loc = {
-      elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
-      elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]};
+      ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+      ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
   for (int d = ndim - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
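
Note (commentary, not part of the patch): the new "n4" kernel variants make each thread of the general copy handle up to four contiguous elements of the innermost dimension, while the host side shrinks the grid's x extent with a ceiling division by work_per_thread; the kernels therefore bound the per-thread loop with i < N && (N * index.x + i) < xshape so the last thread stops at the edge when the innermost extent is not a multiple of four. The sketch below is not MLX code and uses made-up names (xshape, xstride, grid_x); it is a plain C++ host-side model of that loop bound, showing that every element is copied exactly once and nothing is written past the end.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  constexpr int N = 4;        // work_per_thread
  const int xshape = 10;      // innermost extent, deliberately not a multiple of N
  const int64_t xstride = 3;  // strided source: logical element i lives at i * xstride

  std::vector<int> src(xshape * xstride, 0);
  for (int i = 0; i < xshape; ++i)
    src[i * xstride] = i;

  std::vector<int> dst(xshape, -1);
  const int grid_x = (xshape + N - 1) / N;  // ceiling division, as on the host side

  // Each "thread" x covers destination elements [N*x, N*x + N), clamped by the tail guard.
  for (int x = 0; x < grid_x; ++x) {
    int64_t src_idx = int64_t(N) * x * xstride;  // strided source location of element N*x
    int64_t dst_idx = int64_t(N) * x;            // contiguous destination
    for (int i = 0; i < N && (N * x + i) < xshape; ++i) {
      dst[dst_idx + i] = src[src_idx];
      src_idx += xstride;
    }
  }

  for (int i = 0; i < xshape; ++i)
    assert(dst[i] == i);  // every element copied exactly once, no overrun
  return 0;
}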