# Copyright © 2023 Apple Inc.

import itertools
import math
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np


class TestLinalg(mlx_tests.MLXTestCase):
    """Tests for ``mx.linalg``.

    Norms are compared against ``np.linalg.norm``; the factorization and
    inverse routines are checked through the algebraic identities their
    results must satisfy. Those routines are all invoked with
    ``stream=mx.cpu``.
    """

    def test_norm(self):
        """Compare ``mx.linalg.norm`` with NumPy over ords, axes and keepdims."""
        vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
        matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]

        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
            # Test when at least one axis is provided
            for num_axes in range(1, len(shape)):
                # One axis selects a vector norm, two axes a matrix norm.
                if num_axes == 1:
                    ords = vector_ords
                else:
                    ords = matrix_ords
                for axis in itertools.combinations(range(len(shape)), num_axes):
                    for keepdims in [True, False]:
                        for o in ords:
                            out_np = np.linalg.norm(
                                x_np, ord=o, axis=axis, keepdims=keepdims
                            )
                            out_mx = mx.linalg.norm(
                                x_mx, ord=o, axis=axis, keepdims=keepdims
                            )
                            with self.subTest(
                                shape=shape, ord=o, axis=axis, keepdims=keepdims
                            ):
                                self.assertTrue(
                                    np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
                                )

        # Test only ord provided
        for shape in [(3,), (2, 3)]:
            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
            for o in [None, 1, -1, float("inf"), -float("inf")]:
                for keepdims in [True, False]:
                    out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
                    out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
                    with self.subTest(shape=shape, ord=o, keepdims=keepdims):
                        self.assertTrue(
                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
                        )

        # Test no ord and no axis provided
        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
            for keepdims in [True, False]:
                out_np = np.linalg.norm(x_np, keepdims=keepdims)
                out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
                with self.subTest(shape=shape, keepdims=keepdims):
                    self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))

    def test_complex_norm(self):
        """Compare ``mx.linalg.norm`` with NumPy on complex-valued inputs.

        Inputs are drawn from NumPy's global random generator, which this
        method does not seed.
        """
        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_np = np.random.uniform(size=shape).astype(
                np.float32
            ) + 1j * np.random.uniform(size=shape).astype(np.float32)
            x_mx = mx.array(x_np)
            out_np = np.linalg.norm(x_np)
            out_mx = mx.linalg.norm(x_mx)
            with self.subTest(shape=shape):
                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
            for num_axes in range(1, len(shape)):
                for axis in itertools.combinations(range(len(shape)), num_axes):
                    out_np = np.linalg.norm(x_np, axis=axis)
                    out_mx = mx.linalg.norm(x_mx, axis=axis)
                    with self.subTest(shape=shape, axis=axis):
                        self.assertTrue(
                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
                        )

        # Explicit Frobenius norm of a complex square matrix
        x_np = np.random.uniform(size=(4, 4)).astype(
            np.float32
        ) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
        x_mx = mx.array(x_np)
        out_np = np.linalg.norm(x_np, ord="fro")
        out_mx = mx.linalg.norm(x_mx, ord="fro")
        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))

    def test_qr_factorization(self):
        """Check ``mx.linalg.qr`` input validation and the QR identities.

        For each factorization the test asserts that ``Q @ R`` reproduces the
        input, that ``Q`` has orthonormal columns (``Q.T @ Q == I``) and that
        ``R`` has zeros below its diagonal.
        """
        # Scalar, 1-D and integer-typed inputs are expected to be rejected.
        with self.assertRaises(ValueError):
            mx.linalg.qr(mx.array(0.0))

        with self.assertRaises(ValueError):
            mx.linalg.qr(mx.array([0.0, 1.0]))

        with self.assertRaises(ValueError):
            mx.linalg.qr(mx.array([[0, 1], [1, 0]]))

        A = mx.array([[2.0, 3.0], [1.0, 2.0]])
        Q, R = mx.linalg.qr(A, stream=mx.cpu)
        out = Q @ R
        self.assertTrue(mx.allclose(out, A))
        # Orthogonality is Q^T Q = I. Checking Q @ Q only holds when Q also
        # happens to be symmetric, so it does not test the right property.
        out = Q.T @ Q
        self.assertTrue(mx.allclose(out, mx.eye(2), rtol=1e-5, atol=1e-7))
        self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R)))
        self.assertEqual(Q.dtype, mx.float32)
        self.assertEqual(R.dtype, mx.float32)

        # Multiple matrices
        B = mx.array([[-1.0, 2.0], [-4.0, 1.0]])
        A = mx.stack([A, B])
        Q, R = mx.linalg.qr(A, stream=mx.cpu)
        for a, q, r in zip(A, Q, R):
            out = q @ r
            self.assertTrue(mx.allclose(out, a))
            out = q.T @ q
            self.assertTrue(mx.allclose(out, mx.eye(2), rtol=1e-5, atol=1e-7))
            self.assertTrue(mx.allclose(mx.tril(r, -1), mx.zeros_like(r)))

    def test_svd_decomposition(self):
        """Check that ``U[:, :len(S)] @ diag(S) @ Vt`` reconstructs the input."""
        A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
        U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
        self.assertTrue(
            mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)
        )

        # Multiple matrices
        B = A + 10.0
        AB = mx.stack([A, B])
        Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu)
        for M, U, S, Vt in zip([A, B], Us, Ss, Vts):
            self.assertTrue(
                mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
            )

    def test_inverse(self):
        """Check that ``A @ inv(A)`` is the identity, single and batched."""
        A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
        A_inv = mx.linalg.inv(A, stream=mx.cpu)
        self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0]), rtol=0, atol=1e-6))

        # Multiple matrices
        B = A - 100
        AB = mx.stack([A, B])
        invs = mx.linalg.inv(AB, stream=mx.cpu)
        for M, M_inv in zip(AB, invs):
            self.assertTrue(
                mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)
            )

    def test_tri_inverse(self):
        """Check ``mx.linalg.tri_inv`` on batched lower and upper triangular input."""
        for upper in (False, True):
            A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float32)
            B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float32)
            if upper:
                # Transposing the lower-triangular inputs makes them upper.
                A = A.T
                B = B.T
            AB = mx.stack([A, B])
            invs = mx.linalg.tri_inv(AB, upper=upper, stream=mx.cpu)
            for M, M_inv in zip(AB, invs):
                self.assertTrue(
                    mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)
                )

    def test_cholesky(self):
        """Check that the lower and upper Cholesky factors reconstruct the input."""
        sqrtA = mx.array(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float32
        )
        A = sqrtA.T @ sqrtA / 81
        L = mx.linalg.cholesky(A, stream=mx.cpu)
        U = mx.linalg.cholesky(A, upper=True, stream=mx.cpu)
        self.assertTrue(mx.allclose(L @ L.T, A, rtol=1e-5, atol=1e-7))
        self.assertTrue(mx.allclose(U.T @ U, A, rtol=1e-5, atol=1e-7))

        # Multiple matrices
        B = A + 1 / 9
        AB = mx.stack([A, B])
        Ls = mx.linalg.cholesky(AB, stream=mx.cpu)
        for M, L in zip(AB, Ls):
            self.assertTrue(mx.allclose(L @ L.T, M, rtol=1e-5, atol=1e-7))

    def test_pseudo_inverse(self):
        """Check the identity ``A @ pinv(A) @ A == A``, single and batched."""
        A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
        A_plus = mx.linalg.pinv(A, stream=mx.cpu)
        self.assertTrue(mx.allclose(A @ A_plus @ A, A, rtol=0, atol=1e-5))

        # Multiple matrices
        B = A - 100
        AB = mx.stack([A, B])
        pinvs = mx.linalg.pinv(AB, stream=mx.cpu)
        for M, M_plus in zip(AB, pinvs):
            self.assertTrue(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3))

    def test_cholesky_inv(self):
        """Check that ``cholesky_inv`` of a Cholesky factor inverts the matrix.

        Both the lower and the upper factor are exercised, for a single
        matrix and for a stacked batch.
        """
        mx.random.seed(7)

        N = 3
        A = mx.random.uniform(shape=(N, N))
        # A @ A.T is symmetric, as the Cholesky factorization requires.
        A = A @ A.T

        for upper in (False, True):
            L = mx.linalg.cholesky(A, upper=upper, stream=mx.cpu)
            A_inv = mx.linalg.cholesky_inv(L, upper=upper, stream=mx.cpu)
            self.assertTrue(mx.allclose(A @ A_inv, mx.eye(N), atol=1e-4))

        # Multiple matrices
        B = A + 1 / 9
        AB = mx.stack([A, B])
        for upper in (False, True):
            Ls = mx.linalg.cholesky(AB, upper=upper, stream=mx.cpu)
            AB_inv = mx.linalg.cholesky_inv(Ls, upper=upper, stream=mx.cpu)
            for M, M_inv in zip(AB, AB_inv):
                self.assertTrue(mx.allclose(M @ M_inv, mx.eye(N), atol=1e-4))


if __name__ == "__main__":
    unittest.main()