From 41c603d48a5ff79f291868ca5c1e098fb3653876 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 4 Sep 2024 14:03:10 -0700 Subject: [PATCH] fix jit reduce (#1395) --- mlx/backend/metal/jit_kernels.cpp | 33 +++++----- mlx/backend/metal/kernels.h | 6 +- mlx/backend/metal/kernels/reduce.metal | 64 +++++++++---------- .../metal/kernels/reduction/reduce_col.h | 10 +-- .../metal/kernels/reduction/reduce_row.h | 4 +- mlx/backend/metal/nojit_kernels.cpp | 6 +- mlx/backend/metal/reduce.cpp | 40 +++++++----- 7 files changed, 88 insertions(+), 75 deletions(-) diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 5a2a4e88e..346954be4 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -337,35 +337,36 @@ MTL::ComputePipelineState* get_reduce_init_kernel( MTL::ComputePipelineState* get_reduce_kernel( metal::Device& d, const std::string& kernel_name, + const std::string& func_name, const std::string& op_name, const array& in, - const array& out) { - std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); - auto lib = d.get_library(lib_name); + const array& out, + int ndim /* = -1 */, + int bm /* = -1 */, + int bn /* = -1 */) { + auto lib = d.get_library(kernel_name); if (lib == nullptr) { std::string op_type = op_name; op_type[0] = std::toupper(op_name[0]); - bool non_atomic = out.dtype() == int64 || out.dtype() == uint64; std::ostringstream kernel_source; auto in_type = get_type_string(in.dtype()); auto out_type = get_type_string(out.dtype()); - std::vector> reduce_kernels = { - {"all_reduce", "allReduce"}, - {"col_reduce_small", "colReduceSmall"}, - {"col_reduce_looped", "colReduceLooped"}, - {"row_reduce_small", "rowReduceSmall"}, - {"row_reduce_looped", "rowReduceLooped"}, - {"row_reduce_simple", "rowReduceSimple"}}; std::string op = op_type + "<" + out_type + ">"; kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce(); - for (auto [func, name] : reduce_kernels) { + if (bm >= 0) { kernel_source << get_template_definition( - name + "_" + lib_name, func, in_type, out_type, op); + kernel_name, func_name, in_type, out_type, op, ndim, bm, bn); + } else if (ndim >= 0) { + kernel_source << get_template_definition( + kernel_name, func_name, in_type, out_type, op, ndim); + } else { + kernel_source << get_template_definition( + kernel_name, func_name, in_type, out_type, op); } - - lib = d.get_library(lib_name, kernel_source.str()); + lib = d.get_library(kernel_name, kernel_source.str()); } - return d.get_kernel(kernel_name, lib); + auto st = d.get_kernel(kernel_name, lib); + return st; } MTL::ComputePipelineState* get_steel_gemm_fused_kernel( diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 44f260777..51d5121d8 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -83,9 +83,13 @@ MTL::ComputePipelineState* get_reduce_init_kernel( MTL::ComputePipelineState* get_reduce_kernel( metal::Device& d, const std::string& kernel_name, + const std::string& func_name, const std::string& op_name, const array& in, - const array& out); + const array& out, + int ndim = -1, + int bm = -1, + int bn = -1); MTL::ComputePipelineState* get_steel_gemm_fused_kernel( metal::Device& d, diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index acfe6b64f..2b55bd9a6 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -82,9 +82,9 @@ otype, \ op) -#define instantiate_init_reduce(name, otype, op) \ - instantiate_kernel("init_reduce_" #name, \ - init_reduce, \ +#define instantiate_init_reduce(name, otype, op) \ + instantiate_kernel("init_reduce_" #name, \ + init_reduce, \ otype, op) #define instantiate_init_reduce_helper(name, tname, type, op) \ @@ -96,9 +96,9 @@ instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper instantiate_init_reduce(andbool_, bool, And) instantiate_init_reduce(orbool_, bool, Or) -#define instantiate_all_reduce(name, itype, otype, op) \ - instantiate_kernel("allReduce_" #name, \ - all_reduce, \ +#define instantiate_all_reduce(name, itype, otype, op) \ + instantiate_kernel("all_reduce_" #name, \ + all_reduce, \ itype, otype, op) #define instantiate_same_all_reduce_helper(name, tname, type, op) \ @@ -114,16 +114,16 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or) instantiate_all_reduce(sumbool_, bool, uint32_t, Sum) #define instantiate_col_reduce_small(name, itype, otype, op, dim) \ - instantiate_kernel("colReduceSmall_" #dim "_reduce_" #name, \ - col_reduce_small, \ + instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \ + col_reduce_small, \ itype, otype, op, dim) -#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \ - instantiate_kernel("colReduceLooped_" #dim "_" #bm "_" #bn "_reduce_" #name, \ - col_reduce_looped, \ +#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \ + instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \ + col_reduce_looped, \ itype, otype, op, dim, bm, bn) -#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \ +#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \ instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \ instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) @@ -139,7 +139,7 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum) instantiate_col_reduce_looped(name, itype, otype, op, 3) \ instantiate_col_reduce_looped(name, itype, otype, op, 4) -#define instantiate_same_col_reduce_helper(name, tname, type, op) \ +#define instantiate_same_col_reduce_helper(name, tname, type, op) \ instantiate_col_reduce_general(name##tname, type, type, op) instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types) @@ -149,32 +149,32 @@ instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum) instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And) instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or) -#define instantiate_row_reduce_small(name, itype, otype, op, dim) \ - instantiate_kernel("rowReduceSmall_" #dim "_reduce_" #name, \ - row_reduce_small, \ +#define instantiate_row_reduce_small(name, itype, otype, op, dim) \ + instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \ + row_reduce_small, \ itype, otype, op, dim) -#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \ - instantiate_kernel("rowReduceLooped_" #dim "_reduce_" #name, \ +#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \ + instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \ row_reduce_looped, \ itype, otype, op, dim) -#define instantiate_row_reduce_general(name, itype, otype, op) \ - instantiate_row_reduce_small(name, itype, otype, op, 0) \ - instantiate_row_reduce_small(name, itype, otype, op, 1) \ - instantiate_row_reduce_small(name, itype, otype, op, 2) \ - instantiate_row_reduce_small(name, itype, otype, op, 3) \ - instantiate_row_reduce_small(name, itype, otype, op, 4) \ - instantiate_row_reduce_looped(name, itype, otype, op, 0) \ - instantiate_row_reduce_looped(name, itype, otype, op, 1) \ - instantiate_row_reduce_looped(name, itype, otype, op, 2) \ - instantiate_row_reduce_looped(name, itype, otype, op, 3) \ - instantiate_row_reduce_looped(name, itype, otype, op, 4) \ - instantiate_kernel("rowReduceSimple_" #name, \ - row_reduce_simple, \ +#define instantiate_row_reduce_general(name, itype, otype, op) \ + instantiate_row_reduce_small(name, itype, otype, op, 0) \ + instantiate_row_reduce_small(name, itype, otype, op, 1) \ + instantiate_row_reduce_small(name, itype, otype, op, 2) \ + instantiate_row_reduce_small(name, itype, otype, op, 3) \ + instantiate_row_reduce_small(name, itype, otype, op, 4) \ + instantiate_row_reduce_looped(name, itype, otype, op, 0) \ + instantiate_row_reduce_looped(name, itype, otype, op, 1) \ + instantiate_row_reduce_looped(name, itype, otype, op, 2) \ + instantiate_row_reduce_looped(name, itype, otype, op, 3) \ + instantiate_row_reduce_looped(name, itype, otype, op, 4) \ + instantiate_kernel("row_reduce_simple_" #name, \ + row_reduce_simple, \ itype, otype, op) -#define instantiate_same_row_reduce_helper(name, tname, type, op) \ +#define instantiate_same_row_reduce_helper(name, tname, type, op) \ instantiate_row_reduce_general(name##tname, type, type, op) instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types) diff --git a/mlx/backend/metal/kernels/reduction/reduce_col.h b/mlx/backend/metal/kernels/reduction/reduce_col.h index f92562b84..52e763ddc 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_col.h +++ b/mlx/backend/metal/kernels/reduction/reduce_col.h @@ -4,7 +4,7 @@ template < typename T, typename U, typename Op, - int NDIMS = 0, + int NDIMS, int N_READS = REDUCE_N_READS> [[kernel]] void col_reduce_small( const device T* in [[buffer(0)]], @@ -198,13 +198,7 @@ template < * totals with a loop. * 7. Write them to the output */ -template < - typename T, - typename U, - typename Op, - int NDIMS = 0, - int BM = 8, - int BN = 128> +template [[kernel]] void col_reduce_looped( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], diff --git a/mlx/backend/metal/kernels/reduction/reduce_row.h b/mlx/backend/metal/kernels/reduction/reduce_row.h index 903265295..af8a01da7 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_row.h +++ b/mlx/backend/metal/kernels/reduction/reduce_row.h @@ -193,7 +193,7 @@ template < typename T, typename U, typename Op, - int NDIMS = 0, + int NDIMS, int N_READS = REDUCE_N_READS> [[kernel]] void row_reduce_small( const device T* in [[buffer(0)]], @@ -306,7 +306,7 @@ template < typename T, typename U, typename Op, - int NDIMS = 0, + int NDIMS, int N_READS = REDUCE_N_READS> [[kernel]] void row_reduce_looped( const device T* in [[buffer(0)]], diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 361e28916..3947ea419 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -104,8 +104,12 @@ MTL::ComputePipelineState* get_reduce_kernel( metal::Device& d, const std::string& kernel_name, const std::string&, + const std::string&, const array&, - const array&) { + const array&, + int, + int, + int) { return d.get_kernel(kernel_name); } diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 72f15d802..744105812 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -255,8 +255,9 @@ void all_reduce_dispatch( std::vector& copies) { // Set the kernel std::ostringstream kname; - kname << "allReduce_" << op_name << type_to_name(in); - auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); + const std::string func_name = "all_reduce"; + kname << func_name << "_" << op_name << type_to_name(in); + auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out); compute_encoder->setComputePipelineState(kernel); size_t in_size = in.size(); @@ -309,9 +310,9 @@ void all_reduce_dispatch( // 2nd pass std::ostringstream kname_2nd_pass; - kname_2nd_pass << "allReduce_" << op_name << type_to_name(intermediate); - auto kernel_2nd_pass = - get_reduce_kernel(d, kname_2nd_pass.str(), op_name, intermediate, out); + kname_2nd_pass << func_name << "_" << op_name << type_to_name(intermediate); + auto kernel_2nd_pass = get_reduce_kernel( + d, kname_2nd_pass.str(), func_name, op_name, intermediate, out); compute_encoder->setComputePipelineState(kernel_2nd_pass); size_t intermediate_size = n_rows; grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1); @@ -335,8 +336,10 @@ void row_reduce_small( // Set the kernel std::ostringstream kname; int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0; - kname << "rowReduceSmall_" << n << "_reduce_" << op_name << type_to_name(in); - auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); + const std::string func_name = "row_reduce_small"; + kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); + auto kernel = + get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); compute_encoder->setComputePipelineState(kernel); // Figure out the grid dims @@ -370,8 +373,9 @@ void row_reduce_simple( const Stream& s) { // Set the kernel std::ostringstream kname; - kname << "rowReduceSimple_" << op_name << type_to_name(in); - auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); + const std::string func_name = "row_reduce_simple"; + kname << func_name << "_" << op_name << type_to_name(in); + auto kernel = get_reduce_kernel(d, kname.str(), func_name, op_name, in, out); compute_encoder->setComputePipelineState(kernel); // Figure out the grid dims @@ -407,8 +411,10 @@ void row_reduce_looped( // Set the kernel std::ostringstream kname; int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0; - kname << "rowReduceLooped_" << n << "_reduce_" << op_name << type_to_name(in); - auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); + const std::string func_name = "row_reduce_looped"; + kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); + auto kernel = + get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); compute_encoder->setComputePipelineState(kernel); // Figure out the grid @@ -497,8 +503,10 @@ void strided_reduce_small( // Set the kernel int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0; std::ostringstream kname; - kname << "colReduceSmall_" << n << "_reduce_" << op_name << type_to_name(in); - auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); + const std::string func_name = "col_reduce_small"; + kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in); + auto kernel = + get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n); compute_encoder->setComputePipelineState(kernel); // Launch @@ -535,9 +543,11 @@ void strided_reduce_looped( // Set the kernel int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0; std::ostringstream kname; - kname << "colReduceLooped_" << n << "_" << BM << "_" << BN << "_reduce_" + const std::string func_name = "col_reduce_looped"; + kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_" << op_name << type_to_name(in); - auto kernel = get_reduce_kernel(d, kname.str(), op_name, in, out); + auto kernel = + get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN); compute_encoder->setComputePipelineState(kernel); // Launch