diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index 8ab41de13..17e6ab669 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -208,7 +208,7 @@ class BatchNorm(Module): >>> mx.random.seed(42) >>> input = mx.random.normal((5, 4), dtype=mx.float32) >>> # Batch norm - >>> bn = nn.BatchNorm1d(num_features=4, affine=True) + >>> bn = nn.BatchNorm(num_features=4, affine=True) >>> output = bn(x) """