mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Option to JIT steel gemm / conv (#1139)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
@@ -12,11 +11,15 @@
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/jit/sort.h"
|
||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
||||
#include "mlx/backend/metal/jit/ternary.h"
|
||||
#include "mlx/backend/metal/jit/unary.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
using namespace fmt::literals;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::string op_name(const array& arr) {
|
||||
@@ -276,4 +279,208 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_fused()
|
||||
<< fmt::format(
|
||||
steel_gemm_fused_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"trans_a"_a = transpose_a,
|
||||
"trans_b"_a = transpose_b);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool mn_aligned,
|
||||
bool k_aligned) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_splitk()
|
||||
<< fmt::format(
|
||||
steel_gemm_splitk_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(in.dtype()),
|
||||
"otype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"trans_a"_a = transpose_a,
|
||||
"trans_b"_a = transpose_b,
|
||||
"mn_aligned"_a = mn_aligned,
|
||||
"k_aligned"_a = k_aligned);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
bool axbpy) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_splitk()
|
||||
<< fmt::format(
|
||||
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
|
||||
: steel_gemm_splitk_accum_kernels,
|
||||
"name"_a = lib_name,
|
||||
"atype"_a = get_type_string(in.dtype()),
|
||||
"otype"_a = get_type_string(out.dtype()));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
const std::optional<array>& mask_out,
|
||||
const std::optional<array>& mask_op,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool mn_aligned,
|
||||
bool k_aligned) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto out_mask_type = mask_out.has_value()
|
||||
? get_type_string((*mask_out).dtype())
|
||||
: "nomask_t";
|
||||
auto op_mask_type =
|
||||
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_masked()
|
||||
<< fmt::format(
|
||||
steel_gemm_masked_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"outmasktype"_a = out_mask_type,
|
||||
"opmasktype"_a = op_mask_type,
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"trans_a"_a = transpose_a,
|
||||
"trans_b"_a = transpose_b,
|
||||
"mn_aligned"_a = mn_aligned,
|
||||
"k_aligned"_a = k_aligned);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
int n_channel_specialization,
|
||||
bool small_filter) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
||||
<< fmt::format(
|
||||
steel_conv_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"n_channels"_a = n_channel_specialization,
|
||||
"small_filter"_a = small_filter);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::conv()
|
||||
<< metal::steel_conv_general()
|
||||
<< fmt::format(
|
||||
steel_conv_general_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user