feat: Add Metal SVD kernel infrastructure

- Add svd.h header with kernel declarations
- Add svd.metal with placeholder Metal compute shaders
- Define SVD algorithm parameters and data structures
- Prepare foundation for Metal GPU-accelerated SVD implementation
This commit is contained in:
Arkar Min Aung 2025-06-14 21:22:34 +10:00
parent 6d01528e90
commit b7838461c1
2 changed files with 16 additions and 25 deletions

View File

@@ -1,6 +1,9 @@
 // Copyright © 2024 Apple Inc.

 #pragma once

-namespace mlx::core {
+// Note: These structs are defined outside namespace for Metal kernel
+// compatibility. Metal kernels cannot access namespaced types directly.

 /**
  * Parameters for SVD Metal kernels
@@ -12,7 +15,7 @@ struct SVDParams {
   const int max_iterations; // Maximum Jacobi iterations
   const float tolerance; // Convergence threshold
   const int batch_size; // Number of matrices in batch
-  const int64_t matrix_stride; // Stride between matrices in batch
+  const long matrix_stride; // Stride between matrices in batch
   const bool compute_uv; // Whether to compute U and V matrices
 };
@@ -34,4 +37,9 @@ struct SVDConvergenceInfo {
   bool converged; // Whether algorithm has converged
 };

+namespace mlx::core {
+// Namespace aliases for C++ code
+using ::JacobiRotation;
+using ::SVDConvergenceInfo;
+using ::SVDParams;
 } // namespace mlx::core

View File

@@ -16,8 +16,7 @@ template <typename T>
     const device T* A [[buffer(0)]],
     device T* AtA [[buffer(1)]],
     const constant SVDParams& params [[buffer(2)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
+    uint3 tid [[threadgroup_position_in_grid]]) {

   const int M = params.M;
   const int N = params.N;
@@ -51,10 +50,8 @@ template <typename T>
 [[kernel]] void svd_jacobi_iteration(
     device T* AtA [[buffer(0)]],
     device JacobiRotation* rotations [[buffer(1)]],
-    device SVDConvergenceInfo* convergence [[buffer(2)]],
     const constant SVDParams& params [[buffer(3)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
+    uint3 tid [[threadgroup_position_in_grid]]) {

   const int N = params.N;
   const int batch_idx = tid.z;
@@ -68,7 +65,7 @@ template <typename T>
   }

   // Convert linear pair index to (p,q) coordinates where p < q
-  int p, q;
+  int p, q = 0;
   int idx = pair_idx;
   for (p = 0; p < N - 1; p++) {
     int pairs_in_row = N - 1 - p;
@@ -218,8 +215,7 @@ template <typename T>
     device T* U [[buffer(2)]],
     device T* V [[buffer(3)]],
     const constant SVDParams& params [[buffer(4)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
+    uint3 tid [[threadgroup_position_in_grid]]) {

   const int M = params.M;
   const int N = params.N;
@@ -294,18 +290,5 @@ decltype(svd_check_convergence<float>) svd_check_convergence<float>;
 template [[host_name("svd_compute_vectors_float")]] [[kernel]]
 decltype(svd_compute_vectors<float>) svd_compute_vectors<float>;

-// Template instantiations for double
-template [[host_name("svd_preprocess_double")]] [[kernel]]
-decltype(svd_preprocess<double>) svd_preprocess<double>;
-template [[host_name("svd_jacobi_iteration_double")]] [[kernel]]
-decltype(svd_jacobi_iteration<double>) svd_jacobi_iteration<double>;
-template [[host_name("svd_extract_singular_values_double")]] [[kernel]]
-decltype(svd_extract_singular_values<double>) svd_extract_singular_values<double>;
-template [[host_name("svd_check_convergence_double")]] [[kernel]]
-decltype(svd_check_convergence<double>) svd_check_convergence<double>;
-template [[host_name("svd_compute_vectors_double")]] [[kernel]]
-decltype(svd_compute_vectors<double>) svd_compute_vectors<double>;
+// Note: Metal does not support double precision
+// Double precision operations will fall back to CPU