Add SVD primitive (#809)

Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
This commit is contained in:
nicolov
2024-03-12 20:30:11 +01:00
committed by GitHub
parent ffb19df3c0
commit d0c544a868
13 changed files with 318 additions and 1 deletions

View File

@@ -5,6 +5,7 @@
#include <cmath>
#include "mlx/mlx.h"
#include "mlx/ops.h"
using namespace mlx::core;
using namespace mlx::core::linalg;
@@ -267,3 +268,35 @@ TEST_CASE("test QR factorization") {
CHECK_EQ(Q.dtype(), float32);
CHECK_EQ(R.dtype(), float32);
}
TEST_CASE("test SVD factorization") {
// 0D and 1D throw
CHECK_THROWS(linalg::svd(array(0.0)));
CHECK_THROWS(linalg::svd(array({0.0, 1.0})));
// Unsupported types throw
CHECK_THROWS(linalg::svd(array({0, 1}, {1, 2})));
const auto prng_key = random::key(42);
const auto A = mlx::core::random::normal({5, 4}, prng_key);
const auto outs = linalg::svd(A, Device::cpu);
CHECK_EQ(outs.size(), 3);
const auto& U = outs[0];
const auto& S = outs[1];
const auto& Vt = outs[2];
CHECK_EQ(U.shape(), std::vector<int>{5, 5});
CHECK_EQ(S.shape(), std::vector<int>{4});
CHECK_EQ(Vt.shape(), std::vector<int>{4, 4});
const auto U_slice = slice(U, {0, 0}, {U.shape(0), S.shape(0)});
const auto A_again = matmul(matmul(U_slice, diag(S)), Vt);
CHECK(
allclose(A_again, A, /* rtol = */ 1e-4, /* atol = */ 1e-4).item<bool>());
CHECK_EQ(U.dtype(), float32);
CHECK_EQ(S.dtype(), float32);
CHECK_EQ(Vt.dtype(), float32);
}