From e09bf35b28c9cdb06a057742c7027b9c41c4b3f5 Mon Sep 17 00:00:00 2001
From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com>
Date: Mon, 1 Jan 2024 02:01:21 +0400
Subject: [PATCH] feat: Add Dropout3d layer to nn.layers (#313)

* feat: Add Dropout3d layer to nn.layers

* acknowledgement

* Add dropout tests to test_nn.py

* run pre-commit

* Add activation functions and dropout3d ops

* Add dropout tests for bfloat16 and float16
---
 ACKNOWLEDGMENTS.md               |  2 +-
 docs/src/python/nn/layers.rst    |  1 +
 python/mlx/nn/layers/__init__.py |  2 +-
 python/mlx/nn/layers/dropout.py  | 49 ++++++++++++++++++++++++++++++++
 python/tests/test_nn.py          | 48 +++++++++++++++++++++++++++++++
 5 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index ab3d5e1af..398f136b2 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added tri, tril, triu and safetensor support
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index aa59e0af2..7ead319fd 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -28,6 +28,7 @@ Layers
    GroupNorm
    Dropout
    Dropout2d
+   Dropout3d
    Transformer
    MultiHeadAttention
    ALiBi
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 84d2ceb9f..31bcc59dc 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -43,7 +43,7 @@ from mlx.nn.layers.activations import (
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
-from mlx.nn.layers.dropout import Dropout, Dropout2d
+from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py
index f656f8db3..18b9b03a6 100644
--- a/python/mlx/nn/layers/dropout.py
+++ b/python/mlx/nn/layers/dropout.py
@@ -86,3 +86,52 @@ class Dropout2d(Module):
         mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
 
         return (1 / self._p_1) * mask * x
+
+
+class Dropout3d(Module):
+    r"""Apply 3D channel-wise dropout during training.
+
+    Randomly zero out entire channels independently with probability :math:`p`.
+    This layer expects the channels to be last, i.e., the input shape should be
+    `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
+    `H` is the input image height, `W` is the input image width, and `C` is
+    the number of input channels.
+
+    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
+    maintain the expected value of each element. Unlike traditional dropout,
+    which zeros individual entries, this layer zeros entire channels. This is
+    often beneficial for convolutional layers processing 3D data, like in
+    medical imaging or video processing.
+
+    Args:
+        p (float): Probability of zeroing a channel during training.
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+
+        if p < 0 or p >= 1:
+            raise ValueError(f"The dropout probability {p} is not in [0, 1)")
+
+        self._p_1 = 1 - p
+
+    def _extra_repr(self):
+        return f"p={1-self._p_1}"
+
+    def __call__(self, x):
+        if x.ndim not in (4, 5):
+            raise ValueError(
+                f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
+            )
+
+        if self._p_1 == 1 or not self.training:
+            return x
+
+        # Dropout is applied on the whole channel
+        # 4D input: (1, 1, 1, C)
+        # 5D input: (B, 1, 1, 1, C)
+        mask_shape = list(x.shape)
+        mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
+
+        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
+        return (1 / self._p_1) * mask * x
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 6557e7dbe..e620ad831 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -792,6 +792,54 @@ class TestNN(mlx_tests.MLXTestCase):
         loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
         self.assertAlmostEqual(loss.item(), 0.433781, places=6)
 
+    def test_dropout(self):
+        x = mx.ones((2, 4))
+        y = nn.Dropout(0.5)(x)
+        self.assertTrue(y.shape, x.shape)
+        self.assertTrue(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4), dtype=mx.bfloat16)
+        y = nn.Dropout(0.5)(x)
+        self.assertTrue(y.shape, x.shape)
+        self.assertTrue(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4), dtype=mx.float16)
+        y = nn.Dropout(0.5)(x)
+        self.assertTrue(y.shape, x.shape)
+        self.assertTrue(y.dtype, mx.float16)
+
+    def test_dropout2d(self):
+        x = mx.ones((2, 4, 4, 4))
+        y = nn.Dropout2d(0.5)(x)
+        self.assertTrue(y.shape, x.shape)
+        self.assertTrue(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)
+        y = nn.Dropout2d(0.5)(x)
+        self.assertTrue(y.shape, x.shape)
+        self.assertTrue(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4, 4, 4), dtype=mx.float16)
+        y = nn.Dropout2d(0.5)(x)
+        self.assertTrue(y.shape, x.shape)
+        self.assertTrue(y.dtype, mx.float16)
+
+    def test_dropout3d(self):
+        x = mx.ones((2, 4, 4, 4, 4))
+        y = nn.Dropout3d(0.5)(x)
+        self.assertTrue(y.shape, x.shape)
+        self.assertTrue(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)
+        y = nn.Dropout3d(0.5)(x)
+        self.assertTrue(y.shape, x.shape)
+        self.assertTrue(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)
+        y = nn.Dropout3d(0.5)(x)
+        self.assertTrue(y.shape, x.shape)
+        self.assertTrue(y.dtype, mx.float16)
+
 
 if __name__ == "__main__":
     unittest.main()
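
For reference, a minimal usage sketch of the new layer. It assumes the standard `mlx.core` / `mlx.nn` import aliases used in the tests above and the `Module.eval()` switch for inference mode; the dropout probability and tensor shape are arbitrary illustrative values, not part of the patch.

    import mlx.core as mx
    import mlx.nn as nn

    # The layer added by this patch: p is the probability of zeroing an
    # entire channel during training.
    layer = nn.Dropout3d(p=0.5)

    # Channels-last 5D input: N=2, D=4, H=4, W=4, C=8.
    x = mx.ones((2, 4, 4, 4, 8))

    # Freshly constructed modules are in training mode, so on average half
    # of the channels (per sample) are zeroed and the surviving channels
    # are scaled by 1 / (1 - p).
    y_train = layer(x)

    # In evaluation mode the layer passes the input through unchanged.
    layer.eval()
    y_eval = layer(x)

    print(y_train.shape, y_eval.shape)  # both (2, 4, 4, 4, 8)

The channels-last (`NDHWC` / `DHWC`) layout follows the same convention as the existing `Dropout2d` layer in this file and MLX's convolution layers, so no transposes are needed around the dropout.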