diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp
index 33a30d843..88b127bce 100644
--- a/mlx/backend/cpu/svd.cpp
+++ b/mlx/backend/cpu/svd.cpp
@@ -8,7 +8,7 @@ namespace mlx::core {
 
 template <typename T>
-void svd_impl(const array& a, array& u, array& s, array& vt) {
+void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
   // Lapack uses the column-major convention. To avoid having to transpose
   // the input and then transpose the outputs, we swap the indices/sizes of the
   // matrices and take advantage of the following identity (see
@@ -35,13 +35,8 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   array in(a.shape(), a.dtype(), nullptr, {});
   copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
 
-  // Allocate outputs.
-  u.set_data(allocator::malloc_or_wait(u.nbytes()));
-  s.set_data(allocator::malloc_or_wait(s.nbytes()));
-  vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
-
-  static constexpr auto job_u = "V";
-  static constexpr auto job_vt = "V";
+  auto job_u = (u_data && vt_data) ? "V" : "N";
+  auto job_vt = (u_data && vt_data) ? "V" : "N";
   static constexpr auto range = "A";
 
   // Will contain the number of singular values after the call has returned.
@@ -56,6 +51,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
 
   static const int ignored_int = 0;
   static const T ignored_float = 0;
+  static T ignored_output = 0;
 
   int info;
 
@@ -109,12 +105,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
         /* il = */ &ignored_int,
         /* iu = */ &ignored_int,
         /* ns = */ &ns,
-        /* s = */ s.data<T>() + K * i,
+        /* s = */ s_data + K * i,
         // According to the identity above, lapack will write Vᵀᵀ as U.
-        /* u = */ vt.data<T>() + N * N * i,
+        /* u = */ vt_data ? vt_data + N * N * i : &ignored_output,
         /* ldu = */ &ldu,
         // According to the identity above, lapack will write Uᵀ as Vᵀ.
-        /* vt = */ u.data<T>() + M * M * i,
+        /* vt = */ u_data ? u_data + M * M * i : &ignored_output,
         /* ldvt = */ &ldvt,
         /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
         /* lwork = */ &lwork,
@@ -136,15 +132,36 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   }
 }
 
+template <typename T>
+void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) {
+  if (compute_uv) {
+    array& u = outputs[0];
+    array& s = outputs[1];
+    array& vt = outputs[2];
+
+    u.set_data(allocator::malloc_or_wait(u.nbytes()));
+    s.set_data(allocator::malloc_or_wait(s.nbytes()));
+    vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
+
+    svd_impl<T>(a, u.data<T>(), s.data<T>(), vt.data<T>());
+  } else {
+    array& s = outputs[0];
+
+    s.set_data(allocator::malloc_or_wait(s.nbytes()));
+
+    svd_impl<T>(a, nullptr, s.data<T>(), nullptr);
+  }
+}
+
 void SVD::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   switch (inputs[0].dtype()) {
     case float32:
-      svd_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      compute_svd<float>(inputs[0], compute_uv_, outputs);
      break;
    case float64:
-      svd_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      compute_svd<double>(inputs[0], compute_uv_, outputs);
      break;
    default:
      throw std::runtime_error(
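[Annotation, not part of the patch] With the change above, when the caller passes null singular-vector pointers both job flags become "N", so LAPACK computes only the singular values. The backend comment about swapping indices relies on the identity that if A = U Σ Vᵀ then Aᵀ = V Σ Uᵀ, which is why the row-major buffer can be handed directly to the column-major routine and the U/Vᵀ outputs simply come back swapped. A minimal NumPy sketch of that identity (variable names are illustrative only):

    import numpy as np

    A = np.random.default_rng(0).normal(size=(5, 4)).astype(np.float32)
    U, S, Vt = np.linalg.svd(A)       # SVD of A
    Ut, St, Vtt = np.linalg.svd(A.T)  # SVD of A transposed
    assert np.allclose(S, St)                     # identical singular values
    assert np.allclose(np.abs(Ut), np.abs(Vt.T))  # left vectors of A.T are the right vectors of A, up to sign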
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 356d39626..5b9b51ad3 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -102,8 +102,21 @@ inline array matrix_norm(
         dtype,
         s);
   } else if (ord == 2.0 || ord == -2.0) {
-    throw std::runtime_error(
-        "[linalg::norm] Singular value norms are not implemented.");
+    row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];
+    col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];
+    auto a_matrix = (row_axis > col_axis)
+        ? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)
+        : moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);
+    a_matrix = svd(a_matrix, false, s).at(0);
+    a_matrix = (ord == 2.0) ? max(a_matrix, -1, false, s)
+                            : min(a_matrix, -1, false, s);
+    if (keepdims) {
+      std::vector<int> sorted_axes = (row_axis < col_axis)
+          ? std::vector<int>{row_axis, col_axis}
+          : std::vector<int>{col_axis, row_axis};
+      a_matrix = expand_dims(a_matrix, sorted_axes, s);
+    }
+    return astype(a_matrix, dtype, s);
   } else {
     std::ostringstream msg;
     msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
@@ -120,8 +133,19 @@ inline array matrix_norm(
   if (ord == "f" || ord == "fro") {
     return l2_norm(a, axis, keepdims, s);
   } else if (ord == "nuc") {
-    throw std::runtime_error(
-        "[linalg::norm] Nuclear norm not yet implemented.");
+    int row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];
+    int col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];
+    auto a_matrix = (row_axis > col_axis)
+        ? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)
+        : moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);
+    a_matrix = sum(svd(a_matrix, false, s).at(0), -1, false, s);
+    if (keepdims) {
+      std::vector<int> sorted_axes = (row_axis < col_axis)
+          ? std::vector<int>{row_axis, col_axis}
+          : std::vector<int>{col_axis, row_axis};
+      a_matrix = expand_dims(a_matrix, sorted_axes, s);
+    }
+    return a_matrix;
   } else {
     std::ostringstream msg;
     msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
@@ -214,7 +238,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
   return std::make_pair(out[0], out[1]);
 }
 
-std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
+std::vector<array>
+svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::svd]");
   check_float(a.dtype(), "[linalg::svd]");
 
@@ -230,14 +255,22 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   const auto n = a.shape(-1);
   const auto rank = a.ndim();
 
-  auto u_shape = a.shape();
-  u_shape[rank - 2] = m;
-  u_shape[rank - 1] = m;
-
   auto s_shape = a.shape();
   s_shape.pop_back();
   s_shape[rank - 2] = std::min(m, n);
 
+  if (!compute_uv) {
+    return {array(
+        std::move(s_shape),
+        std::move(a.dtype()),
+        std::make_shared<SVD>(to_stream(s), compute_uv),
+        {a})};
+  }
+
+  auto u_shape = a.shape();
+  u_shape[rank - 2] = m;
+  u_shape[rank - 1] = m;
+
   auto vt_shape = a.shape();
   vt_shape[rank - 2] = n;
   vt_shape[rank - 1] = n;
@@ -245,7 +278,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   return array::make_arrays(
       {u_shape, s_shape, vt_shape},
       {a.dtype(), a.dtype(), a.dtype()},
-      std::make_shared<SVD>(to_stream(s)),
+      std::make_shared<SVD>(to_stream(s), compute_uv),
       {a});
 }
 
@@ -323,7 +356,7 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
   int m = a.shape(-2);
   int n = a.shape(-1);
   int k = std::min(m, n);
-  auto outs = linalg::svd(a, s);
+  auto outs = linalg::svd(a, true, s);
   array U = outs[0];
   array S = outs[1];
   array V = outs[2];
diff --git a/mlx/linalg.h b/mlx/linalg.h
index 9fe4dbf60..8c3a2070a 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -62,7 +62,11 @@ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
 
 std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
 
-std::vector<array> svd(const array& a, StreamOrDevice s = {});
+std::vector<array>
+svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */);
+inline std::vector<array> svd(const array& a, StreamOrDevice s = {}) {
+  return svd(a, true, s);
+}
 
 array inv(const array& a, StreamOrDevice s = {});
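[Annotation, not part of the patch] The new `ord == 2.0 || ord == -2.0` and `"nuc"` branches reduce the matrix norms to the singular values returned by `svd(a_matrix, false, s)`. The definitions being implemented, following NumPy's conventions for ord 2, -2, and 'nuc', are, for singular values σ₁ ≥ σ₂ ≥ … of each matrix A:

    \|A\|_{2} = \sigma_{\max}(A), \qquad
    \|A\|_{-2} = \sigma_{\min}(A), \qquad
    \|A\|_{\mathrm{nuc}} = \sum_i \sigma_i(A), \qquad
    \|A\|_{F} = \Bigl(\sum_i \sigma_i(A)^2\Bigr)^{1/2}

The last (Frobenius) relation is what the tests further down use to validate the compute_uv=false path.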
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index ac4c17938..60c13b2c9 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -4940,7 +4940,8 @@ std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
     const std::vector<int>& axes) {
   auto ax = axes[0] >= 0 ? 0 : -1;
   auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
-  return {{linalg::svd(a, stream())}, {ax, ax, ax}};
+  std::vector<int> new_axes(compute_uv_ ? 3 : 1, ax);
+  return {linalg::svd(a, compute_uv_, stream()), std::move(new_axes)};
 }
 
 std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
diff --git a/mlx/primitives.h b/mlx/primitives.h
index a73ddef96..c2c0576aa 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -2287,7 +2287,8 @@ class QRF : public Primitive {
 /* SVD primitive. */
 class SVD : public Primitive {
  public:
-  explicit SVD(Stream stream) : Primitive(stream) {}
+  explicit SVD(Stream stream, bool compute_uv)
+      : Primitive(stream), compute_uv_(compute_uv) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
@@ -2296,6 +2297,12 @@ class SVD : public Primitive {
 
   DEFINE_VMAP()
   DEFINE_PRINT(SVD)
+  auto state() const {
+    return compute_uv_;
+  }
+
+ private:
+  bool compute_uv_;
 };
 
 /* Matrix inversion primitive. */
diff --git a/mlx/random.cpp b/mlx/random.cpp
index a4755605c..d6ce5bb0e 100644
--- a/mlx/random.cpp
+++ b/mlx/random.cpp
@@ -244,7 +244,7 @@ array multivariate_normal(
 
   // Compute the square-root of the covariance matrix, using the SVD
   auto covariance = astype(cov, float32, stream);
-  auto SVD = linalg::svd(covariance, stream);
+  auto SVD = linalg::svd(covariance, true, stream);
   auto std = astype(
       matmul(
           multiply(
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index a43cebbe7..3bc0e5b1b 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -92,6 +92,7 @@ void init_linalg(nb::module_& parent_module) {
             =====  ============================  ==========================
             None   Frobenius norm                2-norm
             'fro'  Frobenius norm                --
+            'nuc'  nuclear norm                  --
             inf    max(sum(abs(x), axis=1))      max(abs(x))
             -inf   min(sum(abs(x), axis=1))      min(abs(x))
             0      --                            sum(x != 0)
@@ -102,9 +103,6 @@ void init_linalg(nb::module_& parent_module) {
             other  --                            sum(abs(x)**ord)**(1./ord)
             =====  ============================  ==========================
 
-            .. warning::
-               Nuclear norm and norms based on singular values are not yet implemented.
-
             The Frobenius norm is given by [1]_:
 
                 :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
@@ -206,15 +204,22 @@ void init_linalg(nb::module_& parent_module) {
       )pbdoc");
   m.def(
       "svd",
-      [](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
-        const auto result = mx::linalg::svd(a, s);
-        return nb::make_tuple(result.at(0), result.at(1), result.at(2));
+      [](const mx::array& a,
+         bool compute_uv /* = true */,
+         mx::StreamOrDevice s /* = {} */) -> nb::object {
+        const auto result = mx::linalg::svd(a, compute_uv, s);
+        if (result.size() == 1) {
+          return nb::cast(result.at(0));
+        } else {
+          return nb::make_tuple(result.at(0), result.at(1), result.at(2));
+        }
       },
       "a"_a,
+      "compute_uv"_a = true,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
+          "def svd(a: array, compute_uv: bool = True, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
       R"pbdoc(
         The Singular Value Decomposition (SVD) of the input matrix.
@@ -224,12 +229,15 @@ void init_linalg(nb::module_& parent_module) {
 
         Args:
            a (array): Input array.
+            compute_uv (bool, optional): If ``True``, return the ``U``, ``S``, and ``Vt`` components.
+              If ``False``, return only the ``S`` array. Default: ``True``.
            stream (Stream, optional): Stream or device. Defaults to ``None``
              in which case the default stream of the default device is used.
 
        Returns:
-            tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that
-              ``A = U @ diag(S) @ Vt``
+            Union[tuple(array, ...), array]:
+              If ``compute_uv`` is ``True``, returns the ``U``, ``S``, and ``Vt`` matrices, such that
+              ``A = U @ diag(S) @ Vt``. If ``compute_uv`` is ``False``, returns the array of singular values ``S``.
       )pbdoc");
   m.def(
       "inv",
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index adc365c62..ffa355c10 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -12,11 +12,11 @@ import numpy as np
 class TestLinalg(mlx_tests.MLXTestCase):
     def test_norm(self):
         vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
-        matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
+        matrix_ords = [None, "fro", "nuc", -1, 1, -2, 2, float("inf"), -float("inf")]
 
         for shape in [(3,), (2, 3), (2, 3, 3)]:
-            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
-            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
+            x_mx = mx.arange(1, math.prod(shape) + 1, dtype=mx.float32).reshape(shape)
+            x_np = np.arange(1, math.prod(shape) + 1, dtype=np.float32).reshape(shape)
             # Test when at least one axis is provided
             for num_axes in range(1, len(shape)):
                 if num_axes == 1:
@@ -26,11 +26,14 @@ class TestLinalg(mlx_tests.MLXTestCase):
                 for axis in itertools.combinations(range(len(shape)), num_axes):
                     for keepdims in [True, False]:
                         for o in ords:
+                            stream = (
+                                mx.cpu if o in ["nuc", -2, 2] else mx.default_device()
+                            )
                             out_np = np.linalg.norm(
                                 x_np, ord=o, axis=axis, keepdims=keepdims
                             )
                             out_mx = mx.linalg.norm(
-                                x_mx, ord=o, axis=axis, keepdims=keepdims
+                                x_mx, ord=o, axis=axis, keepdims=keepdims, stream=stream
                             )
                             with self.subTest(
                                 shape=shape, ord=o, axis=axis, keepdims=keepdims
@@ -133,20 +136,38 @@ class TestLinalg(mlx_tests.MLXTestCase):
     def test_svd_decomposition(self):
         A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
-        U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
+        U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu)
         self.assertTrue(
             mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)
         )
 
+        S = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)
+        self.assertTrue(
+            mx.allclose(
+                mx.linalg.norm(S), mx.linalg.norm(A, ord="fro"), rtol=1e-5, atol=1e-7
+            )
+        )
+
         # Multiple matrices
         B = A + 10.0
         AB = mx.stack([A, B])
-        Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu)
+        Us, Ss, Vts = mx.linalg.svd(AB, compute_uv=True, stream=mx.cpu)
 
         for M, U, S, Vt in zip([A, B], Us, Ss, Vts):
             self.assertTrue(
                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
             )
 
+        Ss = mx.linalg.svd(AB, compute_uv=False, stream=mx.cpu)
+        for M, S in zip([A, B], Ss):
+            self.assertTrue(
+                mx.allclose(
+                    mx.linalg.norm(S),
+                    mx.linalg.norm(M, ord="fro"),
+                    rtol=1e-5,
+                    atol=1e-7,
+                )
+            )
+
     def test_inverse(self):
         A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
         A_inv = mx.linalg.inv(A, stream=mx.cpu)
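[Annotation, not part of the patch] A short usage sketch of the updated Python binding, assuming this patch is applied; SVD-backed operations still require the CPU stream:

    import mlx.core as mx

    A = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

    U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)                  # default: full decomposition
    S_only = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)  # single array of singular values
    nuc = mx.linalg.norm(A, ord="nuc", stream=mx.cpu)           # sum of singular values
    spec = mx.linalg.norm(A, ord=2, stream=mx.cpu)              # largest singular value

    print(S_only, nuc.item(), spec.item())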
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index 81b74d98c..2eee33b5c 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -316,33 +316,56 @@ class TestVmap(mlx_tests.MLXTestCase):
     def test_vmap_svd(self):
         a = mx.random.uniform(shape=(3, 4, 2))
 
-        cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu)
+        cpu_svd_full = lambda x: mx.linalg.svd(x, compute_uv=True, stream=mx.cpu)
+        cpu_svd_singular = lambda x: mx.linalg.svd(x, compute_uv=False, stream=mx.cpu)
 
         # Vmap over the first axis (this is already supported natively by the primitive).
-        Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
+        Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(0,))(a)
         self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))
         self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))
         self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))
 
+        Sv = mx.vmap(cpu_svd_singular, in_axes=(0,))(a)
+        self.assertEqual(Sv.shape, (a.shape[0], a.shape[2]))
+
         for i in range(a.shape[0]):
             M = a[i]
             U, S, Vt = Us[i], Ss[i], Vts[i]
             self.assertTrue(
                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
             )
+            self.assertTrue(
+                mx.allclose(
+                    mx.linalg.norm(Sv[i]),
+                    mx.linalg.norm(M, ord="fro"),
+                    rtol=1e-5,
+                    atol=1e-7,
+                )
+            )
 
         # Vmap over the second axis.
-        Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
+        Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(1,))(a)
         self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))
         self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))
         self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))
 
+        Sv = mx.vmap(cpu_svd_singular, in_axes=(1,))(a)
+        self.assertEqual(Sv.shape, (a.shape[1], a.shape[2]))
+
         for i in range(a.shape[1]):
             M = a[:, i, :]
             U, S, Vt = Us[i], Ss[i], Vts[i]
             self.assertTrue(
                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
             )
+            self.assertTrue(
+                mx.allclose(
+                    mx.linalg.norm(Sv[i]),
+                    mx.linalg.norm(M, ord="fro"),
+                    rtol=1e-5,
+                    atol=1e-7,
+                )
+            )
 
     def test_vmap_inverse(self):
         mx.random.seed(42)
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index b2465c29a..0660a69fe 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -100,7 +100,7 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
       norm(x, -std::numeric_limits<float>::infinity()).item<float>(),
       doctest::Approx(expected));
 
-  x = reshape(arange(9), {3, 3});
+  x = reshape(arange(9, float32), {3, 3});
 
   CHECK(allclose(
             norm(x, 2.0, 0, false),
@@ -129,10 +129,34 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
   CHECK_EQ(
       norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
       doctest::Approx(3.0));
+  CHECK_EQ(
+      norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
+      doctest::Approx(14.226707));
+  CHECK_EQ(
+      norm(x, 2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),
+      doctest::Approx(14.226707));
+  CHECK_EQ(
+      norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
+      doctest::Approx(0.0));
+  CHECK_EQ(
+      norm(x, -2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),
+      doctest::Approx(0.0));
   CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
   CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
   CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
   CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
+  CHECK_EQ(
+      norm(x, 2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),
+      Shape{1, 1});
+  CHECK_EQ(
+      norm(x, 2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),
+      Shape{1, 1});
+  CHECK_EQ(
+      norm(x, -2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),
+      Shape{1, 1});
+  CHECK_EQ(
+      norm(x, -2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),
+      Shape{1, 1});
 
   CHECK_EQ(
       norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
@@ -140,8 +164,14 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
   CHECK_EQ(
       norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
       doctest::Approx(15.0));
+  CHECK_EQ(
+      norm(x, -2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),
+      doctest::Approx(0.0));
+  CHECK_EQ(
+      norm(x, 2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),
+      doctest::Approx(14.226707));
 
-  x = reshape(arange(18), {2, 3, 3});
+  x = reshape(arange(18, float32), {2, 3, 3});
   CHECK_THROWS(norm(x, 2.0, std::vector<int>{0, 1, 2}));
   CHECK(allclose(
             norm(x, 3.0, 0),
@@ -199,13 +229,31 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
             .item<bool>());
   CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
             .item<bool>());
+  CHECK(allclose(
+            norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu),
+            array({22.045408, 24.155825, 26.318918}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, 2.0, std::vector<int>{1, 2}, false, Device::cpu),
+            array({14.226707, 39.759212}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu),
+            array({3, 2.7378995, 2.5128777}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, -2.0, std::vector<int>{1, 2}, false, Device::cpu),
+            array({4.979028e-16, 7.009628e-16}),
+            /* rtol = */ 1e-5,
+            /* atol = */ 1e-6)
+            .item<bool>());
 }
 
 TEST_CASE("[mlx.core.linalg.norm] string ord") {
   array x({1, 2, 3});
   CHECK_THROWS(norm(x, "fro"));
 
-  x = reshape(arange(9), {3, 3});
+  x = reshape(arange(9, float32), {3, 3});
   CHECK_THROWS(norm(x, "bad ord"));
 
   CHECK_EQ(
@@ -214,8 +262,11 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
       norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
       doctest::Approx(14.2828568570857));
+  CHECK_EQ(
+      norm(x, "nuc", std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
+      doctest::Approx(15.491934));
 
-  x = reshape(arange(18), {2, 3, 3});
+  x = reshape(arange(18, float32), {2, 3, 3});
   CHECK(allclose(
             norm(x, "fro", std::vector<int>{0, 1}),
             array({22.24859546, 24.31049156, 26.43860813}))
             .item<bool>());
@@ -240,6 +291,18 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
             norm(x, "f", std::vector<int>{2, 1}),
             array({14.28285686, 39.7617907}))
             .item<bool>());
+  CHECK(allclose(
+            norm(x, "nuc", std::vector<int>{0, 1}, false, Device::cpu),
+            array({25.045408, 26.893724, 28.831797}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, "nuc", std::vector<int>{1, 2}, false, Device::cpu),
+            array({15.491934, 40.211937}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, "nuc", std::vector<int>{-2, -1}, false, Device::cpu),
+            array({15.491934, 40.211937}))
+            .item<bool>());
 }
 
 TEST_CASE("test QR factorization") {
@@ -271,7 +334,7 @@ TEST_CASE("test SVD factorization") {
   const auto prng_key = random::key(42);
   const auto A = mlx::core::random::normal({5, 4}, prng_key);
 
-  const auto outs = linalg::svd(A, Device::cpu);
+  const auto outs = linalg::svd(A, true, Device::cpu);
   CHECK_EQ(outs.size(), 3);
 
   const auto& U = outs[0];
@@ -291,6 +354,15 @@ TEST_CASE("test SVD factorization") {
   CHECK_EQ(U.dtype(), float32);
   CHECK_EQ(S.dtype(), float32);
   CHECK_EQ(Vt.dtype(), float32);
+
+  // Test singular values
+  const auto& outs_sv = linalg::svd(A, false, Device::cpu);
+  const auto SV = outs_sv[0];
+
+  CHECK_EQ(SV.shape(), Shape{4});
+  CHECK_EQ(SV.dtype(), float32);
+
+  CHECK(allclose(norm(SV), norm(A, "fro")).item<bool>());
 }
 
 TEST_CASE("test matrix inversion") {
diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp
index 38011b942..2a2a28571 100644
--- a/tests/vmap_tests.cpp
+++ b/tests/vmap_tests.cpp
@@ -466,15 +466,19 @@ TEST_CASE("test vmap scatter") {
 }
 
 TEST_CASE("test vmap SVD") {
-  auto fun = [](std::vector<array> inputs) {
-    return linalg::svd(inputs.at(0), Device::cpu);
+  auto svd_full = [](std::vector<array> inputs) {
+    return linalg::svd(inputs.at(0), true, Device::cpu);
+  };
+
+  auto svd_singular = [](std::vector<array> inputs) {
+    return linalg::svd(inputs.at(0), false, Device::cpu);
   };
 
   auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
 
   // vmap over the second axis.
   {
-    auto out = vmap(fun, /* in_axes = */ {1})({a});
+    auto out = vmap(svd_full, /* in_axes = */ {1})({a});
     const auto& U = out.at(0);
     const auto& S = out.at(1);
     const auto& Vt = out.at(2);
@@ -486,7 +490,7 @@ TEST_CASE("test vmap SVD") {
 
   // vmap over the third axis.
   {
-    auto out = vmap(fun, /* in_axes = */ {2})({a});
+    auto out = vmap(svd_full, /* in_axes = */ {2})({a});
     const auto& U = out.at(0);
     const auto& S = out.at(1);
     const auto& Vt = out.at(2);
@@ -495,6 +499,21 @@ TEST_CASE("test vmap SVD") {
     CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
     CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
   }
+
+  // test singular values
+  {
+    auto out = vmap(svd_singular, /* in_axes = */ {1})({a});
+    const auto& S = out.at(0);
+
+    CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)});
+  }
+
+  {
+    auto out = vmap(svd_singular, /* in_axes = */ {2})({a});
+    const auto& S = out.at(0);
+
+    CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
+  }
 }
 
 TEST_CASE("test vmap dynamic slices") {