mlx/python/tests
Bhargav Yagnik a098bc92e0
Fix: Preserve input dtype in Dropout layer output (#1323)
* Fix: Preserve input dtype in Dropout layer output

- Modified Dropout implementation to ensure that the output dtype matches the input dtype.
- Resolves issue #1321.

* Update test cases in test_nn.py

- Revised test cases to align with updated dropout code
- Fixed assertion method: replaced self.assertTrue with self.assertEqual for accurate comparisons in test_nn.py (test_rope, test_alibi, and test_dropout)

* Update dropout.py
2024-08-13 11:54:21 -07:00
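
The assertion change described in the commit message above can be illustrated with a minimal standard-library sketch. It is not taken from test_nn.py; the class name `AssertPitfall` and the helper `run_demo` are hypothetical and exist only for this illustration. The point is that `assertTrue(a, b)` treats `b` as the failure message rather than as an expected value, so it may pass even when the two values differ, whereas `assertEqual(a, b)` compares them.

```python
import unittest


class AssertPitfall(unittest.TestCase):
    # Hypothetical test case, not part of the MLX test suite.

    def test_assert_true_ignores_second_argument(self):
        # assertTrue(expr, msg): the second positional argument is the
        # failure message, not an expected value. This passes because 1 is
        # truthy, even though 1 != 2.
        self.assertTrue(1, 2)

    def test_assert_equal_compares_both_arguments(self):
        # assertEqual(first, second) compares the two values, so a mismatch
        # raises AssertionError.
        with self.assertRaises(AssertionError):
            self.assertEqual(1, 2)


def run_demo():
    """Run both illustrative tests and return the unittest result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(AssertPitfall)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Calling `run_demo()` runs both tests. A dtype check written with `assertTrue(y.dtype, x.dtype)` would likely pass for the same reason regardless of whether the dtypes match, which is presumably why the commit switched to `assertEqual`.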
mlx_tests.py Make MLX build on x64 macOS (#901) 2024-03-27 06:14:29 -07:00
mpi_test_distributed.py Add docs for the distributed namespace (#1184) 2024-06-06 11:37:00 -07:00
test_array.py Array api (#1289) 2024-07-26 10:40:49 -07:00
test_autograd.py Custom transforms (#1246) 2024-07-10 18:00:01 -07:00
test_bf16.py fix creating array from bf16 tensors in jax / torch (#1305) 2024-08-01 16:20:51 -07:00
test_blas.py Masked gemv (#1211) 2024-06-14 09:52:26 -07:00
test_compile.py fix donation condition for compilation (#1237) 2024-06-26 09:04:05 -07:00
test_constants.py Array api (#1289) 2024-07-26 10:40:49 -07:00
test_conv.py add bfloat conv for winograd (#1306) 2024-08-05 15:51:13 -07:00
test_device.py Adds device context manager (#679) 2024-02-14 14:14:58 -08:00
test_einsum.py Einsum (#1269) 2024-07-25 09:36:44 -07:00
test_eval.py Fix leak with multi-output primitives (#1274) 2024-07-23 06:34:18 -07:00
test_fast_sdpa.py Add memory_efficient_threshold kwarg to sdpa kernel (#1319) 2024-08-12 12:57:09 -07:00
test_fast.py Fix test tolerance and patch bump (#1315) 2024-08-08 14:51:09 -07:00
test_fft.py fix scatter + test (#1202) 2024-06-11 14:35:12 -07:00
test_graph.py Multi output primitives (#330) 2024-01-08 16:39:08 -08:00
test_init.py Make shape a tuple (#591) 2024-01-30 13:11:01 -08:00
test_linalg.py CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv (#1307) 2024-08-08 15:18:02 -07:00
test_load.py Fix logsumexp edge case (#740) 2024-02-25 08:39:55 -08:00
test_losses.py Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280) 2024-07-24 08:38:22 -07:00
test_metal.py Add softmin, hardshrink, hardtanh (#1180) 2024-06-04 15:48:18 -07:00
test_nn.py Fix: Preserve input dtype in Dropout layer output (#1323) 2024-08-13 11:54:21 -07:00
test_ops.py Add "edge" mode to mx.pad (#1309) 2024-08-06 11:23:10 -07:00
test_optimizers.py Treat 'minimum' differently in cosine decay (#1138) 2024-05-20 08:00:48 -07:00
test_quantized.py Fused Affine Quantize/Dequantize ops (#1282) 2024-07-29 15:11:38 -07:00
test_random.py Implement sampling from laplace distribution. (#1279) 2024-07-24 15:15:37 +02:00
test_reduce.py Add GPU support for uint64/int64 reductions (#569) 2024-01-31 11:18:04 -08:00
test_tree.py Add isort pre-commit and run (#68) 2023-12-08 11:31:47 -08:00
test_upsample.py Upsample with bicubic interpolation (#967) 2024-04-10 15:47:22 -07:00
test_vmap.py Add vmap to scatter (#1200) 2024-08-05 20:12:27 -07:00