diff --git a/mlx/backend/metal/jit/gemv_masked.h b/mlx/backend/metal/jit/gemv_masked.h deleted file mode 100644 index b83ad881f..000000000 --- a/mlx/backend/metal/jit/gemv_masked.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright © 2024 Apple Inc. - -constexpr std::string_view gemv_masked_kernel = R"( -template [[host_name("{name}")]] [[kernel]] void -gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>( - const device {itype}* mat [[buffer(0)]], - const device {itype}* in_vec [[buffer(1)]], - device {itype}* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* vector_batch_stride [[buffer(11)]], - const constant int64_t* matrix_batch_stride [[buffer(12)]], - const device {outm_t}* out_mask [[buffer(20)]], - const device {opm_t}* mat_mask [[buffer(21)]], - const device {opm_t}* vec_mask [[buffer(22)]], - const constant int* mask_strides [[buffer(23)]], - const constant int64_t* mask_batch_strides [[buffer(24)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]); -)"; diff --git a/mlx/backend/metal/jit/steel_conv.h b/mlx/backend/metal/jit/steel_conv.h deleted file mode 100644 index 44b8f95b9..000000000 --- a/mlx/backend/metal/jit/steel_conv.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright © 2024 Apple Inc. - -constexpr std::string_view steel_conv_kernels = R"( -template [[host_name("{name}")]] [[kernel]] void -implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>( - const device {itype}* A [[buffer(0)]], - const device {itype}* B [[buffer(1)]], - device {itype}* C [[buffer(2)]], - const constant MLXConvParams<2>* params [[buffer(3)]], - const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]); -)"; - -constexpr std::string_view steel_conv_general_kernels = R"( -template [[host_name("{name}")]] [[kernel]] void - implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>( - const device {itype}* A [[buffer(0)]], - const device {itype}* B [[buffer(1)]], - device {itype}* C [[buffer(2)]], - const constant MLXConvParams<2>* params [[buffer(3)]], - const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], - const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], - const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], - const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]); -)"; diff --git a/mlx/backend/metal/jit/steel_gemm.h b/mlx/backend/metal/jit/steel_gemm.h deleted file mode 100644 index 85ddc449a..000000000 --- a/mlx/backend/metal/jit/steel_gemm.h +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright © 2024 Apple Inc. - -constexpr std::string_view steel_gemm_fused_kernels = R"( -template [[host_name("{name}")]] -[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>( - const device {itype} *A [[buffer(0)]], - const device {itype} *B [[buffer(1)]], - const device {itype} *C [[buffer(2), function_constant(use_out_source)]], - device {itype} *D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], - const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], - const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], - const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], - const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], - const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]], - const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]); -)"; - -constexpr std::string_view steel_gemm_masked_kernels = R"( -template [[host_name("{name}")]] [[kernel]] void -block_masked_gemm< - {itype}, - {outmasktype}, - {opmasktype}, - {bm}, - {bn}, - {bk}, - {wm}, - {wn}, - {trans_a}, - {trans_b}, - {mn_aligned}, - {k_aligned}>( - const device {itype}* A [[buffer(0)]], - const device {itype}* B [[buffer(1)]], - device {itype}* D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], - const device {outmasktype}* out_mask [[buffer(10)]], - const device {opmasktype}* lhs_mask [[buffer(11)]], - const device {opmasktype}* rhs_mask [[buffer(12)]], - const constant int* mask_strides [[buffer(13)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]); -)"; - -constexpr std::string_view steel_gemm_splitk_kernels = R"( -template [[host_name("{name}")]] [[kernel]] void -gemm_splitk< - {itype}, - {otype}, - {bm}, - {bn}, - {bk}, - {wm}, - {wn}, - {trans_a}, - {trans_b}, - {mn_aligned}, - {k_aligned}>( - const device {itype}* A [[buffer(0)]], - const device {itype}* B [[buffer(1)]], - device {otype}* C [[buffer(2)]], - const constant GEMMSpiltKParams* params [[buffer(3)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]); -)"; - -constexpr std::string_view steel_gemm_splitk_accum_kernels = R"( -template [[host_name("{name}")]] [[kernel]] void -gemm_splitk_accum<{atype}, {otype}>( - const device {atype}* C_split [[buffer(0)]], - device {otype}* D [[buffer(1)]], - const constant int& k_partitions [[buffer(2)]], - const constant int& partition_stride [[buffer(3)]], - const constant int& ldd [[buffer(4)]], - uint2 gid [[thread_position_in_grid]]); -)"; - -constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"( -template [[host_name("{name}")]] [[kernel]] void -gemm_splitk_accum_axpby<{atype}, {otype}>( - const device {atype}* C_split [[buffer(0)]], - device {otype}* D [[buffer(1)]], - const constant int& k_partitions [[buffer(2)]], - const constant int& partition_stride [[buffer(3)]], - const constant int& ldd [[buffer(4)]], - const device {otype}* C [[buffer(5)]], - const constant int& ldc [[buffer(6)]], - const constant int& fdc [[buffer(7)]], - const constant float& alpha [[buffer(8)]], - const constant float& beta [[buffer(9)]], - uint2 gid [[thread_position_in_grid]]); -)"; diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 78560bb2a..31bd7903b 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -1,11 +1,8 @@ // Copyright © 2024 Apple Inc. #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/jit/arange.h" -#include "mlx/backend/metal/jit/gemv_masked.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/softmax.h" -#include "mlx/backend/metal/jit/steel_conv.h" -#include "mlx/backend/metal/jit/steel_gemm.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" @@ -445,17 +442,17 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel( std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_fused() - << fmt::format( - steel_gemm_fused_kernels, - "name"_a = lib_name, - "itype"_a = get_type_string(out.dtype()), - "bm"_a = bm, - "bn"_a = bn, - "bk"_a = bk, - "wm"_a = wm, - "wn"_a = wn, - "trans_a"_a = transpose_a, - "trans_b"_a = transpose_b); + << get_template_definition( + lib_name, + "gemm", + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib, hash_name, func_consts); @@ -480,20 +477,20 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel( std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_splitk() - << fmt::format( - steel_gemm_splitk_kernels, - "name"_a = lib_name, - "itype"_a = get_type_string(in.dtype()), - "otype"_a = get_type_string(out.dtype()), - "bm"_a = bm, - "bn"_a = bn, - "bk"_a = bk, - "wm"_a = wm, - "wn"_a = wn, - "trans_a"_a = transpose_a, - "trans_b"_a = transpose_b, - "mn_aligned"_a = mn_aligned, - "k_aligned"_a = k_aligned); + << get_template_definition( + lib_name, + "gemm_splitk", + get_type_string(in.dtype()), + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b, + mn_aligned, + k_aligned); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); @@ -510,13 +507,12 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel( std::ostringstream kernel_source; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_splitk() - << fmt::format( - fmt::runtime( - axbpy ? steel_gemm_splitk_accum_axbpy_kernels - : steel_gemm_splitk_accum_kernels), - "name"_a = lib_name, - "atype"_a = get_type_string(in.dtype()), - "otype"_a = get_type_string(out.dtype())); + << get_template_definition( + lib_name, + axbpy ? "gemm_splitk_accum_axpby" + : "gemm_splitk_accum", + get_type_string(in.dtype()), + get_type_string(out.dtype())); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); @@ -547,21 +543,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel( mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; kernel_source << metal::utils() << metal::gemm() << metal::steel_gemm_masked() - << fmt::format( - steel_gemm_masked_kernels, - "name"_a = lib_name, - "itype"_a = get_type_string(out.dtype()), - "outmasktype"_a = out_mask_type, - "opmasktype"_a = op_mask_type, - "bm"_a = bm, - "bn"_a = bn, - "bk"_a = bk, - "wm"_a = wm, - "wn"_a = wn, - "trans_a"_a = transpose_a, - "trans_b"_a = transpose_b, - "mn_aligned"_a = mn_aligned, - "k_aligned"_a = k_aligned); + << get_template_definition( + lib_name, + "block_masked_gemm", + get_type_string(out.dtype()), + out_mask_type, + op_mask_type, + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b, + mn_aligned, + k_aligned); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); @@ -590,20 +586,19 @@ MTL::ComputePipelineState* get_gemv_masked_kernel( auto op_mask_type = mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t"; kernel_source << metal::utils() << metal::gemv_masked() - << fmt::format( - gemv_masked_kernel, - "name"_a = lib_name, - "itype"_a = get_type_string(out.dtype()), - "outm_t"_a = out_mask_type, - "opm_t"_a = op_mask_type, - "bm"_a = bm, - "bn"_a = bn, - "sm"_a = sm, - "sn"_a = sn, - "tm"_a = tm, - "tn"_a = tn, - "trans"_a = transpose_mat ? "t_" : "", - "nc"_a = contiguous ? "0" : "1"); + << get_template_definition( + lib_name, + (transpose_mat) ? "gemv_t_masked" : "gemv_masked", + get_type_string(out.dtype()), + out_mask_type, + op_mask_type, + bm, + bn, + sm, + sn, + tm, + tn, + contiguous ? 0 : 1); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); @@ -624,17 +619,17 @@ MTL::ComputePipelineState* get_steel_conv_kernel( auto lib = d.get_library(lib_name, [&]() { std::ostringstream kernel_source; kernel_source << metal::utils() << metal::conv() << metal::steel_conv() - << fmt::format( - steel_conv_kernels, - "name"_a = lib_name, - "itype"_a = get_type_string(out.dtype()), - "bm"_a = bm, - "bn"_a = bn, - "bk"_a = bk, - "wm"_a = wm, - "wn"_a = wn, - "n_channels"_a = n_channel_specialization, - "small_filter"_a = small_filter); + << get_template_definition( + lib_name, + "implicit_gemm_conv_2d", + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn, + n_channel_specialization, + small_filter); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib); @@ -654,15 +649,15 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel( std::ostringstream kernel_source; kernel_source << metal::utils() << metal::conv() << metal::steel_conv_general() - << fmt::format( - steel_conv_general_kernels, - "name"_a = lib_name, - "itype"_a = get_type_string(out.dtype()), - "bm"_a = bm, - "bn"_a = bn, - "bk"_a = bk, - "wm"_a = wm, - "wn"_a = wn); + << get_template_definition( + lib_name, + "implicit_gemm_conv_2d_general", + get_type_string(out.dtype()), + bm, + bn, + bk, + wm, + wn); return kernel_source.str(); }); return d.get_kernel(kernel_name, lib);