mlx/benchmarks/python/svd_bench.py
Arkar Min Aung cb4dc59a9e feat(benchmarks): add comprehensive SVD performance benchmarks
Add benchmarks for Metal SVD implementation as required by CONTRIBUTING.md:
- Square matrix benchmarks (64x64 to 512x512)
- Rectangular matrix benchmarks
- Batched matrix benchmarks
- CPU vs GPU performance comparison
- Special matrices (identity, diagonal, zero)

Benchmarks validate performance improvements from GPU acceleration
and help identify performance regressions in future changes.

Usage:
  python benchmarks/python/svd_bench.py --gpu
  python benchmarks/python/svd_bench.py --compare
  python benchmarks/python/svd_bench.py --all
2025-06-15 18:09:11 +10:00

184 lines
5.3 KiB
Python

# Copyright © 2023 Apple Inc.
import argparse
import time
import mlx.core as mx
from time_utils import time_fn
def time_svd_square():
    """Time SVD on random square matrices of several sizes.

    For each size, two variants are timed through ``time_fn``: singular
    values only (``compute_uv=False``) and the full decomposition
    (``compute_uv=True``).
    """
    print("Benchmarking SVD on square matrices...")

    for dim in (64, 128, 256, 512):
        print(f"\n--- {dim}x{dim} matrix ---")

        # Build the input and evaluate it before any timing starts.
        matrix = mx.random.normal(shape=(dim, dim))
        mx.eval(matrix)

        # Variant 1: singular values only.
        print("SVD (values only):")
        time_fn(lambda m: mx.linalg.svd(m, compute_uv=False), matrix)

        # Variant 2: U, S and Vt.
        print("SVD (full decomposition):")
        time_fn(lambda m: mx.linalg.svd(m, compute_uv=True), matrix)
def time_svd_rectangular():
    """Time the full SVD on tall and wide random matrices."""
    print("\nBenchmarking SVD on rectangular matrices...")

    # Both orientations (rows > cols and rows < cols) are covered.
    for rows, cols in ((128, 64), (64, 128), (256, 128), (128, 256)):
        print(f"\n--- {rows}x{cols} matrix ---")

        # Build the input and evaluate it before any timing starts.
        matrix = mx.random.normal(shape=(rows, cols))
        mx.eval(matrix)

        print("SVD (full decomposition):")
        time_fn(lambda m: mx.linalg.svd(m, compute_uv=True), matrix)
def time_svd_batch():
    """Time the full SVD on stacks of small random square matrices."""
    print("\nBenchmarking SVD on batched matrices...")

    # Each entry is (batch size, rows, cols) and is used directly as the
    # shape of the random input.
    for shape in ((4, 64, 64), (8, 32, 32), (16, 16, 16)):
        count, rows, cols = shape
        print(f"\n--- Batch of {count} {rows}x{cols} matrices ---")

        # Build the input and evaluate it before any timing starts.
        stack = mx.random.normal(shape=shape)
        mx.eval(stack)

        print("Batched SVD (full decomposition):")
        time_fn(lambda m: mx.linalg.svd(m, compute_uv=True), stack)
def _time_svd_once(matrix, device):
    """Time one full SVD of ``matrix`` executed on ``device``.

    Args:
        matrix: Input array handed to ``mx.linalg.svd``.
        device: Device (``mx.cpu`` or ``mx.gpu``) passed as the ``stream``
            argument so the op runs there without touching the default
            device.

    Returns:
        A ``(outputs, seconds)`` pair, where ``outputs`` is the evaluated
        ``(u, s, vt)`` result of the timed run and ``seconds`` is its
        duration measured with ``time.perf_counter``.
    """
    # Untimed warm-up run: keeps any first-call setup cost out of the
    # measurement.
    mx.eval(*mx.linalg.svd(matrix, compute_uv=True, stream=device))

    start = time.perf_counter()
    outputs = mx.linalg.svd(matrix, compute_uv=True, stream=device)
    # MLX is lazy; the work only happens once the outputs are evaluated, so
    # the eval must be inside the timed region.
    mx.eval(*outputs)
    elapsed = time.perf_counter() - start
    return outputs, elapsed


def compare_cpu_gpu():
    """Compare CPU vs GPU performance for SVD.

    For each square size, the same random matrix is decomposed once on the
    CPU and once on the GPU (after an untimed warm-up on each). The two
    timings, their ratio, and the largest absolute difference between the
    sorted singular values are printed.

    The device is selected per operation through the ``stream`` argument,
    so the process-wide default device is left exactly as the caller set it.

    NOTE: this function does not itself check that a GPU device is
    available.
    """
    print("\nComparing CPU vs GPU performance...")
    sizes = [64, 128, 256]
    for size in sizes:
        print(f"\n--- {size}x{size} matrix comparison ---")

        # A single input array is shared by both runs; only the device the
        # SVD executes on differs.
        a = mx.random.normal(shape=(size, size))
        mx.eval(a)

        print("CPU SVD:")
        (_, s_cpu, _), cpu_time = _time_svd_once(a, mx.cpu)

        print("GPU SVD:")
        (_, s_gpu, _), gpu_time = _time_svd_once(a, mx.gpu)

        # Guard against a zero-length measurement.
        speedup = cpu_time / gpu_time if gpu_time > 0 else float("inf")
        print(f"CPU time: {cpu_time:.4f}s")
        print(f"GPU time: {gpu_time:.4f}s")
        print(f"Speedup: {speedup:.2f}x")

        # Verify results are close. Sorting first makes the comparison
        # independent of the order in which each backend returns the
        # singular values.
        diff = mx.max(mx.abs(mx.sort(s_cpu) - mx.sort(s_gpu)))
        print(f"Max singular value difference: {diff.item():.2e}")
def time_svd_special_matrices():
    """Time the full SVD on structured 256x256 inputs.

    Three inputs are timed in order: the identity matrix, a diagonal matrix
    with uniformly random entries on the diagonal, and the all-zero matrix.
    """
    print("\nBenchmarking SVD on special matrices...")
    size = 256

    def _random_diagonal():
        # Diagonal matrix whose diagonal is drawn from mx.random.uniform.
        return mx.diag(mx.random.uniform(shape=(size,)))

    # (label used in the banner, factory for the input). Factories are
    # called lazily inside the loop so each matrix is only created right
    # before it is timed.
    cases = (
        ("identity", lambda: mx.eye(size)),
        ("diagonal", _random_diagonal),
        ("zero", lambda: mx.zeros((size, size))),
    )

    for label, build in cases:
        print(f"\n--- {size}x{size} {label} matrix ---")

        # Build the input and evaluate it before any timing starts.
        matrix = build()
        mx.eval(matrix)

        time_fn(lambda m: mx.linalg.svd(m, compute_uv=True), matrix)
if __name__ == "__main__":
    # Command-line entry point.
    #   (no flag)   square-matrix benchmarks on the CPU
    #   --gpu       square, rectangular and batched benchmarks on the GPU
    #   --compare   CPU vs GPU comparison only
    #   --all       every benchmark, plus the comparison when Metal exists
    #
    # The text is passed as `description`: the first positional parameter of
    # ArgumentParser is `prog`, which would make the usage line print the
    # sentence in place of the script name.
    parser = argparse.ArgumentParser(description="MLX SVD benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument(
        "--compare", action="store_true", help="Compare CPU vs GPU performance."
    )
    parser.add_argument("--all", action="store_true", help="Run all benchmarks.")
    args = parser.parse_args()

    # Both --gpu and --compare run work on the GPU; fail early with a clear
    # message instead of erroring part-way through a benchmark.
    metal_available = mx.metal.is_available()
    if (args.gpu or args.compare) and not metal_available:
        parser.error(
            "--gpu and --compare require the Metal back-end, which is not available."
        )

    if args.gpu:
        mx.set_default_device(mx.gpu)
        print("Using GPU (Metal) backend")
    else:
        mx.set_default_device(mx.cpu)
        print("Using CPU backend")

    if args.compare:
        compare_cpu_gpu()
    elif args.all:
        time_svd_square()
        time_svd_rectangular()
        time_svd_batch()
        time_svd_special_matrices()
        # The comparison needs a GPU; skip it silently on CPU-only machines.
        if metal_available:
            compare_cpu_gpu()
    else:
        time_svd_square()
        # The rectangular and batched runs are only part of the default set
        # when the GPU back-end was requested (availability was checked
        # above).
        if args.gpu:
            time_svd_rectangular()
            time_svd_batch()