Add a multi optimizer (#1916)

This commit is contained in:
Angelos Katharopoulos
2025-03-04 13:16:35 -08:00
committed by GitHub
parent a0737273d3
commit 9680f72cca
5 changed files with 168 additions and 1 deletions

View File

@@ -345,6 +345,7 @@ class TestVmap(mlx_tests.MLXTestCase):
)
def test_vmap_inverse(self):
mx.random.seed(42)
a = mx.random.uniform(shape=(3, 4, 4))
cpu_inv = lambda x: mx.linalg.inv(x, stream=mx.cpu)