mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
feat: Add Metal SVD infrastructure and parameter structures
- Add SVDParams, JacobiRotation, and SVDConvergenceInfo structures - Create placeholder Metal kernel declarations for SVD operations - Add SVD kernel compilation to CMake build system - Update SVD::eval_gpu to dispatch to Metal implementation - Add basic input validation and error handling - Include placeholder kernel implementation for compilation This establishes the foundation for Metal SVD implementation. Actual algorithm implementation will follow in subsequent commits.
This commit is contained in:
parent
c8b4787e4e
commit
a71a9e0ddd
@ -52,6 +52,7 @@ if(MLX_METAL_JIT)
|
||||
make_jit_source(softmax)
|
||||
make_jit_source(scan)
|
||||
make_jit_source(sort)
|
||||
make_jit_source(svd)
|
||||
make_jit_source(
|
||||
reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
|
||||
kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
|
||||
@ -110,6 +111,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
|
@ -27,6 +27,7 @@ const char* scan();
|
||||
const char* scatter_axis();
|
||||
const char* softmax();
|
||||
const char* sort();
|
||||
const char* svd();
|
||||
const char* reduce();
|
||||
|
||||
const char* gemm();
|
||||
|
@ -823,4 +823,21 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
// Requests the library named `kernel_name` from the device, supplying a
// builder that assembles the source text (utils + svd + one template
// definition), then returns the kernel of the same name from that library.
//
//   d            Device used for the library and kernel lookups.
//   kernel_name  Used both as the library name and as the kernel name.
//   out          Only its dtype is read; it selects the type string passed to
//                the template definition.
//   compute_uv   Not used by this placeholder implementation.
//
// NOTE(review): the template instantiated here is `svd_placeholder`; confirm
// that the text returned by metal::svd() actually contains that template.
MTL::ComputePipelineState* get_svd_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    bool compute_uv) {
  auto build_source = [&]() {
    std::string src = metal::utils();
    src += metal::svd();
    // Placeholder only: the real SVD kernel templates are expected to be
    // added in subsequent PRs.
    src += get_template_definition(
        kernel_name, "svd_placeholder", get_type_string(out.dtype()));
    return src;
  };

  auto lib = d.get_library(kernel_name, build_source);
  return d.get_kernel(kernel_name, lib);
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -241,6 +241,12 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
||||
int wn,
|
||||
bool transpose);
|
||||
|
||||
// Look up the pipeline state for the SVD kernel called `kernel_name`.
// `out` supplies the dtype the kernel is instantiated for; `compute_uv` is
// part of the signature but the current definition does not read it.
MTL::ComputePipelineState* get_svd_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const array& out,
    bool compute_uv);
|
||||
|
||||
// Create a GPU kernel template definition for JIT compilation
|
||||
template <typename... Args>
|
||||
std::string
|
||||
|
@ -112,6 +112,7 @@ if(NOT MLX_METAL_JIT)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(logsumexp logsumexp.h)
|
||||
build_kernel(sort sort.h)
|
||||
build_kernel(svd svd.h)
|
||||
build_kernel(ternary ternary.h ternary_ops.h)
|
||||
build_kernel(unary unary.h unary_ops.h)
|
||||
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
|
||||
|
39
mlx/backend/metal/kernels/svd.h
Normal file
39
mlx/backend/metal/kernels/svd.h
Normal file
@ -0,0 +1,39 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
// NOTE: the structs below are deliberately declared at global scope.
// svd.metal includes this header and names `SVDParams`, `JacobiRotation` and
// `SVDConvergenceInfo` without qualification (it only has
// `using namespace metal;`), so declaring them inside `mlx::core` would leave
// those names unresolved in the kernels. They are re-exported into
// `mlx::core` at the bottom so qualified host-side uses keep working.

/**
 * Parameters for SVD Metal kernels
 */
struct SVDParams {
  const int M; // Matrix rows
  const int N; // Matrix columns
  const int K; // min(M, N) - number of singular values
  const int max_iterations; // Maximum Jacobi iterations
  const float tolerance; // Convergence threshold
  const int batch_size; // Number of matrices in batch
  const int64_t matrix_stride; // Stride between matrices in batch
  const bool compute_uv; // Whether to compute U and V matrices
};

/**
 * Jacobi rotation parameters for SVD computation
 */
struct JacobiRotation {
  float cos_theta; // Cosine of rotation angle
  float sin_theta; // Sine of rotation angle
  int p; // First column index of the rotation (p < q)
  int q; // Second column index of the rotation
};

/**
 * Convergence tracking for iterative SVD algorithms
 */
struct SVDConvergenceInfo {
  float off_diagonal_norm; // Norm of off-diagonal elements
  int iteration_count; // Current iteration number
  bool converged; // Whether algorithm has converged
};

namespace mlx::core {

// Keep the historical `mlx::core::` spellings valid.
using ::JacobiRotation;
using ::SVDConvergenceInfo;
using ::SVDParams;

} // namespace mlx::core
|
81
mlx/backend/metal/kernels/svd.metal
Normal file
81
mlx/backend/metal/kernels/svd.metal
Normal file
@ -0,0 +1,81 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/svd.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Forward declarations for SVD kernels
|
||||
// These will be implemented in subsequent PRs
|
||||
|
||||
/**
 * Declaration only; no definition exists in this file yet.
 *
 * Intended role (per the original note): preprocess the matrix for SVD by
 * forming A^T * A for the one-sided Jacobi algorithm.
 *
 * Bindings: A at buffer 0 (read-only), AtA at buffer 1 (writable),
 * params at buffer 2 (constant). `tid` is the threadgroup position in the
 * grid and `lid` the thread position within its threadgroup.
 */
template <typename T>
[[kernel]] void svd_preprocess(
    const device T* A [[buffer(0)]],
    device T* AtA [[buffer(1)]],
    const constant SVDParams& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
/**
 * Declaration only; no definition exists in this file yet.
 *
 * Intended role (per the original note): run one iteration of Jacobi
 * rotations, updating the A^T * A matrix and tracking convergence.
 *
 * Bindings: AtA at buffer 0 (writable), rotations at buffer 1 (writable),
 * convergence at buffer 2 (writable), params at buffer 3 (constant).
 * `tid` is the threadgroup position in the grid and `lid` the thread
 * position within its threadgroup.
 */
template <typename T>
[[kernel]] void svd_jacobi_iteration(
    device T* AtA [[buffer(0)]],
    device JacobiRotation* rotations [[buffer(1)]],
    device SVDConvergenceInfo* convergence [[buffer(2)]],
    const constant SVDParams& params [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
/**
 * Declaration only; no definition exists in this file yet.
 *
 * Intended role (per the original note): extract the singular values from
 * the diagonalized matrix.
 *
 * Bindings: AtA at buffer 0 (read-only), S at buffer 1 (writable),
 * params at buffer 2 (constant). `tid` is the threadgroup position in the
 * grid.
 */
template <typename T>
[[kernel]] void svd_extract_singular_values(
    const device T* AtA [[buffer(0)]],
    device T* S [[buffer(1)]],
    const constant SVDParams& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]]);
|
||||
|
||||
/**
 * Declaration only; no definition exists in this file yet.
 *
 * Intended role (per the original note): compute the singular vectors U
 * and V.
 *
 * Bindings: A at buffer 0 (read-only), rotations at buffer 1 (read-only),
 * U at buffer 2 (writable), V at buffer 3 (writable), params at buffer 4
 * (constant). `tid` is the threadgroup position in the grid and `lid` the
 * thread position within its threadgroup.
 */
template <typename T>
[[kernel]] void svd_compute_vectors(
    const device T* A [[buffer(0)]],
    const device JacobiRotation* rotations [[buffer(1)]],
    device T* U [[buffer(2)]],
    device T* V [[buffer(3)]],
    const constant SVDParams& params [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
|
||||
|
||||
// Placeholder kernel implementation for initial PR
// This will be replaced with actual SVD implementation in subsequent PRs
//
// What it does today: `A` is never read. For every threadgroup whose
// x-position is below `params.K`, the value 1.0 is stored into `S` at that
// position. Its only purpose is to give the build something that compiles
// and can be dispatched.
//
//   A       input matrix data, buffer 0 (unused)
//   S       output singular values, buffer 1
//   params  SVD parameters, buffer 2 (only `K` is read)
//   tid     threadgroup position in the grid (only `x` is read)
template <typename T>
[[kernel]] void svd_placeholder(
    const device T* A [[buffer(0)]],
    device T* S [[buffer(1)]],
    const constant SVDParams& params [[buffer(2)]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  // `K` is a signed int while the thread index is unsigned. Reject
  // non-positive values first so a negative K is never converted into a huge
  // unsigned bound by the comparison below.
  if (params.K <= 0) {
    return;
  }

  const uint index = tid.x;
  if (index < static_cast<uint>(params.K)) {
    S[index] = T(1.0); // Placeholder singular values
  }
}
|
||||
|
||||
// Template instantiations for compilation
//
// Only the float variant is instantiated. The Metal Shading Language does
// not provide a `double` type, so an `svd_placeholder<double>` instantiation
// would stop this file from compiling.
template [[host_name("svd_placeholder_float")]] [[kernel]]
decltype(svd_placeholder<float>) svd_placeholder<float>;
|
@ -18,6 +18,15 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Forward declaration for SVD implementation
|
||||
template <typename T>
|
||||
void svd_metal_impl(
|
||||
const array& a,
|
||||
std::vector<array>& outputs,
|
||||
bool compute_uv,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
template <typename T>
|
||||
void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
|
||||
enc.set_bytes(start, 0);
|
||||
@ -331,7 +340,20 @@ void QRF::eval_gpu(
|
||||
void SVD::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
switch (inputs[0].dtype()) {
|
||||
case float32:
|
||||
svd_metal_impl<float>(inputs[0], outputs, compute_uv_, d, s);
|
||||
break;
|
||||
case float64:
|
||||
svd_metal_impl<double>(inputs[0], outputs, compute_uv_, d, s);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[SVD::eval_gpu] only supports float32 or float64.");
|
||||
}
|
||||
}
|
||||
|
||||
void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
|
||||
|
105
mlx/backend/metal/svd.cpp
Normal file
105
mlx/backend/metal/svd.cpp
Normal file
@ -0,0 +1,105 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/svd.h"
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
/**
 * SVD algorithm families the Metal backend may choose between.
 * Only the first one is returned by select_svd_algorithm today.
 */
enum class SVDAlgorithm {
  // Default for most cases
  JACOBI_ONE_SIDED,
  // Better numerical stability (future)
  JACOBI_TWO_SIDED,
  // For very large matrices (future)
  BIDIAGONAL_QR
};
|
||||
|
||||
// Chooses the SVD algorithm for an M x N matrix of the given dtype.
// None of the three arguments is consulted yet: every call returns
// SVDAlgorithm::JACOBI_ONE_SIDED. Selection heuristics are planned for
// future PRs.
SVDAlgorithm select_svd_algorithm(int M, int N, Dtype dtype) {
  constexpr SVDAlgorithm kDefaultAlgorithm = SVDAlgorithm::JACOBI_ONE_SIDED;
  return kDefaultAlgorithm;
}
|
||||
|
||||
/**
|
||||
* Validate SVD input parameters
|
||||
*/
|
||||
void validate_svd_inputs(const array& a) {
|
||||
if (a.ndim() < 2) {
|
||||
throw std::invalid_argument(
|
||||
"[SVD::eval_gpu] Input must have >= 2 dimensions");
|
||||
}
|
||||
|
||||
if (a.dtype() != float32 && a.dtype() != float64) {
|
||||
throw std::invalid_argument(
|
||||
"[SVD::eval_gpu] Only float32 and float64 supported");
|
||||
}
|
||||
|
||||
// Check for reasonable matrix size
|
||||
int M = a.shape(-2);
|
||||
int N = a.shape(-1);
|
||||
if (M > 4096 || N > 4096) {
|
||||
throw std::invalid_argument(
|
||||
"[SVD::eval_gpu] Matrix too large for current implementation. "
|
||||
"Maximum supported size is 4096x4096");
|
||||
}
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
throw std::invalid_argument(
|
||||
"[SVD::eval_gpu] Matrix dimensions must be positive");
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
/**
|
||||
* Metal implementation of SVD using one-sided Jacobi algorithm
|
||||
* This is a placeholder implementation that will be completed in subsequent PRs
|
||||
*/
|
||||
template <typename T>
|
||||
void svd_metal_impl(
|
||||
const array& a,
|
||||
std::vector<array>& outputs,
|
||||
bool compute_uv,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
// Validate inputs
|
||||
validate_svd_inputs(a);
|
||||
|
||||
// Extract matrix dimensions
|
||||
const int M = a.shape(-2);
|
||||
const int N = a.shape(-1);
|
||||
const int K = std::min(M, N);
|
||||
const size_t num_matrices = a.size() / (M * N);
|
||||
|
||||
// TODO: Implement actual Metal SVD computation in subsequent PRs
|
||||
// For now, throw an informative error
|
||||
throw std::runtime_error(
|
||||
"[SVD::eval_gpu] Metal SVD implementation in progress. "
|
||||
"Matrix size: " +
|
||||
std::to_string(M) + "x" + std::to_string(N) +
|
||||
", batch size: " + std::to_string(num_matrices) +
|
||||
", compute_uv: " + (compute_uv ? "true" : "false"));
|
||||
}
|
||||
|
||||
// Explicit template instantiations
|
||||
template void svd_metal_impl<float>(
|
||||
const array& a,
|
||||
std::vector<array>& outputs,
|
||||
bool compute_uv,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
template void svd_metal_impl<double>(
|
||||
const array& a,
|
||||
std::vector<array>& outputs,
|
||||
bool compute_uv,
|
||||
metal::Device& d,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user