feat: Add Dropout3d layer to nn.layers (#313)

* feat: Add Dropout3d layer to nn.layers

* acknowledgement

* Add dropout tests to test_nn.py

* run pre-commit

* Add activation functions and dropout3d ops

* Add dropout tests for bfloat16 and float16
Authored by Nripesh Niketan on 2024-01-01 02:01:21 +04:00, committed by GitHub
parent 99c20f523e
commit e09bf35b28
5 changed files with 100 additions and 2 deletions


@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added tri, tril, triu and safetensor support


@@ -28,6 +28,7 @@ Layers
 GroupNorm
 Dropout
 Dropout2d
+Dropout3d
 Transformer
 MultiHeadAttention
 ALiBi


@@ -43,7 +43,7 @@ from mlx.nn.layers.activations import (
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
-from mlx.nn.layers.dropout import Dropout, Dropout2d
+from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm


@@ -86,3 +86,52 @@ class Dropout2d(Module):
         mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
 
         return (1 / self._p_1) * mask * x
+
+
+class Dropout3d(Module):
+    r"""Apply 3D channel-wise dropout during training.
+
+    Randomly zero out entire channels independently with probability :math:`p`.
+    This layer expects the channels to be last, i.e., the input shape should be
+    `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
+    `H` is the input image height, `W` is the input image width, and `C` is
+    the number of input channels.
+
+    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
+    maintain the expected value of each element. Unlike traditional dropout,
+    which zeros individual entries, this layer zeros entire channels. This is
+    often beneficial for convolutional layers processing 3D data, like in
+    medical imaging or video processing.
+
+    Args:
+        p (float): Probability of zeroing a channel during training.
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+
+        if p < 0 or p >= 1:
+            raise ValueError(f"The dropout probability {p} is not in [0, 1)")
+
+        self._p_1 = 1 - p
+
+    def _extra_repr(self):
+        return f"p={1-self._p_1}"
+
+    def __call__(self, x):
+        if x.ndim not in (4, 5):
+            raise ValueError(
+                f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
+            )
+
+        if self._p_1 == 1 or not self.training:
+            return x
+
+        # Dropout is applied on the whole channel
+        # 4D input: (1, 1, 1, C)
+        # 5D input: (B, 1, 1, 1, C)
+        mask_shape = list(x.shape)
+        mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
+
+        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
+        return (1 / self._p_1) * mask * x
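
Not part of the diff: a minimal usage sketch of the new layer, assuming the `Dropout3d` implementation added above and MLX's standard module behavior (modules start in training mode and `eval()` switches them to inference). The fraction of zeroed channels is random, so exact outputs vary from run to run.

import mlx.core as mx
import mlx.nn as nn

# 5D input in NDHWC layout: batch=2, depth=4, height=4, width=4, channels=8
x = mx.ones((2, 4, 4, 4, 8))

layer = nn.Dropout3d(p=0.5)  # modules are in training mode by default
y = layer(x)

# Shape and dtype are preserved; each channel is either zeroed entirely
# or scaled by 1 / (1 - p) = 2.0
print(y.shape)

# In evaluation mode the layer is the identity
layer.eval()
print(mx.array_equal(layer(x), x))  # True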


@@ -792,6 +792,54 @@ class TestNN(mlx_tests.MLXTestCase):
         loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
         self.assertAlmostEqual(loss.item(), 0.433781, places=6)
+
+    def test_dropout(self):
+        x = mx.ones((2, 4))
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4), dtype=mx.bfloat16)
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4), dtype=mx.float16)
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
+
+    def test_dropout2d(self):
+        x = mx.ones((2, 4, 4, 4))
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4, 4, 4), dtype=mx.float16)
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
+
+    def test_dropout3d(self):
+        x = mx.ones((2, 4, 4, 4, 4))
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
 
 
 if __name__ == "__main__":
     unittest.main()
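
As a hedged follow-up (not in the commit), one way to check the channel-wise behavior described in the `Dropout3d` docstring, beyond the shape and dtype checks above: within each (batch, channel) slice every spatial position should carry the same value, either 0 or 1/(1 - p). The seed value here is arbitrary and only makes the run reproducible.

import mlx.core as mx
import mlx.nn as nn

mx.random.seed(0)  # arbitrary seed, only for reproducibility
x = mx.ones((2, 4, 4, 4, 8))
y = nn.Dropout3d(0.5)(x)

# Collapsing the spatial axes: max == min per (batch, channel) means each
# channel was either kept (and scaled) or zeroed as a whole.
per_channel_max = mx.max(y, axis=[1, 2, 3])
per_channel_min = mx.min(y, axis=[1, 2, 3])
print(mx.array_equal(per_channel_max, per_channel_min))  # True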