feat: Add Dropout3d layer to nn.layers (#313)

* feat: Add Dropout3d layer to nn.layers

* acknowledgement

* Add dropout tests to test_nn.py

* run pre-commit

* Add activation functions and dropout3d ops

* Add dropout tests for bfloat16 and float16
Author: Nripesh Niketan
Date: 2024-01-01 02:01:21 +04:00
Committed by: GitHub
Parent: 99c20f523e
Commit: e09bf35b28
5 changed files with 100 additions and 2 deletions


@@ -43,7 +43,7 @@ from mlx.nn.layers.activations import (
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
-from mlx.nn.layers.dropout import Dropout, Dropout2d
+from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm

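With this export in place, `Dropout3d` is reachable both from the `layers` submodule and from the top-level `mlx.nn` namespace, which re-exports the layer classes. A minimal sketch of the import paths, assuming a build that includes this commit:

```python
import mlx.nn as nn
from mlx.nn.layers.dropout import Dropout3d

# Both names resolve to the same class once the export above is in place.
assert nn.Dropout3d is Dropout3d
```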

@@ -86,3 +86,52 @@ class Dropout2d(Module):
         mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
         return (1 / self._p_1) * mask * x
+
+
+class Dropout3d(Module):
+    r"""Apply 3D channel-wise dropout during training.
+
+    Randomly zero out entire channels independently with probability :math:`p`.
+    This layer expects the channels to be last, i.e., the input shape should be
+    `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
+    `H` is the input image height, `W` is the input image width, and `C` is
+    the number of input channels.
+
+    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
+    maintain the expected value of each element. Unlike traditional dropout,
+    which zeros individual entries, this layer zeros entire channels. This is
+    often beneficial for convolutional layers processing 3D data, like in
+    medical imaging or video processing.
+
+    Args:
+        p (float): Probability of zeroing a channel during training.
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+
+        if p < 0 or p >= 1:
+            raise ValueError(f"The dropout probability {p} is not in [0, 1)")
+
+        self._p_1 = 1 - p
+
+    def _extra_repr(self):
+        return f"p={1-self._p_1}"
+
+    def __call__(self, x):
+        if x.ndim not in (4, 5):
+            raise ValueError(
+                f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
+            )
+
+        if self._p_1 == 1 or not self.training:
+            return x
+
+        # Dropout is applied on the whole channel
+        # 4D input: (1, 1, 1, C)
+        # 5D input: (B, 1, 1, 1, C)
+        mask_shape = list(x.shape)
+        mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
+
+        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
+        return (1 / self._p_1) * mask * x
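For reference, a small usage sketch of the new layer. The shapes and dropout probability are illustrative only, and it assumes a freshly constructed module starts in training mode, as the `Module` base class does:

```python
import mlx.core as mx
import mlx.nn as nn

# NDHWC input: batch of 2, 4x4x4 volumes, 8 channels.
x = mx.random.normal(shape=(2, 4, 4, 4, 8))

layer = nn.Dropout3d(p=0.5)

# In training mode, whole channels are zeroed with probability p and the
# surviving channels are rescaled by 1 / (1 - p); the shape is unchanged.
y = layer(x)
print(y.shape)  # same shape as x

# In evaluation mode the layer is the identity.
layer.eval()
assert mx.array_equal(layer(x), x)
```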