= R"(
template [[host_name("{name}")]]
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
    const device {itype} *A [[buffer(0)]],
    const device {itype} *B [[buffer(1)]],
    const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
    device {itype} *D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
    const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)"