mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Use reshape and transpose for non-overlapping pooling windows (#867)
This commit is contained in:
parent
f5a1582fe8
commit
53e6a9367c
@ -20,6 +20,22 @@ def _value_or_list(x, n, msg):
|
||||
return [x] * n
|
||||
|
||||
|
||||
def _non_overlapping_sliding_windows(x, shape, window_shape):
|
||||
# Compute the intermediate shape
|
||||
new_shape = [shape[0]]
|
||||
for s, w in zip(shape[1:], window_shape):
|
||||
new_shape.append(s // w)
|
||||
new_shape.append(w)
|
||||
new_shape.append(shape[-1])
|
||||
|
||||
last_axis = len(new_shape) - 1
|
||||
axis_order = [0, *range(1, last_axis, 2), *range(2, last_axis, 2), last_axis]
|
||||
|
||||
x = x.reshape(new_shape)
|
||||
x = x.transpose(axis_order)
|
||||
return x
|
||||
|
||||
|
||||
def _sliding_windows(x, window_shape, window_strides):
|
||||
if x.ndim < 3:
|
||||
raise ValueError(
|
||||
@ -37,6 +53,12 @@ def _sliding_windows(x, window_shape, window_strides):
|
||||
)
|
||||
|
||||
shape = x.shape
|
||||
if all(
|
||||
window == stride and size % window == 0
|
||||
for size, window, stride in zip(spatial_dims, window_shape, window_strides)
|
||||
):
|
||||
return _non_overlapping_sliding_windows(x, shape, window_shape)
|
||||
|
||||
strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:]
|
||||
|
||||
# Compute the output shape
|
||||
|
Loading…
Reference in New Issue
Block a user