expose depends (#2606)

This commit is contained in:
Awni Hannun
2025-09-18 10:06:15 -07:00
committed by GitHub
parent 3f730e77aa
commit 50cc09887f
2 changed files with 53 additions and 0 deletions

View File

@@ -3081,6 +3081,19 @@ class TestOps(mlx_tests.MLXTestCase):
# Doesn't hang
x = mx.power(2, -1)
def test_depends(self):
a = mx.array([1.0, 2.0, 3.0])
b = mx.exp(a)
c = mx.log(a)
out = mx.depends([b], [c])[0]
self.assertTrue(mx.array_equal(out, b))
a = mx.array([1.0, 2.0, 3.0])
b = mx.exp(a)
c = mx.log(a)
out = mx.depends(b, c)
self.assertTrue(mx.array_equal(out, b))
class TestBroadcast(mlx_tests.MLXTestCase):
def test_broadcast_shapes(self):