This commit is contained in:
Awni Hannun 2023-12-25 07:07:24 -08:00
parent 15577cb727
commit 865e53fcab
4 changed files with 41 additions and 36 deletions

View File

@ -5,7 +5,7 @@ from mlx.nn.layers.base import Module
class Dropout(Module):
"""Randomly zero a portion of the elements during training.
r"""Randomly zero a portion of the elements during training.
The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
:math:`p` is the probability of zeroing an element. This is done so the
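The scaling can be illustrated with a minimal, self-contained sketch of inverted dropout (illustrative only; the layer's actual implementation is not shown in this hunk, and the helper name is made up):

    import mlx.core as mx

    def inverted_dropout_sketch(x, p=0.5):
        # Keep each element with probability 1 - p ...
        mask = mx.random.bernoulli(1 - p, x.shape)
        # ... and rescale the survivors by 1 / (1 - p) so the expected
        # value of every element matches the input.
        return (mask * x) / (1 - p)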
@ -36,15 +36,13 @@ class Dropout(Module):
class Dropout2d(Module):
"""Apply 2D channel-wise dropout during training.
r"""Apply 2D channel-wise dropout during training.
Randomly zero out entire channels independently with probability :math:`p`.
This layer expects the channels to be last, i.e. the input shape should be
``NWHC`` or ``WHC`` where:
- ``N`` is the batch dimension
- ``H`` is the input image height
- ``W`` is the input image width
- ``C`` is the number of input channels
``NWHC`` or ``WHC`` where: ``N`` is the batch dimension, ``H`` is the input
image height, ``W`` is the input image width, and ``C`` is the number of
input channels.
The remaining channels are scaled by :math:`\frac{1}{1-p}` to
maintain the expected value of each element. Unlike traditional dropout,

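A short usage sketch for the channels-last convention described above (the shapes and probability are arbitrary, chosen only for illustration):

    import mlx.core as mx
    import mlx.nn as nn

    # Channels-last input: batch of 2, 8x8 spatial, 3 channels.
    x = mx.random.normal((2, 8, 8, 3))
    drop2d = nn.Dropout2d(p=0.5)
    drop2d.train()   # dropout only has an effect in training mode
    y = drop2d(x)    # entire channels are zeroed or scaled by 1/(1-p)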
View File

@ -194,24 +194,31 @@ class BatchNorm(Module):
where :math:`\gamma` and :math:`\beta` are learned per feature dimension
parameters initialized at 1 and 0 respectively.
[1]: https://arxiv.org/abs/1502.03167
The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
batch, ``C`` is the number of features or channels, and ``L`` is the
sequence length. The output has the same shape as the input. For
four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
the height and width respectively.
The input tensor shape is specified as (N, C) or (N, L, C), representing the batch size (N), the number of features or channels (C), and optionally, the sequence length (L). The output tensor maintains the same shape as the input, adhering to (N, C) or (N, L, C).
For three-dimensional tensors, the shape is denoted as (N, H, W, C), where N signifies the batch size, C represents the number of channels, H corresponds to the height, and W denotes the width.
For more information on Batch Normalization, see the original paper `Batch
Normalization: Accelerating Deep Network Training by Reducing Internal
Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
Args:
num_features (int): The feature dimension of the input to normalize over.
eps (float, optional): A small additive constant for numerical stability. Default is 1e-5.
momentum (float, optional): The momentum for updating the running mean and variance. Default is 0.1.
affine (bool, optional): If True, learn an affine transform to apply after the normalization. Default is True.
track_running_stats (bool, optional): If True, track the running mean and variance. Default is True.
num_features (int): The feature dimension to normalize over.
eps (float, optional): A small additive constant for numerical
stability. Default: ``1e-5``.
momentum (float, optional): The momentum for updating the running
mean and variance. Default: ``0.1``.
affine (bool, optional): If ``True``, apply a learned affine
transformation after the normalization. Default: ``True``.
track_running_stats (bool, optional): If ``True``, track the
running mean and variance. Default: ``True``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> mx.random.seed(42)
>>> input = mx.random.normal((5, 4), dtype=mx.float32)
>>> # Batch norm
>>> x = mx.random.normal((5, 4))
>>> bn = nn.BatchNorm(num_features=4, affine=True)
>>> output = bn(x)
"""
@ -229,10 +236,9 @@ class BatchNorm(Module):
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
if affine:
self.weight = mx.ones((num_features,))
self.bias = mx.zeros((num_features,))
@ -241,7 +247,11 @@ class BatchNorm(Module):
self._running_var = mx.ones((num_features,))
def _extra_repr(self):
return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"
return (
f"{self.num_features}, eps={self.eps}, "
f"momentum={self.momentum}, affine={'weight' in self}, "
f"track_running_stats={self.track_running_stats}"
)
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
"""
@ -253,9 +263,7 @@ class BatchNorm(Module):
Returns:
tuple: Tuple containing mean and variance.
"""
reduction_axes = (
(0,) if len(x.shape) == 2 else (0, 1) if len(x.shape) == 3 else (0, 1, 2)
)
reduction_axes = tuple(range(0, x.ndim - 1))
means = mx.mean(x, axis=reduction_axes, keepdims=True)
var = mx.var(x, axis=reduction_axes, keepdims=True)
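The new ``reduction_axes`` expression reduces over every axis except the last (channel) axis, which covers the same 2D, 3D, and 4D cases as the old conditional. A quick illustration, assuming a channels-last 4D input:

    import mlx.core as mx

    x = mx.random.normal((8, 16, 16, 4))          # NHWC
    reduction_axes = tuple(range(0, x.ndim - 1))  # (0, 1, 2): all but the channel axis
    mean = mx.mean(x, axis=reduction_axes, keepdims=True)  # shape (1, 1, 1, 4)
    var = mx.var(x, axis=reduction_axes, keepdims=True)    # shape (1, 1, 1, 4)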
@ -279,7 +287,7 @@ class BatchNorm(Module):
mx.array: Output tensor.
"""
if x.ndim not in [2, 3, 4]:
if x.ndim < 2 or x.ndim > 4:
raise ValueError(
f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
)

View File

@ -286,7 +286,7 @@ def _reduce(loss: mx.array, reduction: str = "none"):
def hinge_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
"""
r"""
Computes the hinge loss between inputs and targets.
.. math::
@ -311,7 +311,7 @@ def hinge_loss(
def huber_loss(
inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
) -> mx.array:
"""
r"""
Computes the Huber loss between inputs and targets.
.. math::
@ -345,7 +345,7 @@ def huber_loss(
def log_cosh_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
"""
r"""
Computes the log cosh loss between inputs and targets.
Logcosh acts like L2 loss for small errors, ensuring stable gradients,

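A numerically stable way to evaluate log(cosh(x)) is via ``logaddexp``; the following is a hedged sketch of that idea, not necessarily the exact implementation in this file:

    import math
    import mlx.core as mx

    def log_cosh_sketch(inputs, targets):
        # log(cosh(e)) == logaddexp(e, -e) - log(2); this form avoids the
        # overflow a direct cosh would hit for large errors.
        errors = inputs - targets
        return mx.logaddexp(errors, -errors) - math.log(2)

The requested reduction ("none", "mean", or "sum") can then be applied to the per-element losses.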
View File

@ -3,7 +3,6 @@
import os
import tempfile
import unittest
from unittest.mock import Mock, patch
import mlx.core as mx
import mlx.nn as nn
@ -342,9 +341,9 @@ class TestNN(mlx_tests.MLXTestCase):
expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
self.assertTrue(x.shape == y.shape)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5))
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
# test eval mode
bn.eval()
@ -360,7 +359,7 @@ class TestNN(mlx_tests.MLXTestCase):
)
self.assertTrue(x.shape == y.shape)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
# test_no_affine
bn = nn.BatchNorm(num_features=4, affine=False)
@ -406,13 +405,13 @@ class TestNN(mlx_tests.MLXTestCase):
],
]
)
self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
expected_mean = mx.array(
[[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]]
)
expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]])
self.assertTrue(np.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(np.allclose(bn._running_var, expected_var, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
x = mx.random.normal((N, L, C, L, C), dtype=mx.float32)
with self.assertRaises(ValueError):