Add SVD primitive (#809)

Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
This commit is contained in:
nicolov
2024-03-12 20:30:11 +01:00
committed by GitHub
parent ffb19df3c0
commit d0c544a868
13 changed files with 318 additions and 1 deletions

View File

@@ -200,4 +200,42 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
return std::make_pair(out[0], out[1]);
}
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::svd] Input array must have type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
"with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
const auto m = a.shape(-2);
const auto n = a.shape(-1);
const auto rank = a.ndim();
std::vector<int> u_shape = a.shape();
u_shape[rank - 2] = m;
u_shape[rank - 1] = m;
std::vector<int> s_shape = a.shape();
s_shape.pop_back();
s_shape[rank - 2] = std::min(m, n);
std::vector<int> vt_shape = a.shape();
vt_shape[rank - 2] = n;
vt_shape[rank - 1] = n;
return array::make_arrays(
{u_shape, s_shape, vt_shape},
{a.dtype(), a.dtype(), a.dtype()},
std::make_unique<SVD>(to_stream(s)),
{a});
}
} // namespace mlx::core::linalg