Adds nuclear norm support (#1894)

* adjust norm unit test tolerance
This commit is contained in:
Abe Leininger
2025-03-04 15:26:02 -06:00
committed by GitHub
parent 9680f72cca
commit 3835a428c5
11 changed files with 260 additions and 55 deletions

View File

@@ -92,6 +92,7 @@ void init_linalg(nb::module_& parent_module) {
===== ============================ ==========================
None Frobenius norm 2-norm
'fro' Frobenius norm --
'nuc' nuclear norm --
inf max(sum(abs(x), axis=1)) max(abs(x))
-inf min(sum(abs(x), axis=1)) min(abs(x))
0 -- sum(x != 0)
@@ -102,9 +103,6 @@ void init_linalg(nb::module_& parent_module) {
other -- sum(abs(x)**ord)**(1./ord)
===== ============================ ==========================
.. warning::
Nuclear norm and norms based on singular values are not yet implemented.
The Frobenius norm is given by [1]_:
:math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
@@ -206,15 +204,22 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"svd",
[](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
const auto result = mx::linalg::svd(a, s);
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
[](const mx::array& a,
bool compute_uv /* = true */,
mx::StreamOrDevice s /* = {} */) -> nb::object {
const auto result = mx::linalg::svd(a, compute_uv, s);
if (result.size() == 1) {
return nb::cast(result.at(0));
} else {
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
}
},
"a"_a,
"compute_uv"_a = true,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
"def svd(a: array, compute_uv: bool = True, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
R"pbdoc(
The Singular Value Decomposition (SVD) of the input matrix.
@@ -224,12 +229,15 @@ void init_linalg(nb::module_& parent_module) {
Args:
a (array): Input array.
compute_uv (bool, optional): If ``True``, return the ``U``, ``S``, and ``Vt`` components.
If ``False``, return only the ``S`` array. Default: ``True``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that
``A = U @ diag(S) @ Vt``
Union[tuple(array, ...), array]:
If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that
``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``.
)pbdoc");
m.def(
"inv",

View File

@@ -12,11 +12,11 @@ import numpy as np
class TestLinalg(mlx_tests.MLXTestCase):
def test_norm(self):
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
matrix_ords = [None, "fro", "nuc", -1, 1, -2, 2, float("inf"), -float("inf")]
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
x_mx = mx.arange(1, math.prod(shape) + 1, dtype=mx.float32).reshape(shape)
x_np = np.arange(1, math.prod(shape) + 1, dtype=np.float32).reshape(shape)
# Test when at least one axis is provided
for num_axes in range(1, len(shape)):
if num_axes == 1:
@@ -26,11 +26,14 @@ class TestLinalg(mlx_tests.MLXTestCase):
for axis in itertools.combinations(range(len(shape)), num_axes):
for keepdims in [True, False]:
for o in ords:
stream = (
mx.cpu if o in ["nuc", -2, 2] else mx.default_device()
)
out_np = np.linalg.norm(
x_np, ord=o, axis=axis, keepdims=keepdims
)
out_mx = mx.linalg.norm(
x_mx, ord=o, axis=axis, keepdims=keepdims
x_mx, ord=o, axis=axis, keepdims=keepdims, stream=stream
)
with self.subTest(
shape=shape, ord=o, axis=axis, keepdims=keepdims
@@ -133,20 +136,38 @@ class TestLinalg(mlx_tests.MLXTestCase):
def test_svd_decomposition(self):
A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu)
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)
)
S = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)
self.assertTrue(
mx.allclose(
mx.linalg.norm(S), mx.linalg.norm(A, ord="fro"), rtol=1e-5, atol=1e-7
)
)
# Multiple matrices
B = A + 10.0
AB = mx.stack([A, B])
Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu)
Us, Ss, Vts = mx.linalg.svd(AB, compute_uv=True, stream=mx.cpu)
for M, U, S, Vt in zip([A, B], Us, Ss, Vts):
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
)
Ss = mx.linalg.svd(AB, compute_uv=False, stream=mx.cpu)
for M, S in zip([A, B], Ss):
self.assertTrue(
mx.allclose(
mx.linalg.norm(S),
mx.linalg.norm(M, ord="fro"),
rtol=1e-5,
atol=1e-7,
)
)
def test_inverse(self):
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
A_inv = mx.linalg.inv(A, stream=mx.cpu)

View File

@@ -316,33 +316,56 @@ class TestVmap(mlx_tests.MLXTestCase):
def test_vmap_svd(self):
a = mx.random.uniform(shape=(3, 4, 2))
cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu)
cpu_svd_full = lambda x: mx.linalg.svd(x, compute_uv=True, stream=mx.cpu)
cpu_svd_singular = lambda x: mx.linalg.svd(x, compute_uv=False, stream=mx.cpu)
# Vmap over the first axis (this is already supported natively by the primitive).
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(0,))(a)
self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))
self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))
self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))
Sv = mx.vmap(cpu_svd_singular, in_axes=(0,))(a)
self.assertEqual(Sv.shape, (a.shape[0], a.shape[2]))
for i in range(a.shape[0]):
M = a[i]
U, S, Vt = Us[i], Ss[i], Vts[i]
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
)
self.assertTrue(
mx.allclose(
mx.linalg.norm(Sv[i]),
mx.linalg.norm(M, ord="fro"),
rtol=1e-5,
atol=1e-7,
)
)
# Vmap over the second axis.
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(1,))(a)
self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))
self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))
self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))
Sv = mx.vmap(cpu_svd_singular, in_axes=(1,))(a)
self.assertEqual(Sv.shape, (a.shape[1], a.shape[2]))
for i in range(a.shape[1]):
M = a[:, i, :]
U, S, Vt = Us[i], Ss[i], Vts[i]
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
)
self.assertTrue(
mx.allclose(
mx.linalg.norm(Sv[i]),
mx.linalg.norm(M, ord="fro"),
rtol=1e-5,
atol=1e-7,
)
)
def test_vmap_inverse(self):
mx.random.seed(42)