diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index d0c872451..0352738c2 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -52,6 +52,7 @@ if(MLX_METAL_JIT)
   make_jit_source(softmax)
   make_jit_source(scan)
   make_jit_source(sort)
+  make_jit_source(svd)
   make_jit_source(
     reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
     kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
@@ -110,6 +111,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h
index 27ae22d05..1b623d25e 100644
--- a/mlx/backend/metal/jit/includes.h
+++ b/mlx/backend/metal/jit/includes.h
@@ -27,6 +27,7 @@ const char* scan();
 const char* scatter_axis();
 const char* softmax();
 const char* sort();
+const char* svd();
 const char* reduce();
 
 const char* gemm();
diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp
index 467380c3a..cb741ca1c 100644
--- a/mlx/backend/metal/jit_kernels.cpp
+++ b/mlx/backend/metal/jit_kernels.cpp
@@ -823,4 +823,21 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
   return d.get_kernel(kernel_name, lib, hash_name, func_consts);
 }
 
+MTL::ComputePipelineState* get_svd_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    bool compute_uv) {
+  auto lib = d.get_library(kernel_name, [&]() {
+    std::string kernel_source = metal::utils();
+    kernel_source += metal::svd();
+    // For now, just add a placeholder template definition
+    // Actual kernel implementations will be added in subsequent PRs
+    kernel_source += get_template_definition(
+        kernel_name, "svd_placeholder", get_type_string(out.dtype()));
+    return kernel_source;
+  });
+  return d.get_kernel(kernel_name, lib);
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h
index 1de5fa47c..7ac030cec 100644
--- a/mlx/backend/metal/kernels.h
+++ b/mlx/backend/metal/kernels.h
@@ -241,6 +241,12 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     int wn,
     bool transpose);
 
+MTL::ComputePipelineState* get_svd_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const array& out,
+    bool compute_uv);
+
 // Create a GPU kernel template definition for JIT compilation
 template <typename... Args>
 std::string
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 3ee88ca46..b610848e7 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -112,6 +112,7 @@ if(NOT MLX_METAL_JIT)
   build_kernel(softmax softmax.h)
   build_kernel(logsumexp logsumexp.h)
   build_kernel(sort sort.h)
+  build_kernel(svd svd.h)
   build_kernel(ternary ternary.h ternary_ops.h)
   build_kernel(unary unary.h unary_ops.h)
   build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
diff --git a/mlx/backend/metal/kernels/svd.h b/mlx/backend/metal/kernels/svd.h
new file mode 100644
index 000000000..908336695
--- /dev/null
+++ b/mlx/backend/metal/kernels/svd.h
@@ -0,0 +1,39 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+namespace mlx::core {
+
+/**
+ * Parameters for SVD Metal kernels
+ */
+struct SVDParams {
+  const int M; // Matrix rows
+  const int N; // Matrix columns
+  const int K; // min(M, N) - number of singular values
+  const int max_iterations; // Maximum Jacobi iterations
+  const float tolerance; // Convergence threshold
+  const int batch_size; // Number of matrices in batch
+  const int64_t matrix_stride; // Stride between matrices in batch
+  const bool compute_uv; // Whether to compute U and V matrices
+};
+
+/**
+ * Jacobi rotation parameters for SVD computation
+ */
+struct JacobiRotation {
+  float cos_theta; // Cosine of rotation angle
+  float sin_theta; // Sine of rotation angle
+  int p, q; // Column indices for rotation (p < q)
+};
+
+/**
+ * Convergence tracking for iterative SVD algorithms
+ */
+struct SVDConvergenceInfo {
+  float off_diagonal_norm; // Norm of off-diagonal elements
+  int iteration_count; // Current iteration number
+  bool converged; // Whether algorithm has converged
+};
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/kernels/svd.metal b/mlx/backend/metal/kernels/svd.metal
new file mode 100644
index 000000000..5c8947c69
--- /dev/null
+++ b/mlx/backend/metal/kernels/svd.metal
@@ -0,0 +1,81 @@
+// Copyright © 2024 Apple Inc.
+
+// clang-format off
+#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/svd.h"
+
+using namespace metal;
+
+// Forward declarations for SVD kernels
+// These will be implemented in subsequent PRs
+
+/**
+ * Preprocess matrix for SVD computation
+ * Computes A^T * A for one-sided Jacobi algorithm
+ */
+template <typename T>
+[[kernel]] void svd_preprocess(
+    const device T* A [[buffer(0)]],
+    device T* AtA [[buffer(1)]],
+    const constant SVDParams& params [[buffer(2)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]]);
+
+/**
+ * Perform one iteration of Jacobi rotations
+ * Updates A^T * A matrix and tracks convergence
+ */
+template <typename T>
+[[kernel]] void svd_jacobi_iteration(
+    device T* AtA [[buffer(0)]],
+    device JacobiRotation* rotations [[buffer(1)]],
+    device SVDConvergenceInfo* convergence [[buffer(2)]],
+    const constant SVDParams& params [[buffer(3)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]]);
+
+/**
+ * Extract singular values from diagonalized matrix
+ */
+template <typename T>
+[[kernel]] void svd_extract_singular_values(
+    const device T* AtA [[buffer(0)]],
+    device T* S [[buffer(1)]],
+    const constant SVDParams& params [[buffer(2)]],
+    uint3 tid [[threadgroup_position_in_grid]]);
+
+/**
+ * Compute singular vectors U and V
+ */
+template <typename T>
+[[kernel]] void svd_compute_vectors(
+    const device T* A [[buffer(0)]],
+    const device JacobiRotation* rotations [[buffer(1)]],
+    device T* U [[buffer(2)]],
+    device T* V [[buffer(3)]],
+    const constant SVDParams& params [[buffer(4)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]]);
+
+// Placeholder kernel implementation for initial PR
+// This will be replaced with actual SVD implementation in subsequent PRs
+template <typename T>
+[[kernel]] void svd_placeholder(
+    const device T* A [[buffer(0)]],
+    device T* S [[buffer(1)]],
+    const constant SVDParams& params [[buffer(2)]],
+    uint3 tid [[threadgroup_position_in_grid]]) {
+  // Placeholder implementation - write dummy singular values for now
+  // This ensures the kernel compiles and can be called
+  uint index = tid.x;
+  if (index < params.K) {
+    S[index] = T(1.0); // Placeholder singular values
+  }
+}
+
+// Template instantiations for compilation
+template [[host_name("svd_placeholder_float")]] [[kernel]]
+decltype(svd_placeholder<float>) svd_placeholder<float>;
+
+template [[host_name("svd_placeholder_double")]] [[kernel]]
+decltype(svd_placeholder<double>) svd_placeholder<double>;
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 2ac543ad8..19f3ab446 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -18,6 +18,15 @@
 
 namespace mlx::core {
 
+// Forward declaration for SVD implementation
+template <typename T>
+void svd_metal_impl(
+    const array& a,
+    std::vector<array>& outputs,
+    bool compute_uv,
+    metal::Device& d,
+    const Stream& s);
+
 template <typename T>
 void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
   enc.set_bytes(start, 0);
@@ -331,7 +340,20 @@ void QRF::eval_gpu(
 void SVD::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  switch (inputs[0].dtype()) {
+    case float32:
+      svd_metal_impl<float>(inputs[0], outputs, compute_uv_, d, s);
+      break;
+    case float64:
+      svd_metal_impl<double>(inputs[0], outputs, compute_uv_, d, s);
+      break;
+    default:
+      throw std::runtime_error(
+          "[SVD::eval_gpu] only supports float32 or float64.");
+  }
 }
 
 void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
diff --git a/mlx/backend/metal/svd.cpp b/mlx/backend/metal/svd.cpp
new file mode 100644
index 000000000..1edca319e
--- /dev/null
+++ b/mlx/backend/metal/svd.cpp
@@ -0,0 +1,105 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/metal/kernels/svd.h"
+#include "mlx/allocator.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/kernels.h"
+#include "mlx/backend/metal/utils.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+namespace {
+
+/**
+ * Select appropriate SVD algorithm based on matrix properties
+ */
+enum class SVDAlgorithm {
+  JACOBI_ONE_SIDED, // Default for most cases
+  JACOBI_TWO_SIDED, // Better numerical stability (future)
+  BIDIAGONAL_QR // For very large matrices (future)
+};
+
+SVDAlgorithm select_svd_algorithm(int M, int N, Dtype dtype) {
+  // For now, always use one-sided Jacobi
+  // Future PRs will add algorithm selection heuristics
+  return SVDAlgorithm::JACOBI_ONE_SIDED;
+}
+
+/**
+ * Validate SVD input parameters
+ */
+void validate_svd_inputs(const array& a) {
+  if (a.ndim() < 2) {
+    throw std::invalid_argument(
+        "[SVD::eval_gpu] Input must have >= 2 dimensions");
+  }
+
+  if (a.dtype() != float32 && a.dtype() != float64) {
+    throw std::invalid_argument(
+        "[SVD::eval_gpu] Only float32 and float64 supported");
+  }
+
+  // Check for reasonable matrix size
+  int M = a.shape(-2);
+  int N = a.shape(-1);
+  if (M > 4096 || N > 4096) {
+    throw std::invalid_argument(
+        "[SVD::eval_gpu] Matrix too large for current implementation. "
" + "Maximum supported size is 4096x4096"); + } + + if (M == 0 || N == 0) { + throw std::invalid_argument( + "[SVD::eval_gpu] Matrix dimensions must be positive"); + } +} + +} // anonymous namespace + +/** + * Metal implementation of SVD using one-sided Jacobi algorithm + * This is a placeholder implementation that will be completed in subsequent PRs + */ +template +void svd_metal_impl( + const array& a, + std::vector& outputs, + bool compute_uv, + metal::Device& d, + const Stream& s) { + // Validate inputs + validate_svd_inputs(a); + + // Extract matrix dimensions + const int M = a.shape(-2); + const int N = a.shape(-1); + const int K = std::min(M, N); + const size_t num_matrices = a.size() / (M * N); + + // TODO: Implement actual Metal SVD computation in subsequent PRs + // For now, throw an informative error + throw std::runtime_error( + "[SVD::eval_gpu] Metal SVD implementation in progress. " + "Matrix size: " + + std::to_string(M) + "x" + std::to_string(N) + + ", batch size: " + std::to_string(num_matrices) + + ", compute_uv: " + (compute_uv ? "true" : "false")); +} + +// Explicit template instantiations +template void svd_metal_impl( + const array& a, + std::vector& outputs, + bool compute_uv, + metal::Device& d, + const Stream& s); + +template void svd_metal_impl( + const array& a, + std::vector& outputs, + bool compute_uv, + metal::Device& d, + const Stream& s); + +} // namespace mlx::core