diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index 6a67b4f57..88edc6baa 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -278,7 +278,21 @@ void Compiled::eval_gpu( /* ndim = */ 0, /* dynamic_dims = */ false, /* use_big_index = */ false, - /* work_per_thread = */ work_per_thread); + /* work_per_thread = */ 1); + if (work_per_thread > 1) { + build_kernel( + kernel, + kernel_lib_ + "_contiguous_n", + inputs_, + outputs_, + tape_, + is_constant_, + /* contiguous = */ true, + /* ndim = */ 0, + /* dynamic_dims = */ false, + /* use_big_index = */ false, + /* work_per_thread = */ work_per_thread); + } build_kernel( kernel, kernel_lib_ + "_contiguous_large", @@ -358,12 +372,20 @@ void Compiled::eval_gpu( int ndim = shape.size(); bool dynamic = ndim >= 8; auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); + int work_per_thread = 1; if (!contiguous) { if (dynamic) { kernel_name += "dynamic"; } else { kernel_name += std::to_string(shape.size()); } + work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1; + } else { + work_per_thread = + get_work_per_thread(outputs[0].dtype(), outputs[0].data_size()); + if (work_per_thread > 1 && !large) { + kernel_name += "_n"; + } } if (large) { kernel_name += "_large"; @@ -420,7 +442,6 @@ void Compiled::eval_gpu( // Launch the kernel if (contiguous) { - int work_per_thread = get_work_per_thread(outputs[0].dtype()); size_t nthreads = ceildiv(outputs[0].data_size(), work_per_thread); MTL::Size group_dims( std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); @@ -433,7 +454,6 @@ void Compiled::eval_gpu( size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t rest = outputs[0].size() / (dim0 * dim1); - int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1; dim0 = (dim0 + work_per_thread - 1) / work_per_thread; NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); int pow2; diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 8dfe15c11..8123b793e 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -55,10 +55,10 @@ void copy_gpu_inplace( std::string kernel_name; switch (ctype) { case CopyType::Scalar: - kernel_name = (large ? "s2" : "s"); + kernel_name = large ? "s2" : "s"; break; case CopyType::Vector: - kernel_name = (large ? "v2" : "v"); + kernel_name = large ? "v2" : "v"; break; case CopyType::General: kernel_name = "g"; @@ -85,7 +85,10 @@ void copy_gpu_inplace( } } } else { - work_per_thread = get_work_per_thread(in.dtype()); + work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); + if (work_per_thread > 1) { + kernel_name += "n"; + } } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) @@ -170,9 +173,10 @@ void fill_gpu(const array& val, array& out, const Stream& s) { } out.set_data(allocator::malloc(out.nbytes())); bool large = out.data_size() > UINT32_MAX; + int work_per_thread = get_work_per_thread(out.dtype(), out.data_size()); auto& d = metal::device(s.device); - std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" + - type_to_name(val) + type_to_name(out); + std::string kernel_name = large ? "s2" : (work_per_thread > 1 ? "sn" : "s"); + concatenate(kernel_name, "_copy", type_to_name(val), type_to_name(out)); auto kernel = get_copy_kernel(d, kernel_name, val, out); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -180,7 +184,6 @@ void fill_gpu(const array& val, array& out, const Stream& s) { compute_encoder.set_input_array(val, 0); compute_encoder.set_output_array(out, 1); - int work_per_thread = get_work_per_thread(val.dtype()); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index ec379717f..b6733c1b7 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -180,12 +180,17 @@ MTL::ComputePipelineState* get_copy_kernel( kernel_source += metal::copy(); auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); + kernel_source += get_template_definition( + "s_" + lib_name, "copy_s", in_type, out_type, 1); kernel_source += - get_template_definition("s_" + lib_name, "copy_s", in_type, out_type); + get_template_definition("sn_" + lib_name, "copy_s", in_type, out_type); kernel_source += get_template_definition("s2_" + lib_name, "copy_s2", in_type, out_type); + kernel_source += get_template_definition( + "v_" + lib_name, "copy_v", in_type, out_type, 1); kernel_source += - get_template_definition("v_" + lib_name, "copy_v", in_type, out_type); + get_template_definition("vn_" + lib_name, "copy_v", in_type, out_type); + kernel_source += get_template_definition("v2_" + lib_name, "copy_v2", in_type, out_type); diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index 2469d1f3d..cf22347ee 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -1,52 +1,76 @@ // Copyright © 2024 Apple Inc. -template ::n> +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - dst[index + i] = static_cast(src[0]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[0]); + } } } -template ::n> +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant uint& size, uint index [[thread_position_in_grid]]) { index *= N; - for (int i = 0; i < N && (index + i) < size; ++i) { - dst[index + i] = static_cast(src[index + i]); + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[index + i]); + } } } -template ::n> +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - dst[offset + i] = static_cast(src[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[0]); + } } } -template ::n> +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); - for (int i = 0; i < N && (offset + i) < size; ++i) { - dst[offset + i] = static_cast(src[offset + i]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } } } diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index bbf268158..fcf8884f8 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -4,9 +4,13 @@ #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/copy.h" -#define instantiate_copy_all(tname, itype, otype) \ - instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ - instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ +#define instantiate_copy_work_per_thread(tname, itype, otype) \ + instantiate_kernel("sn_copy" #tname, copy_s, itype, otype) \ + instantiate_kernel("vn_copy" #tname, copy_v, itype, otype) + +#define instantiate_copy_base(tname, itype, otype) \ + instantiate_kernel("s_copy" #tname, copy_s, itype, otype, 1) \ + instantiate_kernel("v_copy" #tname, copy_v, itype, otype, 1) \ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \ @@ -18,6 +22,10 @@ instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \ instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) +#define instantiate_copy_all(tname, itype, otype) \ + instantiate_copy_base(tname, itype, otype) \ + instantiate_copy_work_per_thread(tname, itype, otype) + #define instantiate_copy_same(tname, type) \ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \ @@ -42,15 +50,15 @@ instantiate_copy_all(itname ##uint8, itype, uint8_t) \ instantiate_copy_all(itname ##uint16, itype, uint16_t) \ instantiate_copy_all(itname ##uint32, itype, uint32_t) \ - instantiate_copy_all(itname ##uint64, itype, uint64_t) \ + instantiate_copy_base(itname ##uint64, itype, uint64_t) \ instantiate_copy_all(itname ##int8, itype, int8_t) \ instantiate_copy_all(itname ##int16, itype, int16_t) \ instantiate_copy_all(itname ##int32, itype, int32_t) \ - instantiate_copy_all(itname ##int64, itype, int64_t) \ + instantiate_copy_base(itname ##int64, itype, int64_t) \ instantiate_copy_all(itname ##float16, itype, half) \ instantiate_copy_all(itname ##float32, itype, float) \ instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \ - instantiate_copy_all(itname ##complex64, itype, complex64_t) + instantiate_copy_base(itname ##complex64, itype, complex64_t) instantiate_copy_itype(bool_, bool) instantiate_copy_itype(uint8, uint8_t)