mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
feat: Add Dropout3d layer to nn.layers (#313)
* feat: Add Dropout3d layer to nn.layers * acknowledgement * Add dropout tests to test_nn.py * run pre-commit * Add activation functions and dropout3d ops * Add dropout tests for bfloat16 and float16
This commit is contained in:
@@ -792,6 +792,54 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
|
||||
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
|
||||
|
||||
def test_dropout(self):
|
||||
x = mx.ones((2, 4))
|
||||
y = nn.Dropout(0.5)(x)
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.float32)
|
||||
|
||||
x = mx.ones((2, 4), dtype=mx.bfloat16)
|
||||
y = nn.Dropout(0.5)(x)
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.bfloat16)
|
||||
|
||||
x = mx.ones((2, 4), dtype=mx.float16)
|
||||
y = nn.Dropout(0.5)(x)
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.float16)
|
||||
|
||||
def test_dropout2d(self):
|
||||
x = mx.ones((2, 4, 4, 4))
|
||||
y = nn.Dropout2d(0.5)(x)
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.float32)
|
||||
|
||||
x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)
|
||||
y = nn.Dropout2d(0.5)(x)
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.bfloat16)
|
||||
|
||||
x = mx.ones((2, 4, 4, 4), dtype=mx.float16)
|
||||
y = nn.Dropout2d(0.5)(x)
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.float16)
|
||||
|
||||
def test_dropout3d(self):
|
||||
x = mx.ones((2, 4, 4, 4, 4))
|
||||
y = nn.Dropout3d(0.5)(x)
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.float32)
|
||||
|
||||
x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)
|
||||
y = nn.Dropout3d(0.5)(x)
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.bfloat16)
|
||||
|
||||
x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)
|
||||
y = nn.Dropout3d(0.5)(x)
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.float16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user