mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00

committed by
GitHub

parent
4c1dfa58b7
commit
71de73a668
@@ -341,7 +341,7 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
atol, rtol = 1e-1, 1e-3
|
||||
else:
|
||||
atol, rtol = 1e-5, 1e-6
|
||||
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol, rtol=rtol))
|
||||
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||
|
||||
for dtype in ("float32", "bfloat16"):
|
||||
for N, C, O in (
|
||||
@@ -1042,6 +1042,14 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(expected[0], grads[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], grads[1]))
|
||||
|
||||
def test_repeated_conv(self):
|
||||
x = mx.random.normal((1, 3, 3, 320))
|
||||
w = mx.random.normal((320, 3, 3, 320))
|
||||
for i in range(8):
|
||||
y1 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
|
||||
y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
|
||||
self.assertTrue(mx.allclose(y1, y2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user