diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index c9969f8d6..9e3b27532 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -8,7 +8,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
 
 - Juarez Bochi: Fixed bug in cross attention.
-- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example.
+- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 
 # Third-Party Software
 
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index fab3ff785..4c7d4aa79 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -27,3 +27,6 @@ Layers
    MultiHeadAttention
    Sequential
    QuantizedLinear
+   Dropout
+   Dropout2d
+
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 3f03064bf..d54e45f6d 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -33,7 +33,7 @@ from mlx.nn.layers.activations import (
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
-from mlx.nn.layers.dropout import Dropout
+from mlx.nn.layers.dropout import Dropout, Dropout2d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py
index 3193cdbd7..e2cc981e2 100644
--- a/python/mlx/nn/layers/dropout.py
+++ b/python/mlx/nn/layers/dropout.py
@@ -32,4 +32,61 @@ class Dropout(Module):
 
         mask = mx.random.bernoulli(self._p_1, x.shape)
 
-        return (1 / self._p_1) * mask.astype(x.dtype) * x
+        return (1 / self._p_1) * mask * x
+
+
+class Dropout2d(Module):
+    r"""Apply 2D channel-wise dropout during training.
+
+    Randomly zero out entire channels independently with probability :math:`p`.
+    This layer expects the channels to be last, i.e. the input shape should be
+    ``NWHC`` or ``WHC`` where:
+    - ``N`` is the batch dimension
+    - ``H`` is the input image height
+    - ``W`` is the input image width
+    - ``C`` is the number of input channels
+
+    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
+    maintain the expected value of each element. Unlike traditional dropout,
+    which zeros individual entries, this layer zeros entire channels. This is
+    beneficial for early convolution layers where adjacent pixels are
+    correlated. In such cases, traditional dropout may not effectively
+    regularize activations. For more details, see [1].
+
+    [1]: Tompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler, C., 2015.
+    Efficient Object Localization Using Convolutional Networks. CVPR 2015.
+
+    Args:
+        p (float): Probability of zeroing a channel during training.
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+
+        if p < 0 or p >= 1:
+            raise ValueError("The dropout probability should be in [0, 1)")
+
+        self._p_1 = 1 - p
+
+    def _extra_repr(self):
+        return f"p={1-self._p_1}"
+
+    def __call__(self, x):
+        if x.ndim not in (3, 4):
+            raise ValueError(
+                f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
+            )
+
+        if self._p_1 == 1 or not self.training:
+            return x
+
+        # Dropout is applied on the whole channel
+        # 3D input: (1, 1, C)
+        # 4D input: (B, 1, 1, C)
+        mask_shape = x.shape
+        mask_shape[-2] = mask_shape[-3] = 1
+
+        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
+        return (1 / self._p_1) * mask * x
+
+
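Below is a minimal usage sketch of the new layer for reviewers. It is not part of the diff; it assumes an `mlx` build that includes this change, the channels-last input convention described in the docstring, and the standard `Module.train()` / `Module.eval()` helpers for toggling the `training` flag checked in `__call__`.

```python
# Minimal sketch: exercising Dropout2d in training and eval mode.
# Assumes mlx is installed with this patch applied.
import mlx.core as mx
import mlx.nn as nn

layer = nn.Dropout2d(p=0.5)

# Channels-last input, 4D case: (N, W, H, C).
x = mx.random.normal(shape=(4, 8, 8, 16))

layer.train()  # modules start in training mode; called here for clarity
y = layer(x)   # whole channels are zeroed; survivors are scaled by 1/(1 - p)

layer.eval()   # with training disabled the layer is the identity
print((layer(x) == x).all().item())  # True
```

In training mode each of the 16 channels of every batch element is dropped independently with probability 0.5, which is why the mask built in `__call__` has shape `(4, 1, 1, 16)` rather than the full input shape.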