mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
Split multi output (#461)
* Multi-output split primitive * Add the multi-output split to the ArrayIterator * Add some grad tests for split
This commit is contained in:

committed by
GitHub

parent
4e290d282f
commit
d8fabaa12b
@@ -339,6 +339,25 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
|
||||
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
|
||||
|
||||
def test_split_against_slice(self):
|
||||
def f_split(x):
|
||||
a, _, b = x.split(3, -1)
|
||||
return (a * b).sum()
|
||||
|
||||
def f_slice(x):
|
||||
step = x.shape[-1] // 3
|
||||
a = x[..., :step]
|
||||
b = x[..., -step:]
|
||||
return (a * b).sum()
|
||||
|
||||
x = mx.random.uniform(shape=(100, 300))
|
||||
mx.eval(x)
|
||||
|
||||
df1 = mx.grad(f_split)
|
||||
df2 = mx.grad(f_slice)
|
||||
|
||||
self.assertTrue(mx.allclose(df1(x), df2(x)))
|
||||
|
||||
def test_vjp_types(self):
|
||||
def fun(x):
|
||||
return x
|
||||
|
Reference in New Issue
Block a user