mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
Fix reduce edge case (#1389)
This commit is contained in:
parent
9592766939
commit
969337345f
@ -32,7 +32,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
|||||||
std::vector<int> shape = {x.shape(axes[0])};
|
std::vector<int> shape = {x.shape(axes[0])};
|
||||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||||
for (int i = 1; i < axes.size(); i++) {
|
for (int i = 1; i < axes.size(); i++) {
|
||||||
if (axes[i] - 1 == axes[i - 1]) {
|
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
|
||||||
shape.back() *= x.shape(axes[i]);
|
shape.back() *= x.shape(axes[i]);
|
||||||
strides.back() = x.strides()[axes[i]];
|
strides.back() = x.strides()[axes[i]];
|
||||||
} else {
|
} else {
|
||||||
|
@ -10,21 +10,21 @@ import numpy as np
|
|||||||
|
|
||||||
class TestReduce(mlx_tests.MLXTestCase):
|
class TestReduce(mlx_tests.MLXTestCase):
|
||||||
def test_axis_permutation_sums(self):
|
def test_axis_permutation_sums(self):
|
||||||
x_npy = np.random.randn(5, 5, 5, 5, 5).astype(np.float32)
|
for shape in [(5, 5, 1, 5, 5), (65, 65, 1, 65)]:
|
||||||
x_mlx = mx.array(x_npy)
|
with self.subTest(shape=shape):
|
||||||
for t in permutations(range(5)):
|
x_npy = (np.random.randn(*shape) * 128).astype(np.int32)
|
||||||
with self.subTest(t=t):
|
x_mlx = mx.array(x_npy)
|
||||||
y_npy = np.transpose(x_npy, t)
|
for t in permutations(range(len(shape))):
|
||||||
y_mlx = mx.transpose(x_mlx, t)
|
with self.subTest(t=t):
|
||||||
for n in range(1, 6):
|
y_npy = np.transpose(x_npy, t)
|
||||||
for a in combinations(range(5), n):
|
y_mlx = mx.transpose(x_mlx, t)
|
||||||
with self.subTest(a=a):
|
for n in range(1, len(shape) + 1):
|
||||||
z_npy = np.sum(y_npy, axis=a)
|
for a in combinations(range(len(shape)), n):
|
||||||
z_mlx = mx.sum(y_mlx, axis=a)
|
with self.subTest(a=a):
|
||||||
mx.eval(z_mlx)
|
z_npy = np.sum(y_npy, axis=a)
|
||||||
self.assertTrue(
|
z_mlx = mx.sum(y_mlx, axis=a)
|
||||||
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
|
mx.eval(z_mlx)
|
||||||
)
|
self.assertTrue(np.all(z_npy == z_mlx))
|
||||||
|
|
||||||
def test_expand_sums(self):
|
def test_expand_sums(self):
|
||||||
x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32)
|
x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32)
|
||||||
|
Loading…
Reference in New Issue
Block a user