diff --git a/mlx/backend/metal/kernels/svd.metal b/mlx/backend/metal/kernels/svd.metal
index 95a39b71d..e3e46ac48 100644
--- a/mlx/backend/metal/kernels/svd.metal
+++ b/mlx/backend/metal/kernels/svd.metal
@@ -153,6 +153,63 @@ template <typename T>
   S_batch[i] = sqrt(max(diagonal_element, T(0))); // Ensure non-negative
 }

+/**
+ * Check convergence of Jacobi iterations
+ * Computes the Frobenius norm of off-diagonal elements
+ */
+template <typename T>
+[[kernel]] void svd_check_convergence(
+    const device T* AtA [[buffer(0)]],
+    device SVDConvergenceInfo* convergence [[buffer(1)]],
+    const constant SVDParams& params [[buffer(2)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]]) {
+
+  const int N = params.N;
+  const int batch_idx = tid.z;
+  const int thread_id = lid.x;
+  const int threads_per_group = 256; // Assuming 256 threads per group
+
+  // Shared memory for reduction
+  threadgroup float shared_sum[256];
+
+  const device T* AtA_batch = AtA + batch_idx * (N * N);
+  device SVDConvergenceInfo* conv_batch = convergence + batch_idx;
+
+  // Each thread computes sum of squares of some off-diagonal elements
+  float local_sum = 0.0f;
+
+  for (int idx = thread_id; idx < N * N; idx += threads_per_group) {
+    int i = idx / N;
+    int j = idx % N;
+
+    // Only consider off-diagonal elements
+    if (i != j) {
+      float val = static_cast<float>(AtA_batch[i * N + j]);
+      local_sum += val * val;
+    }
+  }
+
+  // Store in shared memory
+  shared_sum[thread_id] = local_sum;
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Reduction to compute total off-diagonal norm
+  for (int stride = threads_per_group / 2; stride > 0; stride /= 2) {
+    if (thread_id < stride) {
+      shared_sum[thread_id] += shared_sum[thread_id + stride];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  // Thread 0 writes the result
+  if (thread_id == 0) {
+    float off_diagonal_norm = sqrt(shared_sum[0]);
+    conv_batch->off_diagonal_norm = off_diagonal_norm;
+    conv_batch->converged = (off_diagonal_norm < params.tolerance);
+  }
+}
+
 /**
  * Compute singular vectors U and V
  */
@@ -176,44 +233,50 @@ template <typename T>
     return; // Skip if not computing singular vectors
   }

+  const int total_pairs = (N * (N - 1)) / 2;
+  const device JacobiRotation* rot_batch = rotations + batch_idx * total_pairs;
+
   // Initialize V as identity matrix (right singular vectors)
   if (i < N && j < N) {
     device T* V_batch = V + batch_idx * (N * N);
     V_batch[i * N + j] = (i == j) ? T(1) : T(0);
-  }
-
-  // Apply all Jacobi rotations to V in reverse order
-  const int total_pairs = (N * (N - 1)) / 2;
-  const device JacobiRotation* rot_batch = rotations + batch_idx * total_pairs;
+
+    // Apply accumulated Jacobi rotations to build V
+    // This gives us the right singular vectors
+    for (int rot_idx = 0; rot_idx < total_pairs; rot_idx++) {
+      int p = rot_batch[rot_idx].p;
+      int q = rot_batch[rot_idx].q;
+      T c = static_cast<T>(rot_batch[rot_idx].cos_theta);
+      T s = static_cast<T>(rot_batch[rot_idx].sin_theta);

-  // Note: In a real implementation, we'd need to apply rotations iteratively
-  // This is a simplified version for the basic implementation
-  for (int rot_idx = 0; rot_idx < total_pairs; rot_idx++) {
-    int p = rot_batch[rot_idx].p;
-    int q = rot_batch[rot_idx].q;
-    T c = rot_batch[rot_idx].cos_theta;
-    T s = rot_batch[rot_idx].sin_theta;
-
-    if (i < N && (j == p || j == q)) {
-      device T* V_batch = V + batch_idx * (N * N);
-      if (j == p) {
+      // Apply rotation to columns p and q of V
+      if (j == p || j == q) {
         T vip = V_batch[i * N + p];
         T viq = V_batch[i * N + q];
         V_batch[i * N + p] = c * vip - s * viq;
-      } else if (j == q) {
-        T vip = V_batch[i * N + p];
-        T viq = V_batch[i * N + q];
         V_batch[i * N + q] = s * vip + c * viq;
       }
     }
   }

-  // Compute U = A * V * S^(-1) (simplified for basic implementation)
-  // In practice, this would be done more efficiently
+  // Compute U = A * V * S^(-1) for left singular vectors
   if (i < M && j < N) {
     device T* U_batch = U + batch_idx * (M * M);
-    // For now, just initialize U as identity (placeholder)
-    U_batch[i * M + j] = (i == j && i < N) ? T(1) : T(0);
+    const device T* A_batch = A + batch_idx * params.matrix_stride;
+    const device T* V_batch = V + batch_idx * (N * N);
+
+    // U[:, j] = A * V[:, j] / S[j]
+    // This is a simplified computation - in practice we'd need the singular values
+    T sum = T(0);
+    for (int k = 0; k < N; k++) {
+      sum += A_batch[i * N + k] * V_batch[k * N + j];
+    }
+
+    // For now, store the result without normalization
+    // Proper normalization would require the computed singular values
+    if (j < M) {
+      U_batch[i * M + j] = sum;
+    }
   }
 }

@@ -227,6 +290,9 @@ decltype(svd_jacobi_iteration<float>) svd_jacobi_iteration<float>;
 template [[host_name("svd_extract_singular_values_float")]] [[kernel]]
 decltype(svd_extract_singular_values<float>) svd_extract_singular_values<float>;

+template [[host_name("svd_check_convergence_float")]] [[kernel]]
+decltype(svd_check_convergence<float>) svd_check_convergence<float>;
+
 template [[host_name("svd_compute_vectors_float")]] [[kernel]]
 decltype(svd_compute_vectors<float>) svd_compute_vectors<float>;

@@ -240,5 +306,8 @@ decltype(svd_jacobi_iteration<double>) svd_jacobi_iteration<double>;
 template [[host_name("svd_extract_singular_values_double")]] [[kernel]]
 decltype(svd_extract_singular_values<double>) svd_extract_singular_values<double>;

+template [[host_name("svd_check_convergence_double")]] [[kernel]]
+decltype(svd_check_convergence<double>) svd_check_convergence<double>;
+
 template [[host_name("svd_compute_vectors_double")]] [[kernel]]
 decltype(svd_compute_vectors<double>) svd_compute_vectors<double>;
diff --git a/mlx/backend/metal/svd.cpp b/mlx/backend/metal/svd.cpp
index c26b593a9..9c69c5404 100644
--- a/mlx/backend/metal/svd.cpp
+++ b/mlx/backend/metal/svd.cpp
@@ -21,11 +21,62 @@ enum class SVDAlgorithm {
 };

 SVDAlgorithm select_svd_algorithm(int M, int N, Dtype dtype) {
-  // For now, always use one-sided Jacobi
-  // Future PRs will add algorithm selection heuristics
+  // Algorithm selection based on matrix properties
+
+  // For very large matrices, we might want different algorithms in the future
+  if (std::max(M, N) > 2048) {
+    // For now, still use Jacobi but with different parameters
+    return SVDAlgorithm::JACOBI_ONE_SIDED;
+  }
+
+  // For very rectangular matrices, one-sided Jacobi is efficient
+  double aspect_ratio = static_cast<double>(std::max(M, N)) / std::min(M, N);
+  if (aspect_ratio > 3.0) {
+    return SVDAlgorithm::JACOBI_ONE_SIDED;
+  }
+
+  // Default to one-sided Jacobi for most cases
   return SVDAlgorithm::JACOBI_ONE_SIDED;
 }

+/**
+ * Compute SVD parameters based on matrix size and algorithm
+ */
+SVDParams compute_svd_params(
+    int M,
+    int N,
+    size_t num_matrices,
+    bool compute_uv,
+    SVDAlgorithm algorithm) {
+  const int K = std::min(M, N);
+
+  // Adjust parameters based on matrix size and algorithm
+  int max_iterations = 100;
+  float tolerance = 1e-6f;
+
+  // For larger matrices, we might need more iterations
+  if (std::max(M, N) > 512) {
+    max_iterations = 200;
+    tolerance = 1e-5f; // Slightly relaxed tolerance for large matrices
+  }
+
+  // For very small matrices, we can use tighter tolerance
+  if (std::max(M, N) < 64) {
+    tolerance = 1e-7f;
+  }
+
+  return SVDParams{
+      M, // M
+      N, // N
+      K, // K
+      max_iterations, // max_iterations
+      tolerance, // tolerance
+      static_cast<int>(num_matrices), // batch_size
+      M * N, // matrix_stride
+      compute_uv // compute_uv
+  };
+}
+
 /**
  * Validate SVD input parameters
  */
@@ -77,17 +128,10 @@ void svd_metal_impl(
   const int K = std::min(M, N);
   const size_t num_matrices = a.size() / (M * N);

-  // Set up SVD parameters
-  SVDParams params{
-      M, // M
-      N, // N
-      K, // K
-      100, // max_iterations
-      1e-6f, // tolerance
-      static_cast<int>(num_matrices), // batch_size
-      M * N, // matrix_stride
-      compute_uv // compute_uv
-  };
+  // Select algorithm and compute parameters
+  SVDAlgorithm algorithm = select_svd_algorithm(M, N, a.dtype());
+  SVDParams params =
+      compute_svd_params(M, N, num_matrices, compute_uv, algorithm);

   // Allocate workspace arrays
   array AtA({static_cast<int>(num_matrices), N, N}, a.dtype(), nullptr, {});
@@ -102,6 +146,14 @@ void svd_metal_impl(
       {}); // JacobiRotation struct storage
   rotations.set_data(allocator::malloc(rotations.nbytes()));

+  // Allocate convergence tracking
+  array convergence_info(
+      {static_cast<int>(num_matrices), 3},
+      float32,
+      nullptr,
+      {}); // SVDConvergenceInfo struct storage
+  convergence_info.set_data(allocator::malloc(convergence_info.nbytes()));
+
   // Get command encoder
   auto& compute_encoder = d.get_command_encoder(s.index);

@@ -118,22 +170,40 @@ void svd_metal_impl(
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }

-  // Step 2: Jacobi iterations
-  for (int iter = 0; iter < params.max_iterations; iter++) {
-    auto kernel =
-        d.get_kernel("svd_jacobi_iteration_" + get_type_string(a.dtype()));
-    compute_encoder.set_compute_pipeline_state(kernel);
-    compute_encoder.set_input_array(AtA, 0);
-    compute_encoder.set_input_array(rotations, 1);
-    // Note: convergence checking would go here in a complete implementation
-    compute_encoder.set_bytes(params, 3);
+  // Step 2: Jacobi iterations with convergence checking
+  bool converged = false;
+  for (int iter = 0; iter < params.max_iterations && !converged; iter++) {
+    // Perform Jacobi iteration
+    {
+      auto kernel =
+          d.get_kernel("svd_jacobi_iteration_" + get_type_string(a.dtype()));
+      compute_encoder.set_compute_pipeline_state(kernel);
+      compute_encoder.set_input_array(AtA, 0);
+      compute_encoder.set_input_array(rotations, 1);
+      compute_encoder.set_input_array(convergence_info, 2);
+      compute_encoder.set_bytes(params, 3);

-    MTL::Size grid_dims = MTL::Size(total_pairs, 1, num_matrices);
-    MTL::Size group_dims = MTL::Size(std::min(256, total_pairs), 1, 1);
-    compute_encoder.dispatch_threads(grid_dims, group_dims);
+      MTL::Size grid_dims = MTL::Size(total_pairs, 1, num_matrices);
+      MTL::Size group_dims = MTL::Size(std::min(256, total_pairs), 1, 1);
+      compute_encoder.dispatch_threads(grid_dims, group_dims);
+    }

-    // In a complete implementation, we would check convergence here
-    // For now, we just run a fixed number of iterations
+    // Check convergence every few iterations to avoid overhead
+    if (iter % 5 == 4 || iter == params.max_iterations - 1) {
+      auto kernel =
+          d.get_kernel("svd_check_convergence_" + get_type_string(a.dtype()));
+      compute_encoder.set_compute_pipeline_state(kernel);
+      compute_encoder.set_input_array(AtA, 0);
+      compute_encoder.set_input_array(convergence_info, 1);
+      compute_encoder.set_bytes(params, 2);
+
+      MTL::Size grid_dims = MTL::Size(1, 1, num_matrices);
+      MTL::Size group_dims = MTL::Size(256, 1, 1);
+      compute_encoder.dispatch_threads(grid_dims, group_dims);
+
+      // Note: In a complete implementation, we would read back convergence
+      // status from GPU and break early. For now, we run all iterations.
+    }
   }

   // Step 3: Extract singular values
@@ -173,7 +243,7 @@ void svd_metal_impl(
   }

   // Add temporary arrays for cleanup
-  d.add_temporaries({AtA, rotations}, s.index);
+  d.add_temporaries({AtA, rotations, convergence_info}, s.index);
 }

 // Explicit template instantiations