Merge branch 'ml-explore:main' into adding-Muon-optimizer

This commit is contained in:
Gökdeniz Gülmez
2025-03-12 16:52:21 +01:00
committed by GitHub
23 changed files with 2251 additions and 2432 deletions

View File

@@ -183,9 +183,11 @@ class TestDistributed(mlx_tests.MLXTestCase):
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
all_sum_only = mx.metal.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
all_sum_with_binary = mx.metal.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)

View File

@@ -171,7 +171,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
rtol = 1e-2
self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))
q = mx.random.normal(shape=(1, 32, 1, Dk))
k = mx.random.normal(shape=(1, 32, 32, Dk))
v = mx.random.normal(shape=(1, 32, 128, Dk))
@@ -201,6 +200,38 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
def test_fast_sdpa_vector_kv_transposed_head_seq(self):
D = 64
Nq = 4
Nkv = 1
scale = 1.0
mx.random.seed(0)
q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
lengths = [43, 4096]
for L in lengths:
k = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))
v = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))
k = k.swapaxes(1, 2)
v = v.swapaxes(1, 2)
masks = [
mx.array(True),
mx.array([True] * (L - 10) + [False] * 10),
mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
]
for m in masks:
ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
out = mx.fast.scaled_dot_product_attention(
q,
k,
v,
scale=scale,
mask=m,
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_fast_sdpa_vector(self):
D = 64
L = 43
@@ -292,7 +323,6 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
return
L = 4096
scale = 1.0
mx.random.seed(0)

View File

@@ -659,6 +659,16 @@ class TestVmap(mlx_tests.MLXTestCase):
self.assertEqual(mem_pre, mem_post)
def test_vmap_flatten(self):
def fun(x):
return mx.flatten(x, 0, 1)
x = mx.zeros((2, 3, 4))
self.assertEqual(mx.vmap(fun)(x).shape, (2, 12))
self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
if __name__ == "__main__":
unittest.main()