mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix: Preserve input dtype in Dropout layer output (#1323)
* Fix: Preserve input dtype in Dropout layer output
  - Modified the Dropout implementation so that the output dtype matches the input dtype.
  - Resolves issue #1321.
* Update test cases in test_nn.py
  - Revised test cases to align with the updated dropout code.
  - Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py (test_rope, test_alibi, and test_dropout).
* Updated dropout.py
This commit is contained in:
parent
1086dc4db0
commit
a098bc92e0
@ -32,7 +32,7 @@ class Dropout(Module):
|
|||||||
|
|
||||||
mask = mx.random.bernoulli(self._p_1, x.shape)
|
mask = mx.random.bernoulli(self._p_1, x.shape)
|
||||||
|
|
||||||
return (1 / self._p_1) * mask * x
|
return (mask * x) * (1 / self._p_1)
|
||||||
|
|
||||||
|
|
||||||
class Dropout2d(Module):
|
class Dropout2d(Module):
|
||||||
@ -85,7 +85,7 @@ class Dropout2d(Module):
|
|||||||
mask_shape[-2] = mask_shape[-3] = 1
|
mask_shape[-2] = mask_shape[-3] = 1
|
||||||
|
|
||||||
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
||||||
return (1 / self._p_1) * mask * x
|
return (mask * x) * (1 / self._p_1)
|
||||||
|
|
||||||
|
|
||||||
class Dropout3d(Module):
|
class Dropout3d(Module):
|
||||||
@ -134,4 +134,4 @@ class Dropout3d(Module):
|
|||||||
mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
|
mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
|
||||||
|
|
||||||
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
||||||
return (1 / self._p_1) * mask * x
|
return (mask * x) * (1 / self._p_1)
|
||||||
|
@ -939,73 +939,73 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
shape = (1, 3, 4)
|
shape = (1, 3, 4)
|
||||||
x = mx.random.uniform(shape=shape)
|
x = mx.random.uniform(shape=shape)
|
||||||
y = rope(x)
|
y = rope(x)
|
||||||
self.assertTrue(y.shape, shape)
|
self.assertEqual(y.shape, shape)
|
||||||
self.assertTrue(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
y = rope(x, offset=3)
|
y = rope(x, offset=3)
|
||||||
self.assertTrue(y.shape, shape)
|
self.assertEqual(y.shape, shape)
|
||||||
|
|
||||||
y = rope(x.astype(mx.float16))
|
y = rope(x.astype(mx.float16))
|
||||||
self.assertTrue(y.dtype, mx.float16)
|
self.assertEqual(y.dtype, mx.float16)
|
||||||
|
|
||||||
def test_alibi(self):
|
def test_alibi(self):
|
||||||
alibi = nn.ALiBi()
|
alibi = nn.ALiBi()
|
||||||
shape = [1, 8, 20, 20]
|
shape = (1, 8, 20, 20)
|
||||||
x = mx.random.uniform(shape=shape)
|
x = mx.random.uniform(shape=shape)
|
||||||
y = alibi(x)
|
y = alibi(x)
|
||||||
self.assertTrue(y.shape, shape)
|
self.assertEqual(y.shape, shape)
|
||||||
self.assertTrue(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
y = alibi(x.astype(mx.float16))
|
y = alibi(x.astype(mx.float16))
|
||||||
self.assertTrue(y.dtype, mx.float16)
|
self.assertEqual(y.dtype, mx.float16)
|
||||||
|
|
||||||
def test_dropout(self):
|
def test_dropout(self):
|
||||||
x = mx.ones((2, 4))
|
x = mx.ones((2, 4))
|
||||||
y = nn.Dropout(0.5)(x)
|
y = nn.Dropout(0.5)(x)
|
||||||
self.assertTrue(y.shape, x.shape)
|
self.assertEqual(y.shape, x.shape)
|
||||||
self.assertTrue(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
x = mx.ones((2, 4), dtype=mx.bfloat16)
|
x = mx.ones((2, 4), dtype=mx.bfloat16)
|
||||||
y = nn.Dropout(0.5)(x)
|
y = nn.Dropout(0.5)(x)
|
||||||
self.assertTrue(y.shape, x.shape)
|
self.assertEqual(y.shape, x.shape)
|
||||||
self.assertTrue(y.dtype, mx.bfloat16)
|
self.assertEqual(y.dtype, mx.bfloat16)
|
||||||
|
|
||||||
x = mx.ones((2, 4), dtype=mx.float16)
|
x = mx.ones((2, 4), dtype=mx.float16)
|
||||||
y = nn.Dropout(0.5)(x)
|
y = nn.Dropout(0.5)(x)
|
||||||
self.assertTrue(y.shape, x.shape)
|
self.assertEqual(y.shape, x.shape)
|
||||||
self.assertTrue(y.dtype, mx.float16)
|
self.assertEqual(y.dtype, mx.float16)
|
||||||
|
|
||||||
def test_dropout2d(self):
|
def test_dropout2d(self):
|
||||||
x = mx.ones((2, 4, 4, 4))
|
x = mx.ones((2, 4, 4, 4))
|
||||||
y = nn.Dropout2d(0.5)(x)
|
y = nn.Dropout2d(0.5)(x)
|
||||||
self.assertTrue(y.shape, x.shape)
|
self.assertEqual(y.shape, x.shape)
|
||||||
self.assertTrue(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)
|
x = mx.ones((2, 4, 4, 4), dtype=mx.bfloat16)
|
||||||
y = nn.Dropout2d(0.5)(x)
|
y = nn.Dropout2d(0.5)(x)
|
||||||
self.assertTrue(y.shape, x.shape)
|
self.assertEqual(y.shape, x.shape)
|
||||||
self.assertTrue(y.dtype, mx.bfloat16)
|
self.assertEqual(y.dtype, mx.bfloat16)
|
||||||
|
|
||||||
x = mx.ones((2, 4, 4, 4), dtype=mx.float16)
|
x = mx.ones((2, 4, 4, 4), dtype=mx.float16)
|
||||||
y = nn.Dropout2d(0.5)(x)
|
y = nn.Dropout2d(0.5)(x)
|
||||||
self.assertTrue(y.shape, x.shape)
|
self.assertEqual(y.shape, x.shape)
|
||||||
self.assertTrue(y.dtype, mx.float16)
|
self.assertEqual(y.dtype, mx.float16)
|
||||||
|
|
||||||
def test_dropout3d(self):
|
def test_dropout3d(self):
|
||||||
x = mx.ones((2, 4, 4, 4, 4))
|
x = mx.ones((2, 4, 4, 4, 4))
|
||||||
y = nn.Dropout3d(0.5)(x)
|
y = nn.Dropout3d(0.5)(x)
|
||||||
self.assertTrue(y.shape, x.shape)
|
self.assertEqual(y.shape, x.shape)
|
||||||
self.assertTrue(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)
|
x = mx.ones((2, 4, 4, 4, 4), dtype=mx.bfloat16)
|
||||||
y = nn.Dropout3d(0.5)(x)
|
y = nn.Dropout3d(0.5)(x)
|
||||||
self.assertTrue(y.shape, x.shape)
|
self.assertEqual(y.shape, x.shape)
|
||||||
self.assertTrue(y.dtype, mx.bfloat16)
|
self.assertEqual(y.dtype, mx.bfloat16)
|
||||||
|
|
||||||
x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)
|
x = mx.ones((2, 4, 4, 4, 4), dtype=mx.float16)
|
||||||
y = nn.Dropout3d(0.5)(x)
|
y = nn.Dropout3d(0.5)(x)
|
||||||
self.assertTrue(y.shape, x.shape)
|
self.assertEqual(y.shape, x.shape)
|
||||||
self.assertTrue(y.dtype, mx.float16)
|
self.assertEqual(y.dtype, mx.float16)
|
||||||
|
|
||||||
def test_upsample(self):
|
def test_upsample(self):
|
||||||
b, h, w, c = 1, 2, 2, 1
|
b, h, w, c = 1, 2, 2, 1
|
||||||
|
Loading…
Reference in New Issue
Block a user