mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Fix deep recursion with siblings (#1462)
* fix recursion with siblings * fix * add test * increase tol
This commit is contained in:
@@ -902,7 +902,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
dw10 = (cotan[1::s, :-1:s] * x).sum()
|
||||
dw11 = (cotan[1::s, 1::s] * x).sum()
|
||||
expected = mx.array([[dw00, dw01], [dw10, dw11]])
|
||||
self.assertTrue(mx.allclose(dw, expected))
|
||||
self.assertTrue(mx.allclose(dw, expected, rtol=1e-5, atol=1e-5))
|
||||
|
||||
def test_conv_groups_grad(self):
|
||||
def fn(x, w):
|
||||
|
Reference in New Issue
Block a user