mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix metal jit
This commit is contained in:
@@ -652,6 +652,43 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
|||||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the compute pipeline state for a segmented steel GEMM kernel,
// handing the device a callback that produces the library source.
//
// The library is requested from `d` under the same name as the kernel. The
// callback passed to `get_library` assembles the source from metal::utils(),
// metal::gemm(), metal::steel_gemm_segmented() and a "segmented_mm" template
// definition built from the dtype string of `out`, then bm, bn, bk, wm, wn,
// transpose_a and transpose_b (in that order).
//
// NOTE(review): presumably the callback only runs when the device does not
// already hold a library with this name -- confirm against
// Device::get_library.
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& out,
    bool transpose_a,
    bool transpose_b,
    int bm,
    int bn,
    int bk,
    int wm,
    int wn) {
  // Source builder; captures the arguments by reference, so it must not
  // outlive this call.
  auto make_source = [&]() {
    std::string source;
    concatenate(
        source,
        metal::utils(),
        metal::gemm(),
        metal::steel_gemm_segmented(),
        get_template_definition(
            kernel_name,
            "segmented_mm",
            get_type_string(out.dtype()),
            bm,
            bn,
            bk,
            wm,
            wn,
            transpose_a,
            transpose_b));
    return source;
  };

  // The library name and the template instantiation name are both the
  // kernel name.
  auto lib = d.get_library(kernel_name, make_source);
  return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
|
|||||||
@@ -175,6 +175,20 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
|||||||
int wn,
|
int wn,
|
||||||
bool rhs);
|
bool rhs);
|
||||||
|
|
||||||
|
// Declaration of the getter for the segmented steel GEMM kernel.
//
// Takes the Metal device, the kernel name, a hash name and a list of function
// constants, followed by the output array, two transpose flags and five
// integer parameters (bm, bn, bk, wm, wn). Returns a pointer to a
// MTL::ComputePipelineState.
//
// NOTE(review): bm/bn/bk and wm/wn look like tile and per-threadgroup sizes,
// but this declaration alone does not establish that -- confirm against the
// kernel template.
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const std::string& hash_name,
|
||||||
|
const metal::MTLFCList& func_consts,
|
||||||
|
const array& out,
|
||||||
|
bool transpose_a,
|
||||||
|
bool transpose_b,
|
||||||
|
int bm,
|
||||||
|
||||||
|
int bn,
|
||||||
|
int bk,
|
||||||
|
int wm,
|
||||||
|
int wn);
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
|
|||||||
@@ -1962,7 +1962,7 @@ void segmented_mm(
|
|||||||
|
|
||||||
// Get and set the kernel
|
// Get and set the kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
auto kernel = get_steel_gemm_gather_kernel(
|
auto kernel = get_steel_gemm_segmented_kernel(
|
||||||
d,
|
d,
|
||||||
base_name,
|
base_name,
|
||||||
hash_name,
|
hash_name,
|
||||||
@@ -1974,8 +1974,7 @@ void segmented_mm(
|
|||||||
bn,
|
bn,
|
||||||
bk,
|
bk,
|
||||||
wm,
|
wm,
|
||||||
wn,
|
wn);
|
||||||
false);
|
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
// Prepare the matmul params
|
// Prepare the matmul params
|
||||||
|
|||||||
@@ -210,6 +210,22 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
|||||||
return d.get_kernel(kernel_name, hash_name, func_consts);
|
return d.get_kernel(kernel_name, hash_name, func_consts);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the compute pipeline state for a segmented steel GEMM kernel by
// asking the device for `kernel_name` directly, without supplying a library.
//
// Only `d`, `kernel_name`, `hash_name` and `func_consts` are used. The output
// array, transpose flags and integer parameters are accepted so the signature
// matches the declaration, but they are ignored here.
//
// NOTE(review): presumably this is the variant built when kernels are
// precompiled rather than generated at runtime -- confirm against the build
// configuration.
MTL::ComputePipelineState* get_steel_gemm_segmented_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& hash_name,
    const metal::MTLFCList& func_consts,
    const array& /* out */,
    bool /* transpose_a */,
    bool /* transpose_b */,
    int /* bm */,
    int /* bn */,
    int /* bk */,
    int /* wm */,
    int /* wn */) {
  MTL::ComputePipelineState* kernel =
      d.get_kernel(kernel_name, hash_name, func_consts);
  return kernel;
}
|
||||||
|
|
||||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
|
|||||||
Reference in New Issue
Block a user