mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 21:21:16 +08:00
feat: Add Metal SVD kernel infrastructure
- Add svd.h header with kernel declarations - Add svd.metal with placeholder Metal compute shaders - Define SVD algorithm parameters and data structures - Prepare foundation for Metal GPU-accelerated SVD implementation
This commit is contained in:
parent
6d01528e90
commit
b7838461c1
@ -1,6 +1,9 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
namespace mlx::core {
|
// Note: These structs are defined outside namespace for Metal kernel
|
||||||
|
// compatibility Metal kernels cannot access namespaced types directly
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parameters for SVD Metal kernels
|
* Parameters for SVD Metal kernels
|
||||||
@ -12,7 +15,7 @@ struct SVDParams {
|
|||||||
const int max_iterations; // Maximum Jacobi iterations
|
const int max_iterations; // Maximum Jacobi iterations
|
||||||
const float tolerance; // Convergence threshold
|
const float tolerance; // Convergence threshold
|
||||||
const int batch_size; // Number of matrices in batch
|
const int batch_size; // Number of matrices in batch
|
||||||
const int64_t matrix_stride; // Stride between matrices in batch
|
const long matrix_stride; // Stride between matrices in batch
|
||||||
const bool compute_uv; // Whether to compute U and V matrices
|
const bool compute_uv; // Whether to compute U and V matrices
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -34,4 +37,9 @@ struct SVDConvergenceInfo {
|
|||||||
bool converged; // Whether algorithm has converged
|
bool converged; // Whether algorithm has converged
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
// Namespace aliases for C++ code
|
||||||
|
using ::JacobiRotation;
|
||||||
|
using ::SVDConvergenceInfo;
|
||||||
|
using ::SVDParams;
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -16,8 +16,7 @@ template <typename T>
|
|||||||
const device T* A [[buffer(0)]],
|
const device T* A [[buffer(0)]],
|
||||||
device T* AtA [[buffer(1)]],
|
device T* AtA [[buffer(1)]],
|
||||||
const constant SVDParams& params [[buffer(2)]],
|
const constant SVDParams& params [[buffer(2)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
|
||||||
|
|
||||||
const int M = params.M;
|
const int M = params.M;
|
||||||
const int N = params.N;
|
const int N = params.N;
|
||||||
@ -51,10 +50,8 @@ template <typename T>
|
|||||||
[[kernel]] void svd_jacobi_iteration(
|
[[kernel]] void svd_jacobi_iteration(
|
||||||
device T* AtA [[buffer(0)]],
|
device T* AtA [[buffer(0)]],
|
||||||
device JacobiRotation* rotations [[buffer(1)]],
|
device JacobiRotation* rotations [[buffer(1)]],
|
||||||
device SVDConvergenceInfo* convergence [[buffer(2)]],
|
|
||||||
const constant SVDParams& params [[buffer(3)]],
|
const constant SVDParams& params [[buffer(3)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
|
||||||
|
|
||||||
const int N = params.N;
|
const int N = params.N;
|
||||||
const int batch_idx = tid.z;
|
const int batch_idx = tid.z;
|
||||||
@ -68,7 +65,7 @@ template <typename T>
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Convert linear pair index to (p,q) coordinates where p < q
|
// Convert linear pair index to (p,q) coordinates where p < q
|
||||||
int p, q;
|
int p, q = 0;
|
||||||
int idx = pair_idx;
|
int idx = pair_idx;
|
||||||
for (p = 0; p < N - 1; p++) {
|
for (p = 0; p < N - 1; p++) {
|
||||||
int pairs_in_row = N - 1 - p;
|
int pairs_in_row = N - 1 - p;
|
||||||
@ -218,8 +215,7 @@ template <typename T>
|
|||||||
device T* U [[buffer(2)]],
|
device T* U [[buffer(2)]],
|
||||||
device T* V [[buffer(3)]],
|
device T* V [[buffer(3)]],
|
||||||
const constant SVDParams& params [[buffer(4)]],
|
const constant SVDParams& params [[buffer(4)]],
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
|
||||||
|
|
||||||
const int M = params.M;
|
const int M = params.M;
|
||||||
const int N = params.N;
|
const int N = params.N;
|
||||||
@ -294,18 +290,5 @@ decltype(svd_check_convergence<float>) svd_check_convergence<float>;
|
|||||||
template [[host_name("svd_compute_vectors_float")]] [[kernel]]
|
template [[host_name("svd_compute_vectors_float")]] [[kernel]]
|
||||||
decltype(svd_compute_vectors<float>) svd_compute_vectors<float>;
|
decltype(svd_compute_vectors<float>) svd_compute_vectors<float>;
|
||||||
|
|
||||||
// Template instantiations for double
|
// Note: Metal does not support double precision
|
||||||
template [[host_name("svd_preprocess_double")]] [[kernel]]
|
// Double precision operations will fall back to CPU
|
||||||
decltype(svd_preprocess<double>) svd_preprocess<double>;
|
|
||||||
|
|
||||||
template [[host_name("svd_jacobi_iteration_double")]] [[kernel]]
|
|
||||||
decltype(svd_jacobi_iteration<double>) svd_jacobi_iteration<double>;
|
|
||||||
|
|
||||||
template [[host_name("svd_extract_singular_values_double")]] [[kernel]]
|
|
||||||
decltype(svd_extract_singular_values<double>) svd_extract_singular_values<double>;
|
|
||||||
|
|
||||||
template [[host_name("svd_check_convergence_double")]] [[kernel]]
|
|
||||||
decltype(svd_check_convergence<double>) svd_check_convergence<double>;
|
|
||||||
|
|
||||||
template [[host_name("svd_compute_vectors_double")]] [[kernel]]
|
|
||||||
decltype(svd_compute_vectors<double>) svd_compute_vectors<double>;
|
|
||||||
|
Loading…
Reference in New Issue
Block a user