Fix: Preserve input dtype in Dropout layer output (#1323)

* Fix: Preserve input dtype in Dropout layer output

- Modified Dropout implementation to ensure that the output dtype matches the input dtype.
- This resolves the issue #1321

* Update test cases in test_nn.py

- Revised test cases to align with updated dropout code
- Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py -> test_rope, test_alibi and test_dropout,

* updated dropout.py
This commit is contained in:
Bhargav Yagnik 2024-08-13 14:54:21 -04:00 committed by GitHub
parent 1086dc4db0
commit a098bc92e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 29 deletions

View File

@ -32,7 +32,7 @@ class Dropout(Module):
mask = mx.random.bernoulli(self._p_1, x.shape)
return (1 / self._p_1) * mask * x
return (mask * x) * (1 / self._p_1)
class Dropout2d(Module):
@ -85,7 +85,7 @@ class Dropout2d(Module):
mask_shape[-2] = mask_shape[-3] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
return (1 / self._p_1) * mask * x
return (mask * x) * (1 / self._p_1)
class Dropout3d(Module):
@ -134,4 +134,4 @@ class Dropout3d(Module):
mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
return (1 / self._p_1) * mask * x
return (mask * x) * (1 / self._p_1)

View File

@ -939,73 +939,73 @@ class TestLayers(mlx_tests.MLXTestCase):
shape = (1, 3, 4)
x = mx.random.uniform(shape=shape)
y = rope(x)
self.assertTrue(y.shape, shape)
self.assertTrue(y.dtype, mx.float32)
self.assertEqual(y.shape, shape)
self.assertEqual(y.dtype, mx.float32)
y = rope(x, offset=3)
self.assertTrue(y.shape, shape)
self.assertEqual(y.shape, shape)
y = rope(x.astype(mx.float16))
self.assertTrue(y.dtype, mx.float16)
self.assertEqual(y.dtype, mx.float16)
def test_alibi(self):
alibi = nn.ALiBi()
shape = [1, 8, 20, 20]
shape = (1, 8, 20, 20)
x = mx.random.uniform(shape=shape)
y = alibi(x)
self.assertTrue(y.shape, shape)
self.assertTrue(y.dtype, mx.float32)
self.assertEqual(y.shape, shape)
self.assertEqual(y.dtype, mx.float32)
y = alibi(x.astype(mx.float16))
self.assertTrue(y.dtype, mx.float16)
self.assertEqual(y.dtype, mx.float16)
def test_dropout(self):
x = mx.ones((2, 4))
y = nn.Dropout(0.5)(x)
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float32)
self.assertEqual(y.shape, x.shape)
self.assertEqual(y.dtype, mx.float32)
x = mx.ones((2, 4), dtype=mx.bfloat16)
y = nn.Dropout(0.5)(x)
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.bfloat16)
self.assertEqual(y.shape, x.shape)
self.assertEqual(y.dtype, mx.bfloat16)
x = mx.ones((2, 4), dtype=mx.float16)
y = nn.Dropout(0.5)(x)
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float16)
self.assertEqual(y.shape, x.shape)
self.assertEqual(y.dtype, mx.float16)
def test_dropout2d(self):
x = mx.ones((2, 4, 4, 4))
y = nn.Dropout2d(0.5)(x)
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float32)
self.assertEqual(y.shape, x.shape)
self.assertEqual(y.dtype, mx.float32)
x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)
y = nn.Dropout2d(0.5)(x)
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.bfloat16)
self.assertEqual(y.shape, x.shape)
self.assertEqual(y.dtype, mx.bfloat16)
x = mx.ones((2, 4, 4, 4), dtype=mx.float16)
y = nn.Dropout2d(0.5)(x)
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float16)
self.assertEqual(y.shape, x.shape)
self.assertEqual(y.dtype, mx.float16)
def test_dropout3d(self):
x = mx.ones((2, 4, 4, 4, 4))
y = nn.Dropout3d(0.5)(x)
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float32)
self.assertEqual(y.shape, x.shape)
self.assertEqual(y.dtype, mx.float32)
x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)
y = nn.Dropout3d(0.5)(x)
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.bfloat16)
self.assertEqual(y.shape, x.shape)
self.assertEqual(y.dtype, mx.bfloat16)
x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)
y = nn.Dropout3d(0.5)(x)
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float16)
self.assertEqual(y.shape, x.shape)
self.assertEqual(y.dtype, mx.float16)
def test_upsample(self):
b, h, w, c = 1, 2, 2, 1