mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add AdaptiveAvgPool2d and AdaptiveAvgPool3d
This commit is contained in:
@@ -20,6 +20,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||
- Vincent Amato: Added `AdaptiveAvgPool2d` and `AdaptiveAvgPool3d` layers.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
||||
@@ -77,6 +77,8 @@ from mlx.nn.layers.normalization import (
|
||||
RMSNorm,
|
||||
)
|
||||
from mlx.nn.layers.pooling import (
|
||||
AdaptiveAvgPool2d,
|
||||
AdaptiveAvgPool3d,
|
||||
AvgPool1d,
|
||||
AvgPool2d,
|
||||
AvgPool3d,
|
||||
|
||||
@@ -396,3 +396,237 @@ class AvgPool3d(_Pool3d):
|
||||
padding: Optional[Union[int, Tuple[int, int, int]]] = 0,
|
||||
):
|
||||
super().__init__(mx.mean, 0, kernel_size, stride, padding)
|
||||
|
||||
|
||||
class AdaptiveAvgPool2d(Module):
|
||||
r"""Applies 2-dimensional adaptive average pooling.
|
||||
|
||||
The output size is H x W, for any input size. The number of output
|
||||
features is equal to the number of input planes.
|
||||
|
||||
Args:
|
||||
output_size: the target output size of the form H x W.
|
||||
Can be a tuple (H, W) or a single int for a square image.
|
||||
H and W can be either an ``int``, or ``None`` which means the size
|
||||
will be the same as that of the input.
|
||||
|
||||
Examples:
|
||||
>>> import mlx.core as mx
|
||||
>>> import mlx.nn as nn
|
||||
>>> x = mx.random.normal(shape=(8, 32, 32, 4))
|
||||
>>> pool = nn.AdaptiveAvgPool2d((5, 7))
|
||||
>>> pool(x)
|
||||
>>> pool = nn.AdaptiveAvgPool2d(7)
|
||||
>>> pool(x)
|
||||
"""
|
||||
|
||||
def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]]):
|
||||
super().__init__()
|
||||
self.output_size = output_size
|
||||
|
||||
def __call__(self, x):
|
||||
return adaptive_avg_pool2d(x, self.output_size)
|
||||
|
||||
|
||||
def adaptive_avg_pool2d(
|
||||
x: mx.array, output_size: Union[int, Tuple[Optional[int], Optional[int]]]
|
||||
) -> mx.array:
|
||||
r"""Apply 2-dimensional adaptive average pooling.
|
||||
|
||||
Args:
|
||||
x: Input array of shape (N, H, W, C) or (H, W, C).
|
||||
output_size: Target output size (H, W) or single int for square output.
|
||||
Values can be None to keep the corresponding input dimension.
|
||||
|
||||
Returns:
|
||||
Output array with spatial dimensions matching output_size.
|
||||
"""
|
||||
# Parse output_size
|
||||
if isinstance(output_size, int):
|
||||
output_size = (output_size, output_size)
|
||||
elif len(output_size) == 1:
|
||||
output_size = (output_size[0], output_size[0])
|
||||
|
||||
# Get input dimensions
|
||||
*batch_dims, H, W, C = x.shape
|
||||
|
||||
# Handle None values in output_size
|
||||
output_H = H if output_size[0] is None else output_size[0]
|
||||
output_W = W if output_size[1] is None else output_size[1]
|
||||
|
||||
# If already the right size, return as is
|
||||
if H == output_H and W == output_W:
|
||||
return x
|
||||
|
||||
# Calculate kernel size and stride
|
||||
kernel_H = H // output_H
|
||||
kernel_W = W // output_W
|
||||
stride_H = H // output_H
|
||||
stride_W = W // output_W
|
||||
|
||||
# For exact division, use regular pooling
|
||||
if H % output_H == 0 and W % output_W == 0:
|
||||
# Reshape for pooling: (batch..., H, W, C) -> (batch..., output_H, kernel_H, output_W, kernel_W, C)
|
||||
new_shape = batch_dims + [output_H, kernel_H, output_W, kernel_W, C]
|
||||
x_reshaped = x.reshape(new_shape)
|
||||
|
||||
# Average over kernel dimensions (kernel_H is at -4, kernel_W is at -2)
|
||||
result = mx.mean(
|
||||
x_reshaped, axis=[-4, -2]
|
||||
) # Average over kernel_H and kernel_W
|
||||
return result
|
||||
|
||||
# For non-exact division, use strided approach with overlap
|
||||
else:
|
||||
# Calculate actual stride to fit exactly
|
||||
stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
|
||||
stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
|
||||
|
||||
# Collect all averaged values
|
||||
values = []
|
||||
for i in range(output_H):
|
||||
row_values = []
|
||||
for j in range(output_W):
|
||||
h_start = i * stride_H
|
||||
h_end = min(h_start + kernel_H, H)
|
||||
w_start = j * stride_W
|
||||
w_end = min(w_start + kernel_W, W)
|
||||
|
||||
# Extract region and average
|
||||
region = x[..., h_start:h_end, w_start:w_end, :]
|
||||
avg_val = mx.mean(region, axis=(-3, -2)) # Average over H and W
|
||||
row_values.append(avg_val)
|
||||
values.append(mx.stack(row_values, axis=-2)) # Stack along W dimension
|
||||
|
||||
# Stack all rows along H dimension
|
||||
result = mx.stack(values, axis=-3)
|
||||
return result
|
||||
|
||||
|
||||
class AdaptiveAvgPool3d(Module):
|
||||
r"""Applies 3-dimensional adaptive average pooling.
|
||||
|
||||
The output size is D x H x W, for any input size. The number of output
|
||||
features is equal to the number of input planes.
|
||||
|
||||
Args:
|
||||
output_size: the target output size of the form D x H x W.
|
||||
Can be a tuple (D, H, W) or a single int for a cube D x D x D.
|
||||
D, H and W can be either an ``int``, or ``None`` which means the size
|
||||
will be the same as that of the input.
|
||||
|
||||
Examples:
|
||||
>>> import mlx.core as mx
|
||||
>>> import mlx.nn as nn
|
||||
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
|
||||
>>> pool = nn.AdaptiveAvgPool3d((5, 7, 9))
|
||||
>>> pool(x)
|
||||
>>> pool = nn.AdaptiveAvgPool3d(7)
|
||||
>>> pool(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
|
||||
):
|
||||
super().__init__()
|
||||
self.output_size = output_size
|
||||
|
||||
def __call__(self, x):
|
||||
return adaptive_avg_pool3d(x, self.output_size)
|
||||
|
||||
|
||||
def adaptive_avg_pool3d(
|
||||
x: mx.array,
|
||||
output_size: Union[int, Tuple[Optional[int], Optional[int], Optional[int]]],
|
||||
) -> mx.array:
|
||||
r"""Apply 3-dimensional adaptive average pooling.
|
||||
|
||||
Args:
|
||||
x: Input array of shape (N, D, H, W, C) or (D, H, W, C).
|
||||
output_size: Target output size (D, H, W) or single int for cube output.
|
||||
Values can be None to keep the corresponding input dimension.
|
||||
|
||||
Returns:
|
||||
Output array with spatial dimensions matching output_size.
|
||||
"""
|
||||
# Parse output_size
|
||||
if isinstance(output_size, int):
|
||||
output_size = (output_size, output_size, output_size)
|
||||
elif len(output_size) == 1:
|
||||
output_size = (output_size[0], output_size[0], output_size[0])
|
||||
elif len(output_size) == 2:
|
||||
output_size = (output_size[0], output_size[1], output_size[1])
|
||||
|
||||
# Get input dimensions
|
||||
*batch_dims, D, H, W, C = x.shape
|
||||
|
||||
# Handle None values in output_size
|
||||
output_D = D if output_size[0] is None else output_size[0]
|
||||
output_H = H if output_size[1] is None else output_size[1]
|
||||
output_W = W if output_size[2] is None else output_size[2]
|
||||
|
||||
# If already the right size, return as is
|
||||
if D == output_D and H == output_H and W == output_W:
|
||||
return x
|
||||
|
||||
# Calculate kernel size and stride
|
||||
kernel_D = D // output_D
|
||||
kernel_H = H // output_H
|
||||
kernel_W = W // output_W
|
||||
|
||||
# For exact division, use regular pooling
|
||||
if D % output_D == 0 and H % output_H == 0 and W % output_W == 0:
|
||||
# Reshape for pooling: (batch..., D, H, W, C) -> (batch..., output_D, kernel_D, output_H, kernel_H, output_W, kernel_W, C)
|
||||
new_shape = batch_dims + [
|
||||
output_D,
|
||||
kernel_D,
|
||||
output_H,
|
||||
kernel_H,
|
||||
output_W,
|
||||
kernel_W,
|
||||
C,
|
||||
]
|
||||
x_reshaped = x.reshape(new_shape)
|
||||
|
||||
# Average over kernel dimensions (kernel_D is at -6, kernel_H is at -4, kernel_W is at -2)
|
||||
result = mx.mean(
|
||||
x_reshaped, axis=[-6, -4, -2]
|
||||
) # Average over kernel_D, kernel_H and kernel_W
|
||||
return result
|
||||
|
||||
# For non-exact division, use strided approach with overlap
|
||||
else:
|
||||
# Calculate actual stride to fit exactly
|
||||
stride_D = (D - kernel_D) // (output_D - 1) if output_D > 1 else 1
|
||||
stride_H = (H - kernel_H) // (output_H - 1) if output_H > 1 else 1
|
||||
stride_W = (W - kernel_W) // (output_W - 1) if output_W > 1 else 1
|
||||
|
||||
# Collect all averaged values
|
||||
values = []
|
||||
for i in range(output_D):
|
||||
depth_values = []
|
||||
for j in range(output_H):
|
||||
row_values = []
|
||||
for k in range(output_W):
|
||||
d_start = i * stride_D
|
||||
d_end = min(d_start + kernel_D, D)
|
||||
h_start = j * stride_H
|
||||
h_end = min(h_start + kernel_H, H)
|
||||
w_start = k * stride_W
|
||||
w_end = min(w_start + kernel_W, W)
|
||||
|
||||
# Extract region and average
|
||||
region = x[..., d_start:d_end, h_start:h_end, w_start:w_end, :]
|
||||
avg_val = mx.mean(
|
||||
region, axis=(-4, -3, -2)
|
||||
) # Average over D, H and W
|
||||
row_values.append(avg_val)
|
||||
depth_values.append(
|
||||
mx.stack(row_values, axis=-2)
|
||||
) # Stack along W dimension
|
||||
values.append(mx.stack(depth_values, axis=-3)) # Stack along H dimension
|
||||
|
||||
# Stack all depths along D dimension
|
||||
result = mx.stack(values, axis=-4)
|
||||
return result
|
||||
|
||||
@@ -1818,6 +1818,79 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
"AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))",
|
||||
)
|
||||
|
||||
# Test AdaptiveAvgPool2d
|
||||
# Test exact division case (8x8 -> 4x4)
|
||||
x = mx.arange(64, dtype=mx.float32).reshape(1, 8, 8, 1)
|
||||
pool = nn.AdaptiveAvgPool2d((4, 4))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 4, 4, 1))
|
||||
# Check first few values manually
|
||||
expected_00 = (0 + 1 + 8 + 9) / 4 # Top-left 2x2 block
|
||||
expected_01 = (2 + 3 + 10 + 11) / 4 # Next 2x2 block
|
||||
self.assertAlmostEqual(output[0, 0, 0, 0].item(), expected_00, places=5)
|
||||
self.assertAlmostEqual(output[0, 0, 1, 0].item(), expected_01, places=5)
|
||||
|
||||
# Test square output format (int input)
|
||||
pool = nn.AdaptiveAvgPool2d(2)
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 2, 2, 1))
|
||||
|
||||
# Test None values (keep original dimension)
|
||||
x = mx.random.normal((2, 6, 8, 3))
|
||||
pool = nn.AdaptiveAvgPool2d((3, None))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 3, 8, 3))
|
||||
|
||||
pool = nn.AdaptiveAvgPool2d((None, 4))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 6, 4, 3))
|
||||
|
||||
# Test non-exact division case
|
||||
x = mx.random.normal((1, 7, 5, 2))
|
||||
pool = nn.AdaptiveAvgPool2d((3, 2))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 3, 2, 2))
|
||||
|
||||
# Test 3D input (no batch dimension)
|
||||
x = mx.random.normal((6, 6, 4))
|
||||
pool = nn.AdaptiveAvgPool2d((2, 3))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 3, 4))
|
||||
|
||||
# Test AdaptiveAvgPool3d
|
||||
# Test exact division case (8x8x8 -> 4x4x4)
|
||||
x = mx.arange(512, dtype=mx.float32).reshape(1, 8, 8, 8, 1)
|
||||
pool = nn.AdaptiveAvgPool3d((4, 4, 4))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 4, 4, 4, 1))
|
||||
|
||||
# Test cube output format (int input)
|
||||
pool = nn.AdaptiveAvgPool3d(2)
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 2, 2, 2, 1))
|
||||
|
||||
# Test None values (keep original dimensions)
|
||||
x = mx.random.normal((2, 6, 8, 10, 3))
|
||||
pool = nn.AdaptiveAvgPool3d((3, None, 5))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 3, 8, 5, 3))
|
||||
|
||||
pool = nn.AdaptiveAvgPool3d((None, 4, None))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 6, 4, 10, 3))
|
||||
|
||||
# Test non-exact division case
|
||||
x = mx.random.normal((1, 7, 5, 9, 2))
|
||||
pool = nn.AdaptiveAvgPool3d((3, 2, 4))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (1, 3, 2, 4, 2))
|
||||
|
||||
# Test 4D input (no batch dimension)
|
||||
x = mx.random.normal((6, 6, 6, 4))
|
||||
pool = nn.AdaptiveAvgPool3d((2, 3, 2))
|
||||
output = pool(x)
|
||||
self.assertEqual(output.shape, (2, 3, 2, 4))
|
||||
|
||||
def test_set_dtype(self):
|
||||
def assert_dtype(layer, dtype):
|
||||
for k, v in tree_flatten(layer.parameters()):
|
||||
|
||||
Reference in New Issue
Block a user