mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following https://developer.apple.com/documentation/accelerate/ compressing_an_image_using_linear_algebra Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
This commit is contained in:
@@ -200,4 +200,42 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::svd] Input array must have type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
|
||||
"with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
const auto m = a.shape(-2);
|
||||
const auto n = a.shape(-1);
|
||||
const auto rank = a.ndim();
|
||||
|
||||
std::vector<int> u_shape = a.shape();
|
||||
u_shape[rank - 2] = m;
|
||||
u_shape[rank - 1] = m;
|
||||
|
||||
std::vector<int> s_shape = a.shape();
|
||||
s_shape.pop_back();
|
||||
s_shape[rank - 2] = std::min(m, n);
|
||||
|
||||
std::vector<int> vt_shape = a.shape();
|
||||
vt_shape[rank - 2] = n;
|
||||
vt_shape[rank - 1] = n;
|
||||
|
||||
return array::make_arrays(
|
||||
{u_shape, s_shape, vt_shape},
|
||||
{a.dtype(), a.dtype(), a.dtype()},
|
||||
std::make_unique<SVD>(to_stream(s)),
|
||||
{a});
|
||||
}
|
||||
|
||||
} // namespace mlx::core::linalg
|
||||
|
||||
Reference in New Issue
Block a user