mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following https://developer.apple.com/documentation/accelerate/ compressing_an_image_using_linear_algebra Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
@@ -267,3 +268,35 @@ TEST_CASE("test QR factorization") {
|
||||
CHECK_EQ(Q.dtype(), float32);
|
||||
CHECK_EQ(R.dtype(), float32);
|
||||
}
|
||||
|
||||
TEST_CASE("test SVD factorization") {
|
||||
// 0D and 1D throw
|
||||
CHECK_THROWS(linalg::svd(array(0.0)));
|
||||
CHECK_THROWS(linalg::svd(array({0.0, 1.0})));
|
||||
|
||||
// Unsupported types throw
|
||||
CHECK_THROWS(linalg::svd(array({0, 1}, {1, 2})));
|
||||
|
||||
const auto prng_key = random::key(42);
|
||||
const auto A = mlx::core::random::normal({5, 4}, prng_key);
|
||||
const auto outs = linalg::svd(A, Device::cpu);
|
||||
CHECK_EQ(outs.size(), 3);
|
||||
|
||||
const auto& U = outs[0];
|
||||
const auto& S = outs[1];
|
||||
const auto& Vt = outs[2];
|
||||
|
||||
CHECK_EQ(U.shape(), std::vector<int>{5, 5});
|
||||
CHECK_EQ(S.shape(), std::vector<int>{4});
|
||||
CHECK_EQ(Vt.shape(), std::vector<int>{4, 4});
|
||||
|
||||
const auto U_slice = slice(U, {0, 0}, {U.shape(0), S.shape(0)});
|
||||
|
||||
const auto A_again = matmul(matmul(U_slice, diag(S)), Vt);
|
||||
|
||||
CHECK(
|
||||
allclose(A_again, A, /* rtol = */ 1e-4, /* atol = */ 1e-4).item<bool>());
|
||||
CHECK_EQ(U.dtype(), float32);
|
||||
CHECK_EQ(S.dtype(), float32);
|
||||
CHECK_EQ(Vt.dtype(), float32);
|
||||
}
|
||||
|
Reference in New Issue
Block a user