Distributed layers (#1270)

This commit is contained in:
Angelos Katharopoulos
2025-03-21 13:52:17 -07:00
committed by GitHub
parent 69e4dd506b
commit 4eef8102c9
10 changed files with 895 additions and 80 deletions

View File

@@ -3,10 +3,10 @@
import unittest
import mlx.core as mx
import mlx_tests
import mlx_distributed_tests
class TestRingDistributed(mlx_tests.MLXTestCase):
class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
@classmethod
def setUpClass(cls):
world = mx.distributed.init(strict=True, backend="ring")