Cheng
cb349a291c
[CUDA] Use cuda::std::complex in place of cuComplex ( #2372 )
2025-07-15 00:36:13 -07:00
Awni Hannun
e7d2ebadd2
[CUDA] Affine quantize ( #2354 )
...
* affine quantize and dequantize kernels
* format
* fix
* format
2025-07-14 15:45:44 -07:00
Cheng
d34f887abc
Add Primitive::name and remove Primitive::print ( #2365 )
2025-07-14 14:06:35 -07:00
Angelos Katharopoulos
5201df5030
Fix imag() vjp ( #2367 )
2025-07-14 13:11:16 -07:00
Cheng
2d3c26c565
[CUDA] Do not put kernels in annoymous namespace ( #2362 )
2025-07-12 14:24:45 -07:00
Cheng
6325f60d52
[CUDA] Bundle CCCL for JIT compilation ( #2357 )
...
* Ship CCCL for JIT compilation
* Remove cexpf
2025-07-11 18:45:37 -07:00
Awni Hannun
42cc9cfbc7
fix copy dispatch ( #2360 )
2025-07-11 10:59:35 -07:00
Cheng
8347575ba1
[CUDA] Implement Scan kernel ( #2347 )
...
* Contiguous scan
* Strided scan
* Enable tests
* Fix failing logaddexp test
* Use cexpf in Metal
2025-07-10 16:54:12 -07:00
Angelos Katharopoulos
b6eec20260
Fix edge check in qmm_n QuantizedLoader ( #2355 )
2025-07-10 16:28:50 -07:00
Cheng
afb9817599
[CUDA] Put version in ptx cache dir path ( #2352 )
2025-07-10 07:24:21 -07:00
Cheng
8fb3e7a26c
[CUDA] Set current device before cudaGraphLaunch ( #2351 )
2025-07-10 07:24:02 -07:00
jhavukainen
8c7bc30ce4
Align mlx::core::min op nan propagation with NumPy ( #2346 )
2025-07-10 06:20:43 -07:00
Cheng
85873cb162
[CUDA] Do vectorized store/load in contiguous elementwise ops ( #2342 )
...
* Do vectorized store/load in unary ops
* Do vectorized store/load in binary_two ops
* Do vectorized store/load in copy ops
* Do vectorized store/load in ternary ops
* Use int32_t for IdxT
* binary => binary_two in binary_two.cu
* Fix tests on large arrays
* Use uint as index type
* Contig uses uint as index and non-contig uses int
2025-07-09 18:48:43 -07:00
Awni Hannun
e14ee12491
add zero for argsort vjp ( #2345 )
2025-07-09 14:37:14 -07:00
jhavukainen
8b9a3f3cea
Align mlx::core::max op nan propagation with NumPy ( #2339 )
...
* Make max op NaN propagation rules align with numpy
* Adding benchmarks and testing for max op nanpropagation
* Pre-commit formatting
* Fix max complex64 nan propagation and add test
* Improve the cpp unittest
* Only check nans on non-integral types in simd_reduce_impl.
* Cleanup using namespace alias
* Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16.
* Make the max nanpropagation test more meaningful for integer types
* Remove tuple unpacking syntax to comply with earlier python versions. Add cuda skip to nanpropagation tests, fix cuda implementation in a separate PR.
2025-07-09 11:26:27 -07:00
Awni Hannun
fb4e8b896b
patch bump ( #2343 )
2025-07-08 14:26:07 -07:00
Cheng
2ca533b279
Fix compilation with CUDA 11 ( #2331 )
2025-07-07 20:00:43 -07:00
Angelos Katharopoulos
4a9b29a875
MoE backward improvements ( #2335 )
2025-07-07 17:59:53 -07:00
Cheng
9d10239af7
[CUDA] Do vectorized store/load in binary ops ( #2330 )
2025-07-07 08:44:14 -07:00
Angelos Katharopoulos
f5299f72cd
Fix layernorm race condition ( #2340 )
2025-07-07 06:06:01 -07:00
Cheng
0e0d9ac522
[CUDA] Add MLX_CUDA_GRAPH_CACHE_SIZE env for setting graph cache size ( #2329 )
2025-07-05 08:33:29 -07:00
Awni Hannun
8917022deb
fix graphs for older cuda ( #2328 )
2025-07-02 19:37:58 -07:00
Awni Hannun
ec0d5db67b
[CUDA] Switch to CUDA graphs ( #2317 )
...
* cuda graph prototype
fix signal bug + start to add dependencies
capture more
capture more ops
remaining ops
fix reduce and rope deps
add concurrent context
try update, but not working
cosistent topology order
use node api
use node api directly to reduce overhead
fix bug
use kernels in unary
cache graph
format
fix synchronization
format
* comment
2025-07-02 15:59:13 -07:00
Cheng
e76e9b87f0
Fix compilation error from integral_constant ( #2326 )
2025-07-02 06:04:38 -07:00
Awni Hannun
58f3860306
patch bump ( #2324 )
2025-07-01 12:12:16 -07:00
Awni Hannun
dd4f53db63
use fp32 for testing, add more complex ops ( #2322 )
2025-07-01 07:30:00 -07:00
Angelos Katharopoulos
3d5e17e507
MLX_SWITCH macros to templates ( #2320 )
2025-07-01 01:33:44 -07:00
Angelos Katharopoulos
772f471ff2
[CUDA] Fix reductions ( #2314 )
2025-06-27 12:59:20 -07:00
Angelos Katharopoulos
2c11d10f8d
Split broadcast so it is always fused in compile ( #2318 )
2025-06-26 22:08:18 -07:00
Angelos Katharopoulos
656ed7f780
Fix get 2d grid dims ( #2316 )
2025-06-25 13:03:09 -07:00
Awni Hannun
81bb9a2a9e
Compile float64 functions on CPU ( #2311 )
2025-06-24 10:18:52 -07:00
Awni Hannun
c9a9180584
Cuda perf tuning ( #2307 )
...
* perf tuning
* fix adding inputs arrays in matmul / srot
* format
* fix
2025-06-20 14:50:57 -07:00
Awni Hannun
76831ed83d
Build CUDA release in Circle ( #2306 )
...
* cuda release
* add license
2025-06-19 15:26:36 -07:00
Angelos Katharopoulos
b3d7b85376
Make ptx cache settable by environment variable ( #2304 )
2025-06-17 23:55:56 -07:00
Awni Hannun
cad5c0241c
[CUDA] synch properly waits for all tasks to finish and clear ( #2303 )
...
* cuda synch properly waits for all tasks to finish and clear
* fix copy
2025-06-17 12:03:25 -07:00
Awni Hannun
b8022c578a
divmod, partition, sort fixes ( #2302 )
2025-06-16 18:49:32 -07:00
Awni Hannun
bc53f8293f
Cuda bug fixes 2 ( #2298 )
...
* more bug fixes
* more bug fixes
* format
2025-06-16 13:14:46 -07:00
Awni Hannun
c552ff2451
[CUDA] Fix back-end bugs and enable corresponding tests ( #2296 )
...
* Fix some cuda back-end bugs and enable corresponding tests
* more fixes
* enable more tests
* format
2025-06-16 08:45:40 -07:00
Angelos Katharopoulos
580776559b
RoPE for CUDA ( #2293 )
...
* First working CUDA rope
* Fix random
2025-06-15 06:08:07 -07:00
Awni Hannun
a14aaa7c9d
Fix cuda arg reduce ( #2291 )
2025-06-14 17:54:00 -07:00
Awni Hannun
a6d780154f
fix cuda gemm for bf16 ( #2288 )
2025-06-13 22:10:46 -07:00
Awni Hannun
6871e2eeb7
fix cuda jit ( #2287 )
2025-06-13 19:21:46 -07:00
Awni Hannun
8402a2acf4
Fix complex power and print ( #2286 )
...
* fix complex power and print
* fix complex matmul shape
2025-06-13 11:13:00 -07:00
Jagrit Digani
fddb6933e1
Collection of refactors ( #2274 )
...
* Refactor gemv into a function
* Refactor splitk step 1
* Refactor split k axpby
* Rearrange steel_gemm_regular
* Redirect steel_gemm_regular
* Add axpby routing to steel_matmul_regular
* Refactor AddMM step 1
* Redirect steel_gemm
* Update addmm
* Comments and format
* Some cleanup
* Add architecture gen to device
* Update no copy condition in normalization to account for axis size 1
2025-06-13 10:44:56 -07:00
Cheng
c8b4787e4e
CUDA backend: indexing ops ( #2277 )
2025-06-12 21:44:19 -07:00
Awni Hannun
2188199ff8
[CUDA] ternary with select op ( #2283 )
...
* cuda ternary with select op
* comment + fix
* fix
2025-06-12 20:24:43 -07:00
Awni Hannun
aa07429bad
Fix cuda build ( #2284 )
2025-06-12 17:48:05 -07:00
Awni Hannun
918761a25a
[CUDA] RMSNorm and VJP ( #2280 )
...
* rms norm start
* nit
2025-06-12 17:09:49 -07:00
Cheng
a4fc671d3e
CUDA backend: compile ( #2276 )
...
* CUDA backend: compile
* Rename kernels/ to device/
2025-06-12 17:08:39 -07:00
Awni Hannun
f5f65ef48c
Make sliceUpdate general ( #2282 )
...
* Make sliceUpdate general
* fix
2025-06-12 16:48:54 -07:00
Cheng
c2dd81a8aa
Fix warnings from latest CUDA toolkit ( #2275 )
2025-06-12 06:03:01 -07:00
Cheng
d7e680ffe4
CUDA backend: layernorm ( #2271 )
2025-06-11 15:48:32 -07:00
Cheng
c371baf53a
CUDA backend: softmax ( #2272 )
2025-06-11 13:55:22 -07:00
Cheng
ccf78f566c
CUDA backend: argreduce ( #2270 )
2025-06-11 13:26:17 -07:00
Cheng
c9fa68664a
CUDA backend: reduce ( #2269 )
2025-06-11 11:22:25 -07:00
Awni Hannun
c35f4d089a
start cuda circle config ( #2256 )
...
* rebase
* fix metal kernel linking issue on cuda
* start cuda circle config
2025-06-10 21:19:47 -07:00
Angelos Katharopoulos
8590c0941e
Add load_safe to the general conv loaders ( #2258 )
2025-06-10 20:58:16 -07:00
Cheng
99c33d011d
rebase + nit ( #2260 )
...
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-10 10:51:51 -07:00
Awni Hannun
62fecf3e13
fix conv export ( #2265 )
2025-06-10 09:34:01 -07:00
Cheng
7c4eb5d03e
CUDA backend: random ( #2261 )
2025-06-10 08:59:56 -07:00
Cheng
bae9a6b404
CUDA backend: sort ( #2262 )
...
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-10 08:59:47 -07:00
Cheng
7ebb2e0193
CUDA backend: binary ops ( #2259 )
2025-06-10 06:37:40 -07:00
Awni Hannun
9ce77798b1
fix export to work with gather/scatter axis ( #2263 )
2025-06-09 20:37:27 -07:00
Cheng
f8bad60609
CUDA backend: unary ops ( #2158 )
2025-06-09 06:45:08 -07:00
Awni Hannun
1ca616844b
Fix unintuitive metal kernel caching ( #2242 )
...
* Fix unintuitive metal kernel caching
* alternative solution
2025-06-06 20:08:15 -07:00
Angelos Katharopoulos
2e8cf0b450
Change layernorms to two pass algorithm ( #2246 )
2025-06-06 13:34:56 -07:00
Cheng
24f89173d1
CUDA backend: matmul ( #2241 )
2025-06-06 12:24:04 -07:00
Awni Hannun
c6a20b427a
Improve metal elementwise kernels ( #2247 )
...
* improve metal elementwise kernels
* compile and copy
* fix jit
2025-06-06 11:37:40 -07:00
Cheng
52dc8c8cd5
Add profiler annotations in common primitives for CUDA backend ( #2244 )
2025-06-04 19:55:12 -07:00
Angelos Katharopoulos
aede70e81d
Perf regression fix ( #2243 )
2025-06-03 17:55:12 -07:00
Cheng
85a8beb5e4
Avoid atomic updates across CPU/GPU in CUDA event ( #2231 )
2025-06-03 16:49:06 -07:00
Cheng
0bb89e9e5f
Share more common code in Compiled ( #2240 )
...
* Share more common code in Compiled
* Remove build_lib_name
2025-06-03 16:48:50 -07:00
Cheng
5685ceb3c7
Avoid invoking allocator::malloc when creating CUDA event ( #2232 )
2025-06-03 16:48:40 -07:00
Suryash Malviya
0408ba0a76
Optimizing Complex Matrix Multiplication using Karatsuba’s Algorithm ( #2220 )
...
* Implementing Complex Matmul using Karatsuba Algorithm
* Implemented Karatsuba's Algorithm for complex matmul and pre-commit them
* fix
---------
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-02 15:58:46 -07:00
Awni Hannun
cbad6c3093
version ( #2237 )
2025-06-02 15:58:33 -07:00
Cheng
1b021f6984
Fast primitives decide when to use the fallback ( #2216 )
2025-06-02 13:26:37 -07:00
Cheng
95b7551d65
Do not check event.is_signaled() in eval_impl ( #2230 )
2025-06-02 13:23:34 -07:00
Cheng
db5a7c6192
Add memory cache to CUDA backend ( #2221 )
...
* Move BufferCache out of allocator
* Add memory cache to cuda backend allocator
* Simplify BufferCache assuming buf can not be null
2025-05-30 12:12:54 -07:00
Awni Hannun
6ef2f67e7f
5bit quants ( #2226 )
...
* 5bit quants
* 5bit quants
2025-05-30 12:12:10 -07:00
Cheng
f76ee1ffd2
Move some dims utils to common ( #2223 )
2025-05-29 06:48:30 -07:00
Cheng
54a71f270a
Remove unused defines ( #2217 )
2025-05-23 06:14:58 -07:00
Cheng
79071bfba4
Fix out-of-bounds default value in logsumexp/softmax ( #2213 )
2025-05-21 07:25:16 -07:00
Cheng
7774b87cbd
Remove redundant simd_sum in logsumexp ( #2210 )
2025-05-21 07:25:03 -07:00
Cheng
35c87741cf
Build for compute capability 70 instead of 75 ( #2209 )
2025-05-20 19:42:48 -07:00
Clement Liaw
ab8883dd55
include mlx::core::version() symbols in the mlx static library ( #2207 )
2025-05-20 07:39:11 -07:00
Awni Hannun
eebe73001a
fix large arg reduce ( #2206 )
2025-05-19 13:10:44 -07:00
Cheng
237f9e58a8
Fix BEFORE keyword in target_include_directories ( #2204 )
2025-05-19 06:10:44 -07:00
Awni Hannun
8576e6fe36
fix conv2d bug + faster conv 1d ( #2195 )
...
* fix conv2d bug + faster conv 1d
* revert sort + flaky test
2025-05-18 06:05:11 -07:00
Angelos Katharopoulos
0654543dcc
Add complex eigh ( #2191 )
2025-05-18 00:18:43 -07:00
Awni Hannun
48ef3e74e2
reduce vjp for all and any ( #2193 )
2025-05-16 08:38:49 -07:00
Cheng
7d4b378952
Include cuda_bf16.h for bfloat16 overloads ( #2192 )
...
* Include cuda_bf16.h for bfloat16 overloads
* Add NO_GPU_MULTI(Eig) in cuda backend
2025-05-16 06:44:42 -07:00
Jack Wind
7ff5c41e06
Add set_threadgroup_memory_length to CommandEncoder ( #2183 )
2025-05-16 00:28:03 -07:00
Awni Hannun
602f43e3d1
fix conv grad ( #2187 )
2025-05-15 19:20:36 -07:00
Awni Hannun
c1eb9d05d9
non-symmetric eig and eigh ( #2188 )
2025-05-15 13:01:44 -07:00
Angelos Katharopoulos
cf6c939e86
Fix some complex vjps ( #2178 )
2025-05-14 23:37:12 -07:00
Angelos Katharopoulos
130df35e1b
Add random normal distribution for complex numbers ( #2182 )
2025-05-13 22:43:45 -07:00
Cheng
0751263dec
Fix typo in row_reduce_small ( #2179 )
2025-05-13 20:19:54 -07:00
Cheng
eca2f3eb97
Add remove_index utility ( #2173 )
2025-05-13 17:09:56 -07:00
Angelos Katharopoulos
3aa9cf3f9e
Fix put_along_axis for empty arrays ( #2181 )
2025-05-13 14:27:53 -07:00
Awni Hannun
8f3d208dce
Close a couple edge case bugs: hadamard and addmm on empty inputs ( #2177 )
...
* handle hadamard and addmm on empty inputs
* fix
2025-05-12 10:48:57 -07:00