mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Fix the accelerate dispatch for the power op (#70)
- The exponent and base were swapped because accelerate is using exponent-base instead of base-exponent - Fix also the test for binary ops as it was testing op(x, x) which couldn't catch ordering errors like that
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							4e3bdb560c
						
					
				
				
					commit
					209404239b
				
			| @@ -494,7 +494,7 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) { | ||||
|       b.flags().row_contiguous) { | ||||
|     int size = a.size(); | ||||
|     out.set_data(allocator::malloc_or_wait(out.nbytes())); | ||||
|     vvpowf(out.data<float>(), a.data<float>(), b.data<float>(), &size); | ||||
|     vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size); | ||||
|   } else { | ||||
|     eval(inputs, out); | ||||
|   } | ||||
|   | ||||
| @@ -953,30 +953,32 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                     test_ops(np_vjp, mx_vjp, x_, y_, atol_) | ||||
|  | ||||
|     def test_binary_ops(self): | ||||
|         def test_ops(npop, mlxop, x, y, atol): | ||||
|             r_np = npop(x, x) | ||||
|             r_mlx = mlxop(y, y) | ||||
|         def test_ops(npop, mlxop, x1, x2, y1, y2, atol): | ||||
|             r_np = npop(x1, x2) | ||||
|             r_mlx = mlxop(y1, y2) | ||||
|             mx.eval(r_mlx) | ||||
|             self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) | ||||
|  | ||||
|             r_np = npop(x[:1], x) | ||||
|             r_mlx = mlxop(y[:1], y) | ||||
|             r_np = npop(x1[:1], x2) | ||||
|             r_mlx = mlxop(y1[:1], y2) | ||||
|             mx.eval(r_mlx) | ||||
|             self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) | ||||
|  | ||||
|             r_np = npop(x[:, :1], x) | ||||
|             r_mlx = mlxop(y[:, :1], y) | ||||
|             r_np = npop(x1[:, :1], x2) | ||||
|             r_mlx = mlxop(y1[:, :1], y2) | ||||
|             mx.eval(r_mlx) | ||||
|             self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) | ||||
|  | ||||
|             r_np = npop(x[:, :, :1], x) | ||||
|             r_mlx = mlxop(y[:, :, :1], y) | ||||
|             r_np = npop(x1[:, :, :1], x2) | ||||
|             r_mlx = mlxop(y1[:, :, :1], y2) | ||||
|             mx.eval(r_mlx) | ||||
|             self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) | ||||
|  | ||||
|         x = np.maximum(np.random.rand(18, 28, 38), 0.1) | ||||
|         y = mx.array(x) | ||||
|         mx.eval(y) | ||||
|         x1 = np.maximum(np.random.rand(18, 28, 38), 0.1) | ||||
|         x2 = np.maximum(np.random.rand(18, 28, 38), 0.1) | ||||
|         y1 = mx.array(x1) | ||||
|         y2 = mx.array(x2) | ||||
|         mx.eval(y1, y2) | ||||
|         for op in [ | ||||
|             "add", | ||||
|             "subtract", | ||||
| @@ -1008,9 +1010,13 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                 for dtype in dtypes: | ||||
|                     atol = 1e-3 if dtype == "float16" else 1e-6 | ||||
|                     with self.subTest(dtype=dtype): | ||||
|                         x_ = x.astype(getattr(np, dtype)) | ||||
|                         y_ = mx.array(x_) | ||||
|                         test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) | ||||
|                         x1_ = x1.astype(getattr(np, dtype)) | ||||
|                         x2_ = x2.astype(getattr(np, dtype)) | ||||
|                         y1_ = mx.array(x1_) | ||||
|                         y2_ = mx.array(x2_) | ||||
|                         test_ops( | ||||
|                             getattr(np, op), getattr(mx, op), x1_, x2_, y1_, y2_, atol | ||||
|                         ) | ||||
|  | ||||
|     def test_irregular_binary_ops(self): | ||||
|         # Check transposed binary ops | ||||
|   | ||||
		Reference in New Issue
	
	Block a user