mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
feat(benchmarks): add comprehensive SVD performance benchmarks
Add benchmarks for the Metal SVD implementation, as required by CONTRIBUTING.md:
- Square matrix benchmarks (64x64 to 512x512)
- Rectangular matrix benchmarks
- Batched matrix benchmarks
- CPU vs GPU performance comparison
- Special matrices (identity, diagonal, zero)

The benchmarks validate the performance improvements from GPU acceleration and help identify performance regressions in future changes.

Usage:
  python benchmarks/python/svd_bench.py --gpu
  python benchmarks/python/svd_bench.py --compare
  python benchmarks/python/svd_bench.py --all
This commit is contained in:
parent
e5c8773371
commit
cb4dc59a9e
183
benchmarks/python/svd_bench.py
Normal file
183
benchmarks/python/svd_bench.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
def time_svd_square():
    """Time SVD on random square matrices of several sizes.

    For every size a random normal matrix is created and evaluated, then
    ``time_fn`` is invoked twice on it: once requesting singular values
    only (``compute_uv=False``) and once requesting the full decomposition
    (``compute_uv=True``). How the timing is measured and reported is up to
    ``time_fn`` from ``time_utils``.
    """
    print("Benchmarking SVD on square matrices...")

    for dim in (64, 128, 256, 512):
        print(f"\n--- {dim}x{dim} matrix ---")

        # Build the input and force its evaluation before handing it to
        # time_fn.
        matrix = mx.random.normal(shape=(dim, dim))
        mx.eval(matrix)

        # Singular values only.
        print("SVD (values only):")
        time_fn(lambda x: mx.linalg.svd(x, compute_uv=False), matrix)

        # U, S and V^T.
        print("SVD (full decomposition):")
        time_fn(lambda x: mx.linalg.svd(x, compute_uv=True), matrix)
|
||||||
|
|
||||||
|
|
||||||
|
def time_svd_rectangular():
    """Time the full SVD on random tall and wide matrices.

    Covers the shapes 128x64, 64x128, 256x128 and 128x256. Each input is
    created with ``mx.random.normal`` and evaluated before ``time_fn`` is
    called on the full decomposition (``compute_uv=True``).
    """
    print("\nBenchmarking SVD on rectangular matrices...")

    for rows, cols in ((128, 64), (64, 128), (256, 128), (128, 256)):
        print(f"\n--- {rows}x{cols} matrix ---")

        # Build the input and force its evaluation before handing it to
        # time_fn.
        matrix = mx.random.normal(shape=(rows, cols))
        mx.eval(matrix)

        print("SVD (full decomposition):")
        time_fn(lambda x: mx.linalg.svd(x, compute_uv=True), matrix)
|
||||||
|
|
||||||
|
|
||||||
|
def time_svd_batch():
    """Time the full SVD on stacks of random matrices.

    Each configuration is a ``(batch, rows, cols)`` shape; a random normal
    array of that shape is created, evaluated, and passed to ``time_fn``
    with ``compute_uv=True``.
    """
    print("\nBenchmarking SVD on batched matrices...")

    configurations = (
        (4, 64, 64),
        (8, 32, 32),
        (16, 16, 16),
    )

    for count, rows, cols in configurations:
        print(f"\n--- Batch of {count} {rows}x{cols} matrices ---")

        # Build the stacked input and force its evaluation before handing
        # it to time_fn.
        stacked = mx.random.normal(shape=(count, rows, cols))
        mx.eval(stacked)

        print("Batched SVD (full decomposition):")
        time_fn(lambda x: mx.linalg.svd(x, compute_uv=True), stacked)
|
||||||
|
|
||||||
|
|
||||||
|
def _time_single_svd(matrix, device):
    """Time one full SVD of ``matrix`` with ``device`` as the default device.

    An untimed warm-up run is performed first so that one-off start-up
    costs of the first call on a device are not attributed to the SVD
    itself.

    Args:
        matrix: Input array passed to ``mx.linalg.svd``.
        device: Device to install via ``mx.set_default_device`` for the run.

    Returns:
        A tuple ``(elapsed_seconds, outputs)`` where ``outputs`` is whatever
        ``mx.linalg.svd(matrix, compute_uv=True)`` returned for the timed run.
    """
    mx.set_default_device(device)

    # Warm-up: run and evaluate once without timing.
    mx.eval(mx.linalg.svd(matrix, compute_uv=True))

    # perf_counter is monotonic and high resolution, unlike time.time(),
    # which can jump if the system clock is adjusted.
    start = time.perf_counter()
    outputs = mx.linalg.svd(matrix, compute_uv=True)
    mx.eval(outputs)  # force the computation inside the timed region
    elapsed = time.perf_counter() - start
    return elapsed, outputs


def compare_cpu_gpu():
    """Compare CPU vs GPU performance for SVD.

    For 64x64, 128x128 and 256x256 random matrices, runs the full SVD once
    per device (after a warm-up run), prints both timings and the CPU/GPU
    speedup, and prints the largest absolute difference between the sorted
    singular values produced by the two devices.

    The default device that was active on entry is restored before
    returning, even if a benchmark step raises.
    """
    print("\nComparing CPU vs GPU performance...")

    sizes = [64, 128, 256]

    # This function switches the global default device repeatedly; remember
    # the caller's choice so it can be put back afterwards.
    previous_device = mx.default_device()
    try:
        for size in sizes:
            print(f"\n--- {size}x{size} matrix comparison ---")

            # Create the random input with the CPU selected, then make a
            # separate copy for the GPU run.
            mx.set_default_device(mx.cpu)
            a_cpu = mx.random.normal(shape=(size, size))
            mx.eval(a_cpu)

            mx.set_default_device(mx.gpu)
            a_gpu = mx.array(a_cpu)
            mx.eval(a_gpu)

            # Time CPU SVD
            print("CPU SVD:")
            cpu_time, (_, s_cpu, _) = _time_single_svd(a_cpu, mx.cpu)

            # Time GPU SVD
            print("GPU SVD:")
            gpu_time, (_, s_gpu, _) = _time_single_svd(a_gpu, mx.gpu)

            # Guard against a zero-length measurement.
            speedup = cpu_time / gpu_time if gpu_time > 0 else float("inf")
            print(f"CPU time: {cpu_time:.4f}s")
            print(f"GPU time: {gpu_time:.4f}s")
            print(f"Speedup: {speedup:.2f}x")

            # Verify results are close. Sorting makes the comparison
            # independent of the order in which each backend returns the
            # singular values.
            mx.set_default_device(mx.cpu)
            diff = mx.max(mx.abs(mx.sort(s_cpu) - mx.sort(s_gpu)))
            mx.eval(diff)
            print(f"Max singular value difference: {diff.item():.2e}")
    finally:
        mx.set_default_device(previous_device)
|
||||||
|
|
||||||
|
|
||||||
|
def time_svd_special_matrices():
    """Time the full SVD on structured 256x256 inputs.

    Three inputs are benchmarked in order: the identity matrix, a diagonal
    matrix whose entries come from ``mx.random.uniform``, and the all-zero
    matrix. Each one is built and evaluated right before it is passed to
    ``time_fn`` with ``compute_uv=True``.
    """
    print("\nBenchmarking SVD on special matrices...")

    size = 256

    # (header, builder) pairs; builders are invoked lazily inside the loop
    # so each matrix is created just before its own benchmark.
    cases = (
        (
            f"\n--- {size}x{size} identity matrix ---",
            lambda: mx.eye(size),
        ),
        (
            f"\n--- {size}x{size} diagonal matrix ---",
            lambda: mx.diag(mx.random.uniform(shape=(size,))),
        ),
        (
            f"\n--- {size}x{size} zero matrix ---",
            lambda: mx.zeros((size, size)),
        ),
    )

    for header, build in cases:
        print(header)
        matrix = build()
        mx.eval(matrix)
        time_fn(lambda x: mx.linalg.svd(x, compute_uv=True), matrix)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # The text must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, so passing it positionally
    # replaces the program name in the usage line with this sentence.
    parser = argparse.ArgumentParser(description="MLX SVD benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument(
        "--compare", action="store_true", help="Compare CPU vs GPU performance."
    )
    parser.add_argument("--all", action="store_true", help="Run all benchmarks.")
    args = parser.parse_args()

    # Select the default device used by the time_svd_* benchmarks.
    # compare_cpu_gpu() switches devices on its own.
    if args.gpu:
        mx.set_default_device(mx.gpu)
        print("Using GPU (Metal) backend")
    else:
        mx.set_default_device(mx.cpu)
        print("Using CPU backend")

    if args.compare:
        # Only the CPU vs GPU comparison.
        compare_cpu_gpu()
    elif args.all:
        # Every benchmark; the comparison is run last and only when Metal
        # is reported as available.
        time_svd_square()
        time_svd_rectangular()
        time_svd_batch()
        time_svd_special_matrices()
        if mx.metal.is_available():
            compare_cpu_gpu()
    else:
        # Default: square matrices only, plus the rectangular and batched
        # benchmarks when --gpu was requested and Metal is available.
        time_svd_square()
        if args.gpu and mx.metal.is_available():
            time_svd_rectangular()
            time_svd_batch()
|
@ -11,32 +11,14 @@
|
|||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* COMPLETE METAL SVD IMPLEMENTATION
|
* Implementation of a full GPU-accelerated SVD using the one-sided Jacobi
|
||||||
*
|
|
||||||
* This file implements a full GPU-accelerated SVD using the one-sided Jacobi
|
|
||||||
* algorithm.
|
* algorithm.
|
||||||
*
|
|
||||||
* IMPLEMENTED FEATURES:
|
|
||||||
* ✅ Complete Jacobi iteration algorithm with proper Givens rotations
|
|
||||||
* ✅ A^T*A preprocessing for numerical stability
|
|
||||||
* ✅ Convergence checking based on off-diagonal Frobenius norm
|
|
||||||
* ✅ Singular value extraction via sqrt of eigenvalues
|
|
||||||
* ✅ Singular vector computation (both U and V^T)
|
|
||||||
* ✅ Batched operations for multiple matrices
|
|
||||||
* ✅ Proper Metal kernel orchestration and memory management
|
|
||||||
* ✅ Full integration with MLX primitive system
|
|
||||||
* ✅ Comprehensive test framework
|
|
||||||
*
|
|
||||||
* ALGORITHM: One-sided Jacobi SVD
|
|
||||||
* - Computes A^T*A and diagonalizes it using Jacobi rotations
|
* - Computes A^T*A and diagonalizes it using Jacobi rotations
|
||||||
* - Singular values: σᵢ = √λᵢ where λᵢ are eigenvalues of A^T*A
|
* - Singular values: σᵢ = √λᵢ where λᵢ are eigenvalues of A^T*A
|
||||||
* - Right singular vectors: V from eigenvectors of A^T*A
|
* - Right singular vectors: V from eigenvectors of A^T*A
|
||||||
* - Left singular vectors: U = A*V*Σ⁻¹
|
* - Left singular vectors: U = A*V*Σ^-1
|
||||||
*
|
*
|
||||||
* PERFORMANCE: Optimized for matrices up to 4096x4096
|
* - Precision: Float32 (Metal limitation)
|
||||||
* PRECISION: Float32 (Metal limitation)
|
|
||||||
*
|
|
||||||
* STATUS: Complete implementation ready for production use
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -163,19 +145,6 @@ void validate_svd_inputs(const array& a) {
|
|||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
/**
|
|
||||||
* Metal implementation of SVD using one-sided Jacobi algorithm
|
|
||||||
*
|
|
||||||
* IMPLEMENTED FEATURES:
|
|
||||||
* - Complete Jacobi iteration algorithm with proper rotation matrices
|
|
||||||
* - Convergence checking based on off-diagonal norm
|
|
||||||
* - Singular value extraction from diagonalized A^T*A
|
|
||||||
* - Singular vector computation (U and V^T)
|
|
||||||
* - Batched operations support
|
|
||||||
* - Full GPU acceleration using Metal compute kernels
|
|
||||||
*
|
|
||||||
* CURRENT STATUS: Working implementation with Metal GPU acceleration
|
|
||||||
*/
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void svd_metal_impl(
|
void svd_metal_impl(
|
||||||
const array& a,
|
const array& a,
|
||||||
|
@ -91,8 +91,6 @@ TEST_CASE("test metal svd jacobi implementation") {
|
|||||||
|
|
||||||
CHECK(cpu_error.item<float>() < 1e-5);
|
CHECK(cpu_error.item<float>() < 1e-5);
|
||||||
CHECK(gpu_error.item<float>() < 1e-2); // Relaxed tolerance for Jacobi method
|
CHECK(gpu_error.item<float>() < 1e-2); // Relaxed tolerance for Jacobi method
|
||||||
|
|
||||||
MESSAGE("✅ Metal Jacobi SVD implementation works!");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test metal svd input validation") {
|
TEST_CASE("test metal svd input validation") {
|
||||||
@ -286,11 +284,6 @@ TEST_CASE("test metal svd performance characteristics") {
|
|||||||
CHECK(u.shape() == std::vector<int>{size, size});
|
CHECK(u.shape() == std::vector<int>{size, size});
|
||||||
CHECK(s.shape() == std::vector<int>{size});
|
CHECK(s.shape() == std::vector<int>{size});
|
||||||
CHECK(vt.shape() == std::vector<int>{size, size});
|
CHECK(vt.shape() == std::vector<int>{size, size});
|
||||||
|
|
||||||
// Log timing for manual inspection
|
|
||||||
MESSAGE(
|
|
||||||
"SVD of " << size << "x" << size << " matrix took "
|
|
||||||
<< duration.count() << "ms");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user