feat: Add JIT kernel support for SVD operations

- Implement get_svd_kernel function for JIT compilation
- Add proper library name extraction and template definition
- Support dynamic kernel compilation for SVD operations
- Enable future Metal shader JIT compilation for SVD
- Integrate with existing MLX JIT kernel infrastructure
This commit is contained in:
Arkar Min Aung 2025-06-14 21:30:52 +10:00
parent f2c731c29b
commit 56d2532aad

View File

@ -828,9 +828,12 @@ MTL::ComputePipelineState* get_svd_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const array& out, const array& out,
bool compute_uv) { bool compute_uv) {
auto lib = d.get_library(kernel_name, [&]() { std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils(); std::string kernel_source = metal::utils();
kernel_source += metal::svd(); kernel_source += metal::svd();
kernel_source += get_template_definition(
kernel_name, lib_name, get_type_string(out.dtype()));
return kernel_source; return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);