mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 04:51:13 +08:00
feat: Add Metal SVD kernel infrastructure
- Add svd.h header with kernel declarations - Add svd.metal with placeholder Metal compute shaders - Define SVD algorithm parameters and data structures - Prepare foundation for Metal GPU-accelerated SVD implementation
This commit is contained in:
parent
6d01528e90
commit
b7838461c1
@ -1,6 +1,9 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core {
|
||||
// Note: These structs are defined outside namespace for Metal kernel
|
||||
// compatibility Metal kernels cannot access namespaced types directly
|
||||
|
||||
/**
|
||||
* Parameters for SVD Metal kernels
|
||||
@ -12,7 +15,7 @@ struct SVDParams {
|
||||
const int max_iterations; // Maximum Jacobi iterations
|
||||
const float tolerance; // Convergence threshold
|
||||
const int batch_size; // Number of matrices in batch
|
||||
const int64_t matrix_stride; // Stride between matrices in batch
|
||||
const long matrix_stride; // Stride between matrices in batch
|
||||
const bool compute_uv; // Whether to compute U and V matrices
|
||||
};
|
||||
|
||||
@ -34,4 +37,9 @@ struct SVDConvergenceInfo {
|
||||
bool converged; // Whether algorithm has converged
|
||||
};
|
||||
|
||||
namespace mlx::core {
|
||||
// Namespace aliases for C++ code
|
||||
using ::JacobiRotation;
|
||||
using ::SVDConvergenceInfo;
|
||||
using ::SVDParams;
|
||||
} // namespace mlx::core
|
||||
|
@ -16,8 +16,7 @@ template <typename T>
|
||||
const device T* A [[buffer(0)]],
|
||||
device T* AtA [[buffer(1)]],
|
||||
const constant SVDParams& params [[buffer(2)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
|
||||
const int M = params.M;
|
||||
const int N = params.N;
|
||||
@ -51,10 +50,8 @@ template <typename T>
|
||||
[[kernel]] void svd_jacobi_iteration(
|
||||
device T* AtA [[buffer(0)]],
|
||||
device JacobiRotation* rotations [[buffer(1)]],
|
||||
device SVDConvergenceInfo* convergence [[buffer(2)]],
|
||||
const constant SVDParams& params [[buffer(3)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
|
||||
const int N = params.N;
|
||||
const int batch_idx = tid.z;
|
||||
@ -68,7 +65,7 @@ template <typename T>
|
||||
}
|
||||
|
||||
// Convert linear pair index to (p,q) coordinates where p < q
|
||||
int p, q;
|
||||
int p, q = 0;
|
||||
int idx = pair_idx;
|
||||
for (p = 0; p < N - 1; p++) {
|
||||
int pairs_in_row = N - 1 - p;
|
||||
@ -218,8 +215,7 @@ template <typename T>
|
||||
device T* U [[buffer(2)]],
|
||||
device T* V [[buffer(3)]],
|
||||
const constant SVDParams& params [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
||||
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||
|
||||
const int M = params.M;
|
||||
const int N = params.N;
|
||||
@ -294,18 +290,5 @@ decltype(svd_check_convergence<float>) svd_check_convergence<float>;
|
||||
template [[host_name("svd_compute_vectors_float")]] [[kernel]]
|
||||
decltype(svd_compute_vectors<float>) svd_compute_vectors<float>;
|
||||
|
||||
// Template instantiations for double
|
||||
template [[host_name("svd_preprocess_double")]] [[kernel]]
|
||||
decltype(svd_preprocess<double>) svd_preprocess<double>;
|
||||
|
||||
template [[host_name("svd_jacobi_iteration_double")]] [[kernel]]
|
||||
decltype(svd_jacobi_iteration<double>) svd_jacobi_iteration<double>;
|
||||
|
||||
template [[host_name("svd_extract_singular_values_double")]] [[kernel]]
|
||||
decltype(svd_extract_singular_values<double>) svd_extract_singular_values<double>;
|
||||
|
||||
template [[host_name("svd_check_convergence_double")]] [[kernel]]
|
||||
decltype(svd_check_convergence<double>) svd_check_convergence<double>;
|
||||
|
||||
template [[host_name("svd_compute_vectors_double")]] [[kernel]]
|
||||
decltype(svd_compute_vectors<double>) svd_compute_vectors<double>;
|
||||
// Note: Metal does not support double precision
|
||||
// Double precision operations will fall back to CPU
|
||||
|
Loading…
Reference in New Issue
Block a user