Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-25 01:41:17 +08:00
feat: Add Dropout3d layer to nn.layers (#313)
* feat: Add Dropout3d layer to nn.layers
* acknowledgement
* Add dropout tests to test_nn.py
* run pre-commit
* Add activation functions and dropout3d ops
* Add dropout tests for bfloat16 and float16
This commit is contained in:
parent 99c20f523e
commit e09bf35b28
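For orientation, here is a minimal usage sketch of the layer this commit adds (not part of the diff; it assumes the usual `mlx.core` / `mlx.nn` import names, and the shapes are made up for illustration):

import mlx.core as mx
import mlx.nn as nn

# Channel-wise dropout on a 5D NDHWC input: (batch, depth, height, width, channels).
x = mx.ones((2, 4, 4, 4, 8))
layer = nn.Dropout3d(p=0.5)

layer.train()   # training mode: whole channels are zeroed, survivors scaled by 1/(1 - p)
y_train = layer(x)

layer.eval()    # evaluation mode: the layer is the identity
y_eval = layer(x)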
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:

 MLX was developed with contributions from the following individuals:

-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added tri, tril, triu and safetensor support
@@ -28,6 +28,7 @@ Layers
 GroupNorm
 Dropout
 Dropout2d
+Dropout3d
 Transformer
 MultiHeadAttention
 ALiBi
@@ -43,7 +43,7 @@ from mlx.nn.layers.activations import (
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
-from mlx.nn.layers.dropout import Dropout, Dropout2d
+from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
@@ -86,3 +86,52 @@ class Dropout2d(Module):

         mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
         return (1 / self._p_1) * mask * x
+
+
+class Dropout3d(Module):
+    r"""Apply 3D channel-wise dropout during training.
+
+    Randomly zero out entire channels independently with probability :math:`p`.
+    This layer expects the channels to be last, i.e., the input shape should be
+    `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
+    `H` is the input image height, `W` is the input image width, and `C` is
+    the number of input channels.
+
+    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
+    maintain the expected value of each element. Unlike traditional dropout,
+    which zeros individual entries, this layer zeros entire channels. This is
+    often beneficial for convolutional layers processing 3D data, like in
+    medical imaging or video processing.
+
+    Args:
+        p (float): Probability of zeroing a channel during training.
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+
+        if p < 0 or p >= 1:
+            raise ValueError(f"The dropout probability {p} is not in [0, 1)")
+
+        self._p_1 = 1 - p
+
+    def _extra_repr(self):
+        return f"p={1-self._p_1}"
+
+    def __call__(self, x):
+        if x.ndim not in (4, 5):
+            raise ValueError(
+                f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
+            )
+
+        if self._p_1 == 1 or not self.training:
+            return x
+
+        # Dropout is applied on whole channels, so the mask broadcasts over the
+        # spatial dimensions:
+        # 4D input: (1, 1, 1, C)
+        # 5D input: (B, 1, 1, 1, C)
+        mask_shape = list(x.shape)
+        mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
+
+        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
+        return (1 / self._p_1) * mask * x
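As a quick sanity check of the docstring's scaling claim (a sketch, not part of the commit): a channel survives with probability 1 - p and is multiplied by 1/(1 - p), so the expected value of every element is unchanged. Empirically:

import mlx.core as mx
import mlx.nn as nn

# Average many dropout samples of an all-ones input; the mean should stay near 1.0.
mx.random.seed(0)
x = mx.ones((1, 8, 8, 8, 16))
layer = nn.Dropout3d(p=0.25)
layer.train()

samples = mx.stack([layer(x) for _ in range(200)])
print(samples.mean())  # expected to be close to 1.0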
@@ -792,6 +792,54 @@ class TestNN(mlx_tests.MLXTestCase):
         loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
         self.assertAlmostEqual(loss.item(), 0.433781, places=6)
+
+    def test_dropout(self):
+        x = mx.ones((2, 4))
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4), dtype=mx.bfloat16)
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4), dtype=mx.float16)
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
+
+    def test_dropout2d(self):
+        x = mx.ones((2, 4, 4, 4))
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4, 4, 4), dtype=mx.float16)
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
+
+    def test_dropout3d(self):
+        x = mx.ones((2, 4, 4, 4, 4))
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)


 if __name__ == "__main__":
     unittest.main()
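Beyond the shape and dtype checks above, the channel-wise behavior itself could be verified; a rough sketch (not in the commit) that checks each (batch, channel) slice is constant, i.e. either entirely dropped or entirely rescaled:

import mlx.core as mx
import mlx.nn as nn

layer = nn.Dropout3d(0.5)
layer.train()

x = mx.ones((2, 4, 4, 4, 8))
y = layer(x)

# Reduce over the spatial axes; max == min for a slice means the slice is constant:
# all zeros (dropped) or all 1 / (1 - p) = 2.0 (kept and rescaled).
per_slice_max = mx.max(y, axis=[1, 2, 3])
per_slice_min = mx.min(y, axis=[1, 2, 3])
print(mx.array_equal(per_slice_max, per_slice_min))  # expected: True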