feat: Add Metal SVD infrastructure and parameter structures

- Add SVDParams, JacobiRotation, and SVDConvergenceInfo structures
- Create placeholder Metal kernel declarations for SVD operations
- Add SVD kernel compilation to CMake build system
- Update SVD::eval_gpu to dispatch to Metal implementation
- Add basic input validation and error handling
- Include placeholder kernel implementation for compilation

This establishes the foundation for the Metal SVD implementation.
The actual algorithm will follow in subsequent commits.
Arkar Min Aung 2025-06-13 23:28:52 +10:00
parent c8b4787e4e
commit a71a9e0ddd
9 changed files with 275 additions and 1 deletion
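
For orientation, a hedged sketch of the user-facing path this backend work targets (entry points per mlx/linalg.h; not part of this commit):

#include "mlx/mlx.h"
using namespace mlx::core;

void svd_example() {
  // linalg::svd returns {U, S, Vt}. With this commit applied, forcing
  // evaluation on a Metal device routes through SVD::eval_gpu below.
  array a = random::normal({64, 64});
  auto usv = linalg::svd(a); // std::vector<array>
  eval(usv);
}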


@@ -52,6 +52,7 @@ if(MLX_METAL_JIT)
make_jit_source(softmax)
make_jit_source(scan)
make_jit_source(sort)
make_jit_source(svd)
make_jit_source(
reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
@@ -110,6 +111,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp


@@ -27,6 +27,7 @@ const char* scan();
const char* scatter_axis();
const char* softmax();
const char* sort();
const char* svd();
const char* reduce();
const char* gemm();


@@ -823,4 +823,21 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_svd_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
bool compute_uv) {
auto lib = d.get_library(kernel_name, [&]() {
std::string kernel_source = metal::utils();
kernel_source += metal::svd();
// For now, just add a placeholder template definition
// Actual kernel implementations will be added in subsequent PRs
kernel_source += get_template_definition(
kernel_name, "svd_placeholder", get_type_string(out.dtype()));
return kernel_source;
});
return d.get_kernel(kernel_name, lib);
}
} // namespace mlx::core
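
A hedged sketch of how the eventual svd_metal_impl might consume the pipeline state returned above. The encoder methods mirror mlx/backend/metal/device.h; the variables (a, s_out, params) and the grid sizing are illustrative assumptions, not part of this commit:

// Assumes: input array a, output array s_out, a populated SVDParams params,
// metal::Device& d, and Stream s from the calling context.
auto* kernel = get_svd_kernel(d, "svd_placeholder_float", s_out, /*compute_uv=*/false);
auto& enc = d.get_command_encoder(s.index);
enc.set_compute_pipeline_state(kernel);
enc.set_input_array(a, 0);
enc.set_output_array(s_out, 1);
enc.set_bytes(params, 2);
// One thread per singular value; real kernels will use richer grids.
enc.dispatch_threads(MTL::Size(params.K, 1, 1), MTL::Size(32, 1, 1));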


@@ -241,6 +241,12 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
int wn,
bool transpose);
MTL::ComputePipelineState* get_svd_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
bool compute_uv);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string


@@ -112,6 +112,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h)
build_kernel(svd svd.h)
build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h)
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
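
Note that svd is now registered on both build paths: make_jit_source(svd) above, inside if(MLX_METAL_JIT), embeds the kernel source for run-time compilation through get_svd_kernel, while build_kernel(svd svd.h) here pre-compiles it into the metallib for non-JIT builds. The two lists must stay in sync.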


@@ -0,0 +1,39 @@
// Copyright © 2024 Apple Inc.
#pragma once
namespace mlx::core {
/**
* Parameters for SVD Metal kernels
*/
struct SVDParams {
const int M; // Matrix rows
const int N; // Matrix columns
const int K; // min(M, N) - number of singular values
const int max_iterations; // Maximum Jacobi iterations
const float tolerance; // Convergence threshold
const int batch_size; // Number of matrices in batch
const int64_t matrix_stride; // Stride between matrices in batch
const bool compute_uv; // Whether to compute U and V matrices
};
/**
* Jacobi rotation parameters for SVD computation
*/
struct JacobiRotation {
float cos_theta; // Cosine of rotation angle
float sin_theta; // Sine of rotation angle
int p, q; // Column indices for rotation (p < q)
};
/**
* Convergence tracking for iterative SVD algorithms
*/
struct SVDConvergenceInfo {
float off_diagonal_norm; // Norm of off-diagonal elements
int iteration_count; // Current iteration number
bool converged; // Whether algorithm has converged
};
} // namespace mlx::core
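
To make the JacobiRotation fields concrete: for a column pair (p, q), one-sided Jacobi picks the angle that zeroes entry (p, q) of the Gram matrix G = A^T * A (tan(2*theta) = 2*G(p,q) / (G(q,q) - G(p,p)) in the usual convention). A minimal host-side sketch using the standard textbook formulas (a hypothetical helper, not part of this commit):

#include <cmath>

inline JacobiRotation make_jacobi_rotation(float gpp, float gqq, float gpq, int p, int q) {
  JacobiRotation rot{1.0f, 0.0f, p, q}; // identity rotation by default
  if (std::fabs(gpq) > 1e-30f) {
    float tau = (gqq - gpp) / (2.0f * gpq);
    // Smaller-magnitude root of t^2 + 2*tau*t - 1 = 0, for numerical stability.
    float t = std::copysign(1.0f, tau) / (std::fabs(tau) + std::sqrt(1.0f + tau * tau));
    rot.cos_theta = 1.0f / std::sqrt(1.0f + t * t);
    rot.sin_theta = t * rot.cos_theta;
  }
  return rot;
}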


@@ -0,0 +1,81 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/svd.h"
using namespace metal;
// Forward declarations for SVD kernels
// These will be implemented in subsequent PRs
/**
* Preprocess matrix for SVD computation
* Computes A^T * A for one-sided Jacobi algorithm
*/
template <typename T>
[[kernel]] void svd_preprocess(
const device T* A [[buffer(0)]],
device T* AtA [[buffer(1)]],
const constant SVDParams& params [[buffer(2)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
/**
* Perform one iteration of Jacobi rotations
* Updates A^T * A matrix and tracks convergence
*/
template <typename T>
[[kernel]] void svd_jacobi_iteration(
device T* AtA [[buffer(0)]],
device JacobiRotation* rotations [[buffer(1)]],
device SVDConvergenceInfo* convergence [[buffer(2)]],
const constant SVDParams& params [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
/**
* Extract singular values from diagonalized matrix
*/
template <typename T>
[[kernel]] void svd_extract_singular_values(
const device T* AtA [[buffer(0)]],
device T* S [[buffer(1)]],
const constant SVDParams& params [[buffer(2)]],
uint3 tid [[threadgroup_position_in_grid]]);
/**
* Compute singular vectors U and V
*/
template <typename T>
[[kernel]] void svd_compute_vectors(
const device T* A [[buffer(0)]],
const device JacobiRotation* rotations [[buffer(1)]],
device T* U [[buffer(2)]],
device T* V [[buffer(3)]],
const constant SVDParams& params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
// Placeholder kernel implementation for initial PR
// This will be replaced with actual SVD implementation in subsequent PRs
template <typename T>
[[kernel]] void svd_placeholder(
const device T* A [[buffer(0)]],
device T* S [[buffer(1)]],
const constant SVDParams& params [[buffer(2)]],
uint3 tid [[threadgroup_position_in_grid]]) {
// Placeholder implementation - write unit singular values so the
// kernel compiles and can be dispatched end-to-end
uint index = tid.x;
if (index < uint(params.K)) {
S[index] = T(1.0); // Placeholder singular values
}
}
// Template instantiations for compilation
template [[host_name("svd_placeholder_float")]] [[kernel]]
decltype(svd_placeholder<float>) svd_placeholder<float>;
template [[host_name("svd_placeholder_double")]] [[kernel]]
decltype(svd_placeholder<double>) svd_placeholder<double>;
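
For a sense of where this file is headed, a hedged Metal sketch of the extraction step: once the Jacobi sweeps converge, the singular values are the square roots of the diagonal of A^T * A. The kernel name and indexing below are illustrative assumptions, not declarations from this commit:

template <typename T>
[[kernel]] void svd_extract_singular_values_sketch(
    const device T* AtA [[buffer(0)]],
    device T* S [[buffer(1)]],
    const constant SVDParams& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  uint i = tid.x;
  if (i < uint(params.K)) {
    // Diagonal entry i of the N x N Gram matrix (single matrix; a batched
    // version would add tid.z * params.matrix_stride).
    T d = AtA[i * params.N + i];
    S[i] = metal::sqrt(metal::max(d, T(0))); // clamp roundoff negatives
  }
}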


@@ -18,6 +18,15 @@
namespace mlx::core {
// Forward declaration for SVD implementation
template <typename T>
void svd_metal_impl(
const array& a,
std::vector<array>& outputs,
bool compute_uv,
metal::Device& d,
const Stream& s);
template <typename T>
void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
enc.set_bytes(start, 0);
@@ -331,7 +340,20 @@ void QRF::eval_gpu(
void SVD::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
auto& s = stream();
auto& d = metal::device(s.device);
switch (inputs[0].dtype()) {
case float32:
svd_metal_impl<float>(inputs[0], outputs, compute_uv_, d, s);
break;
case float64:
svd_metal_impl<double>(inputs[0], outputs, compute_uv_, d, s);
break;
default:
throw std::runtime_error(
"[SVD::eval_gpu] only supports float32 or float64.");
}
}
void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {

mlx/backend/metal/svd.cpp (new file)

@@ -0,0 +1,105 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/svd.h"
#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
/**
* Select appropriate SVD algorithm based on matrix properties
*/
enum class SVDAlgorithm {
JACOBI_ONE_SIDED, // Default for most cases
JACOBI_TWO_SIDED, // Better numerical stability (future)
BIDIAGONAL_QR // For very large matrices (future)
};
SVDAlgorithm select_svd_algorithm(int M, int N, Dtype dtype) {
// For now, always use one-sided Jacobi
// Future PRs will add algorithm selection heuristics
return SVDAlgorithm::JACOBI_ONE_SIDED;
}
/**
* Validate SVD input parameters
*/
void validate_svd_inputs(const array& a) {
if (a.ndim() < 2) {
throw std::invalid_argument(
"[SVD::eval_gpu] Input must have >= 2 dimensions");
}
if (a.dtype() != float32 && a.dtype() != float64) {
throw std::invalid_argument(
"[SVD::eval_gpu] Only float32 and float64 supported");
}
// Check for reasonable matrix size
int M = a.shape(-2);
int N = a.shape(-1);
if (M > 4096 || N > 4096) {
throw std::invalid_argument(
"[SVD::eval_gpu] Matrix too large for current implementation. "
"Maximum supported size is 4096x4096");
}
if (M == 0 || N == 0) {
throw std::invalid_argument(
"[SVD::eval_gpu] Matrix dimensions must be positive");
}
}
} // anonymous namespace
/**
* Metal implementation of SVD using one-sided Jacobi algorithm
* This is a placeholder implementation that will be completed in subsequent PRs
*/
template <typename T>
void svd_metal_impl(
const array& a,
std::vector<array>& outputs,
bool compute_uv,
metal::Device& d,
const Stream& s) {
// Validate inputs
validate_svd_inputs(a);
// Extract matrix dimensions
const int M = a.shape(-2);
const int N = a.shape(-1);
const int K = std::min(M, N);
const size_t num_matrices = a.size() / (M * N);
// TODO: Implement actual Metal SVD computation in subsequent PRs
// For now, throw an informative error
throw std::runtime_error(
"[SVD::eval_gpu] Metal SVD implementation in progress. "
"Matrix size: " +
std::to_string(M) + "x" + std::to_string(N) +
", batch size: " + std::to_string(num_matrices) +
", compute_uv: " + (compute_uv ? "true" : "false"));
}
// Explicit template instantiations
template void svd_metal_impl<float>(
const array& a,
std::vector<array>& outputs,
bool compute_uv,
metal::Device& d,
const Stream& s);
template void svd_metal_impl<double>(
const array& a,
std::vector<array>& outputs,
bool compute_uv,
metal::Device& d,
const Stream& s);
} // namespace mlx::core
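
When the TODO above is replaced with a real dispatch, the kernel parameters can be assembled from the dimensions svd_metal_impl already computes. A hedged sketch (field order per kernels/svd.h; the iteration cap and tolerance are illustrative defaults, not values fixed by this commit):

SVDParams params{
    /* M              */ M,
    /* N              */ N,
    /* K              */ K,
    /* max_iterations */ 100,
    /* tolerance      */ 1e-6f,
    /* batch_size     */ static_cast<int>(num_matrices),
    /* matrix_stride  */ static_cast<int64_t>(M) * N,
    /* compute_uv     */ compute_uv};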