Compare commits

...

2 Commits

Author SHA1 Message Date
Cheng
d5f61a93fa Fix typo: refs/head/main => refs/heads/main (#2818)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-22 09:43:35 +09:00
Awni Hannun
4a09264236 Tolerance for some ops tests on cuda (#2815) 2025-11-21 16:06:16 -08:00
2 changed files with 26 additions and 27 deletions

View File

@@ -13,7 +13,7 @@ permissions:
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }} group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/head/main' }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs: jobs:
check_lint: check_lint:

View File

@@ -1443,23 +1443,22 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertListEqual(a.tolist(), expected) self.assertListEqual(a.tolist(), expected)
def test_unary_ops(self): def test_unary_ops(self):
def test_ops(npop, mlxop, x, y, atol): def test_ops(npop, mlxop, x, y, atol, rtol):
r_np = npop(x) r_np = npop(x)
r_mlx = mlxop(y) r_mlx = mlxop(y)
mx.eval(r_mlx) mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, rtol=rtol))
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
x = np.random.rand(18, 28, 38) x = np.random.rand(18, 28, 38)
for op in ["abs", "exp", "log", "square", "sqrt"]: for op in ["abs", "exp", "log", "square", "sqrt"]:
with self.subTest(op=op): with self.subTest(op=op):
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)] float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
for dtype, atol in float_dtypes: for dtype, atol, rtol in float_dtypes:
with self.subTest(dtype=dtype): with self.subTest(dtype=dtype):
x_ = x.astype(getattr(np, dtype)) x_ = x.astype(getattr(np, dtype))
y_ = mx.array(x_) y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
def test_unary_ops_from_non_array(self): def test_unary_ops_from_non_array(self):
unary_ops = [ unary_ops = [
@@ -1511,12 +1510,14 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True)) self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
def test_trig_ops(self): def test_trig_ops(self):
def test_ops(npop, mlxop, x, y, atol): def test_ops(npop, mlxop, x, y, atol, rtol):
r_np = npop(x) r_np = npop(x)
r_mlx = mlxop(y) r_mlx = mlxop(y)
mx.eval(r_mlx) mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, equal_nan=True)) self.assertTrue(
np.allclose(r_np, r_mlx, atol=atol, rtol=rtol, equal_nan=True)
)
x = np.random.rand(9, 12, 18) x = np.random.rand(9, 12, 18)
xi = np.random.rand(9, 12, 18) xi = np.random.rand(9, 12, 18)
@@ -1526,34 +1527,34 @@ class TestOps(mlx_tests.MLXTestCase):
for op in all_fwd_ops: for op in all_fwd_ops:
with self.subTest(op=op): with self.subTest(op=op):
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)] float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
for dtype, atol in float_dtypes: for dtype, atol, rtol in float_dtypes:
with self.subTest(dtype=dtype): with self.subTest(dtype=dtype):
x_ = x.astype(getattr(np, dtype)) x_ = x.astype(getattr(np, dtype))
y_ = mx.array(x_) y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
with self.subTest(op=op): with self.subTest(op=op):
float_dtypes = [("complex64", 1e-5)] dtype = "complex64"
with self.subTest(dtype=dtype):
for dtype, atol in float_dtypes: x_ = x + 1.0j * xi
with self.subTest(dtype=dtype): x_ = x_.astype(getattr(np, dtype))
x_ = x + 1.0j * xi y_ = mx.array(x_)
x_ = x_.astype(getattr(np, dtype)) test_ops(getattr(np, op), getattr(mx, op), x_, y_, 1e-5, 1e-5)
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
with self.subTest(op="arc" + op): with self.subTest(op="arc" + op):
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)] float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
op_inv = "arc" + op op_inv = "arc" + op
for dtype, atol in float_dtypes: for dtype, atol, rtol in float_dtypes:
with self.subTest(dtype=dtype): with self.subTest(dtype=dtype):
np_op_fwd = getattr(np, op) np_op_fwd = getattr(np, op)
x_ = np_op_fwd(x).astype(getattr(np, dtype)) x_ = np_op_fwd(x).astype(getattr(np, dtype))
y_ = mx.array(x_) y_ = mx.array(x_)
test_ops(getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol) test_ops(
getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol, rtol
)
# Test grads # Test grads
np_vjp_funcs = { np_vjp_funcs = {
@@ -1579,11 +1580,10 @@ class TestOps(mlx_tests.MLXTestCase):
x_ = x.astype(np.float32) x_ = x.astype(np.float32)
y_ = mx.array(x_) y_ = mx.array(x_)
op_ = op op_ = op
atol_ = 1e-5
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x) np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0] mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
test_ops(np_vjp, mx_vjp, x_, y_, atol_) test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
with self.subTest(op="arc" + op): with self.subTest(op="arc" + op):
np_op_fwd = getattr(np, op) np_op_fwd = getattr(np, op)
@@ -1599,11 +1599,10 @@ class TestOps(mlx_tests.MLXTestCase):
x_ = x.astype(np.float32) x_ = x.astype(np.float32)
y_ = mx.array(x_) y_ = mx.array(x_)
op_ = "arc" + op op_ = "arc" + op
atol_ = 1e-5
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x) np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0] mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
test_ops(np_vjp, mx_vjp, x_, y_, atol_) test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
def test_binary_ops(self): def test_binary_ops(self):
def test_ops(npop, mlxop, x1, x2, y1, y2, atol): def test_ops(npop, mlxop, x1, x2, y1, y2, atol):