diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index fab1b155c..ebb45afb8 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -828,9 +828,12 @@ MTL::ComputePipelineState* get_svd_kernel( const std::string& kernel_name, const array& out, bool compute_uv) { - auto lib = d.get_library(kernel_name, [&]() { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name, [&]() { std::string kernel_source = metal::utils(); kernel_source += metal::svd(); + kernel_source += get_template_definition( + kernel_name, lib_name, get_type_string(out.dtype())); return kernel_source; }); return d.get_kernel(kernel_name, lib);