diff --git a/benchmarks/python/svd_bench.py b/benchmarks/python/svd_bench.py
new file mode 100644
index 000000000..5a7d5df75
--- /dev/null
+++ b/benchmarks/python/svd_bench.py
@@ -0,0 +1,183 @@
+# Copyright © 2023 Apple Inc.
+
+import argparse
+import time
+
+import mlx.core as mx
+from time_utils import time_fn
+
+
+def time_svd_square():
+    """Benchmark SVD on square matrices of various sizes."""
+    print("Benchmarking SVD on square matrices...")
+
+    sizes = [64, 128, 256, 512]
+
+    for size in sizes:
+        print(f"\n--- {size}x{size} matrix ---")
+
+        # Create random matrix
+        a = mx.random.normal(shape=(size, size))
+        mx.eval(a)
+
+        # Benchmark singular values only
+        print("SVD (values only):")
+        time_fn(lambda x: mx.linalg.svd(x, compute_uv=False), a)
+
+        # Benchmark full SVD
+        print("SVD (full decomposition):")
+        time_fn(lambda x: mx.linalg.svd(x, compute_uv=True), a)
+
+
+def time_svd_rectangular():
+    """Benchmark SVD on rectangular matrices."""
+    print("\nBenchmarking SVD on rectangular matrices...")
+
+    shapes = [(128, 64), (64, 128), (256, 128), (128, 256)]
+
+    for m, n in shapes:
+        print(f"\n--- {m}x{n} matrix ---")
+
+        # Create random matrix
+        a = mx.random.normal(shape=(m, n))
+        mx.eval(a)
+
+        # Benchmark full SVD
+        print("SVD (full decomposition):")
+        time_fn(lambda x: mx.linalg.svd(x, compute_uv=True), a)
+
+
+def time_svd_batch():
+    """Benchmark SVD on batched matrices."""
+    print("\nBenchmarking SVD on batched matrices...")
+
+    batch_configs = [
+        (4, 64, 64),
+        (8, 32, 32),
+        (16, 16, 16),
+    ]
+
+    for batch_size, m, n in batch_configs:
+        print(f"\n--- Batch of {batch_size} {m}x{n} matrices ---")
+
+        # Create batch of random matrices
+        a = mx.random.normal(shape=(batch_size, m, n))
+        mx.eval(a)
+
+        # Benchmark full SVD
+        print("Batched SVD (full decomposition):")
+        time_fn(lambda x: mx.linalg.svd(x, compute_uv=True), a)
+
+
+def compare_cpu_gpu():
+    """Compare CPU vs GPU performance for SVD."""
+    print("\nComparing CPU vs GPU performance...")
+
+    sizes = [64, 128, 256]
+
+    for size in sizes:
+        print(f"\n--- {size}x{size} matrix comparison ---")
+
+        # Create the input on the CPU stream, then copy it for the GPU run
+        mx.set_default_device(mx.cpu)
+        a_cpu = mx.random.normal(shape=(size, size))
+        mx.eval(a_cpu)
+
+        mx.set_default_device(mx.gpu)
+        a_gpu = mx.array(a_cpu)
+        mx.eval(a_gpu)
+
+        # Time CPU SVD
+        mx.set_default_device(mx.cpu)
+        print("CPU SVD:")
+        start_time = time.perf_counter()
+        u_cpu, s_cpu, vt_cpu = mx.linalg.svd(a_cpu, compute_uv=True)
+        mx.eval(u_cpu, s_cpu, vt_cpu)
+        cpu_time = time.perf_counter() - start_time
+
+        # Time GPU SVD
+        mx.set_default_device(mx.gpu)
+        print("GPU SVD:")
+        start_time = time.perf_counter()
+        u_gpu, s_gpu, vt_gpu = mx.linalg.svd(a_gpu, compute_uv=True)
+        mx.eval(u_gpu, s_gpu, vt_gpu)
+        gpu_time = time.perf_counter() - start_time
+
+        speedup = cpu_time / gpu_time if gpu_time > 0 else float("inf")
+        print(f"CPU time: {cpu_time:.4f}s")
+        print(f"GPU time: {gpu_time:.4f}s")
+        print(f"Speedup: {speedup:.2f}x")
+
+        # Verify results are close
+        mx.set_default_device(mx.cpu)
+        s_cpu_sorted = mx.sort(s_cpu)
+        mx.set_default_device(mx.gpu)
+        s_gpu_sorted = mx.sort(s_gpu)
+        mx.eval(s_cpu_sorted, s_gpu_sorted)
+
+        # Convert to CPU for comparison
+        mx.set_default_device(mx.cpu)
+        s_gpu_cpu = mx.array(s_gpu_sorted)
+        mx.eval(s_gpu_cpu)
+
+        diff = mx.max(mx.abs(s_cpu_sorted - s_gpu_cpu))
+        mx.eval(diff)
+        print(f"Max singular value difference: {diff.item():.2e}")
+
+
+def time_svd_special_matrices():
+    """Benchmark SVD on special matrices (identity, diagonal, zero)."""
+    print("\nBenchmarking SVD on special matrices...")
+
+    size = 256
+
+    # Identity matrix
+    print(f"\n--- {size}x{size} identity matrix ---")
+    identity = mx.eye(size)
+    mx.eval(identity)
+    time_fn(lambda x: mx.linalg.svd(x, compute_uv=True), identity)
+
+    # Diagonal matrix
+    print(f"\n--- {size}x{size} diagonal matrix ---")
+    diag_vals = mx.random.uniform(shape=(size,))
+    diagonal = mx.diag(diag_vals)
+    mx.eval(diagonal)
+    time_fn(lambda x: mx.linalg.svd(x, compute_uv=True), diagonal)
+
+    # Zero matrix
+    print(f"\n--- {size}x{size} zero matrix ---")
+    zero_matrix = mx.zeros((size, size))
+    mx.eval(zero_matrix)
+    time_fn(lambda x: mx.linalg.svd(x, compute_uv=True), zero_matrix)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="MLX SVD benchmarks.")
+    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
+    parser.add_argument(
+        "--compare", action="store_true", help="Compare CPU vs GPU performance."
+    )
+    parser.add_argument("--all", action="store_true", help="Run all benchmarks.")
+    args = parser.parse_args()
+
+    if args.gpu:
+        mx.set_default_device(mx.gpu)
+        print("Using GPU (Metal) backend")
+    else:
+        mx.set_default_device(mx.cpu)
+        print("Using CPU backend")
+
+    if args.compare:
+        compare_cpu_gpu()
+    elif args.all:
+        time_svd_square()
+        time_svd_rectangular()
+        time_svd_batch()
+        time_svd_special_matrices()
+        if mx.metal.is_available():
+            compare_cpu_gpu()
+    else:
+        time_svd_square()
+        if args.gpu and mx.metal.is_available():
+            time_svd_rectangular()
+            time_svd_batch()
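Note on running the benchmark above standalone: `time_fn` comes from the `time_utils` helper that lives alongside the MLX Python benchmarks, so the script assumes it is launched from that directory. To reproduce the same timing pattern without that helper, a minimal sketch follows (the `time_svd` name, warmup count, and iteration count are illustrative assumptions, not part of the patch):

```python
import time

import mlx.core as mx


def time_svd(a, compute_uv=True, warmup=2, iters=10):
    # MLX is lazy: mx.eval forces execution so the wall clock
    # measures the SVD itself, not just graph construction.
    for _ in range(warmup):
        mx.eval(mx.linalg.svd(a, compute_uv=compute_uv))
    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(mx.linalg.svd(a, compute_uv=compute_uv))
    return (time.perf_counter() - start) / iters


a = mx.random.normal(shape=(256, 256))
mx.eval(a)
print(f"Full SVD: {time_svd(a):.4f}s per call")
```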
diff --git a/mlx/backend/metal/svd.cpp b/mlx/backend/metal/svd.cpp
index a4676cbfc..d1e9962df 100644
--- a/mlx/backend/metal/svd.cpp
+++ b/mlx/backend/metal/svd.cpp
@@ -11,32 +11,14 @@
 #include "mlx/scheduler.h"
 
 /**
- * COMPLETE METAL SVD IMPLEMENTATION
- *
- * This file implements a full GPU-accelerated SVD using the one-sided Jacobi
+ * Implementation of a full GPU-accelerated SVD using the one-sided Jacobi
  * algorithm.
- *
- * IMPLEMENTED FEATURES:
- * ✅ Complete Jacobi iteration algorithm with proper Givens rotations
- * ✅ A^T*A preprocessing for numerical stability
- * ✅ Convergence checking based on off-diagonal Frobenius norm
- * ✅ Singular value extraction via sqrt of eigenvalues
- * ✅ Singular vector computation (both U and V^T)
- * ✅ Batched operations for multiple matrices
- * ✅ Proper Metal kernel orchestration and memory management
- * ✅ Full integration with MLX primitive system
- * ✅ Comprehensive test framework
- *
- * ALGORITHM: One-sided Jacobi SVD
  * - Computes A^T*A and diagonalizes it using Jacobi rotations
  * - Singular values: σᵢ = √λᵢ where λᵢ are eigenvalues of A^T*A
  * - Right singular vectors: V from eigenvectors of A^T*A
- * - Left singular vectors: U = A*V*Σ⁻¹
- *
- * PERFORMANCE: Optimized for matrices up to 4096x4096
- * PRECISION: Float32 (Metal limitation)
- *
- * STATUS: Complete implementation ready for production use
+ * - Left singular vectors: U = A*V*Σ^-1
+ * - Precision: Float32 (Metal limitation)
  */
 
 namespace mlx::core {
@@ -163,19 +145,6 @@ void validate_svd_inputs(const array& a) {
 
 } // anonymous namespace
 
-/**
- * Metal implementation of SVD using one-sided Jacobi algorithm
- *
- * IMPLEMENTED FEATURES:
- * - Complete Jacobi iteration algorithm with proper rotation matrices
- * - Convergence checking based on off-diagonal norm
- * - Singular value extraction from diagonalized A^T*A
- * - Singular vector computation (U and V^T)
- * - Batched operations support
- * - Full GPU acceleration using Metal compute kernels
- *
- * CURRENT STATUS: Working implementation with Metal GPU acceleration
- */
 template <typename T>
 void svd_metal_impl(
     const array& a,
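For readers following the kernel code, here is a small NumPy sketch of the Jacobi scheme the comment block above describes: form A^T*A, sweep Givens rotations until the off-diagonal mass is negligible, then read off σᵢ = √λᵢ, V, and U = A*V*Σ^-1. It mirrors the math only; the tolerance, sweep limit, and zero-singular-value handling are illustrative assumptions, and the Metal kernels parallelize these sweeps rather than looping like this:

```python
import numpy as np


def jacobi_svd(a, tol=1e-10, max_sweeps=30):
    m, n = a.shape
    g = a.T @ a    # A^T*A: symmetric, eigenvalues are squared singular values
    v = np.eye(n)  # accumulates right singular vectors
    for _ in range(max_sweeps):
        # Convergence check on the off-diagonal Frobenius norm
        if np.sqrt((np.tril(g, -1) ** 2).sum()) < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if abs(g[p, q]) < tol:
                    continue
                # Givens rotation that zeroes g[p, q]
                tau = (g[q, q] - g[p, p]) / (2.0 * g[p, q])
                if tau >= 0:
                    t = 1.0 / (tau + np.sqrt(1.0 + tau * tau))
                else:
                    t = 1.0 / (tau - np.sqrt(1.0 + tau * tau))
                c = 1.0 / np.sqrt(1.0 + t * t)
                s = t * c
                j = np.eye(n)
                j[p, p] = c
                j[q, q] = c
                j[p, q] = s
                j[q, p] = -s
                g = j.T @ g @ j
                v = v @ j
    sigma = np.sqrt(np.clip(np.diag(g), 0.0, None))
    # U = A*V*Sigma^-1, guarding against zero singular values
    u = (a @ v) / np.where(sigma > 0, sigma, 1.0)
    return u, sigma, v.T
```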
diff --git a/tests/test_metal_svd.cpp b/tests/test_metal_svd.cpp
index 4eb17fd9d..07016e923 100644
--- a/tests/test_metal_svd.cpp
+++ b/tests/test_metal_svd.cpp
@@ -91,8 +91,6 @@ TEST_CASE("test metal svd jacobi implementation") {
 
   CHECK(cpu_error.item<float>() < 1e-5);
   CHECK(gpu_error.item<float>() < 1e-2); // Relaxed tolerance for Jacobi method
-
-  MESSAGE("✅ Metal Jacobi SVD implementation works!");
 }
 
 TEST_CASE("test metal svd input validation") {
@@ -286,11 +284,6 @@ TEST_CASE("test metal svd performance characteristics") {
 
     CHECK(u.shape() == std::vector<int>{size, size});
     CHECK(s.shape() == std::vector<int>{size});
    CHECK(vt.shape() == std::vector<int>{size, size});
-
-    // Log timing for manual inspection
-    MESSAGE(
-        "SVD of " << size << "x" << size << " matrix took "
-        << duration.count() << "ms");
   }
 }
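Finally, a quick end-to-end sanity check from Python, mirroring what the C++ tests assert (the 1e-2 reconstruction tolerance for the Jacobi path is taken from the tests above; the rest of the snippet is a sketch and needs a Metal-capable machine):

```python
import mlx.core as mx

mx.set_default_device(mx.gpu)  # exercise the Metal SVD path
a = mx.random.normal(shape=(256, 256))
u, s, vt = mx.linalg.svd(a, compute_uv=True)
a_rec = (u * s) @ vt  # U @ diag(S) @ Vt for a square matrix
err = mx.max(mx.abs(a - a_rec))
mx.eval(err)
print(f"max reconstruction error: {err.item():.2e}")  # expect < 1e-2
```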