Fix the accelerate dispatch for the power op (#70)

- The exponent and base were swapped because accelerate is using
  exponent-base instead of base-exponent
- Fix also the test for binary ops as it was testing op(x, x) which
  couldn't catch ordering errors like that
This commit is contained in:
Angelos Katharopoulos 2023-12-08 10:58:03 -08:00 committed by GitHub
parent 4e3bdb560c
commit 209404239b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 16 deletions

View File

@ -494,7 +494,7 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
b.flags().row_contiguous) {
int size = a.size();
out.set_data(allocator::malloc_or_wait(out.nbytes()));
vvpowf(out.data<float>(), a.data<float>(), b.data<float>(), &size);
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else {
eval(inputs, out);
}

View File

@ -953,30 +953,32 @@ class TestOps(mlx_tests.MLXTestCase):
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
def test_binary_ops(self):
def test_ops(npop, mlxop, x, y, atol):
r_np = npop(x, x)
r_mlx = mlxop(y, y)
def test_ops(npop, mlxop, x1, x2, y1, y2, atol):
r_np = npop(x1, x2)
r_mlx = mlxop(y1, y2)
mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
r_np = npop(x[:1], x)
r_mlx = mlxop(y[:1], y)
r_np = npop(x1[:1], x2)
r_mlx = mlxop(y1[:1], y2)
mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
r_np = npop(x[:, :1], x)
r_mlx = mlxop(y[:, :1], y)
r_np = npop(x1[:, :1], x2)
r_mlx = mlxop(y1[:, :1], y2)
mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
r_np = npop(x[:, :, :1], x)
r_mlx = mlxop(y[:, :, :1], y)
r_np = npop(x1[:, :, :1], x2)
r_mlx = mlxop(y1[:, :, :1], y2)
mx.eval(r_mlx)
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
x = np.maximum(np.random.rand(18, 28, 38), 0.1)
y = mx.array(x)
mx.eval(y)
x1 = np.maximum(np.random.rand(18, 28, 38), 0.1)
x2 = np.maximum(np.random.rand(18, 28, 38), 0.1)
y1 = mx.array(x1)
y2 = mx.array(x2)
mx.eval(y1, y2)
for op in [
"add",
"subtract",
@ -1008,9 +1010,13 @@ class TestOps(mlx_tests.MLXTestCase):
for dtype in dtypes:
atol = 1e-3 if dtype == "float16" else 1e-6
with self.subTest(dtype=dtype):
x_ = x.astype(getattr(np, dtype))
y_ = mx.array(x_)
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
x1_ = x1.astype(getattr(np, dtype))
x2_ = x2.astype(getattr(np, dtype))
y1_ = mx.array(x1_)
y2_ = mx.array(x2_)
test_ops(
getattr(np, op), getattr(mx, op), x1_, x2_, y1_, y2_, atol
)
def test_irregular_binary_ops(self):
# Check transposed binary ops