diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 4cde5bbaf..d4051fbd3 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -494,7 +494,7 @@ void Power::eval_cpu(const std::vector& inputs, array& out) { b.flags().row_contiguous) { int size = a.size(); out.set_data(allocator::malloc_or_wait(out.nbytes())); - vvpowf(out.data(), a.data(), b.data(), &size); + vvpowf(out.data(), b.data(), a.data(), &size); } else { eval(inputs, out); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 4fdd48ac3..217a52589 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -953,30 +953,32 @@ class TestOps(mlx_tests.MLXTestCase): test_ops(np_vjp, mx_vjp, x_, y_, atol_) def test_binary_ops(self): - def test_ops(npop, mlxop, x, y, atol): - r_np = npop(x, x) - r_mlx = mlxop(y, y) + def test_ops(npop, mlxop, x1, x2, y1, y2, atol): + r_np = npop(x1, x2) + r_mlx = mlxop(y1, y2) mx.eval(r_mlx) self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) - r_np = npop(x[:1], x) - r_mlx = mlxop(y[:1], y) + r_np = npop(x1[:1], x2) + r_mlx = mlxop(y1[:1], y2) mx.eval(r_mlx) self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) - r_np = npop(x[:, :1], x) - r_mlx = mlxop(y[:, :1], y) + r_np = npop(x1[:, :1], x2) + r_mlx = mlxop(y1[:, :1], y2) mx.eval(r_mlx) self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) - r_np = npop(x[:, :, :1], x) - r_mlx = mlxop(y[:, :, :1], y) + r_np = npop(x1[:, :, :1], x2) + r_mlx = mlxop(y1[:, :, :1], y2) mx.eval(r_mlx) self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) - x = np.maximum(np.random.rand(18, 28, 38), 0.1) - y = mx.array(x) - mx.eval(y) + x1 = np.maximum(np.random.rand(18, 28, 38), 0.1) + x2 = np.maximum(np.random.rand(18, 28, 38), 0.1) + y1 = mx.array(x1) + y2 = mx.array(x2) + mx.eval(y1, y2) for op in [ "add", "subtract", @@ -1008,9 +1010,13 @@ class TestOps(mlx_tests.MLXTestCase): for dtype in dtypes: atol = 1e-3 if dtype == "float16" else 1e-6 with self.subTest(dtype=dtype): - x_ = x.astype(getattr(np, dtype)) - y_ = mx.array(x_) - test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) + x1_ = x1.astype(getattr(np, dtype)) + x2_ = x2.astype(getattr(np, dtype)) + y1_ = mx.array(x1_) + y2_ = mx.array(x2_) + test_ops( + getattr(np, op), getattr(mx, op), x1_, x2_, y1_, y2_, atol + ) def test_irregular_binary_ops(self): # Check transposed binary ops