mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
0dbc7e5bee
...
d5f61a93fa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d5f61a93fa | ||
|
|
4a09264236 |
2
.github/workflows/pull_request.yml
vendored
2
.github/workflows/pull_request.yml
vendored
@@ -13,7 +13,7 @@ permissions:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
jobs:
|
||||
check_lint:
|
||||
|
||||
@@ -1443,23 +1443,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(a.tolist(), expected)
|
||||
|
||||
def test_unary_ops(self):
|
||||
def test_ops(npop, mlxop, x, y, atol):
|
||||
def test_ops(npop, mlxop, x, y, atol, rtol):
|
||||
r_np = npop(x)
|
||||
r_mlx = mlxop(y)
|
||||
mx.eval(r_mlx)
|
||||
|
||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, rtol=rtol))
|
||||
|
||||
x = np.random.rand(18, 28, 38)
|
||||
for op in ["abs", "exp", "log", "square", "sqrt"]:
|
||||
with self.subTest(op=op):
|
||||
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
|
||||
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
|
||||
|
||||
for dtype, atol in float_dtypes:
|
||||
for dtype, atol, rtol in float_dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
x_ = x.astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
|
||||
|
||||
def test_unary_ops_from_non_array(self):
|
||||
unary_ops = [
|
||||
@@ -1511,12 +1510,14 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
|
||||
|
||||
def test_trig_ops(self):
|
||||
def test_ops(npop, mlxop, x, y, atol):
|
||||
def test_ops(npop, mlxop, x, y, atol, rtol):
|
||||
r_np = npop(x)
|
||||
r_mlx = mlxop(y)
|
||||
mx.eval(r_mlx)
|
||||
|
||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol, equal_nan=True))
|
||||
self.assertTrue(
|
||||
np.allclose(r_np, r_mlx, atol=atol, rtol=rtol, equal_nan=True)
|
||||
)
|
||||
|
||||
x = np.random.rand(9, 12, 18)
|
||||
xi = np.random.rand(9, 12, 18)
|
||||
@@ -1526,34 +1527,34 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
for op in all_fwd_ops:
|
||||
with self.subTest(op=op):
|
||||
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
|
||||
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
|
||||
|
||||
for dtype, atol in float_dtypes:
|
||||
for dtype, atol, rtol in float_dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
x_ = x.astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol, rtol)
|
||||
|
||||
with self.subTest(op=op):
|
||||
float_dtypes = [("complex64", 1e-5)]
|
||||
|
||||
for dtype, atol in float_dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
x_ = x + 1.0j * xi
|
||||
x_ = x_.astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
||||
dtype = "complex64"
|
||||
with self.subTest(dtype=dtype):
|
||||
x_ = x + 1.0j * xi
|
||||
x_ = x_.astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, 1e-5, 1e-5)
|
||||
|
||||
with self.subTest(op="arc" + op):
|
||||
float_dtypes = [("float16", 1e-3), ("float32", 1e-6)]
|
||||
float_dtypes = [("float16", 1e-3, 1e-3), ("float32", 1e-6, 1e-5)]
|
||||
op_inv = "arc" + op
|
||||
|
||||
for dtype, atol in float_dtypes:
|
||||
for dtype, atol, rtol in float_dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
np_op_fwd = getattr(np, op)
|
||||
x_ = np_op_fwd(x).astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol)
|
||||
test_ops(
|
||||
getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol, rtol
|
||||
)
|
||||
|
||||
# Test grads
|
||||
np_vjp_funcs = {
|
||||
@@ -1579,11 +1580,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
x_ = x.astype(np.float32)
|
||||
y_ = mx.array(x_)
|
||||
op_ = op
|
||||
atol_ = 1e-5
|
||||
|
||||
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
|
||||
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
|
||||
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
|
||||
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
|
||||
|
||||
with self.subTest(op="arc" + op):
|
||||
np_op_fwd = getattr(np, op)
|
||||
@@ -1599,11 +1599,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
x_ = x.astype(np.float32)
|
||||
y_ = mx.array(x_)
|
||||
op_ = "arc" + op
|
||||
atol_ = 1e-5
|
||||
|
||||
np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x)
|
||||
mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0]
|
||||
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
|
||||
test_ops(np_vjp, mx_vjp, x_, y_, 1e-5, 1e-5)
|
||||
|
||||
def test_binary_ops(self):
|
||||
def test_ops(npop, mlxop, x1, x2, y1, y2, atol):
|
||||
|
||||
Reference in New Issue
Block a user