From cf5a5a4a01e298b93c95a0ed84d7feb2cea7e76b Mon Sep 17 00:00:00 2001 From: m0saan Date: Sat, 23 Dec 2023 23:08:10 +0100 Subject: [PATCH] updated the batch norm doc string ^^ --- python/mlx/nn/layers/normalization.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index 1cff5af56..b3c1ceefb 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -196,6 +196,9 @@ class BatchNorm(Module): [1]: https://arxiv.org/abs/1502.03167 + The input tensor shape is specified as (N, C) or (N, C, L), representing the batch size (N), the number of features or channels (C), and optionally, the sequence length (L). The output tensor maintains the same shape as the input, adhering to (N, C) or (N, C, L). + For three-dimensional tensors, the shape is denoted as (N, C, H, W), where N signifies the batch size, C represents the number of channels, H corresponds to the height, and W denotes the width. + Args: num_features (int): The feature dimension of the input to normalize over. eps (float, optional): A small additive constant for numerical stability. Default is 1e-5.