mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -8,7 +8,7 @@
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
|
||||
// Lapack uses the column-major convention. To avoid having to transpose
|
||||
// the input and then transpose the outputs, we swap the indices/sizes of the
|
||||
// matrices and take advantage of the following identity (see
|
||||
@@ -35,13 +35,8 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
// Allocate outputs.
|
||||
u.set_data(allocator::malloc_or_wait(u.nbytes()));
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
|
||||
|
||||
static constexpr auto job_u = "V";
|
||||
static constexpr auto job_vt = "V";
|
||||
auto job_u = (u_data && vt_data) ? "V" : "N";
|
||||
auto job_vt = (u_data && vt_data) ? "V" : "N";
|
||||
static constexpr auto range = "A";
|
||||
|
||||
// Will contain the number of singular values after the call has returned.
|
||||
@@ -56,6 +51,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
|
||||
static const int ignored_int = 0;
|
||||
static const T ignored_float = 0;
|
||||
static T ignored_output = 0;
|
||||
|
||||
int info;
|
||||
|
||||
@@ -109,12 +105,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ s.data<T>() + K * i,
|
||||
/* s = */ s_data + K * i,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ vt.data<T>() + N * N * i,
|
||||
/* u = */ vt_data ? vt_data + N * N * i : &ignored_output,
|
||||
/* ldu = */ &ldu,
|
||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||
/* vt = */ u.data<T>() + M * M * i,
|
||||
/* vt = */ u_data ? u_data + M * M * i : &ignored_output,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
@@ -136,15 +132,36 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) {
|
||||
if (compute_uv) {
|
||||
array& u = outputs[0];
|
||||
array& s = outputs[1];
|
||||
array& vt = outputs[2];
|
||||
|
||||
u.set_data(allocator::malloc_or_wait(u.nbytes()));
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
|
||||
|
||||
svd_impl<T>(a, u.data<T>(), s.data<T>(), vt.data<T>());
|
||||
} else {
|
||||
array& s = outputs[0];
|
||||
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
|
||||
svd_impl<T>(a, nullptr, s.data<T>(), nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
void SVD::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
switch (inputs[0].dtype()) {
|
||||
case float32:
|
||||
svd_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
compute_svd<float>(inputs[0], compute_uv_, outputs);
|
||||
break;
|
||||
case float64:
|
||||
svd_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
compute_svd<double>(inputs[0], compute_uv_, outputs);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
|
||||
Reference in New Issue
Block a user