mlx/python
Bhargav Yagnik a098bc92e0
Fix: Preserve input dtype in Dropout layer output (#1323)
* Fix: Preserve input dtype in Dropout layer output

- Modified the Dropout implementation so that the output dtype matches the input dtype.
- Resolves issue #1321.
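Here is a hedged illustration of what "preserving the input dtype" means for dropout. It is a minimal inverted-dropout sketch that uses NumPy as a stand-in. It is not MLX's Dropout code, and the function name `dropout` and its parameters are hypothetical. The idea is to keep every intermediate value (the keep-mask and the 1/(1-p) rescaling factor) in the input's dtype, so that no step promotes the result to a wider float type.

```python
import numpy as np


def dropout(x: np.ndarray, p: float, rng: np.random.Generator,
            training: bool = True) -> np.ndarray:
    """Inverted dropout whose output dtype equals x.dtype.

    Hypothetical NumPy stand-in for illustration only; not MLX's implementation.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        # Identity at inference time (or when nothing is dropped).
        return x
    keep = 1.0 - p
    # Boolean mask: True where the element is kept.
    mask = rng.random(x.shape) < keep
    # Cast both the mask and the rescaling factor to x's dtype so that the
    # product stays in that dtype (e.g. float16 is not promoted to float64).
    scale = x.dtype.type(1.0 / keep)
    return x * mask.astype(x.dtype) * scale
```

For example, `dropout(np.ones(8, dtype=np.float16), 0.5, np.random.default_rng(0))` returns a float16 array whose entries are either 0.0 or 2.0.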

* Update test cases in test_nn.py

- Revised test cases to align with the updated Dropout code
- Fixed assertions in test_nn.py (test_rope, test_alibi and test_dropout): replaced self.assertTrue with self.assertEqual so the tests actually compare the values
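The commit does not show the original assertions. Presumably they had the form `self.assertTrue(actual, expected)`. The sketch below shows why that form gives no real check. `unittest.TestCase.assertTrue(expr, msg=None)` only checks that its first argument is truthy and treats the second argument as the failure message, while `assertEqual` compares the two values. The class name and the dtype strings are hypothetical placeholders.

```python
import unittest


class AssertionPitfall(unittest.TestCase):
    """Hypothetical tests contrasting assertTrue(a, b) with assertEqual(a, b)."""

    def test_assert_true_with_two_args_always_passes(self):
        actual, expected = "float32", "float16"
        # assertTrue(expr, msg): only checks that `actual` is truthy; `expected`
        # is used as the failure message, so this passes despite the mismatch.
        self.assertTrue(actual, expected)

    def test_assert_equal_catches_mismatch(self):
        actual, expected = "float32", "float16"
        # assertEqual really compares the values, so it raises here.
        with self.assertRaises(AssertionError):
            self.assertEqual(actual, expected)
```

Both tests pass when run (for example with `python -m unittest`). The first test passing is the problem: an `assertTrue(actual, expected)` check never fails as long as `actual` is truthy.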

* Update dropout.py
2024-08-13 11:54:21 -07:00
mlx Fix: Preserve input dtype in Dropout layer output (#1323) 2024-08-13 11:54:21 -07:00
src Add memory_efficient_threshold kwarg to sdpa kernel (#1319) 2024-08-12 12:57:09 -07:00
tests Fix: Preserve input dtype in Dropout layer output (#1323) 2024-08-13 11:54:21 -07:00