From a098bc92e0083bf3613559ba3990ca52ed8631c6 Mon Sep 17 00:00:00 2001 From: Bhargav Yagnik <41851612+bhargavyagnik@users.noreply.github.com> Date: Tue, 13 Aug 2024 14:54:21 -0400 Subject: [PATCH] Fix: Preserve input dtype in Dropout layer output (#1323) * Fix: Preserve input dtype in Dropout layer output - Modified Dropout implementation to ensure that the output dtype matches the input dtype. - This resolves issue #1321. * Update test cases in test_nn.py - Revised test cases to align with updated dropout code - Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py (test_rope, test_alibi, test_dropout, test_dropout2d and test_dropout3d). * updated dropout.py --- python/mlx/nn/layers/dropout.py | 6 ++-- python/tests/test_nn.py | 52 ++++++++++++++++----------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py index 7008547c0..657f8c47a 100644 --- a/python/mlx/nn/layers/dropout.py +++ b/python/mlx/nn/layers/dropout.py @@ -32,7 +32,7 @@ class Dropout(Module): mask = mx.random.bernoulli(self._p_1, x.shape) - return (1 / self._p_1) * mask * x + return (mask * x) * (1 / self._p_1) class Dropout2d(Module): @@ -85,7 +85,7 @@ class Dropout2d(Module): mask_shape[-2] = mask_shape[-3] = 1 mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape) - return (1 / self._p_1) * mask * x + return (mask * x) * (1 / self._p_1) class Dropout3d(Module): @@ -134,4 +134,4 @@ class Dropout3d(Module): mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1 mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape) - return (1 / self._p_1) * mask * x + return (mask * x) * (1 / self._p_1) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index f3d47cf26..38659625f 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -939,73 +939,73 @@ class TestLayers(mlx_tests.MLXTestCase): shape = (1, 3, 4) x = mx.random.uniform(shape=shape) y = rope(x) -
self.assertTrue(y.shape, shape) - self.assertTrue(y.dtype, mx.float32) + self.assertEqual(y.shape, shape) + self.assertEqual(y.dtype, mx.float32) y = rope(x, offset=3) - self.assertTrue(y.shape, shape) + self.assertEqual(y.shape, shape) y = rope(x.astype(mx.float16)) - self.assertTrue(y.dtype, mx.float16) + self.assertEqual(y.dtype, mx.float16) def test_alibi(self): alibi = nn.ALiBi() - shape = [1, 8, 20, 20] + shape = (1, 8, 20, 20) x = mx.random.uniform(shape=shape) y = alibi(x) - self.assertTrue(y.shape, shape) - self.assertTrue(y.dtype, mx.float32) + self.assertEqual(y.shape, shape) + self.assertEqual(y.dtype, mx.float32) y = alibi(x.astype(mx.float16)) - self.assertTrue(y.dtype, mx.float16) + self.assertEqual(y.dtype, mx.float16) def test_dropout(self): x = mx.ones((2, 4)) y = nn.Dropout(0.5)(x) - self.assertTrue(y.shape, x.shape) - self.assertTrue(y.dtype, mx.float32) + self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, mx.float32) x = mx.ones((2, 4), dtype=mx.bfloat16) y = nn.Dropout(0.5)(x) - self.assertTrue(y.shape, x.shape) - self.assertTrue(y.dtype, mx.bfloat16) + self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, mx.bfloat16) x = mx.ones((2, 4), dtype=mx.float16) y = nn.Dropout(0.5)(x) - self.assertTrue(y.shape, x.shape) - self.assertTrue(y.dtype, mx.float16) + self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, mx.float16) def test_dropout2d(self): x = mx.ones((2, 4, 4, 4)) y = nn.Dropout2d(0.5)(x) - self.assertTrue(y.shape, x.shape) - self.assertTrue(y.dtype, mx.float32) + self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, mx.float32) x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16) y = nn.Dropout2d(0.5)(x) - self.assertTrue(y.shape, x.shape) - self.assertTrue(y.dtype, mx.bfloat16) + self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, mx.bfloat16) x = mx.ones((2, 4, 4, 4), dtype=mx.float16) y = nn.Dropout2d(0.5)(x) - self.assertTrue(y.shape, x.shape) - self.assertTrue(y.dtype, mx.float16) + 
self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, mx.float16) def test_dropout3d(self): x = mx.ones((2, 4, 4, 4, 4)) y = nn.Dropout3d(0.5)(x) - self.assertTrue(y.shape, x.shape) - self.assertTrue(y.dtype, mx.float32) + self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, mx.float32) x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16) y = nn.Dropout3d(0.5)(x) - self.assertTrue(y.shape, x.shape) - self.assertTrue(y.dtype, mx.bfloat16) + self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, mx.bfloat16) x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16) y = nn.Dropout3d(0.5)(x) - self.assertTrue(y.shape, x.shape) - self.assertTrue(y.dtype, mx.float16) + self.assertEqual(y.shape, x.shape) + self.assertEqual(y.dtype, mx.float16) def test_upsample(self): b, h, w, c = 1, 2, 2, 1