From 56d2532aad7b155f470c2237b2ba0e61859f7a32 Mon Sep 17 00:00:00 2001 From: Arkar Min Aung Date: Sat, 14 Jun 2025 21:30:52 +1000 Subject: [PATCH] feat: Add JIT kernel support for SVD operations - Implement get_svd_kernel function for JIT compilation - Add proper library name extraction and template definition - Support dynamic kernel compilation for SVD operations - Enable future Metal shader JIT compilation for SVD - Integrate with existing MLX JIT kernel infrastructure --- mlx/backend/metal/jit_kernels.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index fab1b155c..ebb45afb8 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -828,9 +828,12 @@ MTL::ComputePipelineState* get_svd_kernel( const std::string& kernel_name, const array& out, bool compute_uv) { - auto lib = d.get_library(kernel_name, [&]() { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::svd(); + kernel_source += get_template_definition( + kernel_name, lib_name, get_type_string(out.dtype())); return kernel_source; }); return d.get_kernel(kernel_name, lib);