mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
CPU LU factorization and linear solvers (#1451)
* linalg solve backend * nits * more nits + fix * luf primitive and lu, solve, and solve_triangular backends * changes / nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -330,6 +330,123 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
|
||||
) # Non-square matrix
|
||||
|
||||
def test_lu(self):
|
||||
with self.assertRaises(ValueError):
|
||||
mx.linalg.lu(mx.array(0.0), stream=mx.cpu)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu)
|
||||
|
||||
# Test 3x3 matrix
|
||||
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
|
||||
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
|
||||
self.assertTrue(mx.allclose(L[P, :] @ U, a))
|
||||
|
||||
# Test batch dimension
|
||||
a = mx.broadcast_to(a, (5, 5, 3, 3))
|
||||
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
|
||||
L = mx.take_along_axis(L, P[..., None], axis=-2)
|
||||
self.assertTrue(mx.allclose(L @ U, a))
|
||||
|
||||
def test_lu_factor(self):
|
||||
mx.random.seed(7)
|
||||
|
||||
# Test 3x3 matrix
|
||||
a = mx.random.uniform(shape=(5, 5))
|
||||
LU, pivots = mx.linalg.lu_factor(a, stream=mx.cpu)
|
||||
n = a.shape[-1]
|
||||
|
||||
pivots = pivots.tolist()
|
||||
perm = list(range(n))
|
||||
for i in range(len(pivots)):
|
||||
perm[i], perm[pivots[i]] = perm[pivots[i]], perm[i]
|
||||
|
||||
L = mx.add(mx.tril(LU, k=-1), mx.eye(n))
|
||||
U = mx.triu(LU)
|
||||
self.assertTrue(mx.allclose(L @ U, a[perm, :]))
|
||||
|
||||
def test_solve(self):
|
||||
mx.random.seed(7)
|
||||
|
||||
# Test 3x3 matrix with 1D rhs
|
||||
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
|
||||
b = mx.array([11.0, 35.0, 28.0])
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test symmetric positive-definite matrix
|
||||
N = 5
|
||||
a = mx.random.uniform(shape=(N, N))
|
||||
a = mx.matmul(a, a.T) + N * mx.eye(N)
|
||||
b = mx.random.uniform(shape=(N, 1))
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test batch dimension
|
||||
a = mx.random.uniform(shape=(5, 5, 4, 4))
|
||||
b = mx.random.uniform(shape=(5, 5, 4, 1))
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected, atol=1e-5))
|
||||
|
||||
# Test large matrix
|
||||
N = 1000
|
||||
a = mx.random.uniform(shape=(N, N))
|
||||
b = mx.random.uniform(shape=(N, 1))
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected, atol=1e-3))
|
||||
|
||||
# Test multi-column rhs
|
||||
a = mx.random.uniform(shape=(5, 5))
|
||||
b = mx.random.uniform(shape=(5, 8))
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test batched multi-column rhs
|
||||
a = mx.broadcast_to(a, (3, 2, 5, 5))
|
||||
b = mx.broadcast_to(b, (3, 1, 5, 8))
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5))
|
||||
|
||||
def test_solve_triangular(self):
|
||||
# Test lower triangular matrix
|
||||
a = mx.array([[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]])
|
||||
b = mx.array([8.0, 14.0, 3.0])
|
||||
|
||||
result = mx.linalg.solve_triangular(a, b, upper=False, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test upper triangular matrix
|
||||
a = mx.array([[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]])
|
||||
b = mx.array([13.0, 33.0, 18.0])
|
||||
|
||||
result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test batch multi-column rhs
|
||||
a = mx.broadcast_to(a, (3, 4, 3, 3))
|
||||
b = mx.broadcast_to(mx.expand_dims(b, -1), (3, 4, 3, 8))
|
||||
|
||||
result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user