Fix test tolerance and patch bump (#1315)

This commit is contained in:
Angelos Katharopoulos
2024-08-08 14:51:09 -07:00
committed by GitHub
parent eb8819e91e
commit 780c197f95
3 changed files with 7 additions and 7 deletions

View File

@@ -357,9 +357,9 @@ class TestFast(mlx_tests.MLXTestCase):
gx1, gw1, gb1 = mx.grad(f1, argnums=(0, 1, 2))(x, w, b, y)
gx2, gw2, gb2 = mx.grad(f2, argnums=(0, 1, 2))(x, w, b, y)
self.assertLess(mx.abs(gx1 - gx2).max(), 1e-5)
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 1e-5)
self.assertLess(mx.abs(gx1 - gx2).max(), 5e-5)
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 5e-5)
self.assertLess(mx.abs(gb1 - gb2).max() / mx.abs(gb1).mean(), 5e-5)
def gf(f):
def inner(x, w, b, y):
@@ -370,8 +370,8 @@ class TestFast(mlx_tests.MLXTestCase):
gx1, gw1, gb1 = mx.grad(gf(f1), argnums=(0, 1, 2))(x, w, b, y)
gx2, gw2, gb2 = mx.grad(gf(f2), argnums=(0, 1, 2))(x, w, b, y)
self.assertLess(mx.abs(gx1 - gx2).max() / mx.abs(gx1).mean(), 1e-5)
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 1e-5)
self.assertLess(mx.abs(gx1 - gx2).max() / mx.abs(gx1).mean(), 5e-5)
self.assertLess(mx.abs(gw1 - gw2).max() / mx.abs(gw1).mean(), 5e-5)
self.assertLess(mx.abs(gb1).max(), 1e-9)
self.assertLess(mx.abs(gb2).max(), 1e-9)