mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add gemv masked to JIT plus some fixes (#1310)
* add gemv masked to JIT plus some fixes * some cleanup * add utils * fix * fix 2 * more cleaning * fix * remove unused mps matmul support * one more nit * revert
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/copy.h"
|
||||
#include "mlx/backend/metal/jit/gemv_masked.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/reduce.h"
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
@@ -50,10 +51,12 @@ MTL::ComputePipelineState* get_unary_kernel(
|
||||
std::ostringstream kernel_source;
|
||||
auto u_def = get_template_definition(
|
||||
"v" + lib_name, "unary_v", get_type_string(out_type), op);
|
||||
auto u2_def = get_template_definition(
|
||||
"v2" + lib_name, "unary_v2", get_type_string(out_type), op);
|
||||
auto g_def = get_template_definition(
|
||||
"g" + lib_name, "unary_g", get_type_string(out_type), op);
|
||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
|
||||
<< u_def << g_def;
|
||||
<< u_def << u2_def << g_def;
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -70,6 +73,9 @@ void add_binary_kernels(
|
||||
{"vs", "binary_vs"},
|
||||
{"sv", "binary_sv"},
|
||||
{"vv", "binary_vv"},
|
||||
{"vs2", "binary_vs2"},
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g2", "binary_g_nd2"},
|
||||
{"g3", "binary_g_nd3"},
|
||||
@@ -146,6 +152,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
|
||||
std::ostringstream kernel_source;
|
||||
const std::map<std::string, std::string> kernel_types = {
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g", "ternary_g"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g2", "ternary_g_nd2"},
|
||||
@@ -496,6 +503,49 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
const std::optional<array>& mask_out,
|
||||
const std::optional<array>& mask_op,
|
||||
bool transpose_mat,
|
||||
int bm,
|
||||
int bn,
|
||||
int sm,
|
||||
int sn,
|
||||
int tm,
|
||||
int tn,
|
||||
bool contiguous) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto out_mask_type = mask_out.has_value()
|
||||
? get_type_string((*mask_out).dtype())
|
||||
: "nomask_t";
|
||||
auto op_mask_type =
|
||||
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||
kernel_source << metal::utils() << metal::gemv_masked()
|
||||
<< fmt::format(
|
||||
gemv_masked_kernel,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"outm_t"_a = out_mask_type,
|
||||
"opm_t"_a = op_mask_type,
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"sm"_a = sm,
|
||||
"sn"_a = sn,
|
||||
"tm"_a = tm,
|
||||
"tn"_a = tn,
|
||||
"trans"_a = transpose_mat ? "t_" : "",
|
||||
"nc"_a = contiguous ? "0" : "1");
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
||||
Reference in New Issue
Block a user