mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
feat: Add Metal SVD infrastructure and parameter structures
- Add SVDParams, JacobiRotation, and SVDConvergenceInfo structures - Create placeholder Metal kernel declarations for SVD operations - Add SVD kernel compilation to CMake build system - Update SVD::eval_gpu to dispatch to Metal implementation - Add basic input validation and error handling - Include placeholder kernel implementation for compilation This establishes the foundation for Metal SVD implementation. Actual algorithm implementation will follow in subsequent commits.
This commit is contained in:
parent
c8b4787e4e
commit
a71a9e0ddd
@ -52,6 +52,7 @@ if(MLX_METAL_JIT)
|
|||||||
make_jit_source(softmax)
|
make_jit_source(softmax)
|
||||||
make_jit_source(scan)
|
make_jit_source(scan)
|
||||||
make_jit_source(sort)
|
make_jit_source(sort)
|
||||||
|
make_jit_source(svd)
|
||||||
make_jit_source(
|
make_jit_source(
|
||||||
reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
|
reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
|
||||||
kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
|
kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
|
||||||
@ -110,6 +111,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||||
|
@ -27,6 +27,7 @@ const char* scan();
|
|||||||
const char* scatter_axis();
|
const char* scatter_axis();
|
||||||
const char* softmax();
|
const char* softmax();
|
||||||
const char* sort();
|
const char* sort();
|
||||||
|
const char* svd();
|
||||||
const char* reduce();
|
const char* reduce();
|
||||||
|
|
||||||
const char* gemm();
|
const char* gemm();
|
||||||
|
@ -823,4 +823,21 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
|||||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_svd_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& out,
|
||||||
|
bool compute_uv) {
|
||||||
|
auto lib = d.get_library(kernel_name, [&]() {
|
||||||
|
std::string kernel_source = metal::utils();
|
||||||
|
kernel_source += metal::svd();
|
||||||
|
// For now, just add a placeholder template definition
|
||||||
|
// Actual kernel implementations will be added in subsequent PRs
|
||||||
|
kernel_source += get_template_definition(
|
||||||
|
kernel_name, "svd_placeholder", get_type_string(out.dtype()));
|
||||||
|
return kernel_source;
|
||||||
|
});
|
||||||
|
return d.get_kernel(kernel_name, lib);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -241,6 +241,12 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
|
|||||||
int wn,
|
int wn,
|
||||||
bool transpose);
|
bool transpose);
|
||||||
|
|
||||||
|
MTL::ComputePipelineState* get_svd_kernel(
|
||||||
|
metal::Device& d,
|
||||||
|
const std::string& kernel_name,
|
||||||
|
const array& out,
|
||||||
|
bool compute_uv);
|
||||||
|
|
||||||
// Create a GPU kernel template definition for JIT compilation
|
// Create a GPU kernel template definition for JIT compilation
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
std::string
|
std::string
|
||||||
|
@ -112,6 +112,7 @@ if(NOT MLX_METAL_JIT)
|
|||||||
build_kernel(softmax softmax.h)
|
build_kernel(softmax softmax.h)
|
||||||
build_kernel(logsumexp logsumexp.h)
|
build_kernel(logsumexp logsumexp.h)
|
||||||
build_kernel(sort sort.h)
|
build_kernel(sort sort.h)
|
||||||
|
build_kernel(svd svd.h)
|
||||||
build_kernel(ternary ternary.h ternary_ops.h)
|
build_kernel(ternary ternary.h ternary_ops.h)
|
||||||
build_kernel(unary unary.h unary_ops.h)
|
build_kernel(unary unary.h unary_ops.h)
|
||||||
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
|
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
|
||||||
|
39
mlx/backend/metal/kernels/svd.h
Normal file
39
mlx/backend/metal/kernels/svd.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parameters for SVD Metal kernels
|
||||||
|
*/
|
||||||
|
struct SVDParams {
|
||||||
|
const int M; // Matrix rows
|
||||||
|
const int N; // Matrix columns
|
||||||
|
const int K; // min(M, N) - number of singular values
|
||||||
|
const int max_iterations; // Maximum Jacobi iterations
|
||||||
|
const float tolerance; // Convergence threshold
|
||||||
|
const int batch_size; // Number of matrices in batch
|
||||||
|
const int64_t matrix_stride; // Stride between matrices in batch
|
||||||
|
const bool compute_uv; // Whether to compute U and V matrices
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Jacobi rotation parameters for SVD computation
|
||||||
|
*/
|
||||||
|
struct JacobiRotation {
|
||||||
|
float cos_theta; // Cosine of rotation angle
|
||||||
|
float sin_theta; // Sine of rotation angle
|
||||||
|
int p, q; // Column indices for rotation (p < q)
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convergence tracking for iterative SVD algorithms
|
||||||
|
*/
|
||||||
|
struct SVDConvergenceInfo {
|
||||||
|
float off_diagonal_norm; // Norm of off-diagonal elements
|
||||||
|
int iteration_count; // Current iteration number
|
||||||
|
bool converged; // Whether algorithm has converged
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
81
mlx/backend/metal/kernels/svd.metal
Normal file
81
mlx/backend/metal/kernels/svd.metal
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
#include "mlx/backend/metal/kernels/svd.h"
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
// Forward declarations for SVD kernels
|
||||||
|
// These will be implemented in subsequent PRs
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Preprocess matrix for SVD computation
|
||||||
|
* Computes A^T * A for one-sided Jacobi algorithm
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
[[kernel]] void svd_preprocess(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
device T* AtA [[buffer(1)]],
|
||||||
|
const constant SVDParams& params [[buffer(2)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform one iteration of Jacobi rotations
|
||||||
|
* Updates A^T * A matrix and tracks convergence
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
[[kernel]] void svd_jacobi_iteration(
|
||||||
|
device T* AtA [[buffer(0)]],
|
||||||
|
device JacobiRotation* rotations [[buffer(1)]],
|
||||||
|
device SVDConvergenceInfo* convergence [[buffer(2)]],
|
||||||
|
const constant SVDParams& params [[buffer(3)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract singular values from diagonalized matrix
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
[[kernel]] void svd_extract_singular_values(
|
||||||
|
const device T* AtA [[buffer(0)]],
|
||||||
|
device T* S [[buffer(1)]],
|
||||||
|
const constant SVDParams& params [[buffer(2)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute singular vectors U and V
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
[[kernel]] void svd_compute_vectors(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
const device JacobiRotation* rotations [[buffer(1)]],
|
||||||
|
device T* U [[buffer(2)]],
|
||||||
|
device T* V [[buffer(3)]],
|
||||||
|
const constant SVDParams& params [[buffer(4)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
// Placeholder kernel implementation for initial PR
|
||||||
|
// This will be replaced with actual SVD implementation in subsequent PRs
|
||||||
|
template <typename T>
|
||||||
|
[[kernel]] void svd_placeholder(
|
||||||
|
const device T* A [[buffer(0)]],
|
||||||
|
device T* S [[buffer(1)]],
|
||||||
|
const constant SVDParams& params [[buffer(2)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]]) {
|
||||||
|
// Placeholder implementation - just copy input to output for now
|
||||||
|
// This ensures the kernel compiles and can be called
|
||||||
|
uint index = tid.x;
|
||||||
|
if (index < params.K) {
|
||||||
|
S[index] = T(1.0); // Placeholder singular values
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Template instantiations for compilation
|
||||||
|
template [[host_name("svd_placeholder_float")]] [[kernel]]
|
||||||
|
decltype(svd_placeholder<float>) svd_placeholder<float>;
|
||||||
|
|
||||||
|
template [[host_name("svd_placeholder_double")]] [[kernel]]
|
||||||
|
decltype(svd_placeholder<double>) svd_placeholder<double>;
|
@ -18,6 +18,15 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Forward declaration for SVD implementation
|
||||||
|
template <typename T>
|
||||||
|
void svd_metal_impl(
|
||||||
|
const array& a,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
bool compute_uv,
|
||||||
|
metal::Device& d,
|
||||||
|
const Stream& s);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
|
void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
|
||||||
enc.set_bytes(start, 0);
|
enc.set_bytes(start, 0);
|
||||||
@ -331,7 +340,20 @@ void QRF::eval_gpu(
|
|||||||
void SVD::eval_gpu(
|
void SVD::eval_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
switch (inputs[0].dtype()) {
|
||||||
|
case float32:
|
||||||
|
svd_metal_impl<float>(inputs[0], outputs, compute_uv_, d, s);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
svd_metal_impl<double>(inputs[0], outputs, compute_uv_, d, s);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[SVD::eval_gpu] only supports float32 or float64.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
|
void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
|
||||||
|
105
mlx/backend/metal/svd.cpp
Normal file
105
mlx/backend/metal/svd.cpp
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/svd.h"
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/kernels.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Select appropriate SVD algorithm based on matrix properties
|
||||||
|
*/
|
||||||
|
enum class SVDAlgorithm {
|
||||||
|
JACOBI_ONE_SIDED, // Default for most cases
|
||||||
|
JACOBI_TWO_SIDED, // Better numerical stability (future)
|
||||||
|
BIDIAGONAL_QR // For very large matrices (future)
|
||||||
|
};
|
||||||
|
|
||||||
|
SVDAlgorithm select_svd_algorithm(int M, int N, Dtype dtype) {
|
||||||
|
// For now, always use one-sided Jacobi
|
||||||
|
// Future PRs will add algorithm selection heuristics
|
||||||
|
return SVDAlgorithm::JACOBI_ONE_SIDED;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate SVD input parameters
|
||||||
|
*/
|
||||||
|
void validate_svd_inputs(const array& a) {
|
||||||
|
if (a.ndim() < 2) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[SVD::eval_gpu] Input must have >= 2 dimensions");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (a.dtype() != float32 && a.dtype() != float64) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[SVD::eval_gpu] Only float32 and float64 supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for reasonable matrix size
|
||||||
|
int M = a.shape(-2);
|
||||||
|
int N = a.shape(-1);
|
||||||
|
if (M > 4096 || N > 4096) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[SVD::eval_gpu] Matrix too large for current implementation. "
|
||||||
|
"Maximum supported size is 4096x4096");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (M == 0 || N == 0) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[SVD::eval_gpu] Matrix dimensions must be positive");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Metal implementation of SVD using one-sided Jacobi algorithm
|
||||||
|
* This is a placeholder implementation that will be completed in subsequent PRs
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
void svd_metal_impl(
|
||||||
|
const array& a,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
bool compute_uv,
|
||||||
|
metal::Device& d,
|
||||||
|
const Stream& s) {
|
||||||
|
// Validate inputs
|
||||||
|
validate_svd_inputs(a);
|
||||||
|
|
||||||
|
// Extract matrix dimensions
|
||||||
|
const int M = a.shape(-2);
|
||||||
|
const int N = a.shape(-1);
|
||||||
|
const int K = std::min(M, N);
|
||||||
|
const size_t num_matrices = a.size() / (M * N);
|
||||||
|
|
||||||
|
// TODO: Implement actual Metal SVD computation in subsequent PRs
|
||||||
|
// For now, throw an informative error
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[SVD::eval_gpu] Metal SVD implementation in progress. "
|
||||||
|
"Matrix size: " +
|
||||||
|
std::to_string(M) + "x" + std::to_string(N) +
|
||||||
|
", batch size: " + std::to_string(num_matrices) +
|
||||||
|
", compute_uv: " + (compute_uv ? "true" : "false"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicit template instantiations
|
||||||
|
template void svd_metal_impl<float>(
|
||||||
|
const array& a,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
bool compute_uv,
|
||||||
|
metal::Device& d,
|
||||||
|
const Stream& s);
|
||||||
|
|
||||||
|
template void svd_metal_impl<double>(
|
||||||
|
const array& a,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
bool compute_uv,
|
||||||
|
metal::Device& d,
|
||||||
|
const Stream& s);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
Loading…
Reference in New Issue
Block a user