Add dropout2d (#250)

Justin Deschenaux 2023-12-22 17:02:29 +01:00 committed by GitHub
parent 8385f93cea
commit e8deca84e0
4 changed files with 63 additions and 3 deletions


@@ -8,7 +8,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
# Third-Party Software


@@ -27,3 +27,6 @@ Layers
MultiHeadAttention
Sequential
QuantizedLinear
Dropout
Dropout2d


@@ -33,7 +33,7 @@ from mlx.nn.layers.activations import (
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout
from mlx.nn.layers.dropout import Dropout, Dropout2d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
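
With this re-export, Dropout2d becomes importable directly from mlx.nn. A minimal sketch (not part of the commit) of how it might sit after an early convolution, assuming the existing Conv2d, ReLU, and Sequential layers:

import mlx.core as mx
import mlx.nn as nn

# Hypothetical toy stack: Dropout2d after an early convolution, where zeroing
# whole feature maps regularizes correlated pixels better than per-element dropout.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Dropout2d(p=0.25),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
)

x = mx.random.normal((2, 32, 32, 3))  # channels-last (N, H, W, C) input
print(model(x).shape)  # (2, 32, 32, 32)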


@@ -32,4 +32,61 @@ class Dropout(Module):
        mask = mx.random.bernoulli(self._p_1, x.shape)
        return (1 / self._p_1) * mask.astype(x.dtype) * x
        return (1 / self._p_1) * mask * x
class Dropout2d(Module):
    r"""Apply 2D channel-wise dropout during training.

    Randomly zero out entire channels independently with probability :math:`p`.
    This layer expects the channels to be last, i.e. the input shape should be
    ``NWHC`` or ``WHC`` where:

    - ``N`` is the batch dimension
    - ``H`` is the input image height
    - ``W`` is the input image width
    - ``C`` is the number of input channels

    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
    maintain the expected value of each element. Unlike traditional dropout,
    which zeros individual entries, this layer zeros entire channels. This is
    beneficial for early convolution layers where adjacent pixels are
    correlated. In such cases, traditional dropout may not effectively
    regularize activations. For more details, see [1].

    [1]: Tompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler, C., 2015.
    Efficient Object Localization Using Convolutional Networks. CVPR 2015.

    Args:
        p (float): Probability of zeroing a channel during training.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError("The dropout probability should be in [0, 1)")

        # Store the keep probability 1 - p once so the forward pass only scales.
        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        if x.ndim not in (3, 4):
            raise ValueError(
                f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
            )

        if self._p_1 == 1 or not self.training:
            return x

        # Dropout is applied to whole channels, so the mask is broadcast over
        # the spatial dimensions:
        #   3D input: (1, 1, C)
        #   4D input: (B, 1, 1, C)
        mask_shape = list(x.shape)
        mask_shape[-2] = mask_shape[-3] = 1

        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
        return (1 / self._p_1) * mask * x
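
A minimal usage sketch (not part of the commit) illustrating the behavior described in the docstring: whole channels of a channels-last input are either zeroed or scaled by 1 / (1 - p), and the layer is the identity outside training. It assumes the usual mlx.core / mlx.nn imports and the Module.train() / Module.eval() toggles.

import mlx.core as mx
import mlx.nn as nn

layer = nn.Dropout2d(p=0.5)
layer.train()  # dropout is only active while training

x = mx.ones((4, 8, 8, 3))  # (N, H, W, C), channels last
y = layer(x)

# Every channel of every sample is either all zeros or all 1 / (1 - p) = 2.0.
print(y[0, :, :, 0])

layer.eval()
print(mx.array_equal(layer(x), x))  # True: identity when not training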