diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index 1bcd2a2dc..3733fd777 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -317,7 +317,7 @@ class AvgPool2d(_Pool2d): >>> import mlx.core as mx >>> import mlx.nn.layers as nn >>> x = mx.random.normal(shape=(8, 32, 32, 4)) - >>> pool = nn.MaxPool2d(kernel_size=2, stride=2) + >>> pool = nn.AvgPool2d(kernel_size=2, stride=2) >>> pool(x) """