diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index ffa05f5d2..9d6267bf9 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -20,6 +20,22 @@ def _value_or_list(x, n, msg): return [x] * n +def _non_overlapping_sliding_windows(x, shape, window_shape): + # Compute the intermediate shape + new_shape = [shape[0]] + for s, w in zip(shape[1:], window_shape): + new_shape.append(s // w) + new_shape.append(w) + new_shape.append(shape[-1]) + + last_axis = len(new_shape) - 1 + axis_order = [0, *range(1, last_axis, 2), *range(2, last_axis, 2), last_axis] + + x = x.reshape(new_shape) + x = x.transpose(axis_order) + return x + + def _sliding_windows(x, window_shape, window_strides): if x.ndim < 3: raise ValueError( @@ -37,6 +53,12 @@ def _sliding_windows(x, window_shape, window_strides): ) shape = x.shape + if all( + window == stride and size % window == 0 + for size, window, stride in zip(spatial_dims, window_shape, window_strides) + ): + return _non_overlapping_sliding_windows(x, shape, window_shape) + strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:] # Compute the output shape