mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix metal jit
This commit is contained in:
@@ -652,6 +652,43 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
// Looks up the pipeline state for a segmented steel GEMM kernel, supplying the
// device with a callback that can generate the kernel's Metal source.
//
// d            - device queried for both the library and the kernel.
// kernel_name  - used as the library name, as the template instantiation
//                name, and as the kernel name.
// hash_name    - forwarded unchanged to d.get_kernel().
// func_consts  - forwarded unchanged to d.get_kernel().
// out          - only its dtype is read; it selects the element type string
//                placed in the template definition.
// transpose_a, transpose_b, bm, bn, bk, wm, wn
//              - forwarded as template arguments of "segmented_mm"
//                (presumably transpose flags and block/warp tile sizes;
//                confirm against the kernel template).
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn) {
  // Source generator handed to the device: shared utilities, the GEMM
  // helpers, the segmented kernel body, and finally the explicit template
  // instantiation for this configuration.
  auto build_source = [&]() {
    std::string source;
    concatenate(
        source,
        metal::utils(),
        metal::gemm(),
        metal::steel_gemm_segmented(),
        get_template_definition(
            kernel_name,
            "segmented_mm",
            get_type_string(out.dtype()),
            bm,
            bn,
            bk,
            wm,
            wn,
            transpose_a,
            transpose_b));
    return source;
  };

  // The library is keyed by the kernel name itself.
  auto library = d.get_library(kernel_name, build_source);
  return d.get_kernel(kernel_name, library, hash_name, func_consts);
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
||||
@@ -175,6 +175,20 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
int wn,
|
||||
bool rhs);
|
||||
|
||||
// Returns the compute pipeline state for the segmented steel GEMM kernel; the
// segmented_mm hunk in this diff calls it with (d, base_name, hash_name, ...).
//
// d            - Metal device the kernel is fetched from.
// kernel_name  - name passed to d.get_kernel(); one definition in this diff
//                also uses it as the library name.
// hash_name    - forwarded to d.get_kernel().
// func_consts  - forwarded to d.get_kernel().
// out          - one definition in this diff reads only out.dtype(); the
//                other ignores the argument.
// transpose_a, transpose_b, bm, bn, bk, wm, wn
//              - forwarded to get_template_definition() by the definition
//                that generates kernel source; left unnamed and unused by the
//                other definition.
// NOTE(review): bm/bn/bk/wm/wn look like block and warp tile sizes; confirm
// against the kernel template before relying on that reading.
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
||||
@@ -1962,7 +1962,7 @@ void segmented_mm(
|
||||
|
||||
// Get and set the kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_gemm_gather_kernel(
|
||||
auto kernel = get_steel_gemm_segmented_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
@@ -1974,8 +1974,7 @@ void segmented_mm(
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
false);
|
||||
wn);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Prepare the matmul params
|
||||
|
||||
@@ -210,6 +210,22 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||
}
|
||||
|
||||
// Variant of the segmented steel GEMM kernel lookup that generates no source:
// the kernel is requested from the device by name only. The output array,
// transpose flags and tiling arguments are accepted to match the shared
// signature but are not used here.
//
// d            - device queried for the kernel.
// kernel_name  - forwarded to d.get_kernel().
// hash_name    - forwarded to d.get_kernel().
// func_consts  - forwarded to d.get_kernel().
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& /* out */,
    bool /* transpose_a */,
    bool /* transpose_b */,
    int /* bm */,
    int /* bn */,
    int /* bk */,
    int /* wm */,
    int /* wn */) {
  MTL::ComputePipelineState* pipeline =
      d.get_kernel(kernel_name, hash_name, func_consts);
  return pipeline;
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
|
||||
Reference in New Issue
Block a user