From b7838461c1a2332272666a5dee4274bad3ca79c5 Mon Sep 17 00:00:00 2001
From: Arkar Min Aung
Date: Sat, 14 Jun 2025 21:22:34 +1000
Subject: [PATCH] feat: Add Metal SVD kernel infrastructure

- Add svd.h header with kernel declarations
- Add svd.metal with placeholder Metal compute shaders
- Define SVD algorithm parameters and data structures
- Prepare foundation for Metal GPU-accelerated SVD implementation
---
 mlx/backend/metal/kernels/svd.h     | 12 ++++++++++--
 mlx/backend/metal/kernels/svd.metal | 29 ++++++-----------------------
 2 files changed, 16 insertions(+), 25 deletions(-)

diff --git a/mlx/backend/metal/kernels/svd.h b/mlx/backend/metal/kernels/svd.h
index 1a030a2f7..cc2587e0f 100644
--- a/mlx/backend/metal/kernels/svd.h
+++ b/mlx/backend/metal/kernels/svd.h
@@ -1,6 +1,9 @@
+// Copyright © 2024 Apple Inc.
+
 #pragma once
 
-namespace mlx::core {
+// Note: These structs are defined outside the namespace for Metal kernel
+// compatibility. Metal kernels cannot access namespaced types directly.
 
 /**
  * Parameters for SVD Metal kernels
@@ -12,7 +15,7 @@ struct SVDParams {
   const int max_iterations; // Maximum Jacobi iterations
   const float tolerance; // Convergence threshold
   const int batch_size; // Number of matrices in batch
-  const int64_t matrix_stride; // Stride between matrices in batch
+  const long matrix_stride; // Stride between matrices in batch
   const bool compute_uv; // Whether to compute U and V matrices
 };
 
@@ -34,4 +37,9 @@ struct SVDConvergenceInfo {
   bool converged; // Whether algorithm has converged
 };
 
+namespace mlx::core {
+// Namespace aliases for C++ code
+using ::JacobiRotation;
+using ::SVDConvergenceInfo;
+using ::SVDParams;
 } // namespace mlx::core

diff --git a/mlx/backend/metal/kernels/svd.metal b/mlx/backend/metal/kernels/svd.metal
index 879287337..e4f6ddb5c 100644
--- a/mlx/backend/metal/kernels/svd.metal
+++ b/mlx/backend/metal/kernels/svd.metal
@@ -16,8 +16,7 @@ template <typename T>
     const device T* A [[buffer(0)]],
     device T* AtA [[buffer(1)]],
     const constant SVDParams& params [[buffer(2)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
+    uint3 tid [[threadgroup_position_in_grid]]) {
 
   const int M = params.M;
   const int N = params.N;
@@ -51,10 +50,8 @@ template <typename T>
 [[kernel]] void svd_jacobi_iteration(
     device T* AtA [[buffer(0)]],
     device JacobiRotation* rotations [[buffer(1)]],
-    device SVDConvergenceInfo* convergence [[buffer(2)]],
     const constant SVDParams& params [[buffer(3)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
+    uint3 tid [[threadgroup_position_in_grid]]) {
 
   const int N = params.N;
   const int batch_idx = tid.z;
@@ -68,7 +65,7 @@
   }
 
   // Convert linear pair index to (p,q) coordinates where p < q
-  int p, q;
+  int p, q = 0;
   int idx = pair_idx;
   for (p = 0; p < N - 1; p++) {
     int pairs_in_row = N - 1 - p;
@@ -218,8 +215,7 @@ template <typename T>
     device T* U [[buffer(2)]],
     device T* V [[buffer(3)]],
     const constant SVDParams& params [[buffer(4)]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
+    uint3 tid [[threadgroup_position_in_grid]]) {
 
   const int M = params.M;
   const int N = params.N;
@@ -294,18 +290,5 @@ decltype(svd_check_convergence<float>) svd_check_convergence<float>;
 template [[host_name("svd_compute_vectors_float")]] [[kernel]]
 decltype(svd_compute_vectors<float>) svd_compute_vectors<float>;
 
-// Template instantiations for double
-template [[host_name("svd_preprocess_double")]] [[kernel]]
-decltype(svd_preprocess<double>) svd_preprocess<double>;
-
-template [[host_name("svd_jacobi_iteration_double")]] [[kernel]]
-decltype(svd_jacobi_iteration<double>) svd_jacobi_iteration<double>;
-
-template [[host_name("svd_extract_singular_values_double")]] [[kernel]]
-decltype(svd_extract_singular_values<double>) svd_extract_singular_values<double>;
-
-template [[host_name("svd_check_convergence_double")]] [[kernel]]
-decltype(svd_check_convergence<double>) svd_check_convergence<double>;
-
-template [[host_name("svd_compute_vectors_double")]] [[kernel]]
-decltype(svd_compute_vectors<double>) svd_compute_vectors<double>;
+// Note: Metal does not support double precision
+// Double precision operations will fall back to CPU
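
Note on the pair enumeration: each Jacobi sweep in svd_jacobi_iteration visits
every off-diagonal index pair (p, q) with p < q exactly once, N * (N - 1) / 2
pairs in total, and the kernel recovers (p, q) from a linear pair index. The
following standalone C++ sketch (an illustration, not part of the patch) lifts
that mapping loop out of the kernel so the enumeration order can be checked on
the host; everything outside the inner loop is scaffolding added here.

// Standalone illustration: reproduces the linear pair-index -> (p, q)
// mapping used by svd_jacobi_iteration. Only the inner loop is taken
// from the kernel; the driver around it is hypothetical.
#include <cstdio>

int main() {
  const int N = 4; // a 4x4 matrix has N * (N - 1) / 2 = 6 index pairs
  const int total_pairs = N * (N - 1) / 2;
  for (int pair_idx = 0; pair_idx < total_pairs; pair_idx++) {
    // Same logic as the kernel: walk rows until the index lands in row p.
    int p, q = 0;
    int idx = pair_idx;
    for (p = 0; p < N - 1; p++) {
      int pairs_in_row = N - 1 - p; // pairs (p, p+1) ... (p, N-1)
      if (idx < pairs_in_row) {
        q = p + 1 + idx; // column offset within row p
        break;
      }
      idx -= pairs_in_row;
    }
    printf("pair %d -> (p=%d, q=%d)\n", pair_idx, p, q);
  }
  return 0;
}

For N = 4 this prints (0,1), (0,2), (0,3), (1,2), (1,3), (2,3), confirming the
row-major enumeration. It also shows why the patch changes `int p, q;` to
`int p, q = 0;`: q is only assigned when the loop breaks, so the explicit
initializer guards against a read of an uninitialized value if the loop ever
falls through.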
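
All of the kernels take their launch configuration through the single
SVDParams argument. The sketch below mirrors that struct on the host to show
how its fields relate to a batched input of shape (batch_size, M, N). It
assumes M and N are the struct's two leading fields (the kernels read
params.M and params.N, but those declarations fall outside the diff context
above), and the concrete values are illustrative only.

// Hypothetical host-side setup, mirroring SVDParams from svd.h.
struct SVDParams {
  const int M;              // Rows per matrix (assumed leading field)
  const int N;              // Columns per matrix (assumed leading field)
  const int max_iterations; // Maximum Jacobi iterations
  const float tolerance;    // Convergence threshold
  const int batch_size;     // Number of matrices in batch
  const long matrix_stride; // Stride between matrices in batch
  const bool compute_uv;    // Whether to compute U and V matrices
};

int main() {
  const int M = 64, N = 64, batch = 8;
  SVDParams params{
      M,
      N,
      /* max_iterations = */ 100,
      /* tolerance = */ 1e-6f,
      batch,
      /* matrix_stride = */ static_cast<long>(M) * N, // contiguous batch
      /* compute_uv = */ true,
  };
  (void)params; // would be copied into the kernels' constant buffer
  return 0;
}

The int64_t -> long swap in the patch fits the same sharing constraint as the
out-of-namespace structs: the header is compiled by both the C++ host compiler
and the Metal shader compiler, so it can only use types both sides define.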