mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Allow unary ops to accept array like (#1093)
This commit is contained in:
@@ -1217,6 +1217,54 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
||||
|
||||
def test_unary_ops_from_non_array(self):
|
||||
unary_ops = [
|
||||
"abs",
|
||||
"exp",
|
||||
"log",
|
||||
"square",
|
||||
"sqrt",
|
||||
"sin",
|
||||
"cos",
|
||||
"tan",
|
||||
"sinh",
|
||||
"cosh",
|
||||
"tanh",
|
||||
"sign",
|
||||
"negative",
|
||||
"expm1",
|
||||
"arcsin",
|
||||
"arccos",
|
||||
"arctan",
|
||||
"arcsinh",
|
||||
"arctanh",
|
||||
"degrees",
|
||||
"radians",
|
||||
"log2",
|
||||
"log10",
|
||||
"log1p",
|
||||
"floor",
|
||||
"ceil",
|
||||
]
|
||||
|
||||
x = 0.5
|
||||
x_np = np.random.rand(10).astype(np.float32)
|
||||
for op in unary_ops:
|
||||
with self.subTest(op=op):
|
||||
# Test from scalar
|
||||
expected = getattr(np, op)(x)
|
||||
out = getattr(mx, op)(x)
|
||||
|
||||
# Check close
|
||||
self.assertTrue(np.allclose(expected, out, equal_nan=True))
|
||||
|
||||
# Test from NumPy
|
||||
expected = getattr(np, op)(x_np)
|
||||
out = getattr(mx, op)(x_np)
|
||||
|
||||
# Check close
|
||||
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
|
||||
|
||||
def test_trig_ops(self):
|
||||
def test_ops(npop, mlxop, x, y, atol):
|
||||
r_np = npop(x)
|
||||
|
Reference in New Issue
Block a user