mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 20:11:19 +08:00
Fix the accelerate dispatch for the power op (#70)
- The exponent and base were swapped because accelerate is using exponent-base instead of base-exponent - Fix also the test for binary ops as it was testing op(x, x) which couldn't catch ordering errors like that
This commit is contained in:
parent
4e3bdb560c
commit
209404239b
@ -494,7 +494,7 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
b.flags().row_contiguous) {
|
b.flags().row_contiguous) {
|
||||||
int size = a.size();
|
int size = a.size();
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
vvpowf(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
|
||||||
} else {
|
} else {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
|
@ -953,30 +953,32 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
|
test_ops(np_vjp, mx_vjp, x_, y_, atol_)
|
||||||
|
|
||||||
def test_binary_ops(self):
|
def test_binary_ops(self):
|
||||||
def test_ops(npop, mlxop, x, y, atol):
|
def test_ops(npop, mlxop, x1, x2, y1, y2, atol):
|
||||||
r_np = npop(x, x)
|
r_np = npop(x1, x2)
|
||||||
r_mlx = mlxop(y, y)
|
r_mlx = mlxop(y1, y2)
|
||||||
mx.eval(r_mlx)
|
mx.eval(r_mlx)
|
||||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
||||||
|
|
||||||
r_np = npop(x[:1], x)
|
r_np = npop(x1[:1], x2)
|
||||||
r_mlx = mlxop(y[:1], y)
|
r_mlx = mlxop(y1[:1], y2)
|
||||||
mx.eval(r_mlx)
|
mx.eval(r_mlx)
|
||||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
||||||
|
|
||||||
r_np = npop(x[:, :1], x)
|
r_np = npop(x1[:, :1], x2)
|
||||||
r_mlx = mlxop(y[:, :1], y)
|
r_mlx = mlxop(y1[:, :1], y2)
|
||||||
mx.eval(r_mlx)
|
mx.eval(r_mlx)
|
||||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
||||||
|
|
||||||
r_np = npop(x[:, :, :1], x)
|
r_np = npop(x1[:, :, :1], x2)
|
||||||
r_mlx = mlxop(y[:, :, :1], y)
|
r_mlx = mlxop(y1[:, :, :1], y2)
|
||||||
mx.eval(r_mlx)
|
mx.eval(r_mlx)
|
||||||
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
self.assertTrue(np.allclose(r_np, r_mlx, atol=atol))
|
||||||
|
|
||||||
x = np.maximum(np.random.rand(18, 28, 38), 0.1)
|
x1 = np.maximum(np.random.rand(18, 28, 38), 0.1)
|
||||||
y = mx.array(x)
|
x2 = np.maximum(np.random.rand(18, 28, 38), 0.1)
|
||||||
mx.eval(y)
|
y1 = mx.array(x1)
|
||||||
|
y2 = mx.array(x2)
|
||||||
|
mx.eval(y1, y2)
|
||||||
for op in [
|
for op in [
|
||||||
"add",
|
"add",
|
||||||
"subtract",
|
"subtract",
|
||||||
@ -1008,9 +1010,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
atol = 1e-3 if dtype == "float16" else 1e-6
|
atol = 1e-3 if dtype == "float16" else 1e-6
|
||||||
with self.subTest(dtype=dtype):
|
with self.subTest(dtype=dtype):
|
||||||
x_ = x.astype(getattr(np, dtype))
|
x1_ = x1.astype(getattr(np, dtype))
|
||||||
y_ = mx.array(x_)
|
x2_ = x2.astype(getattr(np, dtype))
|
||||||
test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol)
|
y1_ = mx.array(x1_)
|
||||||
|
y2_ = mx.array(x2_)
|
||||||
|
test_ops(
|
||||||
|
getattr(np, op), getattr(mx, op), x1_, x2_, y1_, y2_, atol
|
||||||
|
)
|
||||||
|
|
||||||
def test_irregular_binary_ops(self):
|
def test_irregular_binary_ops(self):
|
||||||
# Check transposed binary ops
|
# Check transposed binary ops
|
||||||
|
Loading…
Reference in New Issue
Block a user