	feat: Add Dropout3d layer to nn.layers (#313)
* feat: Add Dropout3d layer to nn.layers
* acknowledgement
* Add dropout tests to test_nn.py
* run pre-commit
* Add activation functions and dropout3d ops
* Add dropout tests for bfloat16 and float16
@@ -43,7 +43,7 @@ from mlx.nn.layers.activations import (
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
-from mlx.nn.layers.dropout import Dropout, Dropout2d
+from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm

@@ -86,3 +86,52 @@ class Dropout2d(Module):
 
         mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
         return (1 / self._p_1) * mask * x
+
+
+class Dropout3d(Module):
+    r"""Apply 3D channel-wise dropout during training.
+
+    Randomly zero out entire channels independently with probability :math:`p`.
+    This layer expects the channels to be last, i.e., the input shape should be
+    `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
+    `H` is the input image height, `W` is the input image width, and `C` is
+    the number of input channels.
+
+    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
+    maintain the expected value of each element. Unlike traditional dropout,
+    which zeros individual entries, this layer zeros entire channels. This is
+    often beneficial for convolutional layers processing 3D data, like in
+    medical imaging or video processing.
+
+    Args:
+        p (float): Probability of zeroing a channel during training.
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+
+        if p < 0 or p >= 1:
+            raise ValueError(f"The dropout probability {p} is not in [0, 1)")
+
+        self._p_1 = 1 - p
+
+    def _extra_repr(self):
+        return f"p={1-self._p_1}"
+
+    def __call__(self, x):
+        if x.ndim not in (4, 5):
+            raise ValueError(
+                f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
+            )
+
+        if self._p_1 == 1 or not self.training:
+            return x
+
+        # Dropout is applied on the whole channel
+        # 4D input: (1, 1, 1, C)
+        # 5D input: (B, 1, 1, 1, C)
+        mask_shape = list(x.shape)
+        mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
+
+        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
+        return (1 / self._p_1) * mask * x

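For context, a minimal usage sketch of the new layer (not part of the commit; it assumes the usual mlx.core / mlx.nn imports and the Module train()/eval() toggles):

import mlx.core as mx
import mlx.nn as nn

# Hypothetical example: Dropout3d on a 5D NDHWC volume
# (batch=2, depth=4, height=4, width=4, channels=8).
x = mx.ones((2, 4, 4, 4, 8))

layer = nn.Dropout3d(p=0.5)
layer.train()   # masking is only applied while training
y = layer(x)    # surviving channels are scaled by 1 / (1 - p) = 2.0

layer.eval()
z = layer(x)    # identity in evaluation mode; z == x
print(y.shape, z.shape)  # both (2, 4, 4, 4, 8)
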
@@ -792,6 +792,54 @@ class TestNN(mlx_tests.MLXTestCase):
         loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
         self.assertAlmostEqual(loss.item(), 0.433781, places=6)
 
+    def test_dropout(self):
+        x = mx.ones((2, 4))
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4), dtype=mx.bfloat16)
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4), dtype=mx.float16)
+        y = nn.Dropout(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
+
+    def test_dropout2d(self):
+        x = mx.ones((2, 4, 4, 4))
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4, 4, 4), dtype=mx.float16)
+        y = nn.Dropout2d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
+
+    def test_dropout3d(self):
+        x = mx.ones((2, 4, 4, 4, 4))
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float32)
+
+        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.bfloat16)
+
+        x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)
+        y = nn.Dropout3d(0.5)(x)
+        self.assertEqual(y.shape, x.shape)
+        self.assertEqual(y.dtype, mx.float16)
+
 
 if __name__ == "__main__":
     unittest.main()
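These tests cover shapes and dtypes; a stricter, channel-wise check could complement them. A sketch (not part of this commit, assuming the reductions accept a list of axes): since each channel is either fully zeroed or fully scaled by 1 / (1 - p), the values within any one channel are constant.

import mlx.core as mx
import mlx.nn as nn

x = mx.ones((2, 4, 4, 4, 4))
layer = nn.Dropout3d(0.5)
layer.train()  # ensure the mask is applied
y = layer(x)

# With channels last, reduce over D, H, W: every channel should be constant,
# i.e. its min equals its max (either 0.0 or 2.0 here).
per_channel_min = mx.min(y, axis=[1, 2, 3])  # shape (2, 4)
per_channel_max = mx.max(y, axis=[1, 2, 3])  # shape (2, 4)
assert mx.array_equal(per_channel_min, per_channel_max).item()
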
Nripesh Niketan