From 53e6a9367ca1218e9e97e5b174bf41070a71386d Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 21 Mar 2024 10:21:03 -0700 Subject: [PATCH] Use reshape and transpose for non-overlapping pooling windows (#867) --- python/mlx/nn/layers/pooling.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index ffa05f5d2..9d6267bf9 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -20,6 +20,22 @@ def _value_or_list(x, n, msg): return [x] * n +def _non_overlapping_sliding_windows(x, shape, window_shape): + # Compute the intermediate shape + new_shape = [shape[0]] + for s, w in zip(shape[1:], window_shape): + new_shape.append(s // w) + new_shape.append(w) + new_shape.append(shape[-1]) + + last_axis = len(new_shape) - 1 + axis_order = [0, *range(1, last_axis, 2), *range(2, last_axis, 2), last_axis] + + x = x.reshape(new_shape) + x = x.transpose(axis_order) + return x + + def _sliding_windows(x, window_shape, window_strides): if x.ndim < 3: raise ValueError( @@ -37,6 +53,12 @@ def _sliding_windows(x, window_shape, window_strides): ) shape = x.shape + if all( + window == stride and size % window == 0 + for size, window, stride in zip(spatial_dims, window_shape, window_strides) + ): + return _non_overlapping_sliding_windows(x, shape, window_shape) + strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:] # Compute the output shape