From 5866b3857bb46b2f41368b69a2df2b2f4c874231 Mon Sep 17 00:00:00 2001 From: Emmanuel Ferdman Date: Sat, 7 Jun 2025 16:12:08 +0300 Subject: [PATCH] Refactor the lu test (#2250) Signed-off-by: Emmanuel Ferdman --- python/tests/test_linalg.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index f5eeda837..764d11f6e 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -359,36 +359,6 @@ class TestLinalg(mlx_tests.MLXTestCase): mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) ) # Non-square matrix - def test_lu(self): - with self.assertRaises(ValueError): - mx.linalg.lu(mx.array(0.0), stream=mx.cpu) - - with self.assertRaises(ValueError): - mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu) - - with self.assertRaises(ValueError): - mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu) - - # Test 3x3 matrix - a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]) - P, L, U = mx.linalg.lu(a, stream=mx.cpu) - self.assertTrue(mx.allclose(L[P, :] @ U, a)) - - # Test batch dimension - a = mx.broadcast_to(a, (5, 5, 3, 3)) - P, L, U = mx.linalg.lu(a, stream=mx.cpu) - L = mx.take_along_axis(L, P[..., None], axis=-2) - self.assertTrue(mx.allclose(L @ U, a)) - - # Test non-square matrix - a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0]]) - P, L, U = mx.linalg.lu(a, stream=mx.cpu) - self.assertTrue(mx.allclose(L[P, :] @ U, a)) - - a = mx.array([[3.0, 1.0], [1.0, 8.0], [9.0, 2.0]]) - P, L, U = mx.linalg.lu(a, stream=mx.cpu) - self.assertTrue(mx.allclose(L[P, :] @ U, a)) - def test_eigh(self): tols = {"atol": 1e-5, "rtol": 1e-5}