mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
@@ -12,11 +12,11 @@ import numpy as np
|
||||
class TestLinalg(mlx_tests.MLXTestCase):
|
||||
def test_norm(self):
|
||||
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
|
||||
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
|
||||
matrix_ords = [None, "fro", "nuc", -1, 1, -2, 2, float("inf"), -float("inf")]
|
||||
|
||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
x_mx = mx.arange(1, math.prod(shape) + 1, dtype=mx.float32).reshape(shape)
|
||||
x_np = np.arange(1, math.prod(shape) + 1, dtype=np.float32).reshape(shape)
|
||||
# Test when at least one axis is provided
|
||||
for num_axes in range(1, len(shape)):
|
||||
if num_axes == 1:
|
||||
@@ -26,11 +26,14 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
for axis in itertools.combinations(range(len(shape)), num_axes):
|
||||
for keepdims in [True, False]:
|
||||
for o in ords:
|
||||
stream = (
|
||||
mx.cpu if o in ["nuc", -2, 2] else mx.default_device()
|
||||
)
|
||||
out_np = np.linalg.norm(
|
||||
x_np, ord=o, axis=axis, keepdims=keepdims
|
||||
)
|
||||
out_mx = mx.linalg.norm(
|
||||
x_mx, ord=o, axis=axis, keepdims=keepdims
|
||||
x_mx, ord=o, axis=axis, keepdims=keepdims, stream=stream
|
||||
)
|
||||
with self.subTest(
|
||||
shape=shape, ord=o, axis=axis, keepdims=keepdims
|
||||
@@ -133,20 +136,38 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_svd_decomposition(self):
|
||||
A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
|
||||
U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
|
||||
U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu)
|
||||
self.assertTrue(
|
||||
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)
|
||||
)
|
||||
|
||||
S = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
mx.linalg.norm(S), mx.linalg.norm(A, ord="fro"), rtol=1e-5, atol=1e-7
|
||||
)
|
||||
)
|
||||
|
||||
# Multiple matrices
|
||||
B = A + 10.0
|
||||
AB = mx.stack([A, B])
|
||||
Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu)
|
||||
Us, Ss, Vts = mx.linalg.svd(AB, compute_uv=True, stream=mx.cpu)
|
||||
for M, U, S, Vt in zip([A, B], Us, Ss, Vts):
|
||||
self.assertTrue(
|
||||
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
|
||||
)
|
||||
|
||||
Ss = mx.linalg.svd(AB, compute_uv=False, stream=mx.cpu)
|
||||
for M, S in zip([A, B], Ss):
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
mx.linalg.norm(S),
|
||||
mx.linalg.norm(M, ord="fro"),
|
||||
rtol=1e-5,
|
||||
atol=1e-7,
|
||||
)
|
||||
)
|
||||
|
||||
def test_inverse(self):
|
||||
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
|
||||
A_inv = mx.linalg.inv(A, stream=mx.cpu)
|
||||
|
@@ -316,33 +316,56 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
def test_vmap_svd(self):
|
||||
a = mx.random.uniform(shape=(3, 4, 2))
|
||||
|
||||
cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu)
|
||||
cpu_svd_full = lambda x: mx.linalg.svd(x, compute_uv=True, stream=mx.cpu)
|
||||
cpu_svd_singular = lambda x: mx.linalg.svd(x, compute_uv=False, stream=mx.cpu)
|
||||
|
||||
# Vmap over the first axis (this is already supported natively by the primitive).
|
||||
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
|
||||
Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(0,))(a)
|
||||
self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))
|
||||
self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))
|
||||
self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))
|
||||
|
||||
Sv = mx.vmap(cpu_svd_singular, in_axes=(0,))(a)
|
||||
self.assertEqual(Sv.shape, (a.shape[0], a.shape[2]))
|
||||
|
||||
for i in range(a.shape[0]):
|
||||
M = a[i]
|
||||
U, S, Vt = Us[i], Ss[i], Vts[i]
|
||||
self.assertTrue(
|
||||
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
mx.linalg.norm(Sv[i]),
|
||||
mx.linalg.norm(M, ord="fro"),
|
||||
rtol=1e-5,
|
||||
atol=1e-7,
|
||||
)
|
||||
)
|
||||
|
||||
# Vmap over the second axis.
|
||||
Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
|
||||
Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(1,))(a)
|
||||
self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))
|
||||
self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))
|
||||
self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))
|
||||
|
||||
Sv = mx.vmap(cpu_svd_singular, in_axes=(1,))(a)
|
||||
self.assertEqual(Sv.shape, (a.shape[1], a.shape[2]))
|
||||
|
||||
for i in range(a.shape[1]):
|
||||
M = a[:, i, :]
|
||||
U, S, Vt = Us[i], Ss[i], Vts[i]
|
||||
self.assertTrue(
|
||||
mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
mx.linalg.norm(Sv[i]),
|
||||
mx.linalg.norm(M, ord="fro"),
|
||||
rtol=1e-5,
|
||||
atol=1e-7,
|
||||
)
|
||||
)
|
||||
|
||||
def test_vmap_inverse(self):
|
||||
mx.random.seed(42)
|
||||
|
Reference in New Issue
Block a user