From b444a6a693354c59bd0d0aaa14c58b53057db565 Mon Sep 17 00:00:00 2001 From: __mo_san__ <50895527+m0saan@users.noreply.github.com> Date: Fri, 22 Dec 2023 20:57:32 +0100 Subject: [PATCH] Update normalization.py --- python/mlx/nn/layers/normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index 8ab41de13..17e6ab669 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -208,7 +208,7 @@ class BatchNorm(Module): >>> mx.random.seed(42) >>> input = mx.random.normal((5, 4), dtype=mx.float32) >>> # Batch norm - >>> bn = nn.BatchNorm1d(num_features=4, affine=True) + >>> bn = nn.BatchNorm(num_features=4, affine=True) >>> output = bn(x) """