diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 467380c3a..fd0e0db09 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -652,6 +652,43 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( return d.get_kernel(kernel_name, lib, hash_name, func_consts); } +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn) { /* JIT variant: asks the device for the library named after kernel_name, supplying a builder for its source, then returns the pipeline state looked up with kernel_name, that library, hash_name and func_consts. */ + const auto& lib_name = kernel_name; /* the library is requested under the kernel's own name */ + auto lib = d.get_library(lib_name, [&]() { /* source builder handed to get_library; presumably invoked only when the library is not already available - confirm against Device::get_library */ + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::steel_gemm_segmented(), /* segmented GEMM source, appended after the utils and gemm sources */ + get_template_definition( + lib_name, + "segmented_mm", /* presumably the kernel template being instantiated, with the following arguments as its parameters - confirm against get_template_definition */ + get_type_string(out.dtype()), /* element type string is taken from the output array's dtype */ + bm, + bn, + bk, + wm, + wn, + transpose_a, + transpose_b)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 1de5fa47c..794c67bdc 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -175,6 +175,20 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( int wn, bool rhs); +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& out, + bool transpose_a, + bool transpose_b, + int bm, + int bn, + int bk, + int wm, + int wn); /* Declaration of the segmented GEMM kernel getter: device, kernel and hash names, function constants, output array, two transpose flags and five integer tile parameters (bm, bn, bk, wm, wn). */ + MTL::ComputePipelineState* get_steel_conv_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 8803aaf0a..55b8be3a9 100644 --- 
a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1962,7 +1962,7 @@ void segmented_mm( // Get and set the kernel auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = get_steel_gemm_gather_kernel( + auto kernel = get_steel_gemm_segmented_kernel( /* segmented_mm now requests the segmented kernel getter instead of the gather one */ d, base_name, hash_name, @@ -1974,8 +1974,7 @@ void segmented_mm( bn, bk, wm, - wn, - false); + wn); /* the trailing bool that was passed to the gather getter is dropped; the segmented getter takes no such argument */ compute_encoder.set_compute_pipeline_state(kernel); // Prepare the matmul params diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index b0375e37f..32d3e75f7 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -210,6 +210,22 @@ MTL::ComputePipelineState* get_steel_gemm_gather_kernel( return d.get_kernel(kernel_name, hash_name, func_consts); } +MTL::ComputePipelineState* get_steel_gemm_segmented_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + bool, + bool, + int, + int, + int, + int, + int) { /* No-JIT variant: the output array, transpose flags and tile sizes are unnamed and unused here; they keep the parameter list identical to the JIT definition. */ + return d.get_kernel(kernel_name, hash_name, func_consts); /* fetched by kernel name, hash name and function constants only; no library argument is passed on this path */ +} + MTL::ComputePipelineState* get_gemv_masked_kernel( metal::Device& d, const std::string& kernel_name,