mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
feat: Add JIT kernel support for SVD operations
- Implement get_svd_kernel function for JIT compilation - Add proper library name extraction and template definition - Support dynamic kernel compilation for SVD operations - Enable future Metal shader JIT compilation for SVD - Integrate with existing MLX JIT kernel infrastructure
This commit is contained in:
parent
f2c731c29b
commit
56d2532aad
@ -828,9 +828,12 @@ MTL::ComputePipelineState* get_svd_kernel(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
const array& out,
|
const array& out,
|
||||||
bool compute_uv) {
|
bool compute_uv) {
|
||||||
auto lib = d.get_library(kernel_name, [&]() {
|
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||||
|
auto lib = d.get_library(lib_name, [&]() {
|
||||||
std::string kernel_source = metal::utils();
|
std::string kernel_source = metal::utils();
|
||||||
kernel_source += metal::svd();
|
kernel_source += metal::svd();
|
||||||
|
kernel_source += get_template_definition(
|
||||||
|
kernel_name, lib_name, get_type_string(out.dtype()));
|
||||||
return kernel_source;
|
return kernel_source;
|
||||||
});
|
});
|
||||||
return d.get_kernel(kernel_name, lib);
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
Loading…
Reference in New Issue
Block a user