Support transposed head/seq for kv (#1950)

* support transposed head/seq for kv

* fix flaky test

* nit
This commit is contained in:
Awni Hannun
2025-03-10 10:53:45 -07:00
committed by GitHub
parent cffceda6ee
commit 3c3e558c60
4 changed files with 84 additions and 45 deletions

View File

@@ -183,9 +183,11 @@ class TestDistributed(mlx_tests.MLXTestCase):
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
all_sum_only = mx.metal.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize(mx.default_stream(mx.default_device()))
all_sum_with_binary = mx.metal.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)