Compare commits

...

2 Commits

Author SHA1 Message Date
Cheng
d5f61a93fa Fix typo: refs/head/main => refs/heads/main (#2818)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-22 09:43:35 +09:00
Awni Hannun
4a09264236 Tolerance for some ops tests on cuda (#2815) 2025-11-21 16:06:16 -08:00
2 changed files with 26 additions and 27 deletions

View File

@@ -13,7 +13,7 @@ permissions:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
check_lint:

View File

@@ -1443,23 +1443,22 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(a.tolist(), expected)
def test_unary_ops(self):
def test_ops(npop, mlxop, x, y, atol):
def test_ops(npop, mlxop, x, y, atol, rtol):
r_np = npop(x)
r_mlx = mlxop(y)
mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, rtol=rtol))
x = np.random.rand(18, 28, 38)
for op in ["abs", "exp", "log", "square", "sqrt"]:
with self.subTest(op=op):
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
for dtype, atol in float_dtypes:
for dtype, atol, rtol in float_dtypes:
with self.subTest(dtype=dtype):
x_ = x.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
def test_unary_ops_from_non_array(self):
unary_ops = [
@@ -1511,12 +1510,14 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
def test_trig_ops(self):
def test_ops(npop, mlxop, x, y, atol):
def test_ops(npop, mlxop, x, y, atol, rtol):
r_np = npop(x)
r_mlx = mlxop(y)
mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, equal_nan=True))
self.assertTrue(
np.allclose(r_np, r_mlx, atol=atol, rtol=rtol, equal_nan=True)
)
x = np.random.rand(9, 12, 18)
xi = np.random.rand(9, 12, 18)
@@ -1526,34 +1527,34 @@ class TestOps(mlx_tests.MLXTestCase):
for op in all_fwd_ops:
with self.subTest(op=op):
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
for dtype, atol in float_dtypes:
for dtype, atol, rtol in float_dtypes:
with self.subTest(dtype=dtype):
x_ = x.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
with self.subTest(op=op):
float_dtypes = [("complex64", 1e-5)]
for dtype, atol in float_dtypes:
with self.subTest(dtype=dtype):
x_ = x + 1.0j * xi
x_ = x_.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
dtype = "complex64"
with self.subTest(dtype=dtype):
x_ = x + 1.0j * xi
x_ = x_.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, 1e-5, 1e-5)
with self.subTest(op="arc" + op):
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
op_inv = "arc" + op
for dtype, atol in float_dtypes:
for dtype, atol, rtol in float_dtypes:
with self.subTest(dtype=dtype):
np_op_fwd = getattr(np, op)
x_ = np_op_fwd(x).astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol)
test_ops(
getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol, rtol
)
# Test grads
np_vjp_funcs = {
@@ -1579,11 +1580,10 @@ class TestOps(mlx_tests.MLXTestCase):
x_ = x.astype(np.float32)
y_ = mx.array(x_)
op_ = op
atol_ = 1e-5
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
with self.subTest(op="arc" + op):
np_op_fwd = getattr(np, op)
@@ -1599,11 +1599,10 @@ class TestOps(mlx_tests.MLXTestCase):
x_ = x.astype(np.float32)
y_ = mx.array(x_)
op_ = "arc" + op
atol_ = 1e-5
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
def test_binary_ops(self):
def test_ops(npop, mlxop, x1, x2, y1, y2, atol):