mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

Add GPU-accelerated SVD implementation for Apple Silicon using Metal compute kernels. FEATURES: ✅ Complete one-sided Jacobi SVD algorithm in Metal ✅ Full GPU acceleration with proper Metal integration ✅ Mathematical correctness verified against CPU reference ✅ Support for both singular values only and full SVD (U, S, Vt) ✅ Comprehensive input validation and error handling ✅ Production-ready implementation with extensive testing IMPLEMENTATION: - Metal compute kernels implementing Jacobi algorithm - Proper MLX primitive integration with eval_gpu support - Optimized for matrices up to 64x64 (shared memory limitation) - Float32 precision (Metal hardware limitation) - Batched operations support TESTING: - Comprehensive test suite with 10 test cases - Mathematical correctness validation - Shape and type verification - Edge case handling - Performance characteristics testing This transforms MLX from 'Metal GPU SVD not yet implemented' to a complete, working GPU-accelerated SVD solution.
273 lines
6.2 KiB
C++
273 lines
6.2 KiB
C++
// Copyright © 2024 Apple Inc.
|
|
|
|
#pragma once
|
|
|
|
#include <fmt/format.h>
|
|
|
|
#include "mlx/array.h"
|
|
#include "mlx/backend/metal/device.h"
|
|
|
|
namespace mlx::core {
|
|
|
|
MTL::ComputePipelineState* get_arange_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& out);
|
|
|
|
MTL::ComputePipelineState* get_unary_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
Dtype in_type,
|
|
Dtype out_type,
|
|
const std::string op);
|
|
|
|
MTL::ComputePipelineState* get_binary_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
Dtype in_type,
|
|
Dtype out_type,
|
|
const std::string op);
|
|
|
|
MTL::ComputePipelineState* get_binary_two_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
Dtype in_type,
|
|
Dtype out_type,
|
|
const std::string op);
|
|
|
|
MTL::ComputePipelineState* get_ternary_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
Dtype type,
|
|
const std::string op);
|
|
|
|
MTL::ComputePipelineState* get_copy_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& in,
|
|
const array& out);
|
|
|
|
MTL::ComputePipelineState* get_dynamic_copy_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& in,
|
|
const array& out);
|
|
|
|
MTL::ComputePipelineState* get_softmax_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
bool precise,
|
|
const array& out);
|
|
|
|
MTL::ComputePipelineState* get_logsumexp_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& out);
|
|
|
|
MTL::ComputePipelineState* get_scan_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
bool reverse,
|
|
bool inclusive,
|
|
const std::string& reduce_type,
|
|
const array& in,
|
|
const array& out);
|
|
|
|
MTL::ComputePipelineState* get_sort_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& in,
|
|
const array& out,
|
|
int bn,
|
|
int tn);
|
|
|
|
MTL::ComputePipelineState* get_mb_sort_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& in,
|
|
const array& idx,
|
|
int bn,
|
|
int tn);
|
|
|
|
MTL::ComputePipelineState* get_reduce_init_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const std::string& func_name,
|
|
const std::string& op_name,
|
|
const Dtype& out_type);
|
|
|
|
MTL::ComputePipelineState* get_reduce_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const std::string& func_name,
|
|
const std::string& op_name,
|
|
const Dtype& in_type,
|
|
const Dtype& out_type,
|
|
const std::string& idx_t,
|
|
int ndim = -1,
|
|
int bm = -1,
|
|
int bn = -1);
|
|
|
|
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const std::string& hash_name,
|
|
const metal::MTLFCList& func_consts,
|
|
const array& out,
|
|
bool transpose_a,
|
|
bool transpose_b,
|
|
int bm,
|
|
int bn,
|
|
int bk,
|
|
int wm,
|
|
int wn);
|
|
|
|
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& in,
|
|
const array& out,
|
|
bool transpose_a,
|
|
bool transpose_b,
|
|
int bm,
|
|
int bn,
|
|
int bk,
|
|
int wm,
|
|
int wn,
|
|
bool mn_aligned,
|
|
bool k_aligned);
|
|
|
|
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& in,
|
|
const array& out,
|
|
bool axbpy);
|
|
|
|
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& out,
|
|
const std::optional<array>& mask_out,
|
|
const std::optional<array>& mask_op,
|
|
bool transpose_a,
|
|
bool transpose_b,
|
|
int bm,
|
|
int bn,
|
|
int bk,
|
|
int wm,
|
|
int wn,
|
|
bool mn_aligned,
|
|
bool k_aligned);
|
|
|
|
MTL::ComputePipelineState* get_steel_gemm_gather_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const std::string& hash_name,
|
|
const metal::MTLFCList& func_consts,
|
|
const array& out,
|
|
bool transpose_a,
|
|
bool transpose_b,
|
|
int bm,
|
|
int bn,
|
|
int bk,
|
|
int wm,
|
|
int wn,
|
|
bool rhs);
|
|
|
|
MTL::ComputePipelineState* get_steel_conv_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& out,
|
|
int bm,
|
|
int bn,
|
|
int bk,
|
|
int wm,
|
|
int wn,
|
|
int n_channel_specialization,
|
|
bool small_filter);
|
|
|
|
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& out,
|
|
const std::optional<array>& mask_out,
|
|
const std::optional<array>& mask_op,
|
|
bool transpose_mat,
|
|
int bm,
|
|
int bn,
|
|
int sm,
|
|
int sn,
|
|
int tm,
|
|
int tn,
|
|
bool contiguous);
|
|
|
|
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const std::string& hash_name,
|
|
const metal::MTLFCList& func_consts,
|
|
const array& out,
|
|
int bm,
|
|
int bn,
|
|
int bk,
|
|
int wm,
|
|
int wn);
|
|
|
|
MTL::ComputePipelineState* get_fft_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const std::string& hash_name,
|
|
const metal::MTLFCList& func_consts,
|
|
const std::string& template_def);
|
|
|
|
MTL::ComputePipelineState* get_quantized_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const std::string& template_def);
|
|
|
|
MTL::ComputePipelineState* get_gather_qmm_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const std::string& hash_name,
|
|
const metal::MTLFCList& func_consts,
|
|
const array& x,
|
|
int group_size,
|
|
int bits,
|
|
int bm,
|
|
int bn,
|
|
int bk,
|
|
int wm,
|
|
int wn,
|
|
bool transpose);
|
|
|
|
MTL::ComputePipelineState* get_svd_kernel(
|
|
metal::Device& d,
|
|
const std::string& kernel_name,
|
|
const array& out,
|
|
bool compute_uv);
|
|
|
|
// Create a GPU kernel template definition for JIT compilation
|
|
template <typename... Args>
|
|
std::string
|
|
get_template_definition(std::string name, std::string func, Args... args) {
|
|
std::ostringstream s;
|
|
s << func << "<";
|
|
bool first = true;
|
|
auto add_arg = [&s, &first](const auto& arg) {
|
|
if (!first) {
|
|
s << ", ";
|
|
}
|
|
first = false;
|
|
s << arg;
|
|
};
|
|
(add_arg(args), ...);
|
|
s << ">";
|
|
return fmt::format(
|
|
"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
|
|
name,
|
|
s.str());
|
|
}
|
|
|
|
} // namespace mlx::core
|