mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Refactor adaptive pooling for style consistency
This commit is contained in:
@@ -398,21 +398,40 @@ class AvgPool3d(_Pool3d):
|
|||||||
super().__init__(mx.mean, 0, kernel_size, stride, padding)
|
super().__init__(mx.mean, 0, kernel_size, stride, padding)
|
||||||
|
|
||||||
|
|
||||||
class AdaptiveAvgPool2d(Module):
|
class _AdaptivePool(Module):
|
||||||
|
"""Base class for adaptive pooling layers."""
|
||||||
|
|
||||||
|
def __init__(self, output_size):
|
||||||
|
super().__init__()
|
||||||
|
self.output_size = output_size
|
||||||
|
|
||||||
|
def _extra_repr(self):
|
||||||
|
return f"output_size={self.output_size}"
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveAvgPool2d(_AdaptivePool):
|
||||||
r"""Applies 2-dimensional adaptive average pooling.
|
r"""Applies 2-dimensional adaptive average pooling.
|
||||||
|
|
||||||
The output size is H x W, for any input size. The number of output
|
Spatially downsamples the input by taking the average over pooling regions
|
||||||
features is equal to the number of input planes.
|
such that the output size is H x W, for any input size.
|
||||||
|
|
||||||
|
The parameters can be:
|
||||||
|
|
||||||
|
* a single ``int`` -- in which case the same value is used for both the
|
||||||
|
height and width dimensions, creating a square output.
|
||||||
|
* a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
|
||||||
|
used for the height dimension, the second ``int`` for the width dimension.
|
||||||
|
* ``None`` can be used for either dimension to keep the input size unchanged.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_size: the target output size of the form H x W.
|
output_size (int or tuple(int, int)): The target output size of the form H x W.
|
||||||
Can be a tuple (H, W) or a single int for a square image.
|
Can be a tuple (H, W) or a single int for a square output.
|
||||||
H and W can be either an ``int``, or ``None`` which means the size
|
H and W can be either an ``int``, or ``None`` which means the size
|
||||||
will be the same as that of the input.
|
will be the same as that of the input.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mlx.core as mx
|
>>> import mlx.core as mx
|
||||||
>>> import mlx.nn as nn
|
>>> import mlx.nn.layers as nn
|
||||||
>>> x = mx.random.normal(shape=(8, 32, 32, 4))
|
>>> x = mx.random.normal(shape=(8, 32, 32, 4))
|
||||||
>>> pool = nn.AdaptiveAvgPool2d((5, 7))
|
>>> pool = nn.AdaptiveAvgPool2d((5, 7))
|
||||||
>>> pool(x)
|
>>> pool(x)
|
||||||
@@ -421,103 +440,74 @@ class AdaptiveAvgPool2d(Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]):
|
def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]):
|
||||||
super().__init__()
|
super().__init__(output_size)
|
||||||
self.output_size = output_size
|
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return adaptive_avg_pool2d(x, self.output_size)
|
output_size = self.output_size
|
||||||
|
if isinstance(output_size, int):
|
||||||
|
output_size = (output_size, output_size)
|
||||||
|
elif len(output_size) == 1:
|
||||||
|
output_size = (output_size[0], output_size[0])
|
||||||
|
|
||||||
|
*batch_dims, H, W, C = x.shape
|
||||||
|
|
||||||
|
output_H = H if output_size[0] is None else output_size[0]
|
||||||
|
output_W = W if output_size[1] is None else output_size[1]
|
||||||
|
|
||||||
|
if H == output_H and W == output_W:
|
||||||
|
return x
|
||||||
|
|
||||||
|
kernel_H = H // output_H
|
||||||
|
kernel_W = W // output_W
|
||||||
|
|
||||||
|
if H % output_H == 0 and W % output_W == 0:
|
||||||
|
new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C]
|
||||||
|
x_reshaped = x.reshape(new_shape)
|
||||||
|
return mx.mean(x_reshaped, axis=[-4, -2])
|
||||||
|
else:
|
||||||
|
stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
|
||||||
|
stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
|
||||||
|
|
||||||
|
values = []
|
||||||
|
for i in range(output_H):
|
||||||
|
row_values = []
|
||||||
|
for j in range(output_W):
|
||||||
|
h_start = i * stride_H
|
||||||
|
h_end = min(h_start + kernel_H, H)
|
||||||
|
w_start = j * stride_W
|
||||||
|
w_end = min(w_start + kernel_W, W)
|
||||||
|
|
||||||
|
region = x[..., h_start:h_end, w_start:w_end, :]
|
||||||
|
row_values.append(mx.mean(region, axis=(-3, -2)))
|
||||||
|
values.append(mx.stack(row_values, axis=-2))
|
||||||
|
|
||||||
|
return mx.stack(values, axis=-3)
|
||||||
|
|
||||||
|
|
||||||
def adaptive_avg_pool2d(
|
class AdaptiveAvgPool3d(_AdaptivePool):
|
||||||
x: mx.array, output_size: Union[int, Tuple[Optional[int], Optional[int]]]
|
|
||||||
) -> mx.array:
|
|
||||||
r"""Apply 2-dimensional adaptive average pooling.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: Input array of shape (N, H, W, C) or (H, W, C).
|
|
||||||
output_size: Target output size (H, W) or single int for square output.
|
|
||||||
Values can be None to keep the corresponding input dimension.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Output array with spatial dimensions matching output_size.
|
|
||||||
"""
|
|
||||||
# Parse output_size
|
|
||||||
if isinstance(output_size, int):
|
|
||||||
output_size = (output_size, output_size)
|
|
||||||
elif len(output_size) == 1:
|
|
||||||
output_size = (output_size[0], output_size[0])
|
|
||||||
|
|
||||||
# Get input dimensions
|
|
||||||
*batch_dims, H, W, C = x.shape
|
|
||||||
|
|
||||||
# Handle None values in output_size
|
|
||||||
output_H = H if output_size[0] is None else output_size[0]
|
|
||||||
output_W = W if output_size[1] is None else output_size[1]
|
|
||||||
|
|
||||||
# If already the right size, return as is
|
|
||||||
if H == output_H and W == output_W:
|
|
||||||
return x
|
|
||||||
|
|
||||||
# Calculate kernel size and stride
|
|
||||||
kernel_H = H // output_H
|
|
||||||
kernel_W = W // output_W
|
|
||||||
stride_H = H // output_H
|
|
||||||
stride_W = W // output_W
|
|
||||||
|
|
||||||
# For exact division, use regular pooling
|
|
||||||
if H % output_H == 0 and W % output_W == 0:
|
|
||||||
# Reshape for pooling: (batch..., H, W, C) -> (batch..., output_H, kernel_H, output_W, kernel_W, C)
|
|
||||||
new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C]
|
|
||||||
x_reshaped = x.reshape(new_shape)
|
|
||||||
|
|
||||||
# Average over kernel dimensions (kernel_H is at -4, kernel_W is at -2)
|
|
||||||
result = mx.mean(
|
|
||||||
x_reshaped, axis=[-4, -2]
|
|
||||||
) # Average over kernel_H and kernel_W
|
|
||||||
return result
|
|
||||||
|
|
||||||
# For non-exact division, use strided approach with overlap
|
|
||||||
else:
|
|
||||||
# Calculate actual stride to fit exactly
|
|
||||||
stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
|
|
||||||
stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
|
|
||||||
|
|
||||||
# Collect all averaged values
|
|
||||||
values = []
|
|
||||||
for i in range(output_H):
|
|
||||||
row_values = []
|
|
||||||
for j in range(output_W):
|
|
||||||
h_start = i * stride_H
|
|
||||||
h_end = min(h_start + kernel_H, H)
|
|
||||||
w_start = j * stride_W
|
|
||||||
w_end = min(w_start + kernel_W, W)
|
|
||||||
|
|
||||||
# Extract region and average
|
|
||||||
region = x[..., h_start:h_end, w_start:w_end, :]
|
|
||||||
avg_val = mx.mean(region, axis=(-3, -2)) # Average over H and W
|
|
||||||
row_values.append(avg_val)
|
|
||||||
values.append(mx.stack(row_values, axis=-2)) # Stack along W dimension
|
|
||||||
|
|
||||||
# Stack all rows along H dimension
|
|
||||||
result = mx.stack(values, axis=-3)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class AdaptiveAvgPool3d(Module):
|
|
||||||
r"""Applies 3-dimensional adaptive average pooling.
|
r"""Applies 3-dimensional adaptive average pooling.
|
||||||
|
|
||||||
The output size is D x H x W, for any input size. The number of output
|
Spatially downsamples the input by taking the average over pooling regions
|
||||||
features is equal to the number of input planes.
|
such that the output size is D x H x W, for any input size.
|
||||||
|
|
||||||
|
The parameters can be:
|
||||||
|
|
||||||
|
* a single ``int`` -- in which case the same value is used for the depth,
|
||||||
|
height, and width dimensions, creating a cube output.
|
||||||
|
* a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
|
||||||
|
for the depth dimension, the second ``int`` for the height dimension, and
|
||||||
|
the third ``int`` for the width dimension.
|
||||||
|
* ``None`` can be used for any dimension to keep the input size unchanged.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_size: the target output size of the form D x H x W.
|
output_size (int or tuple(int, int, int)): The target output size of the form D x H x W.
|
||||||
Can be a tuple (D, H, W) or a single int for a cube D x D x D.
|
Can be a tuple (D, H, W) or a single int for a cube output.
|
||||||
D, H and W can be either an ``int``, or ``None`` which means the size
|
D, H and W can be either an ``int``, or ``None`` which means the size
|
||||||
will be the same as that of the input.
|
will be the same as that of the input.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mlx.core as mx
|
>>> import mlx.core as mx
|
||||||
>>> import mlx.nn as nn
|
>>> import mlx.nn.layers as nn
|
||||||
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
|
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
|
||||||
>>> pool = nn.AdaptiveAvgPool3d((5, 7, 9))
|
>>> pool = nn.AdaptiveAvgPool3d((5, 7, 9))
|
||||||
>>> pool(x)
|
>>> pool(x)
|
||||||
@@ -529,104 +519,63 @@ class AdaptiveAvgPool3d(Module):
|
|||||||
self,
|
self,
|
||||||
output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
|
output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(output_size)
|
||||||
self.output_size = output_size
|
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return adaptive_avg_pool3d(x, self.output_size)
|
output_size = self.output_size
|
||||||
|
if isinstance(output_size, int):
|
||||||
|
output_size = (output_size, output_size, output_size)
|
||||||
|
elif len(output_size) == 1:
|
||||||
|
output_size = (output_size[0], output_size[0], output_size[0])
|
||||||
|
elif len(output_size) == 2:
|
||||||
|
output_size = (output_size[0], output_size[1], output_size[1])
|
||||||
|
|
||||||
|
*batch_dims, D, H, W, C = x.shape
|
||||||
|
|
||||||
def adaptive_avg_pool3d(
|
output_D = D if output_size[0] is None else output_size[0]
|
||||||
x: mx.array,
|
output_H = H if output_size[1] is None else output_size[1]
|
||||||
output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
|
output_W = W if output_size[2] is None else output_size[2]
|
||||||
) -> mx.array:
|
|
||||||
r"""Apply 3-dimensional adaptive average pooling.
|
|
||||||
|
|
||||||
Args:
|
if D == output_D and H == output_H and W == output_W:
|
||||||
x: Input array of shape (N, D, H, W, C) or (D, H, W, C).
|
return x
|
||||||
output_size: Target output size (D, H, W) or single int for cube output.
|
|
||||||
Values can be None to keep the corresponding input dimension.
|
|
||||||
|
|
||||||
Returns:
|
kernel_D = D // output_D
|
||||||
Output array with spatial dimensions matching output_size.
|
kernel_H = H // output_H
|
||||||
"""
|
kernel_W = W // output_W
|
||||||
# Parse output_size
|
|
||||||
if isinstance(output_size, int):
|
|
||||||
output_size = (output_size, output_size, output_size)
|
|
||||||
elif len(output_size) == 1:
|
|
||||||
output_size = (output_size[0], output_size[0], output_size[0])
|
|
||||||
elif len(output_size) == 2:
|
|
||||||
output_size = (output_size[0], output_size[1], output_size[1])
|
|
||||||
|
|
||||||
# Get input dimensions
|
if D % output_D == 0 and H % output_H == 0 and W % output_W == 0:
|
||||||
*batch_dims, D, H, W, C = x.shape
|
new_shape = batch_dims + [
|
||||||
|
output_D,
|
||||||
|
kernel_D,
|
||||||
|
output_H,
|
||||||
|
kernel_H,
|
||||||
|
output_W,
|
||||||
|
kernel_W,
|
||||||
|
C,
|
||||||
|
]
|
||||||
|
x_reshaped = x.reshape(new_shape)
|
||||||
|
return mx.mean(x_reshaped, axis=[-6, -4, -2])
|
||||||
|
else:
|
||||||
|
stride_D = (D - kernel_D) // (output_D - 1) if output_D > 1 else 1
|
||||||
|
stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
|
||||||
|
stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
|
||||||
|
|
||||||
# Handle None values in output_size
|
values = []
|
||||||
output_D = D if output_size[0] is None else output_size[0]
|
for i in range(output_D):
|
||||||
output_H = H if output_size[1] is None else output_size[1]
|
depth_values = []
|
||||||
output_W = W if output_size[2] is None else output_size[2]
|
for j in range(output_H):
|
||||||
|
row_values = []
|
||||||
|
for k in range(output_W):
|
||||||
|
d_start = i * stride_D
|
||||||
|
d_end = min(d_start + kernel_D, D)
|
||||||
|
h_start = j * stride_H
|
||||||
|
h_end = min(h_start + kernel_H, H)
|
||||||
|
w_start = k * stride_W
|
||||||
|
w_end = min(w_start + kernel_W, W)
|
||||||
|
|
||||||
# If already the right size, return as is
|
region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :]
|
||||||
if D == output_D and H == output_H and W == output_W:
|
row_values.append(mx.mean(region, axis=(-4, -3, -2)))
|
||||||
return x
|
depth_values.append(mx.stack(row_values, axis=-2))
|
||||||
|
values.append(mx.stack(depth_values, axis=-3))
|
||||||
|
|
||||||
# Calculate kernel size and stride
|
return mx.stack(values, axis=-4)
|
||||||
kernel_D = D // output_D
|
|
||||||
kernel_H = H // output_H
|
|
||||||
kernel_W = W // output_W
|
|
||||||
|
|
||||||
# For exact division, use regular pooling
|
|
||||||
if D % output_D == 0 and H % output_H == 0 and W % output_W == 0:
|
|
||||||
# Reshape for pooling: (batch..., D, H, W, C) -> (batch..., output_D, kernel_D, output_H, kernel_H, output_W, kernel_W, C)
|
|
||||||
new_shape = batch_dims + [
|
|
||||||
output_D,
|
|
||||||
kernel_D,
|
|
||||||
output_H,
|
|
||||||
kernel_H,
|
|
||||||
output_W,
|
|
||||||
kernel_W,
|
|
||||||
C,
|
|
||||||
]
|
|
||||||
x_reshaped = x.reshape(new_shape)
|
|
||||||
|
|
||||||
# Average over kernel dimensions (kernel_D is at -6, kernel_H is at -4, kernel_W is at -2)
|
|
||||||
result = mx.mean(
|
|
||||||
x_reshaped, axis=[-6, -4, -2]
|
|
||||||
) # Average over kernel_D, kernel_H and kernel_W
|
|
||||||
return result
|
|
||||||
|
|
||||||
# For non-exact division, use strided approach with overlap
|
|
||||||
else:
|
|
||||||
# Calculate actual stride to fit exactly
|
|
||||||
stride_D = (D - kernel_D) // (output_D - 1) if output_D > 1 else 1
|
|
||||||
stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
|
|
||||||
stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
|
|
||||||
|
|
||||||
# Collect all averaged values
|
|
||||||
values = []
|
|
||||||
for i in range(output_D):
|
|
||||||
depth_values = []
|
|
||||||
for j in range(output_H):
|
|
||||||
row_values = []
|
|
||||||
for k in range(output_W):
|
|
||||||
d_start = i * stride_D
|
|
||||||
d_end = min(d_start + kernel_D, D)
|
|
||||||
h_start = j * stride_H
|
|
||||||
h_end = min(h_start + kernel_H, H)
|
|
||||||
w_start = k * stride_W
|
|
||||||
w_end = min(w_start + kernel_W, W)
|
|
||||||
|
|
||||||
# Extract region and average
|
|
||||||
region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :]
|
|
||||||
avg_val = mx.mean(
|
|
||||||
region, axis=(-4, -3, -2)
|
|
||||||
) # Average over D, H and W
|
|
||||||
row_values.append(avg_val)
|
|
||||||
depth_values.append(
|
|
||||||
mx.stack(row_values, axis=-2)
|
|
||||||
) # Stack along W dimension
|
|
||||||
values.append(mx.stack(depth_values, axis=-3)) # Stack along H dimension
|
|
||||||
|
|
||||||
# Stack all depths along D dimension
|
|
||||||
result = mx.stack(values, axis=-4)
|
|
||||||
return result
|
|
||||||
|
|||||||
Reference in New Issue
Block a user