mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Allow unary ops to accept array like (#1093)
This commit is contained in:
		| @@ -1217,6 +1217,54 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                         y_ = mx.array(x_) | ||||
|                         test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) | ||||
|  | ||||
|     def test_unary_ops_from_non_array(self): | ||||
|         unary_ops = [ | ||||
|             "abs", | ||||
|             "exp", | ||||
|             "log", | ||||
|             "square", | ||||
|             "sqrt", | ||||
|             "sin", | ||||
|             "cos", | ||||
|             "tan", | ||||
|             "sinh", | ||||
|             "cosh", | ||||
|             "tanh", | ||||
|             "sign", | ||||
|             "negative", | ||||
|             "expm1", | ||||
|             "arcsin", | ||||
|             "arccos", | ||||
|             "arctan", | ||||
|             "arcsinh", | ||||
|             "arctanh", | ||||
|             "degrees", | ||||
|             "radians", | ||||
|             "log2", | ||||
|             "log10", | ||||
|             "log1p", | ||||
|             "floor", | ||||
|             "ceil", | ||||
|         ] | ||||
|  | ||||
|         x = 0.5 | ||||
|         x_np = np.random.rand(10).astype(np.float32) | ||||
|         for op in unary_ops: | ||||
|             with self.subTest(op=op): | ||||
|                 # Test from scalar | ||||
|                 expected = getattr(np, op)(x) | ||||
|                 out = getattr(mx, op)(x) | ||||
|  | ||||
|                 # Check close | ||||
|                 self.assertTrue(np.allclose(expected, out, equal_nan=True)) | ||||
|  | ||||
|                 # Test from NumPy | ||||
|                 expected = getattr(np, op)(x_np) | ||||
|                 out = getattr(mx, op)(x_np) | ||||
|  | ||||
|                 # Check close | ||||
|                 self.assertTrue(np.allclose(expected, np.array(out), equal_nan=True)) | ||||
|  | ||||
|     def test_trig_ops(self): | ||||
|         def test_ops(npop, mlxop, x, y, atol): | ||||
|             r_np = npop(x) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun