Commit Graph

13 Commits

Author SHA1 Message Date
Awni Hannun
ef631d63af
faster rms norm (#2433) 2025-07-29 13:12:00 -07:00
Awni Hannun
d107d8d495
add cuda gemv (#2400) 2025-07-22 08:24:13 -07:00
Cheng
f55c4ed1d6
Remove thrust iterators (#2396) 2025-07-21 07:30:27 -07:00
Cheng
cb349a291c
[CUDA] Use cuda::std::complex in place of cuComplex (#2372) 2025-07-15 00:36:13 -07:00
Cheng
d34f887abc
Add Primitive::name and remove Primitive::print (#2365) 2025-07-14 14:06:35 -07:00
Cheng
85873cb162
[CUDA] Do vectorized store/load in contiguous elementwise ops (#2342)
* Do vectorized store/load in unary ops

* Do vectorized store/load in binary_two ops

* Do vectorized store/load in copy ops

* Do vectorized store/load in ternary ops

* Use int32_t for IdxT

* binary => binary_two in binary_two.cu

* Fix tests on large arrays

* Use uint as index type

* Contig uses uint as index and non-contig uses int
2025-07-09 18:48:43 -07:00
Awni Hannun
ec0d5db67b
[CUDA] Switch to CUDA graphs (#2317)
* cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

cosistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

* comment
2025-07-02 15:59:13 -07:00
Awni Hannun
dd4f53db63
use fp32 for testing, add more complex ops (#2322) 2025-07-01 07:30:00 -07:00
Angelos Katharopoulos
3d5e17e507
MLX_SWITCH macros to templates (#2320) 2025-07-01 01:33:44 -07:00
Awni Hannun
bc53f8293f
Cuda bug fixes 2 (#2298)
* more bug fixes

* more bug fixes

* format
2025-06-16 13:14:46 -07:00
Awni Hannun
c552ff2451
[CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
2025-06-16 08:45:40 -07:00
Cheng
a4fc671d3e
CUDA backend: compile (#2276)
* CUDA backend: compile

* Rename kernels/ to device/
2025-06-12 17:08:39 -07:00
Cheng
f8bad60609
CUDA backend: unary ops (#2158) 2025-06-09 06:45:08 -07:00