Fix reduce edge case (#1389)

This commit is contained in:
Angelos Katharopoulos
2024-09-01 21:37:51 -07:00
committed by GitHub
parent 9592766939
commit 969337345f
2 changed files with 16 additions and 16 deletions

View File

@@ -10,21 +10,21 @@ import numpy as np
class TestReduce(mlx_tests.MLXTestCase):
def test_axis_permutation_sums(self):
x_npy = np.random.randn(5, 5, 5, 5, 5).astype(np.float32)
x_mlx = mx.array(x_npy)
for t in permutations(range(5)):
with self.subTest(t=t):
y_npy = np.transpose(x_npy, t)
y_mlx = mx.transpose(x_mlx, t)
for n in range(1, 6):
for a in combinations(range(5), n):
with self.subTest(a=a):
z_npy = np.sum(y_npy, axis=a)
z_mlx = mx.sum(y_mlx, axis=a)
mx.eval(z_mlx)
self.assertTrue(
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
)
for shape in [(5, 5, 1, 5, 5), (65, 65, 1, 65)]:
with self.subTest(shape=shape):
x_npy = (np.random.randn(*shape) * 128).astype(np.int32)
x_mlx = mx.array(x_npy)
for t in permutations(range(len(shape))):
with self.subTest(t=t):
y_npy = np.transpose(x_npy, t)
y_mlx = mx.transpose(x_mlx, t)
for n in range(1, len(shape) + 1):
for a in combinations(range(len(shape)), n):
with self.subTest(a=a):
z_npy = np.sum(y_npy, axis=a)
z_mlx = mx.sum(y_mlx, axis=a)
mx.eval(z_mlx)
self.assertTrue(np.all(z_npy == z_mlx))
def test_expand_sums(self):
x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32)