Commit Graph

1243 Commits

Author SHA1 Message Date
Angelos Katharopoulos
3d4174cd37 Add gradient for the scales and biases in gather qmm 2025-07-05 00:58:17 -07:00
Angelos Katharopoulos
bda1534a44 Improve the gradient of gather_qmm as well 2025-07-04 20:23:58 -07:00
Angelos Katharopoulos
b28577289e Disable the test for CUDA 2025-07-04 19:17:45 -07:00
Angelos Katharopoulos
2d0f452aae Fix the test and cpu edge case 2025-07-04 18:36:20 -07:00
Angelos Katharopoulos
bd0622c4d9 Address floating point exception on linux blas 2025-07-04 13:16:54 -07:00
Angelos Katharopoulos
22f9b8a6fc More JIT fixes 2025-07-03 17:33:25 -07:00
Angelos Katharopoulos
aaf44d915f Add no cpu and no gpu implementations where needed 2025-07-03 17:03:35 -07:00
Angelos Katharopoulos
ed1a81210d Fix the JIT 2025-07-03 15:57:43 -07:00
Angelos Katharopoulos
f2994d5b29 Fix scatter_axis on single axis arrays 2025-07-03 14:50:23 -07:00
Angelos Katharopoulos
06a2e74eb2 Fix metal jit 2025-07-03 14:09:05 -07:00
Angelos Katharopoulos
d96a33c776 Add rudimentary test for gather_mm with sorted indices 2025-07-03 14:02:33 -07:00
Angelos Katharopoulos
4babc035a3 Add a test for segmented_mm 2025-07-03 13:49:46 -07:00
Angelos Katharopoulos
a8d7b74984 Simplify the jacobian as well 2025-07-02 18:23:48 -07:00
Angelos Katharopoulos
a29fa053c6 Use segmented_mm to calculate the MoE gradient 2025-07-02 16:24:21 -07:00
Angelos Katharopoulos
8f771efb82 Add the metal version of segmented mm 2025-07-02 15:45:09 -07:00
Angelos Katharopoulos
3104c3eb14 Change the segments to be more general 2025-07-02 15:45:09 -07:00
Angelos Katharopoulos
6020ad6363 Start the segmented_mm op and CPU primitive 2025-07-02 15:45:09 -07:00
Cheng
e76e9b87f0 Fix compilation error from integral_constant (#2326) 2025-07-02 06:04:38 -07:00
Awni Hannun
cfb6a244ea allow parameters to be deleted (#2325) 2025-07-01 21:27:23 -07:00
Awni Hannun
58f3860306 patch bump (#2324) v0.26.2 2025-07-01 12:12:16 -07:00
Awni Hannun
dd4f53db63 use fp32 for testing, add more complex ops (#2322) 2025-07-01 07:30:00 -07:00
Angelos Katharopoulos
3d5e17e507 MLX_SWITCH macros to templates (#2320) 2025-07-01 01:33:44 -07:00
Awni Hannun
33bf1a244b Fix module update in strict mode (#2321)
* fix module update in strict mode

* allow GELU to be pickled
2025-06-29 11:12:29 -07:00
Angelos Katharopoulos
772f471ff2 [CUDA] Fix reductions (#2314) 2025-06-27 12:59:20 -07:00
Angelos Katharopoulos
2c11d10f8d Split broadcast so it is always fused in compile (#2318) 2025-06-26 22:08:18 -07:00
Angelos Katharopoulos
656ed7f780 Fix get 2d grid dims (#2316) 2025-06-25 13:03:09 -07:00
Awni Hannun
81bb9a2a9e Compile float64 functions on CPU (#2311) 2025-06-24 10:18:52 -07:00
Angelos Katharopoulos
5adf185f86 Fix update_modules() when providing a subset (#2308) 2025-06-20 17:19:46 -07:00
Awni Hannun
c9a9180584 Cuda perf tuning (#2307)
* perf tuning

* fix adding inputs arrays in matmul / srot

* format

* fix
2025-06-20 14:50:57 -07:00
Awni Hannun
76831ed83d Build CUDA release in Circle (#2306)
* cuda release

* add license
2025-06-19 15:26:36 -07:00
Angelos Katharopoulos
b3d7b85376 Make ptx cache settable by environment variable (#2304) 2025-06-17 23:55:56 -07:00
Awni Hannun
cad5c0241c [CUDA] synch properly waits for all tasks to finish and clear (#2303)
* cuda synch properly waits for all tasks to finish and clear

* fix copy
2025-06-17 12:03:25 -07:00
Awni Hannun
b8022c578a divmod, partition, sort fixes (#2302) 2025-06-16 18:49:32 -07:00
Awni Hannun
bc53f8293f Cuda bug fixes 2 (#2298)
* more bug fixes

* more bug fixes

* format
2025-06-16 13:14:46 -07:00
Awni Hannun
c552ff2451 [CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
2025-06-16 08:45:40 -07:00
Awni Hannun
4fda5fbdf9 add python testing for cuda with ability to skip list of tests (#2295) 2025-06-15 10:56:48 -07:00
Angelos Katharopoulos
580776559b RoPE for CUDA (#2293)
* First working CUDA rope

* Fix random
2025-06-15 06:08:07 -07:00
Awni Hannun
a14aaa7c9d Fix cuda arg reduce (#2291) 2025-06-14 17:54:00 -07:00
Awni Hannun
a6d780154f fix cuda gemm for bf16 (#2288) 2025-06-13 22:10:46 -07:00
Awni Hannun
6871e2eeb7 fix cuda jit (#2287) 2025-06-13 19:21:46 -07:00
Awni Hannun
8402a2acf4 Fix complex power and print (#2286)
* fix complex power and print

* fix complex matmul shape
2025-06-13 11:13:00 -07:00
Jagrit Digani
fddb6933e1 Collection of refactors (#2274)
* Refactor gemv into a function

* Refactor splitk step 1

* Refactor split k axpby

* Rearrange steel_gemm_regular

* Redirect steel_gemm_regular

* Add axpby routing to steel_matmul_regular

* Refactor AddMM step 1

* Redirect steel_gemm

* Update addmm

* Comments and format

* Some cleanup

* Add architecture gen to device

* Update no copy condition in normalization to account for axis size 1
2025-06-13 10:44:56 -07:00
Cheng
c8b4787e4e CUDA backend: indexing ops (#2277) 2025-06-12 21:44:19 -07:00
Awni Hannun
2188199ff8 [CUDA] ternary with select op (#2283)
* cuda ternary with select op

* comment + fix

* fix
2025-06-12 20:24:43 -07:00
Awni Hannun
aa07429bad Fix cuda build (#2284) 2025-06-12 17:48:05 -07:00
Awni Hannun
918761a25a [CUDA] RMSNorm and VJP (#2280)
* rms norm start

* nit
2025-06-12 17:09:49 -07:00
Cheng
a4fc671d3e CUDA backend: compile (#2276)
* CUDA backend: compile

* Rename kernels/ to device/
2025-06-12 17:08:39 -07:00
Awni Hannun
f5f65ef48c Make sliceUpdate general (#2282)
* Make sliceUpdate general

* fix
2025-06-12 16:48:54 -07:00
Cheng
c2dd81a8aa Fix warnings from latest CUDA toolkit (#2275) 2025-06-12 06:03:01 -07:00
Cheng
d7e680ffe4 CUDA backend: layernorm (#2271) 2025-06-11 15:48:32 -07:00