mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following https://developer.apple.com/documentation/accelerate/ compressing_an_image_using_linear_algebra Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
This commit is contained in:
parent
ffb19df3c0
commit
d0c544a868
@ -71,6 +71,7 @@ DEFAULT(Slice)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Transpose)
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
@ -53,6 +53,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
)
|
||||
|
||||
|
@ -100,6 +100,7 @@ DEFAULT(Square)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Subtract)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Tan)
|
||||
DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
|
23
mlx/backend/common/lapack_helper.h
Normal file
23
mlx/backend/common/lapack_helper.h
Normal file
@ -0,0 +1,23 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
||||
#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
|
||||
|
||||
// This is to work around a change in the function signatures of lapack >= 3.9.1
|
||||
// where functions taking char* also include a strlen argument, see a similar
|
||||
// change in OpenCV:
|
||||
// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
|
||||
#define MLX_LAPACK_FUNC(f) LAPACK_##f
|
||||
|
||||
#else
|
||||
|
||||
#define MLX_LAPACK_FUNC(f) f##_
|
||||
|
||||
#endif
|
148
mlx/backend/common/svd.cpp
Normal file
148
mlx/backend/common/svd.cpp
Normal file
@ -0,0 +1,148 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack_helper.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void svd_impl(const array& a, array& u, array& s, array& vt) {
|
||||
// Lapack uses the column-major convention. To avoid having to transpose
|
||||
// the input and then transpose the outputs, we swap the indices/sizes of the
|
||||
// matrices and take advantage of the following identity (see
|
||||
// https://math.stackexchange.com/a/30077)
|
||||
// A = UΣVᵀ
|
||||
// Aᵀ = VΣUᵀ
|
||||
// As a result some of the indices/sizes are swapped as noted above.
|
||||
|
||||
// Rows and cols of the original matrix in row-major order.
|
||||
const int M = a.shape(-2);
|
||||
const int N = a.shape(-1);
|
||||
const int K = std::min(M, N);
|
||||
|
||||
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
|
||||
const int lda = N;
|
||||
// U of shape M x M. (N x N in lapack).
|
||||
const int ldu = N;
|
||||
// Vᵀ of shape N x N. (M x M in lapack).
|
||||
const int ldvt = M;
|
||||
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
|
||||
// lapack clobbers the input, so we have to make a copy.
|
||||
array in(a.shape(), float32, nullptr, {});
|
||||
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
// Allocate outputs.
|
||||
u.set_data(allocator::malloc_or_wait(u.nbytes()));
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
|
||||
|
||||
static constexpr auto job_u = "V";
|
||||
static constexpr auto job_vt = "V";
|
||||
static constexpr auto range = "A";
|
||||
|
||||
// Will contain the number of singular values after the call has returned.
|
||||
int ns = 0;
|
||||
float workspace_dimension = 0;
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not used
|
||||
// here but required by lapack).
|
||||
std::vector<int> iwork;
|
||||
iwork.resize(12 * K);
|
||||
|
||||
static const int lwork_query = -1;
|
||||
|
||||
static const int ignored_int = 0;
|
||||
static const float ignored_float = 0;
|
||||
|
||||
int info;
|
||||
|
||||
// Compute workspace size.
|
||||
MLX_LAPACK_FUNC(sgesvdx)
|
||||
(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ nullptr,
|
||||
/* u = */ nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
/* vt = */ nullptr,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ &workspace_dimension,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* iwork = */ iwork.data(),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_dimension;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
MLX_LAPACK_FUNC(sgesvdx)
|
||||
(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ in.data<float>() + M * N * i,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ s.data<float>() + K * i,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ vt.data<float>() + N * N * i,
|
||||
/* ldu = */ &ldu,
|
||||
// According to the identity above, lapack will write Uᵀ as Vᵀ.
|
||||
/* vt = */ u.data<float>() + M * M * i,
|
||||
/* ldvt = */ &ldvt,
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* iwork = */ iwork.data(),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
if (ns != K) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: expected " << K << " singular values, but " << ns
|
||||
<< " were computed.";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
if (!(inputs[0].dtype() == float32)) {
|
||||
throw std::runtime_error("[SVD::eval] only supports float32.");
|
||||
}
|
||||
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -890,4 +890,10 @@ void QRF::eval_gpu(
|
||||
throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
|
||||
}
|
||||
|
||||
void SVD::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -93,6 +93,7 @@ NO_GPU(Square)
|
||||
NO_GPU(Sqrt)
|
||||
NO_GPU(StopGradient)
|
||||
NO_GPU(Subtract)
|
||||
NO_GPU_MULTI(SVD)
|
||||
NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Transpose)
|
||||
|
@ -200,4 +200,42 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::svd] Input array must have type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
|
||||
"with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
const auto m = a.shape(-2);
|
||||
const auto n = a.shape(-1);
|
||||
const auto rank = a.ndim();
|
||||
|
||||
std::vector<int> u_shape = a.shape();
|
||||
u_shape[rank - 2] = m;
|
||||
u_shape[rank - 1] = m;
|
||||
|
||||
std::vector<int> s_shape = a.shape();
|
||||
s_shape.pop_back();
|
||||
s_shape[rank - 2] = std::min(m, n);
|
||||
|
||||
std::vector<int> vt_shape = a.shape();
|
||||
vt_shape[rank - 2] = n;
|
||||
vt_shape[rank - 1] = n;
|
||||
|
||||
return array::make_arrays(
|
||||
{u_shape, s_shape, vt_shape},
|
||||
{a.dtype(), a.dtype(), a.dtype()},
|
||||
std::make_unique<SVD>(to_stream(s)),
|
||||
{a});
|
||||
}
|
||||
|
||||
} // namespace mlx::core::linalg
|
||||
|
@ -62,4 +62,6 @@ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
|
||||
|
||||
std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
|
||||
|
||||
std::vector<array> svd(const array& a, StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core::linalg
|
||||
|
@ -1840,4 +1840,20 @@ class QRF : public Primitive {
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
};
|
||||
|
||||
/* SVD primitive. */
|
||||
class SVD : public Primitive {
|
||||
public:
|
||||
explicit SVD(Stream stream) : Primitive(stream){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(SVD)
|
||||
|
||||
private:
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -16,6 +16,13 @@ using namespace py::literals;
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
|
||||
namespace {
|
||||
py::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||
const auto result = svd(a, s);
|
||||
return py::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void init_linalg(py::module_& parent_module) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
@ -186,7 +193,7 @@ void init_linalg(py::module_& parent_module) {
|
||||
R"pbdoc(
|
||||
qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)
|
||||
|
||||
The QR factorizatoin of the input matrix.
|
||||
The QR factorization of the input matrix.
|
||||
|
||||
This function supports arrays with at least 2 dimensions. The matrices
|
||||
which are factorized are assumed to be in the last two dimensions of
|
||||
@ -210,4 +217,28 @@ void init_linalg(py::module_& parent_module) {
|
||||
array([[-2.23607, -3.57771],
|
||||
[0, 0.447214]], dtype=float32)
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"svd",
|
||||
&svd_helper,
|
||||
"a"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)
|
||||
|
||||
The Singular Value Decomposition (SVD) of the input matrix.
|
||||
|
||||
This function supports arrays with at least 2 dimensions. When the input
|
||||
has more than two dimensions, the function iterates over all indices of the first
|
||||
a.ndim - 2 dimensions and for each combination SVD is applied to the last two indices.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that
|
||||
``A = U @ diag(S) @ Vt``
|
||||
)pbdoc");
|
||||
}
|
||||
|
@ -120,6 +120,22 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(out, mx.eye(2), rtol=1e-5, atol=1e-7))
|
||||
self.assertTrue(mx.allclose(mx.tril(r, -1), mx.zeros_like(r)))
|
||||
|
||||
def test_svd_decomposition(self):
|
||||
A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
|
||||
U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
|
||||
self.assertTrue(
|
||||
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)
|
||||
)
|
||||
|
||||
# Multiple matrices
|
||||
B = A + 10.0
|
||||
AB = mx.stack([A, B])
|
||||
Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu)
|
||||
for M, U, S, Vt in zip([A, B], Us, Ss, Vts):
|
||||
self.assertTrue(
|
||||
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
@ -267,3 +268,35 @@ TEST_CASE("test QR factorization") {
|
||||
CHECK_EQ(Q.dtype(), float32);
|
||||
CHECK_EQ(R.dtype(), float32);
|
||||
}
|
||||
|
||||
TEST_CASE("test SVD factorization") {
|
||||
// 0D and 1D throw
|
||||
CHECK_THROWS(linalg::svd(array(0.0)));
|
||||
CHECK_THROWS(linalg::svd(array({0.0, 1.0})));
|
||||
|
||||
// Unsupported types throw
|
||||
CHECK_THROWS(linalg::svd(array({0, 1}, {1, 2})));
|
||||
|
||||
const auto prng_key = random::key(42);
|
||||
const auto A = mlx::core::random::normal({5, 4}, prng_key);
|
||||
const auto outs = linalg::svd(A, Device::cpu);
|
||||
CHECK_EQ(outs.size(), 3);
|
||||
|
||||
const auto& U = outs[0];
|
||||
const auto& S = outs[1];
|
||||
const auto& Vt = outs[2];
|
||||
|
||||
CHECK_EQ(U.shape(), std::vector<int>{5, 5});
|
||||
CHECK_EQ(S.shape(), std::vector<int>{4});
|
||||
CHECK_EQ(Vt.shape(), std::vector<int>{4, 4});
|
||||
|
||||
const auto U_slice = slice(U, {0, 0}, {U.shape(0), S.shape(0)});
|
||||
|
||||
const auto A_again = matmul(matmul(U_slice, diag(S)), Vt);
|
||||
|
||||
CHECK(
|
||||
allclose(A_again, A, /* rtol = */ 1e-4, /* atol = */ 1e-4).item<bool>());
|
||||
CHECK_EQ(U.dtype(), float32);
|
||||
CHECK_EQ(S.dtype(), float32);
|
||||
CHECK_EQ(Vt.dtype(), float32);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user