Merge branch 'ml-explore:main' into feature_expand_nn_linear

Asaf Zorea 2024-01-01 00:40:24 +02:00 committed by GitHub
commit b36b9017b3
5 changed files with 100 additions and 2 deletions

ACKNOWLEDGMENTS.md

@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added tri, tril, triu and safetensor support

docs/src/python/nn/layers.rst

@@ -28,6 +28,7 @@ Layers
    GroupNorm
    Dropout
    Dropout2d
+   Dropout3d
    Transformer
    MultiHeadAttention
    ALiBi

python/mlx/nn/layers/__init__.py

@@ -43,7 +43,7 @@ from mlx.nn.layers.activations import (
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
-from mlx.nn.layers.dropout import Dropout, Dropout2d
+from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Bilinear, Identity, Linear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
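With this re-export in place, the new layer is reachable from the package root as well as from its defining module. A minimal illustration (a usage note, not part of the diff):

from mlx.nn import Dropout3d                  # via the package re-export
from mlx.nn.layers.dropout import Dropout3d  # direct module path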

python/mlx/nn/layers/dropout.py

@@ -86,3 +86,52 @@ class Dropout2d(Module):
         mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
 
         return (1 / self._p_1) * mask * x
+
+
+class Dropout3d(Module):
+    r"""Apply 3D channel-wise dropout during training.
+
+    Randomly zero out entire channels independently with probability :math:`p`.
+    This layer expects the channels to be last, i.e., the input shape should be
+    `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
+    `H` is the input image height, `W` is the input image width, and `C` is
+    the number of input channels.
+
+    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
+    maintain the expected value of each element. Unlike traditional dropout,
+    which zeros individual entries, this layer zeros entire channels. This is
+    often beneficial for convolutional layers processing 3D data, like in
+    medical imaging or video processing.
+
+    Args:
+        p (float): Probability of zeroing a channel during training.
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+
+        if p < 0 or p >= 1:
+            raise ValueError(f"The dropout probability {p} is not in [0, 1)")
+
+        self._p_1 = 1 - p
+
+    def _extra_repr(self):
+        return f"p={1-self._p_1}"
+
+    def __call__(self, x):
+        if x.ndim not in (4, 5):
+            raise ValueError(
+                f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
+            )
+
+        if self._p_1 == 1 or not self.training:
+            return x
+
+        # Dropout is applied to whole channels at once:
+        # 4D input (DHWC):  mask shape (1, 1, 1, C)
+        # 5D input (NDHWC): mask shape (N, 1, 1, 1, C)
+        mask_shape = list(x.shape)
+        mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
+        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
+
+        return (1 / self._p_1) * mask * x
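As a quick orientation for readers of the diff, here is a minimal usage sketch of the new layer (assuming an `mlx` build that includes this commit; the shapes are illustrative):

import mlx.core as mx
import mlx.nn as nn

layer = nn.Dropout3d(p=0.5)
layer.train()  # channel dropout is only active in training mode

x = mx.ones((2, 8, 8, 8, 4))  # NDHWC: batch=2, depth=height=width=8, channels=4
y = layer(x)
# Each (batch, channel) slice is either all zeros or all 2.0,
# since surviving channels are scaled by 1/(1 - p) = 2.

layer.eval()
assert mx.array_equal(layer(x), x).item()  # identity at inference time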

python/tests/test_nn.py

@@ -805,6 +805,54 @@ class TestNN(mlx_tests.MLXTestCase):
         loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
         self.assertAlmostEqual(loss.item(), 0.433781, places=6)
 
+    def test_dropout(self):
+        x = mx.ones((2, 4))
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4), dtype=mx.bfloat16)
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4), dtype=mx.float16)
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
+
+    def test_dropout2d(self):
+        x = mx.ones((2, 4, 4, 4))
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4, 4, 4), dtype=mx.float16)
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
+
+    def test_dropout3d(self):
+        x = mx.ones((2, 4, 4, 4, 4))
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
+
 
 if __name__ == "__main__":
     unittest.main()
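The assertions above only cover shape and dtype preservation (the original hunk used `self.assertTrue(y.shape, x.shape)`, which never fails because the second argument is treated as a message; `assertEqual` is the intended check). A quick statistical sanity check of the scaling property can be run by hand; a minimal sketch, not part of the commit, with an arbitrary batch size:

import mlx.core as mx
import mlx.nn as nn

layer = nn.Dropout3d(p=0.5)
layer.train()

# With many independent channel draws, the mean should stay near 1.0,
# because channels dropped with probability p are offset by the
# 1/(1 - p) scaling of the survivors.
x = mx.ones((1000, 2, 2, 2, 8))
y = layer(x)
print(y.mean().item())  # ~1.0, up to sampling noise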