mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00
Move some kernels to get_template_definition
(#1782)
This commit is contained in:
parent
90532b1f37
commit
1f4c127fb9
@ -1,25 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
constexpr std::string_view gemv_masked_kernel = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
|
|
||||||
const device {itype}* mat [[buffer(0)]],
|
|
||||||
const device {itype}* in_vec [[buffer(1)]],
|
|
||||||
device {itype}* out_vec [[buffer(3)]],
|
|
||||||
const constant int& in_vec_size [[buffer(4)]],
|
|
||||||
const constant int& out_vec_size [[buffer(5)]],
|
|
||||||
const constant int& marix_ld [[buffer(6)]],
|
|
||||||
const constant int& batch_ndim [[buffer(9)]],
|
|
||||||
const constant int* batch_shape [[buffer(10)]],
|
|
||||||
const constant int64_t* vector_batch_stride [[buffer(11)]],
|
|
||||||
const constant int64_t* matrix_batch_stride [[buffer(12)]],
|
|
||||||
const device {outm_t}* out_mask [[buffer(20)]],
|
|
||||||
const device {opm_t}* mat_mask [[buffer(21)]],
|
|
||||||
const device {opm_t}* vec_mask [[buffer(22)]],
|
|
||||||
const constant int* mask_strides [[buffer(23)]],
|
|
||||||
const constant int64_t* mask_batch_strides [[buffer(24)]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
|
||||||
)";
|
|
@ -1,32 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
constexpr std::string_view steel_conv_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
|
|
||||||
const device {itype}* A [[buffer(0)]],
|
|
||||||
const device {itype}* B [[buffer(1)]],
|
|
||||||
device {itype}* C [[buffer(2)]],
|
|
||||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
|
||||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
|
||||||
)";
|
|
||||||
|
|
||||||
constexpr std::string_view steel_conv_general_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
|
|
||||||
const device {itype}* A [[buffer(0)]],
|
|
||||||
const device {itype}* B [[buffer(1)]],
|
|
||||||
device {itype}* C [[buffer(2)]],
|
|
||||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
|
||||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
|
||||||
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
|
|
||||||
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
|
|
||||||
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]],
|
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
|
||||||
)";
|
|
@ -1,106 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
constexpr std::string_view steel_gemm_fused_kernels = R"(
|
|
||||||
template [[host_name("{name}")]]
|
|
||||||
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
|
|
||||||
const device {itype} *A [[buffer(0)]],
|
|
||||||
const device {itype} *B [[buffer(1)]],
|
|
||||||
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
|
|
||||||
device {itype} *D [[buffer(3)]],
|
|
||||||
const constant GEMMParams* params [[buffer(4)]],
|
|
||||||
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
|
|
||||||
const constant int* batch_shape [[buffer(6)]],
|
|
||||||
const constant int64_t* batch_strides [[buffer(7)]],
|
|
||||||
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
|
|
||||||
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
|
|
||||||
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
|
|
||||||
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
|
|
||||||
const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
|
|
||||||
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
)";
|
|
||||||
|
|
||||||
constexpr std::string_view steel_gemm_masked_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
block_masked_gemm<
|
|
||||||
{itype},
|
|
||||||
{outmasktype},
|
|
||||||
{opmasktype},
|
|
||||||
{bm},
|
|
||||||
{bn},
|
|
||||||
{bk},
|
|
||||||
{wm},
|
|
||||||
{wn},
|
|
||||||
{trans_a},
|
|
||||||
{trans_b},
|
|
||||||
{mn_aligned},
|
|
||||||
{k_aligned}>(
|
|
||||||
const device {itype}* A [[buffer(0)]],
|
|
||||||
const device {itype}* B [[buffer(1)]],
|
|
||||||
device {itype}* D [[buffer(3)]],
|
|
||||||
const constant GEMMParams* params [[buffer(4)]],
|
|
||||||
const constant int* batch_shape [[buffer(6)]],
|
|
||||||
const constant int64_t* batch_strides [[buffer(7)]],
|
|
||||||
const device {outmasktype}* out_mask [[buffer(10)]],
|
|
||||||
const device {opmasktype}* lhs_mask [[buffer(11)]],
|
|
||||||
const device {opmasktype}* rhs_mask [[buffer(12)]],
|
|
||||||
const constant int* mask_strides [[buffer(13)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
)";
|
|
||||||
|
|
||||||
constexpr std::string_view steel_gemm_splitk_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
gemm_splitk<
|
|
||||||
{itype},
|
|
||||||
{otype},
|
|
||||||
{bm},
|
|
||||||
{bn},
|
|
||||||
{bk},
|
|
||||||
{wm},
|
|
||||||
{wn},
|
|
||||||
{trans_a},
|
|
||||||
{trans_b},
|
|
||||||
{mn_aligned},
|
|
||||||
{k_aligned}>(
|
|
||||||
const device {itype}* A [[buffer(0)]],
|
|
||||||
const device {itype}* B [[buffer(1)]],
|
|
||||||
device {otype}* C [[buffer(2)]],
|
|
||||||
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
)";
|
|
||||||
|
|
||||||
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
gemm_splitk_accum<{atype}, {otype}>(
|
|
||||||
const device {atype}* C_split [[buffer(0)]],
|
|
||||||
device {otype}* D [[buffer(1)]],
|
|
||||||
const constant int& k_partitions [[buffer(2)]],
|
|
||||||
const constant int& partition_stride [[buffer(3)]],
|
|
||||||
const constant int& ldd [[buffer(4)]],
|
|
||||||
uint2 gid [[thread_position_in_grid]]);
|
|
||||||
)";
|
|
||||||
|
|
||||||
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
|
|
||||||
template [[host_name("{name}")]] [[kernel]] void
|
|
||||||
gemm_splitk_accum_axpby<{atype}, {otype}>(
|
|
||||||
const device {atype}* C_split [[buffer(0)]],
|
|
||||||
device {otype}* D [[buffer(1)]],
|
|
||||||
const constant int& k_partitions [[buffer(2)]],
|
|
||||||
const constant int& partition_stride [[buffer(3)]],
|
|
||||||
const constant int& ldd [[buffer(4)]],
|
|
||||||
const device {otype}* C [[buffer(5)]],
|
|
||||||
const constant int& ldc [[buffer(6)]],
|
|
||||||
const constant int& fdc [[buffer(7)]],
|
|
||||||
const constant float& alpha [[buffer(8)]],
|
|
||||||
const constant float& beta [[buffer(9)]],
|
|
||||||
uint2 gid [[thread_position_in_grid]]);
|
|
||||||
)";
|
|
@ -1,11 +1,8 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/metal/jit/arange.h"
|
#include "mlx/backend/metal/jit/arange.h"
|
||||||
#include "mlx/backend/metal/jit/gemv_masked.h"
|
|
||||||
#include "mlx/backend/metal/jit/includes.h"
|
#include "mlx/backend/metal/jit/includes.h"
|
||||||
#include "mlx/backend/metal/jit/softmax.h"
|
#include "mlx/backend/metal/jit/softmax.h"
|
||||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
|
||||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
|
||||||
#include "mlx/backend/metal/kernels.h"
|
#include "mlx/backend/metal/kernels.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
|
||||||
@ -445,17 +442,17 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
|||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_fused()
|
<< metal::steel_gemm_fused()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
steel_gemm_fused_kernels,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
"gemm",
|
||||||
"itype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"bk"_a = bk,
|
bk,
|
||||||
"wm"_a = wm,
|
wm,
|
||||||
"wn"_a = wn,
|
wn,
|
||||||
"trans_a"_a = transpose_a,
|
transpose_a,
|
||||||
"trans_b"_a = transpose_b);
|
transpose_b);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||||
@ -480,20 +477,20 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
|||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_splitk()
|
<< metal::steel_gemm_splitk()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
steel_gemm_splitk_kernels,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
"gemm_splitk",
|
||||||
"itype"_a = get_type_string(in.dtype()),
|
get_type_string(in.dtype()),
|
||||||
"otype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"bk"_a = bk,
|
bk,
|
||||||
"wm"_a = wm,
|
wm,
|
||||||
"wn"_a = wn,
|
wn,
|
||||||
"trans_a"_a = transpose_a,
|
transpose_a,
|
||||||
"trans_b"_a = transpose_b,
|
transpose_b,
|
||||||
"mn_aligned"_a = mn_aligned,
|
mn_aligned,
|
||||||
"k_aligned"_a = k_aligned);
|
k_aligned);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@ -510,13 +507,12 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
|||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_splitk()
|
<< metal::steel_gemm_splitk()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
fmt::runtime(
|
lib_name,
|
||||||
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
|
axbpy ? "gemm_splitk_accum_axpby"
|
||||||
: steel_gemm_splitk_accum_kernels),
|
: "gemm_splitk_accum",
|
||||||
"name"_a = lib_name,
|
get_type_string(in.dtype()),
|
||||||
"atype"_a = get_type_string(in.dtype()),
|
get_type_string(out.dtype()));
|
||||||
"otype"_a = get_type_string(out.dtype()));
|
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@ -547,21 +543,21 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
|||||||
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||||
kernel_source << metal::utils() << metal::gemm()
|
kernel_source << metal::utils() << metal::gemm()
|
||||||
<< metal::steel_gemm_masked()
|
<< metal::steel_gemm_masked()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
steel_gemm_masked_kernels,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
"block_masked_gemm",
|
||||||
"itype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"outmasktype"_a = out_mask_type,
|
out_mask_type,
|
||||||
"opmasktype"_a = op_mask_type,
|
op_mask_type,
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"bk"_a = bk,
|
bk,
|
||||||
"wm"_a = wm,
|
wm,
|
||||||
"wn"_a = wn,
|
wn,
|
||||||
"trans_a"_a = transpose_a,
|
transpose_a,
|
||||||
"trans_b"_a = transpose_b,
|
transpose_b,
|
||||||
"mn_aligned"_a = mn_aligned,
|
mn_aligned,
|
||||||
"k_aligned"_a = k_aligned);
|
k_aligned);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@ -590,20 +586,19 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
|
|||||||
auto op_mask_type =
|
auto op_mask_type =
|
||||||
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||||
kernel_source << metal::utils() << metal::gemv_masked()
|
kernel_source << metal::utils() << metal::gemv_masked()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
gemv_masked_kernel,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
(transpose_mat) ? "gemv_t_masked" : "gemv_masked",
|
||||||
"itype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"outm_t"_a = out_mask_type,
|
out_mask_type,
|
||||||
"opm_t"_a = op_mask_type,
|
op_mask_type,
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"sm"_a = sm,
|
sm,
|
||||||
"sn"_a = sn,
|
sn,
|
||||||
"tm"_a = tm,
|
tm,
|
||||||
"tn"_a = tn,
|
tn,
|
||||||
"trans"_a = transpose_mat ? "t_" : "",
|
contiguous ? 0 : 1);
|
||||||
"nc"_a = contiguous ? "0" : "1");
|
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@ -624,17 +619,17 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
|||||||
auto lib = d.get_library(lib_name, [&]() {
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
steel_conv_kernels,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
"implicit_gemm_conv_2d",
|
||||||
"itype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"bk"_a = bk,
|
bk,
|
||||||
"wm"_a = wm,
|
wm,
|
||||||
"wn"_a = wn,
|
wn,
|
||||||
"n_channels"_a = n_channel_specialization,
|
n_channel_specialization,
|
||||||
"small_filter"_a = small_filter);
|
small_filter);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
@ -654,15 +649,15 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
|||||||
std::ostringstream kernel_source;
|
std::ostringstream kernel_source;
|
||||||
kernel_source << metal::utils() << metal::conv()
|
kernel_source << metal::utils() << metal::conv()
|
||||||
<< metal::steel_conv_general()
|
<< metal::steel_conv_general()
|
||||||
<< fmt::format(
|
<< get_template_definition(
|
||||||
steel_conv_general_kernels,
|
lib_name,
|
||||||
"name"_a = lib_name,
|
"implicit_gemm_conv_2d_general",
|
||||||
"itype"_a = get_type_string(out.dtype()),
|
get_type_string(out.dtype()),
|
||||||
"bm"_a = bm,
|
bm,
|
||||||
"bn"_a = bn,
|
bn,
|
||||||
"bk"_a = bk,
|
bk,
|
||||||
"wm"_a = wm,
|
wm,
|
||||||
"wn"_a = wn);
|
wn);
|
||||||
return kernel_source.str();
|
return kernel_source.str();
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
Loading…
Reference in New Issue
Block a user