Use reshape and transpose for non-overlapping pooling windows (#867)

This commit is contained in:
Angelos Katharopoulos 2024-03-21 10:21:03 -07:00 committed by GitHub
parent f5a1582fe8
commit 53e6a9367c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -20,6 +20,22 @@ def _value_or_list(x, n, msg):
return [x] * n
def _non_overlapping_sliding_windows(x, shape, window_shape):
# Compute the intermediate shape
new_shape = [shape[0]]
for s, w in zip(shape[1:], window_shape):
new_shape.append(s // w)
new_shape.append(w)
new_shape.append(shape[-1])
last_axis = len(new_shape) - 1
axis_order = [0, *range(1, last_axis, 2), *range(2, last_axis, 2), last_axis]
x = x.reshape(new_shape)
x = x.transpose(axis_order)
return x
def _sliding_windows(x, window_shape, window_strides):
if x.ndim < 3:
raise ValueError(
@ -37,6 +53,12 @@ def _sliding_windows(x, window_shape, window_strides):
)
shape = x.shape
if all(
window == stride and size % window == 0
for size, window, stride in zip(spatial_dims, window_shape, window_strides)
):
return _non_overlapping_sliding_windows(x, shape, window_shape)
strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:]
# Compute the output shape