mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Use reshape and transpose for non-overlapping pooling windows (#867)
This commit is contained in:
parent
f5a1582fe8
commit
53e6a9367c
@ -20,6 +20,22 @@ def _value_or_list(x, n, msg):
|
|||||||
return [x] * n
|
return [x] * n
|
||||||
|
|
||||||
|
|
||||||
|
def _non_overlapping_sliding_windows(x, shape, window_shape):
|
||||||
|
# Compute the intermediate shape
|
||||||
|
new_shape = [shape[0]]
|
||||||
|
for s, w in zip(shape[1:], window_shape):
|
||||||
|
new_shape.append(s // w)
|
||||||
|
new_shape.append(w)
|
||||||
|
new_shape.append(shape[-1])
|
||||||
|
|
||||||
|
last_axis = len(new_shape) - 1
|
||||||
|
axis_order = [0, *range(1, last_axis, 2), *range(2, last_axis, 2), last_axis]
|
||||||
|
|
||||||
|
x = x.reshape(new_shape)
|
||||||
|
x = x.transpose(axis_order)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def _sliding_windows(x, window_shape, window_strides):
|
def _sliding_windows(x, window_shape, window_strides):
|
||||||
if x.ndim < 3:
|
if x.ndim < 3:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -37,6 +53,12 @@ def _sliding_windows(x, window_shape, window_strides):
|
|||||||
)
|
)
|
||||||
|
|
||||||
shape = x.shape
|
shape = x.shape
|
||||||
|
if all(
|
||||||
|
window == stride and size % window == 0
|
||||||
|
for size, window, stride in zip(spatial_dims, window_shape, window_strides)
|
||||||
|
):
|
||||||
|
return _non_overlapping_sliding_windows(x, shape, window_shape)
|
||||||
|
|
||||||
strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:]
|
strides = list(reversed(list(accumulate(reversed(shape + (1,)), operator.mul))))[1:]
|
||||||
|
|
||||||
# Compute the output shape
|
# Compute the output shape
|
||||||
|
Loading…
Reference in New Issue
Block a user