mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
Fix: Preserve input dtype in Dropout layer output (#1323)
* Fix: Preserve input dtype in Dropout layer output — modified the Dropout implementation so that the output dtype matches the input dtype; this resolves issue #1321. * Update test cases in test_nn.py — revised test cases to align with the updated dropout code, and fixed the assertion method by replacing self.assertTrue with self.assertEqual for accurate comparisons in test_rope, test_alibi, and test_dropout. * Updated dropout.py.
This commit is contained in:
@@ -32,7 +32,7 @@ class Dropout(Module):
|
||||
|
||||
mask = mx.random.bernoulli(self._p_1, x.shape)
|
||||
|
||||
return (1 / self._p_1) * mask * x
|
||||
return (mask * x) * (1 / self._p_1)
|
||||
|
||||
|
||||
class Dropout2d(Module):
|
||||
@@ -85,7 +85,7 @@ class Dropout2d(Module):
|
||||
mask_shape[-2] = mask_shape[-3] = 1
|
||||
|
||||
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
||||
return (1 / self._p_1) * mask * x
|
||||
return (mask * x) * (1 / self._p_1)
|
||||
|
||||
|
||||
class Dropout3d(Module):
|
||||
@@ -134,4 +134,4 @@ class Dropout3d(Module):
|
||||
mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
|
||||
|
||||
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
|
||||
return (1 / self._p_1) * mask * x
|
||||
return (mask * x) * (1 / self._p_1)
|
||||
|
Reference in New Issue
Block a user