Manuel Villanueva
9cbb1b0148
Modified sort behavior when running on CPU or Metal to match NumPy/JAX ( #2667 )
* Modified sort behavior when running on CPU or Metal to match NumPy/JAX sorting behavior.
* Modified sort behavior when running on CPU or Metal to match NumPy/JAX
* nits
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-13 14:36:45 -07:00
Awni Hannun
630350ad3e
Precise sigmoid ( #2659 )
* bump patch
* Sigmoid matches PyTorch and is more precise on tails
2025-10-10 10:05:23 -07:00
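A minimal sketch of the improved tail behavior (illustrative values, not taken from the PR):
    import mlx.core as mx

    # Large-magnitude inputs: the tails now saturate toward 0 and 1
    # instead of losing precision.
    x = mx.array([-90.0, 0.0, 90.0])
    print(mx.sigmoid(x))  # approximately [0.0, 0.5, 1.0]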
Awni Hannun
380aeb58ae
enable addmm low-precision cpu ( #2661 )
2025-10-10 09:50:54 -07:00
Awni Hannun
e89e8b4272
Export with callback ( #2612 )
* export with callback
* export with callback
* Add types, fix kwarg ordering bug + test
* cleanup, test, fix
* typos
2025-10-08 19:24:33 -07:00
Awni Hannun
343e33b6d5
fix all_gather vjp ( #2654 )
2025-10-07 06:05:23 -07:00
Awni Hannun
a7a94b29d7
Fix compile when outputs change ( #2648 )
2025-10-03 08:40:57 -07:00
Awni Hannun
1c9ae1eaa1
[CUDA] fix flaky test ( #2646 )
2025-10-02 15:40:04 -07:00
Awni Hannun
e88f2d4a8e
fix cross entropy axis param ( #2641 )
* fix cross entropy axis param
* faster grad clipping
2025-10-01 16:49:55 -07:00
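For context, a hedged sketch of the axis parameter this fixes (shapes are illustrative):
    import mlx.core as mx
    import mlx.nn as nn

    logits = mx.random.normal((4, 10))  # (batch, classes)
    targets = mx.array([1, 3, 5, 7])
    # axis selects the class dimension of the logits
    loss = nn.losses.cross_entropy(logits, targets, axis=-1)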
Angelos Katharopoulos
eb24267b56
Compile can now attach arbitrary data to an entry ( #2634 )
2025-09-30 13:33:27 -07:00
Awni Hannun
dc371ae7a5
fix for max block dim ( #2631 )
2025-09-29 08:59:25 -07:00
Awni Hannun
d4f4ff3c5e
Allow None input to compiled functions ( #2621 )
* Allow None input to compiled functions
* Allow None input to compiled functions
2025-09-25 08:42:23 -07:00
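A small sketch of what this enables, assuming standard mx.compile usage:
    import mlx.core as mx

    def fn(x, bias):
        # bias may be None; compiled functions now accept this
        return x + bias if bias is not None else x

    cfn = mx.compile(fn)
    print(cfn(mx.array([1.0, 2.0]), None))
    print(cfn(mx.array([1.0, 2.0]), mx.array(0.5)))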
Daniel Yeh
fbbf3b9b3e
Support pickling array for bfloat16 ( #2586 )
* add bfloat16 pickling
* Improvements
* improve
---------
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-09-22 20:12:15 -07:00
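A quick round-trip sketch of the new behavior:
    import pickle
    import mlx.core as mx

    a = mx.array([1.0, 2.0], dtype=mx.bfloat16)
    b = pickle.loads(pickle.dumps(a))  # previously failed for bfloat16
    assert b.dtype == mx.bfloat16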
Awni Hannun
711a645807
avoid producing NaN in attention ( #2608 )
2025-09-22 13:10:43 -07:00
Josh Bleecher Snyder
aa9d44b3d4
implement Convolution::output_shape ( #2601 )
- pull conv_out_shape out for re-use
- add Conv::output_shape
- add e2e python tests confirming shapeless=True support and correctness
Updates #2599
2025-09-22 10:09:45 -07:00
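The shapeless=True support mentioned above, sketched end to end (shapes are illustrative; layouts follow MLX's channels-last convention):
    import mlx.core as mx

    def conv(x, w):
        return mx.conv2d(x, w, stride=1, padding=1)

    # Convolution::output_shape lets conv participate in shapeless compilation
    cfn = mx.compile(conv, shapeless=True)
    x = mx.random.normal((1, 8, 8, 3))   # (N, H, W, C_in)
    w = mx.random.normal((16, 3, 3, 3))  # (C_out, kH, kW, C_in)
    y = cfn(x, w)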
Cheng
787c0d90cd
Detect cache thrashing in LRUCache ( #2600 )
* Detect cache thrashing in LRUCache
* Do not check cache thrashing in tests
2025-09-19 09:12:14 +09:00
Awni Hannun
50cc09887f
expose depends ( #2606 )
2025-09-18 10:06:15 -07:00
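Assuming the Python binding mirrors the C++ depends(inputs, dependencies) signature, a hedged sketch:
    import mlx.core as mx

    a = mx.array([1.0])
    side_effect = mx.array([2.0]) * 3
    # assumed binding: schedule `a` only after `side_effect` is computed
    (a,) = mx.depends([a], [side_effect])
    mx.eval(a)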
Cheng
6a3acf2301
[CUDA] Set bias as input when using bias epilogue ( #2584 )
2025-09-11 15:31:09 +09:00
Awni Hannun
d6977f2a57
Add sdpa with sinks ( #2558 )
* add sdpa with sinks
* fix 2 pass
* fix matrix sdpa
* fix perf regression
* add to cuda (#2580)
2025-09-10 14:53:00 -07:00
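A hedged sketch of the new option; the shape of sinks here is an assumption (one sink logit per head):
    import mlx.core as mx

    B, H, L, D = 1, 4, 8, 16
    q = mx.random.normal((B, H, L, D))
    k = mx.random.normal((B, H, L, D))
    v = mx.random.normal((B, H, L, D))
    sinks = mx.zeros((H,))  # assumed: one attention-sink logit per head
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, sinks=sinks)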
Cheng
52b8384d10
Fix flaky addmm tests ( #2581 )
2025-09-10 14:22:22 +09:00
Cheng
44cc5da4bc
[CUDA] Fix alpha not respected when using bias epilogue ( #2578 )
2025-09-10 09:08:01 +09:00
Awni Hannun
17310d91a6
Add batch offsets for mx.fast.rope ( #2564 )
* implement batch rope for Metal
* cuda rope (#2576)
2025-09-08 17:35:07 -07:00
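A sketch of per-batch offsets; the array-valued offset is the new part, and its exact semantics here are assumed:
    import mlx.core as mx

    x = mx.random.normal((2, 8, 16, 64))  # (batch, heads, seq, head_dim)
    offsets = mx.array([0, 100])          # assumed: one starting position per batch element
    y = mx.fast.rope(x, 64, traditional=False, base=10000.0,
                     scale=1.0, offset=offsets)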
Awni Hannun
b61a65e313
fix copies in sdpa ( #2563 )
2025-09-02 11:00:36 -07:00
Awni Hannun
8ce49cd39e
fix quantized vjp for mxfp4 ( #2555 )
2025-08-29 10:06:15 -07:00
Awni Hannun
70560b6bd5
Add mode parameter for quantization ( #2499 )
* add mode parameter for quantization
* mxfp4 quantize/dequantize + start of optional biases
* mxfp4 works
* speedup
* cpu mxfp4
* fix
* fix test tol
* fix
* refactor
* add quant mode enum
2025-08-28 06:45:26 -07:00
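A hedged sketch of the new parameter; the return arity for mxfp4 and whether dequantize mirrors the mode kwarg are assumptions:
    import mlx.core as mx

    w = mx.random.normal((128, 128))
    # mode selects the quantization scheme ("affine" is the pre-existing default)
    out = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
    w_hat = mx.dequantize(*out, group_size=32, bits=4, mode="mxfp4")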
Awni Hannun
7ef8a6f2d5
[CUDA] fix sort ( #2550 )
* [CUDA] fix sort
* fix test
2025-08-27 19:48:43 -07:00
Awni Hannun
5458d43247
add load with path tests ( #2543 )
2025-08-26 14:24:47 -07:00
Awni Hannun
3dcb286baf
Remove stream from average grads so it uses default ( #2532 )
* Remove stream from average grads so it uses default
* comment
2025-08-25 15:56:29 -07:00
Cheng
4822c3dbe9
[CUDA] Implement DynamicSlice/DynamicSliceUpdate ( #2533 )
* Move DynamicSlice to gpu/primitives
* Implement compute_dynamic_offset in CUDA
2025-08-26 07:31:39 +09:00
Anastasiia Filippova
9392fc3f88
NCCL backend ( #2476 )
2025-08-21 11:56:15 -07:00
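Presumably selected like the other distributed backends; a sketch in which the backend name "nccl" is an assumption and launcher/env setup is out of scope:
    import mlx.core as mx

    group = mx.distributed.init(backend="nccl")  # assumed backend name
    x = mx.distributed.all_sum(mx.ones((4,)), group=group)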
Awni Hannun
e843c4d8d5
fix power ( #2523 )
2025-08-21 06:46:01 -07:00
Angelos Katharopoulos
e397177f6e
Custom cuda kernel ( #2517 )
2025-08-20 17:20:22 -07:00
Cheng
f4c8888cbe
[CUDA] Fix stride of singleton dims before passing to cuDNN ( #2521 )
2025-08-21 08:55:26 +09:00
Angelos Katharopoulos
25c1e03205
Fix overflow in large filter small channels ( #2520 )
2025-08-20 08:03:29 -07:00
Cheng
ac85ddfdb7
[CUDA] Add GEMM-based fallback convolution kernels ( #2511 )
* Add gemm_conv
* Add gemm_grouped_conv
2025-08-20 10:06:22 +09:00
Awni Hannun
e7c6e1db82
no segfault with uninitialized array.at ( #2514 )
2025-08-18 08:33:38 -07:00
Awni Hannun
c5fcd5b61b
fix custom kernel test ( #2510 )
2025-08-18 06:45:59 -07:00
Cheng
1ba18ff7d9
[CUDA] Fix conv grads with groups ( #2495 )
* Put reshape utils in one file
* [CUDA] Fix conv grads with groups
* Put the reshape utils in gpu/copy.h
2025-08-16 10:09:18 +09:00
Luca Vivona
728d4db582
Support destination arg in tree flatten/unflatten ( #2450 )
2025-08-06 15:34:59 -07:00
Awni Hannun
fa89f0b150
faster gather qmm sorted test ( #2463 )
2025-08-05 06:27:40 -07:00
Awni Hannun
0b807893a7
fix wraps compile ( #2461 )
2025-08-04 16:14:18 -07:00
Cheng
86c6a15571
[CUDA] Backward convolution ( #2431 )
2025-08-01 09:54:05 +09:00
junpeiz
8b25ce62d5
Add tests for export including control flow models and quantized models ( #2430 )
* Add tests for export, including control flow export and quantized model export.
* Skip quantization related test for CUDA backend.
2025-07-31 11:06:26 -07:00
Awni Hannun
d32519c8ee
fix gemv regression ( #2445 )
2025-07-30 14:23:01 -07:00
Awni Hannun
b405591249
fix circular reference ( #2443 )
2025-07-30 09:37:44 -07:00
Awni Hannun
ef631d63af
faster rms norm ( #2433 )
2025-07-29 13:12:00 -07:00
Awni Hannun
5597fa089c
Fix qvm splitk ( #2415 )
2025-07-25 11:50:24 -07:00
Cheng
6f5874a2f2
[CUDA] Initial implementation of Convolution with cuDNN ( #2385 )
* Link with cuDNN
* Initial implementation
* Remove backend apis
* Fix recording cudnn conv
* More unused backend apis
* Fix C++ conv tests
* include cudnn as python dep
* Install libcudnn9-dev-cuda-12 in CI
* cudnn only accepts contiguous inputs
* Switch to backend apis
* Plan needs to be kept alive
* Turn off tf32
* Add cache
* Test the native cuda graph api
* Set cudnn stream before execution
* Make LRUCache more like a normal container
* Do error check for cublas handle
* Zero-initializing array
* Use tf32 for conv
* Skip TestConv.test_torch_conv_2D test
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2025-07-25 08:12:10 +09:00
Awni Hannun
e1840853ce
full row mask in sdpa consistently gives nan ( #2406 )
2025-07-23 16:37:03 -07:00
Gökdeniz Gülmez
deee214a95
Adding support for the Muon Optimizer ( #1914 )
* initial commit with working optimizer
* update ACKNOWLEDGMENTS.md
* nits and adding it to test
* nits
* G.astype(mx.bfloat16) to G.astype(G.dtype)
* G.ndim >= 2 to assert G.ndim == 2
* remove comments
* replace with mx.addmm
* remove comments
* format
* nits
* match muon
* fix addmm
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2025-07-18 12:25:28 -07:00
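A minimal usage sketch; hyperparameters are illustrative, not the optimizer's documented defaults:
    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(16, 16)
    opt = optim.Muon(learning_rate=0.02)  # illustrative learning rate

    def loss_fn(m, x):
        return m(x).sum()

    loss, grads = nn.value_and_grad(model, loss_fn)(model, mx.random.normal((4, 16)))
    opt.update(model, grads)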
Awni Hannun
f409b229a4
fix ring distributed test ( #2380 )
2025-07-16 11:25:24 -07:00