From 865e53fcab06b0f8ab975acc43dbe1c1e6747dd3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 25 Dec 2023 07:07:24 -0800 Subject: [PATCH] doc nits --- python/mlx/nn/layers/dropout.py | 12 +++----- python/mlx/nn/layers/normalization.py | 44 ++++++++++++++++----------- python/mlx/nn/losses.py | 6 ++-- python/tests/test_nn.py | 15 +++++---- 4 files changed, 41 insertions(+), 36 deletions(-) diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py index 14c5cb15e..caa7a6452 100644 --- a/python/mlx/nn/layers/dropout.py +++ b/python/mlx/nn/layers/dropout.py @@ -5,7 +5,7 @@ from mlx.nn.layers.base import Module class Dropout(Module): - """Randomly zero a portion of the elements during training. + r"""Randomly zero a portion of the elements during training. The remaining elements are multiplied with :math:`\frac{1}{1-p}` where :math:`p` is the probability of zeroing an element. This is done so the @@ -36,15 +36,13 @@ class Dropout(Module): class Dropout2d(Module): - """Apply 2D channel-wise dropout during training. + r"""Apply 2D channel-wise dropout during training. Randomly zero out entire channels independently with probability :math:`p`. This layer expects the channels to be last, i.e. the input shape should be - ``NWHC`` or ``WHC`` where: - - ``N`` is the batch dimension - - ``H`` is the input image height - - ``W`` is the input image width - - ``C`` is the number of input channels + ``NWHC`` or ``WHC`` where:``N`` is the batch dimension,``H`` is the input + image height,``W`` is the input image width, and``C`` is the number of + input channels The remaining channels are scaled by :math:`\frac{1}{1-p}` to maintain the expected value of each element. Unlike traditional dropout, diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index ac8a05f55..9cd578fb2 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -194,24 +194,31 @@ class BatchNorm(Module): where :math:`\gamma` and :math:`\beta` are learned per feature dimension parameters initialized at 1 and 0 respectively. - [1]: https://arxiv.org/abs/1502.03167 + The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the + batch, ``C`` is the number of features or channels, and ``L`` is the + sequence length. The output has the same shape as the input. For + four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are + the height and width respecitvely. - The input tensor shape is specified as (N, C) or (N, L, C), representing the batch size (N), the number of features or channels (C), and optionally, the sequence length (L). The output tensor maintains the same shape as the input, adhering to (N, C) or (N, L, C). - For three-dimensional tensors, the shape is denoted as (N, H, W, C), where N signifies the batch size, C represents the number of channels, H corresponds to the height, and W denotes the width. + For more information on Batch Normalization, see the original paper `Batch + Normalization: Accelerating Deep Network Training by Reducing Internal + Covariate Shift `_. Args: - num_features (int): The feature dimension of the input to normalize over. - eps (float, optional): A small additive constant for numerical stability. Default is 1e-5. - momentum (float, optional): The momentum for updating the running mean and variance. Default is 0.1. - affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True. - track_running_stats (bool, optional): If True, track the running mean and variance. Default is True. + num_features (int): The feature dimension to normalize over. + eps (float, optional): A small additive constant for numerical + stability. Default: ``1e-5``. + momentum (float, optional): The momentum for updating the running + mean and variance. Default: ``0.1``. + affine (bool, optional): If ``True``, apply a learned affine + transformation after the normalization. Default: ``True``. + track_running_stats (bool, optional): If ``True``, track the + running mean and variance. Default: ``True``. Examples: >>> import mlx.core as mx >>> import mlx.nn as nn - >>> mx.random.seed(42) - >>> input = mx.random.normal((5, 4), dtype=mx.float32) - >>> # Batch norm + >>> x = mx.random.normal((5, 4)) >>> bn = nn.BatchNorm(num_features=4, affine=True) >>> output = bn(x) """ @@ -229,10 +236,9 @@ class BatchNorm(Module): self.num_features = num_features self.eps = eps self.momentum = momentum - self.affine = affine self.track_running_stats = track_running_stats - if self.affine: + if affine: self.weight = mx.ones((num_features,)) self.bias = mx.zeros((num_features,)) @@ -241,7 +247,11 @@ class BatchNorm(Module): self._running_var = mx.ones((num_features,)) def _extra_repr(self): - return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}" + return ( + f"{self.num_features}, eps={self.eps}, " + f"momentum={self.momentum}, affine={'weight' in self}, " + f"track_running_stats={self.track_running_stats}" + ) def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]: """ @@ -253,9 +263,7 @@ class BatchNorm(Module): Returns: tuple: Tuple containing mean and variance. """ - reduction_axes = ( - (0,) if len(x.shape) == 2 else (0, 1) if len(x.shape) == 3 else (0, 1, 2) - ) + reduction_axes = tuple(range(0, x.ndim - 1)) means = mx.mean(x, axis=reduction_axes, keepdims=True) var = mx.var(x, axis=reduction_axes, keepdims=True) @@ -279,7 +287,7 @@ class BatchNorm(Module): mx.array: Output tensor. """ - if x.ndim not in [2, 3, 4]: + if x.ndim < 2 or x.ndim > 4: raise ValueError( f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}" ) diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index cfb6ffa15..91316fd04 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -286,7 +286,7 @@ def _reduce(loss: mx.array, reduction: str = "none"): def hinge_loss( inputs: mx.array, targets: mx.array, reduction: str = "none" ) -> mx.array: - """ + r""" Computes the hinge loss between inputs and targets. .. math:: @@ -311,7 +311,7 @@ def hinge_loss( def huber_loss( inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none" ) -> mx.array: - """ + r""" Computes the Huber loss between inputs and targets. .. math:: @@ -345,7 +345,7 @@ def huber_loss( def log_cosh_loss( inputs: mx.array, targets: mx.array, reduction: str = "none" ) -> mx.array: - """ + r""" Computes the log cosh loss between inputs and targets. Logcosh acts like L2 loss for small errors, ensuring stable gradients, diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 9f0eb9ff8..d2c83851e 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -3,7 +3,6 @@ import os import tempfile import unittest -from unittest.mock import Mock, patch import mlx.core as mx import mlx.nn as nn @@ -342,9 +341,9 @@ class TestNN(mlx_tests.MLXTestCase): expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778]) expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258]) self.assertTrue(x.shape == y.shape) - self.assertTrue(np.allclose(y, expected_y, atol=1e-5)) - self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5)) - self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5)) + self.assertTrue(mx.allclose(y, expected_y, atol=1e-5)) + self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5)) + self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5)) # test eval mode bn.eval() @@ -360,7 +359,7 @@ class TestNN(mlx_tests.MLXTestCase): ) self.assertTrue(x.shape == y.shape) - self.assertTrue(np.allclose(y, expected_y, atol=1e-5)) + self.assertTrue(mx.allclose(y, expected_y, atol=1e-5)) # test_no_affine bn = nn.BatchNorm(num_features=4, affine=False) @@ -406,13 +405,13 @@ class TestNN(mlx_tests.MLXTestCase): ], ] ) - self.assertTrue(np.allclose(y, expected_y, atol=1e-5)) + self.assertTrue(mx.allclose(y, expected_y, atol=1e-5)) expected_mean = mx.array( [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]] ) expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]]) - self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5)) - self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5)) + self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5)) + self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5)) x = mx.random.normal((N, L, C, L, C), dtype=mx.float32) with self.assertRaises(ValueError):