Allow unary ops to accept array like (#1093)

This commit is contained in:
Awni Hannun
2024-05-09 09:36:02 -07:00
committed by GitHub
parent cc05a281c4
commit b21242faf1
2 changed files with 162 additions and 42 deletions

View File

@@ -1217,6 +1217,54 @@ class TestOps(mlx_tests.MLXTestCase):
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
def test_unary_ops_from_non_array(self):
unary_ops = [
"abs",
"exp",
"log",
"square",
"sqrt",
"sin",
"cos",
"tan",
"sinh",
"cosh",
"tanh",
"sign",
"negative",
"expm1",
"arcsin",
"arccos",
"arctan",
"arcsinh",
"arctanh",
"degrees",
"radians",
"log2",
"log10",
"log1p",
"floor",
"ceil",
]
x = 0.5
x_np = np.random.rand(10).astype(np.float32)
for op in unary_ops:
with self.subTest(op=op):
# Test from scalar
expected = getattr(np, op)(x)
out = getattr(mx, op)(x)
# Check close
self.assertTrue(np.allclose(expected, out, equal_nan=True))
# Test from NumPy
expected = getattr(np, op)(x_np)
out = getattr(mx, op)(x_np)
# Check close
self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True))
def test_trig_ops(self):
def test_ops(npop, mlxop, x, y, atol):
r_np = npop(x)