mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Fix the accelerate dispatch for the power op (#70)
- The exponent and base were swapped because accelerate is using exponent-base instead of base-exponent - Fix also the test for binary ops as it was testing op(x, x) which couldn't catch ordering errors like that
This commit is contained in:
parent
4e3bdb560c
commit
209404239b
@ -494,7 +494,7 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
b.flags().row_contiguous) {
|
||||
int size = a.size();
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
vvpowf(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
||||
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
@ -953,30 +953,32 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
|
||||
|
||||
def test_binary_ops(self):
|
||||
def test_ops(npop, mlxop, x, y, atol):
|
||||
r_np = npop(x, x)
|
||||
r_mlx = mlxop(y, y)
|
||||
def test_ops(npop, mlxop, x1, x2, y1, y2, atol):
|
||||
r_np = npop(x1, x2)
|
||||
r_mlx = mlxop(y1, y2)
|
||||
mx.eval(r_mlx)
|
||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
||||
|
||||
r_np = npop(x[:1], x)
|
||||
r_mlx = mlxop(y[:1], y)
|
||||
r_np = npop(x1[:1], x2)
|
||||
r_mlx = mlxop(y1[:1], y2)
|
||||
mx.eval(r_mlx)
|
||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
||||
|
||||
r_np = npop(x[:, :1], x)
|
||||
r_mlx = mlxop(y[:, :1], y)
|
||||
r_np = npop(x1[:, :1], x2)
|
||||
r_mlx = mlxop(y1[:, :1], y2)
|
||||
mx.eval(r_mlx)
|
||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
||||
|
||||
r_np = npop(x[:, :, :1], x)
|
||||
r_mlx = mlxop(y[:, :, :1], y)
|
||||
r_np = npop(x1[:, :, :1], x2)
|
||||
r_mlx = mlxop(y1[:, :, :1], y2)
|
||||
mx.eval(r_mlx)
|
||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
||||
|
||||
x = np.maximum(np.random.rand(18, 28, 38), 0.1)
|
||||
y = mx.array(x)
|
||||
mx.eval(y)
|
||||
x1 = np.maximum(np.random.rand(18, 28, 38), 0.1)
|
||||
x2 = np.maximum(np.random.rand(18, 28, 38), 0.1)
|
||||
y1 = mx.array(x1)
|
||||
y2 = mx.array(x2)
|
||||
mx.eval(y1, y2)
|
||||
for op in [
|
||||
"add",
|
||||
"subtract",
|
||||
@ -1008,9 +1010,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
for dtype in dtypes:
|
||||
atol = 1e-3 if dtype == "float16" else 1e-6
|
||||
with self.subTest(dtype=dtype):
|
||||
x_ = x.astype(getattr(np, dtype))
|
||||
y_ = mx.array(x_)
|
||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
||||
x1_ = x1.astype(getattr(np, dtype))
|
||||
x2_ = x2.astype(getattr(np, dtype))
|
||||
y1_ = mx.array(x1_)
|
||||
y2_ = mx.array(x2_)
|
||||
test_ops(
|
||||
getattr(np, op), getattr(mx, op), x1_, x2_, y1_, y2_, atol
|
||||
)
|
||||
|
||||
def test_irregular_binary_ops(self):
|
||||
# Check transposed binary ops
|
||||
|
Loading…
Reference in New Issue
Block a user