Fix deep recursion with siblings (#1462)

* fix recursion with siblings

* fix

* add test

* increase tol
This commit is contained in:
Awni Hannun
2024-10-07 06:15:33 -07:00
committed by GitHub
parent 95d04805b3
commit 0070e1db40
3 changed files with 48 additions and 11 deletions

View File

@@ -902,7 +902,7 @@ class TestConv(mlx_tests.MLXTestCase):
dw10 = (cotan[1::s, :-1:s] * x).sum()
dw11 = (cotan[1::s, 1::s] * x).sum()
expected = mx.array([[dw00, dw01], [dw10, dw11]])
self.assertTrue(mx.allclose(dw, expected))
self.assertTrue(mx.allclose(dw, expected, rtol=1e-5, atol=1e-5))
def test_conv_groups_grad(self):
def fn(x, w):