mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Fix: Preserve input dtype in Dropout layer output (#1323)
* Fix: Preserve input dtype in Dropout layer output - Modified Dropout implementation to ensure that the output dtype matches the input dtype. - This resolves the issue #1321 * Update test cases in test_nn.py - Revised test cases to align with updated dropout code - Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py -> test_rope, test_alibi and test_dropout, * updated dropout.py
This commit is contained in:
		| @@ -32,7 +32,7 @@ class Dropout(Module): | ||||
|  | ||||
|         mask = mx.random.bernoulli(self._p_1, x.shape) | ||||
|  | ||||
|         return (1 / self._p_1) * mask * x | ||||
|         return (mask * x) * (1 / self._p_1) | ||||
|  | ||||
|  | ||||
| class Dropout2d(Module): | ||||
| @@ -85,7 +85,7 @@ class Dropout2d(Module): | ||||
|         mask_shape[-2] = mask_shape[-3] = 1 | ||||
|  | ||||
|         mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape) | ||||
|         return (1 / self._p_1) * mask * x | ||||
|         return (mask * x) * (1 / self._p_1) | ||||
|  | ||||
|  | ||||
| class Dropout3d(Module): | ||||
| @@ -134,4 +134,4 @@ class Dropout3d(Module): | ||||
|         mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1 | ||||
|  | ||||
|         mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape) | ||||
|         return (1 / self._p_1) * mask * x | ||||
|         return (mask * x) * (1 / self._p_1) | ||||
|   | ||||
| @@ -939,73 +939,73 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|             shape = (1, 3, 4) | ||||
|             x = mx.random.uniform(shape=shape) | ||||
|             y = rope(x) | ||||
|             self.assertTrue(y.shape, shape) | ||||
|             self.assertTrue(y.dtype, mx.float32) | ||||
|             self.assertEqual(y.shape, shape) | ||||
|             self.assertEqual(y.dtype, mx.float32) | ||||
|  | ||||
|             y = rope(x, offset=3) | ||||
|             self.assertTrue(y.shape, shape) | ||||
|             self.assertEqual(y.shape, shape) | ||||
|  | ||||
|             y = rope(x.astype(mx.float16)) | ||||
|             self.assertTrue(y.dtype, mx.float16) | ||||
|             self.assertEqual(y.dtype, mx.float16) | ||||
|  | ||||
|     def test_alibi(self): | ||||
|         alibi = nn.ALiBi() | ||||
|         shape = [1, 8, 20, 20] | ||||
|         shape = (1, 8, 20, 20) | ||||
|         x = mx.random.uniform(shape=shape) | ||||
|         y = alibi(x) | ||||
|         self.assertTrue(y.shape, shape) | ||||
|         self.assertTrue(y.dtype, mx.float32) | ||||
|         self.assertEqual(y.shape, shape) | ||||
|         self.assertEqual(y.dtype, mx.float32) | ||||
|  | ||||
|         y = alibi(x.astype(mx.float16)) | ||||
|         self.assertTrue(y.dtype, mx.float16) | ||||
|         self.assertEqual(y.dtype, mx.float16) | ||||
|  | ||||
|     def test_dropout(self): | ||||
|         x = mx.ones((2, 4)) | ||||
|         y = nn.Dropout(0.5)(x) | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.float32) | ||||
|         self.assertEqual(y.shape, x.shape) | ||||
|         self.assertEqual(y.dtype, mx.float32) | ||||
|  | ||||
|         x = mx.ones((2, 4), dtype=mx.bfloat16) | ||||
|         y = nn.Dropout(0.5)(x) | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.bfloat16) | ||||
|         self.assertEqual(y.shape, x.shape) | ||||
|         self.assertEqual(y.dtype, mx.bfloat16) | ||||
|  | ||||
|         x = mx.ones((2, 4), dtype=mx.float16) | ||||
|         y = nn.Dropout(0.5)(x) | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.float16) | ||||
|         self.assertEqual(y.shape, x.shape) | ||||
|         self.assertEqual(y.dtype, mx.float16) | ||||
|  | ||||
|     def test_dropout2d(self): | ||||
|         x = mx.ones((2, 4, 4, 4)) | ||||
|         y = nn.Dropout2d(0.5)(x) | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.float32) | ||||
|         self.assertEqual(y.shape, x.shape) | ||||
|         self.assertEqual(y.dtype, mx.float32) | ||||
|  | ||||
|         x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16) | ||||
|         y = nn.Dropout2d(0.5)(x) | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.bfloat16) | ||||
|         self.assertEqual(y.shape, x.shape) | ||||
|         self.assertEqual(y.dtype, mx.bfloat16) | ||||
|  | ||||
|         x = mx.ones((2, 4, 4, 4), dtype=mx.float16) | ||||
|         y = nn.Dropout2d(0.5)(x) | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.float16) | ||||
|         self.assertEqual(y.shape, x.shape) | ||||
|         self.assertEqual(y.dtype, mx.float16) | ||||
|  | ||||
|     def test_dropout3d(self): | ||||
|         x = mx.ones((2, 4, 4, 4, 4)) | ||||
|         y = nn.Dropout3d(0.5)(x) | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.float32) | ||||
|         self.assertEqual(y.shape, x.shape) | ||||
|         self.assertEqual(y.dtype, mx.float32) | ||||
|  | ||||
|         x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16) | ||||
|         y = nn.Dropout3d(0.5)(x) | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.bfloat16) | ||||
|         self.assertEqual(y.shape, x.shape) | ||||
|         self.assertEqual(y.dtype, mx.bfloat16) | ||||
|  | ||||
|         x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16) | ||||
|         y = nn.Dropout3d(0.5)(x) | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.float16) | ||||
|         self.assertEqual(y.shape, x.shape) | ||||
|         self.assertEqual(y.dtype, mx.float16) | ||||
|  | ||||
|     def test_upsample(self): | ||||
|         b, h, w, c = 1, 2, 2, 1 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Bhargav Yagnik
					Bhargav Yagnik