Split multi output (#461)

* Multi-output split primitive
* Add the multi-output split to the ArrayIterator
* Add some grad tests for split
This commit is contained in:
Angelos Katharopoulos
2024-01-16 13:33:55 -08:00
committed by GitHub
parent 4e290d282f
commit d8fabaa12b
12 changed files with 202 additions and 5 deletions

View File

@@ -339,6 +339,25 @@ class TestAutograd(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
def test_split_against_slice(self):
def f_split(x):
a, _, b = x.split(3, -1)
return (a * b).sum()
def f_slice(x):
step = x.shape[-1] // 3
a = x[..., :step]
b = x[..., -step:]
return (a * b).sum()
x = mx.random.uniform(shape=(100, 300))
mx.eval(x)
df1 = mx.grad(f_split)
df2 = mx.grad(f_slice)
self.assertTrue(mx.allclose(df1(x), df2(x)))
def test_vjp_types(self):
def fun(x):
return x