Adds nuclear norm support (#1894)

* adjust norm unit test tolerance
This commit is contained in:
Abe Leininger
2025-03-04 15:26:02 -06:00
committed by GitHub
parent 9680f72cca
commit 3835a428c5
11 changed files with 260 additions and 55 deletions

View File

@@ -316,33 +316,56 @@ class TestVmap(mlx_tests.MLXTestCase):
def test_vmap_svd(self):
a = mx.random.uniform(shape=(3, 4, 2))
cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu)
cpu_svd_full = lambda x: mx.linalg.svd(x, compute_uv=True, stream=mx.cpu)
cpu_svd_singular = lambda x: mx.linalg.svd(x, compute_uv=False, stream=mx.cpu)
# Vmap over the first axis (this is already supported natively by the primitive).
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(0,))(a)
self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))
self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))
self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))
Sv = mx.vmap(cpu_svd_singular, in_axes=(0,))(a)
self.assertEqual(Sv.shape, (a.shape[0], a.shape[2]))
for i in range(a.shape[0]):
M = a[i]
U, S, Vt = Us[i], Ss[i], Vts[i]
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
)
self.assertTrue(
mx.allclose(
mx.linalg.norm(Sv[i]),
mx.linalg.norm(M, ord="fro"),
rtol=1e-5,
atol=1e-7,
)
)
# Vmap over the second axis.
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(1,))(a)
self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))
self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))
self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))
Sv = mx.vmap(cpu_svd_singular, in_axes=(1,))(a)
self.assertEqual(Sv.shape, (a.shape[1], a.shape[2]))
for i in range(a.shape[1]):
M = a[:, i, :]
U, S, Vt = Us[i], Ss[i], Vts[i]
self.assertTrue(
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
)
self.assertTrue(
mx.allclose(
mx.linalg.norm(Sv[i]),
mx.linalg.norm(M, ord="fro"),
rtol=1e-5,
atol=1e-7,
)
)
def test_vmap_inverse(self):
mx.random.seed(42)