diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index f80f8c3e4..5943c5b27 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -90,7 +90,7 @@ void binary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > UINT32_MAX; - work_per_thread = 1; + work_per_thread = get_work_per_thread(a.dtype()); } std::string kernel_name = get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread); @@ -137,13 +137,20 @@ void binary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(a.size(), arg_idx++); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(a.size(), arg_idx++); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 3399201de..ee004359f 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -104,6 +104,8 @@ void copy_gpu_inplace( "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy"); } } + } else { + work_per_thread = get_work_per_thread(in.dtype()); } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) @@ -165,13 +167,19 @@ void copy_gpu_inplace( MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } @@ -214,14 +222,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) { compute_encoder.set_input_array(val, 0); compute_encoder.set_output_array(out, 1); + int work_per_thread = get_work_per_thread(val.dtype()); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - size_t nthreads = out.data_size(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index 91a02c818..ffc33ad82 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -9,64 +9,85 @@ template c[index] = Op()(a[0], b[0]); } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[0]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[0], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index 8f6b3392d..e261d33c4 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -12,82 +12,103 @@ template d[index] = out[1]; } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[0], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[0]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[0], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[0]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } } template diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index b1367cf4f..2469d1f3d 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -1,39 +1,53 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[0]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + dst[index + i] = static_cast(src[0]); + } } -template +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } } -template +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[0]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } } -template +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 4b3adcc80..8debf2267 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -1,25 +1,32 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void ternary_v( device const bool* a, device const T* b, device const T* c, + constant uint& size, device T* d, uint index [[thread_position_in_grid]]) { - d[index] = Op()(a[index], b[index], c[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } } -template +template ::n> [[kernel]] void ternary_v2( device const bool* a, device const T* b, device const T* c, device T* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - d[offset] = Op()(a[offset], b[offset], c[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } } template diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index 69828599f..b5eaab2e9 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -1,21 +1,28 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void unary_v( device const T* in, device U* out, + constant uint& size, uint index [[thread_position_in_grid]]) { - out[index] = Op()(in[index]); + index *= N; + for (int i = 0; i < N && (index + i) < size; ++i) { + out[index + i] = Op()(in[index + i]); + } } -template +template ::n> [[kernel]] void unary_v2( device const T* in, device U* out, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - out[offset] = Op()(in[offset]); + auto offset = N * (index.x + grid_dim.x * int64_t(index.y)); + for (int i = 0; i < N && (offset + i) < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } } template < diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 1170d5576..0734c2326 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -15,6 +15,13 @@ typedef half float16_t; +// Work per thread values for different types +template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; + /////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index 36bfd3e2b..d59968453 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -45,7 +45,7 @@ void ternary_op_gpu_inplace( work_per_thread = large ? 4 : 2; } else { large = out.data_size() > INT32_MAX; - work_per_thread = 1; + work_per_thread = get_work_per_thread(b.dtype()); } std::string kernel_name; if (topt == TernaryOpType::General) { @@ -106,13 +106,20 @@ void ternary_op_gpu_inplace( compute_encoder.dispatch_threads(grid_dims, group_dims); } else { // Launch a 1D or 2D grid of threads - size_t nthreads = out.data_size(); + auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + size_t nthreads = ceildiv(out.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(out.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index be43c41c2..368e693a9 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -34,18 +34,19 @@ void unary_op_gpu_inplace( }; auto [shape, strides] = maybe_collapse(); int ndim = shape.size(); - size_t nthreads = contig ? in.data_size() : in.size(); bool large; if (!contig) { large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; } else { large = in.data_size() > UINT32_MAX; } - int work_per_thread = !contig && large ? 4 : 1; + int work_per_thread; std::string kernel_name; if (contig) { + work_per_thread = get_work_per_thread(in.dtype()); kernel_name = (large ? "v2" : "v"); } else { + work_per_thread = large ? 4 : 1; kernel_name = "gn" + std::to_string(work_per_thread); if (large) { kernel_name += "large"; @@ -75,12 +76,20 @@ void unary_op_gpu_inplace( MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder.dispatch_threads(grid_dims, group_dims); } else { + size_t nthreads = ceildiv(in.data_size(), work_per_thread); if (thread_group_size > nthreads) { thread_group_size = nthreads; } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides()) - : MTL::Size(nthreads, 1, 1); + MTL::Size grid_dims; + if (large) { + compute_encoder.set_bytes(in.data_size(), 2); + grid_dims = get_2d_grid_dims(out.shape(), out.strides(), work_per_thread); + } else { + compute_encoder.set_bytes(in.data_size(), 2); + grid_dims = MTL::Size(nthreads, 1, 1); + } compute_encoder.dispatch_threads(grid_dims, group_dims); } } diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 079d15f17..1a49b4f51 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -84,4 +84,12 @@ void concatenate(std::string& acc, T first, Args... args) { concatenate(acc, args...); } +inline int get_work_per_thread(Dtype dtype) { + return std::max(1, 8 / dtype.size()); +} + +inline int ceildiv(int n, int m) { + return (n + m - 1) / m; +} + } // namespace mlx::core