feat: Add Metal SVD kernel infrastructure

- Add svd.h header with kernel declarations
- Add svd.metal with placeholder Metal compute shaders
- Define SVD algorithm parameters and data structures
- Prepare foundation for Metal GPU-accelerated SVD implementation
This commit is contained in:
Arkar Min Aung 2025-06-14 21:22:34 +10:00
parent 6d01528e90
commit b7838461c1
2 changed files with 16 additions and 25 deletions

View File

@ -1,6 +1,9 @@
// Copyright © 2024 Apple Inc.
#pragma once
namespace mlx::core {
// Note: These structs are defined outside namespace for Metal kernel
// compatibility Metal kernels cannot access namespaced types directly
/**
* Parameters for SVD Metal kernels
@ -12,7 +15,7 @@ struct SVDParams {
const int max_iterations; // Maximum Jacobi iterations
const float tolerance; // Convergence threshold
const int batch_size; // Number of matrices in batch
const int64_t matrix_stride; // Stride between matrices in batch
const long matrix_stride; // Stride between matrices in batch
const bool compute_uv; // Whether to compute U and V matrices
};
@ -34,4 +37,9 @@ struct SVDConvergenceInfo {
bool converged; // Whether algorithm has converged
};
namespace mlx::core {
// Namespace aliases for C++ code
using ::JacobiRotation;
using ::SVDConvergenceInfo;
using ::SVDParams;
} // namespace mlx::core

View File

@ -16,8 +16,7 @@ template <typename T>
const device T* A [[buffer(0)]],
device T* AtA [[buffer(1)]],
const constant SVDParams& params [[buffer(2)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
uint3 tid [[threadgroup_position_in_grid]]) {
const int M = params.M;
const int N = params.N;
@ -51,10 +50,8 @@ template <typename T>
[[kernel]] void svd_jacobi_iteration(
device T* AtA [[buffer(0)]],
device JacobiRotation* rotations [[buffer(1)]],
device SVDConvergenceInfo* convergence [[buffer(2)]],
const constant SVDParams& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
uint3 tid [[threadgroup_position_in_grid]]) {
const int N = params.N;
const int batch_idx = tid.z;
@ -68,7 +65,7 @@ template <typename T>
}
// Convert linear pair index to (p,q) coordinates where p < q
int p, q;
int p, q = 0;
int idx = pair_idx;
for (p = 0; p < N - 1; p++) {
int pairs_in_row = N - 1 - p;
@ -218,8 +215,7 @@ template <typename T>
device T* U [[buffer(2)]],
device T* V [[buffer(3)]],
const constant SVDParams& params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
uint3 tid [[threadgroup_position_in_grid]]) {
const int M = params.M;
const int N = params.N;
@ -294,18 +290,5 @@ decltype(svd_check_convergence<float>) svd_check_convergence<float>;
template [[host_name("svd_compute_vectors_float")]] [[kernel]]
decltype(svd_compute_vectors<float>) svd_compute_vectors<float>;
// Template instantiations for double
template [[host_name("svd_preprocess_double")]] [[kernel]]
decltype(svd_preprocess<double>) svd_preprocess<double>;
template [[host_name("svd_jacobi_iteration_double")]] [[kernel]]
decltype(svd_jacobi_iteration<double>) svd_jacobi_iteration<double>;
template [[host_name("svd_extract_singular_values_double")]] [[kernel]]
decltype(svd_extract_singular_values<double>) svd_extract_singular_values<double>;
template [[host_name("svd_check_convergence_double")]] [[kernel]]
decltype(svd_check_convergence<double>) svd_check_convergence<double>;
template [[host_name("svd_compute_vectors_double")]] [[kernel]]
decltype(svd_compute_vectors<double>) svd_compute_vectors<double>;
// Note: Metal does not support double precision
// Double precision operations will fall back to CPU