Compare commits

...

223 Commits

Author SHA1 Message Date
Cheng
8347575ba1 [CUDA] Implement Scan kernel (#2347)
* Contiguous scan

* Strided scan

* Enable tests

* Fix failing logaddexp test

* Use cexpf in Metal
2025-07-10 16:54:12 -07:00
Angelos Katharopoulos
b6eec20260 Fix edge check in qmm_n QuantizedLoader (#2355) 2025-07-10 16:28:50 -07:00
Angelos Katharopoulos
0eb035b4b1 Fix type promotion in Adam with bias correction (#2350) 2025-07-10 11:14:42 -07:00
Cheng
afb9817599 [CUDA] Put version in ptx cache dir path (#2352) 2025-07-10 07:24:21 -07:00
Cheng
8fb3e7a26c [CUDA] Set current device before cudaGraphLaunch (#2351) 2025-07-10 07:24:02 -07:00
jhavukainen
8c7bc30ce4 Align mlx::core::min op nan propagation with NumPy (#2346) 2025-07-10 06:20:43 -07:00
Cheng
85873cb162 [CUDA] Do vectorized store/load in contiguous elementwise ops (#2342)
* Do vectorized store/load in unary ops

* Do vectorized store/load in binary_two ops

* Do vectorized store/load in copy ops

* Do vectorized store/load in ternary ops

* Use int32_t for IdxT

* binary => binary_two in binary_two.cu

* Fix tests on large arrays

* Use uint as index type

* Contig uses uint as index and non-contig uses int
2025-07-09 18:48:43 -07:00
Awni Hannun
e14ee12491 add zero for argsort vjp (#2345) 2025-07-09 14:37:14 -07:00
jhavukainen
8b9a3f3cea Align mlx::core::max op nan propagation with NumPy (#2339)
* Make max op NaN propagation rules align with numpy

* Adding benchmarks and testing for max op nanpropagation

* Pre-commit formatting

* Fix max complex64 nan propagation and add test

* Improve the cpp unittest

* Only check nans on non-integral types in simd_reduce_impl.

* Cleanup using namespace alias

* Add cpu Max nanpropagation. Fix a small fib in cpu max dispatch data types for int8/int16.

* Make the max nanpropagation test more meaningful for integer types

* Remove tuple unpacking syntax to comply with earlier python versions. Add cuda skip to nanpropagation tests, fix cuda implementation in a separate PR.
2025-07-09 11:26:27 -07:00
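For context, the NumPy behavior that the min/max commits above (#2339, #2346) align with is that `max` and `min` propagate NaN rather than ignoring it. A minimal illustrative snippet, not taken from the PRs; the `mlx.core` lines assume the post-change behavior:

```python
import numpy as np
import mlx.core as mx

x = np.array([1.0, np.nan, 3.0], dtype=np.float32)
print(np.max(x), np.min(x))      # nan nan -- NumPy propagates NaN
print(mx.max(mx.array(x)))       # expected: nan, matching NumPy after these changes
print(mx.min(mx.array(x)))       # expected: nan, matching NumPy after these changes
```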
Awni Hannun
fb4e8b896b patch bump (#2343) 2025-07-08 14:26:07 -07:00
Cheng
2ca533b279 Fix compilation with CUDA 11 (#2331) 2025-07-07 20:00:43 -07:00
Angelos Katharopoulos
4a9b29a875 MoE backward improvements (#2335) 2025-07-07 17:59:53 -07:00
Awni Hannun
a4fcc893cd auto build linux release (#2341) 2025-07-07 09:29:23 -07:00
Cheng
9d10239af7 [CUDA] Do vectorized store/load in binary ops (#2330) 2025-07-07 08:44:14 -07:00
Cheng
19facd4b20 Build with all cpu cores by default (#2336) 2025-07-07 06:06:45 -07:00
Angelos Katharopoulos
f5299f72cd Fix layernorm race condition (#2340) 2025-07-07 06:06:01 -07:00
Cheng
0e0d9ac522 [CUDA] Add MLX_CUDA_GRAPH_CACHE_SIZE env for setting graph cache size (#2329) 2025-07-05 08:33:29 -07:00
Awni Hannun
8917022deb fix graphs for older cuda (#2328) 2025-07-02 19:37:58 -07:00
Awni Hannun
ec0d5db67b [CUDA] Switch to CUDA graphs (#2317)
* cuda graph prototype

fix signal bug + start to add dependencies

capture more

capture more ops

remaining ops

fix reduce and rope deps

add concurrent context

try update, but not working

consistent topology order

use node api

use node api directly to reduce overhead

fix bug

use kernels in unary

cache graph

format

fix synchronization

format

* comment
2025-07-02 15:59:13 -07:00
Cheng
e76e9b87f0 Fix compilation error from integral_constant (#2326) 2025-07-02 06:04:38 -07:00
Awni Hannun
cfb6a244ea allow parameters to be deleted (#2325) 2025-07-01 21:27:23 -07:00
Awni Hannun
58f3860306 patch bump (#2324) 2025-07-01 12:12:16 -07:00
Awni Hannun
dd4f53db63 use fp32 for testing, add more complex ops (#2322) 2025-07-01 07:30:00 -07:00
Angelos Katharopoulos
3d5e17e507 MLX_SWITCH macros to templates (#2320) 2025-07-01 01:33:44 -07:00
Awni Hannun
33bf1a244b Fix module update in strict mode (#2321)
* fix module update in strict mode

* allow GELU to be pickled
2025-06-29 11:12:29 -07:00
Angelos Katharopoulos
772f471ff2 [CUDA] Fix reductions (#2314) 2025-06-27 12:59:20 -07:00
Angelos Katharopoulos
2c11d10f8d Split broadcast so it is always fused in compile (#2318) 2025-06-26 22:08:18 -07:00
Angelos Katharopoulos
656ed7f780 Fix get 2d grid dims (#2316) 2025-06-25 13:03:09 -07:00
Awni Hannun
81bb9a2a9e Compile float64 functions on CPU (#2311) 2025-06-24 10:18:52 -07:00
Angelos Katharopoulos
5adf185f86 Fix update_modules() when providing a subset (#2308) 2025-06-20 17:19:46 -07:00
Awni Hannun
c9a9180584 Cuda perf tuning (#2307)
* perf tuning

* fix adding inputs arrays in matmul / srot

* format

* fix
2025-06-20 14:50:57 -07:00
Awni Hannun
76831ed83d Build CUDA release in Circle (#2306)
* cuda release

* add license
2025-06-19 15:26:36 -07:00
Angelos Katharopoulos
b3d7b85376 Make ptx cache settable by environment variable (#2304) 2025-06-17 23:55:56 -07:00
Awni Hannun
cad5c0241c [CUDA] synch properly waits for all tasks to finish and clear (#2303)
* cuda synch properly waits for all tasks to finish and clear

* fix copy
2025-06-17 12:03:25 -07:00
Awni Hannun
b8022c578a divmod, partition, sort fixes (#2302) 2025-06-16 18:49:32 -07:00
Awni Hannun
bc53f8293f Cuda bug fixes 2 (#2298)
* more bug fixes

* more bug fixes

* format
2025-06-16 13:14:46 -07:00
Awni Hannun
c552ff2451 [CUDA] Fix back-end bugs and enable corresponding tests (#2296)
* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
2025-06-16 08:45:40 -07:00
Awni Hannun
4fda5fbdf9 add python testing for cuda with ability to skip list of tests (#2295) 2025-06-15 10:56:48 -07:00
Angelos Katharopoulos
580776559b RoPE for CUDA (#2293)
* First working CUDA rope

* Fix random
2025-06-15 06:08:07 -07:00
Awni Hannun
a14aaa7c9d Fix cuda arg reduce (#2291) 2025-06-14 17:54:00 -07:00
Awni Hannun
a6d780154f fix cuda gemm for bf16 (#2288) 2025-06-13 22:10:46 -07:00
Awni Hannun
6871e2eeb7 fix cuda jit (#2287) 2025-06-13 19:21:46 -07:00
Awni Hannun
8402a2acf4 Fix complex power and print (#2286)
* fix complex power and print

* fix complex matmul shape
2025-06-13 11:13:00 -07:00
Jagrit Digani
fddb6933e1 Collection of refactors (#2274)
* Refactor gemv into a function

* Refactor splitk step 1

* Refactor split k axpby

* Rearrange steel_gemm_regular

* Redirect steel_gemm_regular

* Add axpby routing to steel_matmul_regular

* Refactor AddMM step 1

* Redirect steel_gemm

* Update addmm

* Comments and format

* Some cleanup

* Add architecture gen to device

* Update no copy condition in normalization to account for axis size 1
2025-06-13 10:44:56 -07:00
Cheng
c8b4787e4e CUDA backend: indexing ops (#2277) 2025-06-12 21:44:19 -07:00
Awni Hannun
2188199ff8 [CUDA] ternary with select op (#2283)
* cuda ternary with select op

* comment + fix

* fix
2025-06-12 20:24:43 -07:00
Awni Hannun
aa07429bad Fix cuda build (#2284) 2025-06-12 17:48:05 -07:00
Awni Hannun
918761a25a [CUDA] RMSNorm and VJP (#2280)
* rms norm start

* nit
2025-06-12 17:09:49 -07:00
Cheng
a4fc671d3e CUDA backend: compile (#2276)
* CUDA backend: compile

* Rename kernels/ to device/
2025-06-12 17:08:39 -07:00
Awni Hannun
f5f65ef48c Make sliceUpdate general (#2282)
* Make sliceUpdate general

* fix
2025-06-12 16:48:54 -07:00
Cheng
c2dd81a8aa Fix warnings from latest CUDA toolkit (#2275) 2025-06-12 06:03:01 -07:00
Cheng
d7e680ffe4 CUDA backend: layernorm (#2271) 2025-06-11 15:48:32 -07:00
Cheng
c371baf53a CUDA backend: softmax (#2272) 2025-06-11 13:55:22 -07:00
Cheng
ccf78f566c CUDA backend: argreduce (#2270) 2025-06-11 13:26:17 -07:00
Cheng
c9fa68664a CUDA backend: reduce (#2269) 2025-06-11 11:22:25 -07:00
Awni Hannun
c35f4d089a start cuda circle config (#2256)
* rebase

* fix metal kernel linking issue on cuda

* start cuda circle config
2025-06-10 21:19:47 -07:00
Angelos Katharopoulos
8590c0941e Add load_safe to the general conv loaders (#2258) 2025-06-10 20:58:16 -07:00
Cheng
095163b8d1 Fix building cpp benchmarks on Linux (#2268) 2025-06-10 17:10:24 -07:00
Cheng
99c33d011d rebase + nit (#2260)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-10 10:51:51 -07:00
Awni Hannun
62fecf3e13 fix conv export (#2265) 2025-06-10 09:34:01 -07:00
Cheng
7c4eb5d03e CUDA backend: random (#2261) 2025-06-10 08:59:56 -07:00
Cheng
bae9a6b404 CUDA backend: sort (#2262)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-10 08:59:47 -07:00
Christopher Fleetwood
004c1d8ef2 Report number of missing parameters (#2264)
* chore: inform

* chore: format

---------

Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
2025-06-10 06:37:50 -07:00
Cheng
7ebb2e0193 CUDA backend: binary ops (#2259) 2025-06-10 06:37:40 -07:00
Awni Hannun
9ce77798b1 fix export to work with gather/scatter axis (#2263) 2025-06-09 20:37:27 -07:00
Cheng
f8bad60609 CUDA backend: unary ops (#2158) 2025-06-09 06:45:08 -07:00
Emmanuel Ferdman
5866b3857b Refactor the lu test (#2250)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-06-07 06:12:08 -07:00
Awni Hannun
1ca616844b Fix unintuitive metal kernel caching (#2242)
* Fix unintuitive metal kernel caching

* alternative solution
2025-06-06 20:08:15 -07:00
Angelos Katharopoulos
2e8cf0b450 Change layernorms to two pass algorithm (#2246) 2025-06-06 13:34:56 -07:00
Cheng
24f89173d1 CUDA backend: matmul (#2241) 2025-06-06 12:24:04 -07:00
Awni Hannun
c6a20b427a Improve metal elementwise kernels (#2247)
* improve metal elementwise kernels

* compile and copy

* fix jit
2025-06-06 11:37:40 -07:00
Awni Hannun
a5ac9244c4 fix linux linking error (#2248) 2025-06-06 10:41:51 -07:00
Awni Hannun
c763fe1be0 default strict mode for module update and update_modules (#2239) 2025-06-05 15:27:02 -07:00
Cheng
52dc8c8cd5 Add profiler annotations in common primitives for CUDA backend (#2244) 2025-06-04 19:55:12 -07:00
Angelos Katharopoulos
aede70e81d Perf regression fix (#2243) 2025-06-03 17:55:12 -07:00
Cheng
85a8beb5e4 Avoid atomic updates across CPU/GPU in CUDA event (#2231) 2025-06-03 16:49:06 -07:00
Cheng
0bb89e9e5f Share more common code in Compiled (#2240)
* Share more common code in Compiled

* Remove build_lib_name
2025-06-03 16:48:50 -07:00
Cheng
5685ceb3c7 Avoid invoking allocator::malloc when creating CUDA event (#2232) 2025-06-03 16:48:40 -07:00
Suryash Malviya
0408ba0a76 Optimizing Complex Matrix Multiplication using Karatsuba’s Algorithm (#2220)
* Implementing Complex Matmul using Karatsuba Algorithm

* Implemented Karatsuba's Algorithm for complex matmul and pre-commit them

* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-06-02 15:58:46 -07:00
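As background on the approach named in #2220, Karatsuba's trick computes a complex matrix product with three real matmuls instead of four. A small NumPy sketch of the identity (illustrative only, not the MLX kernel):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))
B = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))

# Three real matrix multiplications...
t1 = A.real @ B.real
t2 = A.imag @ B.imag
t3 = (A.real + A.imag) @ (B.real + B.imag)

# ...recombined into the real and imaginary parts of A @ B.
C = (t1 - t2) + 1j * (t3 - t1 - t2)
assert np.allclose(C, A @ B)
```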
Awni Hannun
cbad6c3093 version (#2237) 2025-06-02 15:58:33 -07:00
Cheng
1b021f6984 Fast primitives decide when to use the fallback (#2216) 2025-06-02 13:26:37 -07:00
Cheng
95b7551d65 Do not check event.is_signaled() in eval_impl (#2230) 2025-06-02 13:23:34 -07:00
Cheng
db5a7c6192 Add memory cache to CUDA backend (#2221)
* Move BufferCache out of allocator

* Add memory cache to cuda backend allocator

* Simplify BufferCache assuming buf can not be null
2025-05-30 12:12:54 -07:00
Awni Hannun
6ef2f67e7f 5bit quants (#2226)
* 5bit quants

* 5bit quants
2025-05-30 12:12:10 -07:00
Cheng
f76ee1ffd2 Move some dims utils to common (#2223) 2025-05-29 06:48:30 -07:00
Cheng
54a71f270a Remove unused defines (#2217) 2025-05-23 06:14:58 -07:00
Awni Hannun
55b4062dd8 copyright in docs (#2214) 2025-05-21 17:13:04 -07:00
Cheng
79071bfba4 Fix out-of-bounds default value in logsumexp/softmax (#2213) 2025-05-21 07:25:16 -07:00
Cheng
7774b87cbd Remove redundant simd_sum in logsumexp (#2210) 2025-05-21 07:25:03 -07:00
Cheng
35c87741cf Build for compute capability 70 instead of 75 (#2209) 2025-05-20 19:42:48 -07:00
Jack Wind
4cbe605214 Feat: Allow per-target Metal debug flags (#2201)
* feat: allow per-target Metal debug flags

* formatting fix
2025-05-20 10:22:26 -07:00
Clement Liaw
ab8883dd55 include mlx::core::version() symbols in the mlx static library (#2207) 2025-05-20 07:39:11 -07:00
Awni Hannun
eebe73001a fix large arg reduce (#2206) 2025-05-19 13:10:44 -07:00
Angelos Katharopoulos
0359bf02c9 Nearest upsample (#2202) 2025-05-19 11:23:38 -07:00
Cheng
237f9e58a8 Fix BEFORE keyword in target_include_directories (#2204) 2025-05-19 06:10:44 -07:00
Awni Hannun
8576e6fe36 fix conv2d bug + faster conv 1d (#2195)
* fix conv2d bug + faster conv 1d

* revert sort + flaky test
2025-05-18 06:05:11 -07:00
Angelos Katharopoulos
0654543dcc Add complex eigh (#2191) 2025-05-18 00:18:43 -07:00
Awni Hannun
48ef3e74e2 reduce vjp for all and any (#2193) 2025-05-16 08:38:49 -07:00
Cheng
7d4b378952 Include cuda_bf16.h for bfloat16 overloads (#2192)
* Include cuda_bf16.h for bfloat16 overloads

* Add NO_GPU_MULTI(Eig) in cuda backend
2025-05-16 06:44:42 -07:00
Jack Wind
7ff5c41e06 Add set_threadgroup_memory_length to CommandEncoder (#2183) 2025-05-16 00:28:03 -07:00
Awni Hannun
602f43e3d1 fix conv grad (#2187) 2025-05-15 19:20:36 -07:00
Awni Hannun
a2cadb8218 real and imag properties (#2189) 2025-05-15 18:17:50 -07:00
Awni Hannun
c1eb9d05d9 non-symmetric eig and eigh (#2188) 2025-05-15 13:01:44 -07:00
Angelos Katharopoulos
cf6c939e86 Fix some complex vjps (#2178) 2025-05-14 23:37:12 -07:00
Angelos Katharopoulos
130df35e1b Add random normal distribution for complex numbers (#2182) 2025-05-13 22:43:45 -07:00
Cheng
0751263dec Fix typo in row_reduce_small (#2179) 2025-05-13 20:19:54 -07:00
Cheng
eca2f3eb97 Add remove_index utility (#2173) 2025-05-13 17:09:56 -07:00
Angelos Katharopoulos
3aa9cf3f9e Fix put_along_axis for empty arrays (#2181) 2025-05-13 14:27:53 -07:00
Awni Hannun
8f3d208dce Close a couple edge case bugs: hadamard and addmm on empty inputs (#2177)
* handle hadamard and addmm on empty inputs

* fix
2025-05-12 10:48:57 -07:00
Ivan Fioravanti
caaa3f1f8c Small typos in mx.metal deprecations (#2176) 2025-05-11 06:03:47 -07:00
Awni Hannun
659a51919f patch bump (#2162) 2025-05-09 14:35:14 -07:00
Awni Hannun
6661387066 Fix fft for integer overflow (#2161) 2025-05-09 14:25:12 -07:00
ATurker
a7fae8a176 fix: conv_general differences between gpu, cpu (#2070)
* fix general_conv padding

* fix bugs

* add test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-05-09 10:26:52 -07:00
Cheng
0cae0bdac8 CUDA backend: backbone (#2075) 2025-05-06 21:26:46 -07:00
Awni Hannun
5a1a5d5ed1 fix input coherent kernel launch (#2153) 2025-05-05 17:30:50 -07:00
Cheng
1683975acf Move common gpu primitives to backend/gpu (#2145) 2025-05-05 13:45:29 -07:00
Awni Hannun
af705590ac fix batched vector sdpa (#2152) 2025-05-05 13:13:03 -07:00
Awni Hannun
825124af8f fix bw for elementwise ops (#2151)
* fix bw for elementwise ops

* add compile

* fix

* fix

* fix

* fix
2025-05-05 06:15:04 -07:00
Awni Hannun
9c5e7da507 fix compile merging (#2150) 2025-05-02 15:08:50 -07:00
Angelos Katharopoulos
481349495b GPU Hadamard for large N (#1879) 2025-05-01 17:19:17 -07:00
Awni Hannun
9daa6b003f fix shapeless export (#2148) 2025-05-01 15:02:02 -07:00
Angelos Katharopoulos
a3a632d567 Fix the launcher when ran locally (#2147) 2025-05-01 12:56:09 -07:00
Awni Hannun
e496c5a4b4 fix integer overflow in qmm (#2143) 2025-04-30 09:28:56 -07:00
Cheng
ea890d8710 Remove metal-only tests (#2139) 2025-04-30 09:08:39 -07:00
Awni Hannun
aa5d84f102 Allow quant layer to be unfrozen (#2142) 2025-04-30 09:08:29 -07:00
Awni Hannun
f1606486d2 Generalize gpu backend (#2138)
* generalize gpu backend

* fix no_gpu build

* fix no_gpu build

* generalize gpu backend
2025-04-30 09:08:17 -07:00
Cheng
87720a8908 Fix building with uv (#2141) 2025-04-30 06:04:07 -07:00
Aashiq Dheeraj
bb6565ef14 add fftshift and ifftshift fft helpers (#2135)
* add fftshift and ifftshift fft helpers

* address comments

* axes have to be iterable

* fix fp error in roll + add test

---------

Co-authored-by: Aashiq Dheeraj <aashiq@aashiq-mbp-m4.local>
2025-04-29 22:13:45 -07:00
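A brief usage sketch of the helpers added in #2135, assuming they mirror NumPy's `fft.fftshift`/`ifftshift` semantics (the module path is an assumption based on the PR title calling them fft helpers):

```python
import mlx.core as mx

x = mx.arange(8)
y = mx.fft.fftshift(x)                         # move the zero-frequency bin to the center
assert mx.array_equal(mx.fft.ifftshift(y), x)  # ifftshift undoes fftshift
```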
Awni Hannun
7bb063bcb3 Enable vjp for quantized scale and bias (#2129)
* Enable vjp for quantized scale and bias

* higher tol
2025-04-29 13:03:09 -07:00
Alex Chi Z.
b36dd472bb return library if it is successfully loaded (#2131) 2025-04-29 07:30:36 -07:00
hdeng-apple
167b759a38 Fix typos (#2136) 2025-04-29 07:26:05 -07:00
charan-003
99b9868859 Clarify dimension notation in conv1d, conv2d, and conv3d docstrings (#2123)
* Clarify dimension notation in conv1d, conv2d, and conv3d docstrings

* Updating transposed convs in conv1d, conv2d, and conv3d

---------

Co-authored-by: Sai Charan Arvapally <saicharan@Sais-MacBook-Pro.local>
2025-04-25 12:18:30 -07:00
1ndig0
6b2d5448f2 Fix the error message in mx.right_shift and mx.left_shift (#2121)
* update right_shift and left_shift

* simplify

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-04-25 09:14:28 -07:00
Awni Hannun
eaf709b83e patch (#2119) 2025-04-24 16:11:07 -07:00
Angelos Katharopoulos
f0e70afff0 Fix swift pm load (#2117) 2025-04-24 10:58:29 -07:00
hdeng-apple
86984cad68 Remove static initializers (#2059)
* Remove static initializers in device.cpp, load.cpp, pocketfft.h

* Remove static initializer InTracing::trace_stack

* Remove static initializer of CompilerCache cache

* Revert changes in pocketfft.h

* Remove duplicate private section of thread_pool()
2025-04-24 06:14:49 -07:00
Awni Hannun
fbc89e3ced fix pinv (#2110) 2025-04-23 13:08:28 -07:00
hdeng-apple
38c1e720c2 Search mlx.metallib in macOS framework "Resources" dir (#2061)
---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-04-23 09:53:13 -07:00
Param Thakkar
600e87e03c Added output_padding parameters in conv_transpose (#2092) 2025-04-23 09:26:33 -07:00
Hyunsung Lee
3836445241 Add broadcast_shapes in python API (#2091) 2025-04-22 18:57:39 -07:00
Yury Popov
1d2c9d6a07 Complex scan (#2094) 2025-04-22 18:56:28 -07:00
Awni Hannun
e8ac6bd2f5 irfft throws instead of segfaults on scalars (#2109) 2025-04-22 10:25:55 -07:00
Awni Hannun
fdadc4f22c Add more complex unary ops (#2101) 2025-04-21 13:04:54 -07:00
Awni Hannun
79b527f45f conv vmap (#2102) 2025-04-21 13:04:39 -07:00
Awni Hannun
dc4eada7f0 Use unordered map for kwargs in export/import (#2087)
* use unordered map for kwargs in export/import

* comment
2025-04-21 07:17:22 -07:00
Cheng
70ebc3b598 Return const ref in array::data_shared_ptr (#2100) 2025-04-21 07:17:09 -07:00
Cheng
b13f2aed16 Introduce macros for dispatching dynamic dtypes as static types (#2073) 2025-04-19 06:16:30 -07:00
Param Thakkar
5f04c0f818 Fixed shift operations issue (#2080)
* Fixed shift operations issue

* Added tests and fixes

* Fixed loop syntax error

* Added tests for bool

* Fixed typo
2025-04-18 14:28:33 -07:00
Awni Hannun
55935ccae7 fix py gc edge case (#2079) 2025-04-18 12:46:53 -07:00
Awni Hannun
b529515eb1 minor bump (#2081) 2025-04-17 14:57:11 -07:00
Angelos Katharopoulos
3cde719eb7 Route to gather qmm only for many tokens per expert (#2082) 2025-04-17 14:53:08 -07:00
Angelos Katharopoulos
5de6d94a90 Gather qmm batched kernel and refactoring of quantized (#2078) 2025-04-17 13:53:11 -07:00
Angelos Katharopoulos
99eefd2ec0 Gather mm new kernel and small refactoring (#2040) 2025-04-14 16:37:36 -07:00
Yury Popov
e9e268336b LogCumSumExp (#2069) 2025-04-13 01:27:29 -07:00
Awni Hannun
7275ac7523 Fix release build (#2072) 2025-04-12 20:41:58 -07:00
Angelos Katharopoulos
c4189a38e4 Add float mask to sdpa vector (#2068) 2025-04-11 17:29:40 -07:00
Awni Hannun
68d1b3256b nit: fix exception handling (#2066) 2025-04-11 14:12:08 -07:00
Awni Hannun
9c6953bda7 Fix stubgen (#2065)
* Fix stubgen

* add multi optim to docs
2025-04-11 12:02:54 -07:00
Awni Hannun
ef7ece9851 fix fft bug (#2062) 2025-04-10 19:41:27 -07:00
Angelos Katharopoulos
ddaa4b7dcb Fix the test and add custom min/max reductions for uncommon MPI types (#2060) 2025-04-10 17:01:17 -07:00
Cheng
dfae2c6989 Fix MSVC build due to use of M_LN2 (#2058) 2025-04-10 07:41:41 -07:00
Anastasiia Filippova
515f104926 Min / max reductions (#2041) 2025-04-09 23:22:20 -07:00
Angelos Katharopoulos
9ecefd56db Do not load the default lib if another is requested (#2055) 2025-04-09 13:31:38 -07:00
Awni Hannun
e5d35aa187 no sdpa in grad (#2054) 2025-04-08 19:13:54 -07:00
Awni Hannun
00794c42bc Fix causal mask sdpa vec (#2053)
* fix sdpa vector causal mask

* test
2025-04-08 09:11:23 -07:00
Cheng
08a1bf3f10 Remove Event::Signal() (#2052) 2025-04-08 06:20:27 -07:00
Awni Hannun
60c4154346 Only request residency once (#2051) 2025-04-07 10:47:51 -07:00
Awni Hannun
f2c85308c1 add a half simd gemm fallback (#2046)
* add a half simd gemm fallback

* nit
2025-04-07 09:31:29 -07:00
Awni Hannun
1a28b69ee2 only add to residency set once (#2049) 2025-04-06 17:38:25 -07:00
Cheng
ba09f01ce8 Remove test of converting negative float to uint (#2048) 2025-04-06 06:21:46 -07:00
Cheng
6cf48872b7 wait_for_one should wait for task to finish (#2047) 2025-04-05 20:05:16 -07:00
Angelos Katharopoulos
7b3b8fa000 Fix ci release (#2045) 2025-04-04 20:25:01 -07:00
Awni Hannun
ec5e2aae61 nit in doc (#2044) 2025-04-04 12:04:17 -07:00
Awni Hannun
86389bf970 patch bump (#2043) 2025-04-03 13:15:18 -07:00
Jagrit Digani
3290bfa690 Add new sdpa function overload (#2035)
* Add new sdpa function overload

* Address comments

* Remove std::variant from cpp sdpa function
2025-04-03 11:58:28 -07:00
Jagrit Digani
8777fd104f Depthwise Conv2D optimization (#2036)
- Add a new specialized kernel for small-kernel (kernel size <= 7), small-stride (strides <= 2) depthwise 2D convolutions
- Add related tests
2025-04-03 09:42:04 -07:00
Awni Hannun
c41f7565ed fix softmax / logsumexp (#2042) 2025-04-03 08:32:59 -07:00
Awni Hannun
9ba81e3da4 tune quant dispatch (#2031) 2025-04-02 20:05:54 -07:00
Awni Hannun
c23888acd7 Fix build warning (#2033) 2025-04-01 14:42:27 -07:00
Awni Hannun
f98ce25ab9 fix residency set for real (#2032) 2025-04-01 12:59:48 -07:00
Awni Hannun
de5f38fd48 Custom logsumexp (#2028)
* initial custom logsumexp

* more tests

* comments + fix
2025-03-31 07:36:55 -07:00
Angelos Katharopoulos
ec2854b13a Swap -inf for finite_minimum value (#2029) 2025-03-30 21:55:04 -07:00
Stephen Panaro
90823d2938 Add missing funcs to docs (#2021) 2025-03-30 18:29:33 -07:00
Jesper Stemann Andersen
5f5770e3a2 Fix CPU sign for unsigned ints (#2024)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-03-30 17:56:59 -07:00
Awni Hannun
28f39e9038 Log for complex numbers in Metal (#2025)
* Log for complex numbers in Metal

* fix log2
2025-03-30 17:04:38 -07:00
Awni Hannun
b2d2b37888 fix residency set clearing (#2027) 2025-03-30 16:27:26 -07:00
Awni Hannun
fe597e141c add pinv to doc (#2020) 2025-03-30 15:54:18 -07:00
Yi Wang
72ca1539e0 Remove unused variable in /setup.py (#2026)
This is a follow up of https://github.com/ml-explore/mlx/pull/2011
2025-03-30 12:52:33 -07:00
Awni Hannun
13b26775f1 use minimum deployment target (#2016) 2025-03-28 14:31:53 -07:00
Awni Hannun
05d7118561 causal vector sdpa (#2018)
* causal vector sdpa

* get rid of memory threshold
2025-03-28 12:36:13 -07:00
Awni Hannun
98b901ad66 enable complex gemm (#2017) 2025-03-28 10:45:13 -07:00
Awni Hannun
5580b47291 iinfo and scalar overflow detection (#2009) 2025-03-27 19:54:56 -07:00
Awni Hannun
bc62932984 sdpa specialization for head dim 256 (#2007) 2025-03-27 19:31:25 -07:00
Awni Hannun
a6b5d6e759 revise cmake minimum for doctest (#2014) 2025-03-27 19:30:58 -07:00
Yi Wang
a8931306e1 Remove unused variable in CMakeBuild (#2011)
Fix https://github.com/ml-explore/mlx/issues/2010
2025-03-27 16:00:51 -07:00
Yi Wang
fecdb8717e Polish CONTRIBUTING.md (#2005) 2025-03-25 19:06:34 -07:00
Awni Hannun
916fd273ea wire cache (#2006) 2025-03-25 18:54:01 -07:00
Yi Wang
0da8506552 Update docs for extensions (#2004) 2025-03-25 18:35:03 -07:00
Cheng
eda7a7b43e Do not join threads during process exit on Windows (#1738) 2025-03-25 06:33:08 -07:00
Chunyang Wen
022eabb734 Remove unused import (#1987) 2025-03-24 20:19:32 -07:00
Awni Hannun
aba899cef8 patch bump (#2000) 2025-03-24 12:47:05 -07:00
Jagrit Digani
6a40e1c176 Fix looping limit in causal attention (#1999) 2025-03-24 12:28:00 -07:00
Jesper Stemann Andersen
9307b2ab8b Fixed 32-bit platform support for distributed/ring implementation (#1996)
Replaced unsigned long integer literals with size_t literals in ring implementation, e.g., 1UL with size_t(1).
2025-03-24 08:08:40 -07:00
Jesper Stemann Andersen
522d8d3917 Added missing netinet/in.h include that fixes build on FreeBSD (#1997)
Defines IPPROTO_TCP.
2025-03-24 08:07:34 -07:00
Awni Hannun
a84cc0123f promote mask when needed (#1998) 2025-03-23 19:58:28 -07:00
Andrey Velichkevich
f018e248cd fix(backend): Include algorithm library in Allocator (#1992)
Signed-off-by: Andrey Velichkevich <andrey.velichkevich@gmail.com>
2025-03-22 21:27:51 -07:00
Awni Hannun
cfd7237a80 fix docs (#1991) 2025-03-21 19:58:53 -07:00
Angelos Katharopoulos
4eef8102c9 Distributed layers (#1270) 2025-03-21 13:52:17 -07:00
Angelos Katharopoulos
69e4dd506b Add a ring all gather (#1985) 2025-03-21 13:36:51 -07:00
Angelos Katharopoulos
25814a9458 Disable mpi on version mismatch (#1989) 2025-03-21 13:36:26 -07:00
Awni Hannun
2a980a76ce Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests

* linux memory and default

* fix
2025-03-21 12:28:36 -07:00
Angelos Katharopoulos
d343782c8b Cross platform libmpi loading (#1975) 2025-03-21 11:23:10 -07:00
Awni Hannun
4e1994e9d7 move memory APIs into top level mlx.core (#1982) 2025-03-21 07:25:12 -07:00
jiyzhang
65a38c452b update the formula of smooth_l1_loss (#1986) 2025-03-21 06:25:23 -07:00
Awni Hannun
7b7e2352cd fix malloc or wait deadlock (#1976) 2025-03-20 16:48:43 -07:00
Awni Hannun
1177d28395 patch bump (#1981) 2025-03-20 15:12:22 -07:00
Awni Hannun
005e7efa64 fix mask in sdpa (#1980)
* fix mask in sdpa

* fix attention mask

* Re-enable routing for array mask

---------

Co-authored-by: Jagrit Digani <digani@apple.com>
2025-03-20 14:53:12 -07:00
Jagrit Digani
b42d13ec84 Update attention tests to show diff, disable array masks (#1978) 2025-03-20 14:25:38 -07:00
Jagrit Digani
9adcd1a650 Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests
2025-03-20 11:01:32 -07:00
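The headline change in #1924 is accepting `mask='causal'` (alongside boolean and additive array masks) in `fast::sdpa`. A minimal usage sketch with illustrative shapes:

```python
import mlx.core as mx

B, H, L, D = 1, 8, 128, 64
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

# Causal masking is requested by name instead of materializing an L x L mask.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, mask="causal")
```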
Awni Hannun
3c164fca8c Fix multistream GPU deadlock (#1969)
* fix multistream GPU deadlock

* comments
2025-03-20 07:19:47 -07:00
jiyzhang
95e335db7b Update smooth_l1_loss in losses.py (#1974)
According to the definition of smooth_l1_loss, the line

diff = predictions - targets

should be updated to

diff = mx.abs(predictions - targets)

After the modification, the result is consistent with PyTorch's smooth_l1_loss
2025-03-19 20:19:02 -07:00
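A standalone sketch of the corrected formula (a hypothetical helper, not the actual `mlx.nn.losses` implementation, which also exposes a reduction argument):

```python
import mlx.core as mx

def smooth_l1_loss_sketch(predictions, targets, beta=1.0):
    diff = mx.abs(predictions - targets)   # the fix: take the absolute difference first
    return mx.mean(
        mx.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
    )
```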
Awni Hannun
f90206ad74 Guard nullptr dereference (#1972)
* guard nullptr dereference

* comment
2025-03-19 16:24:10 -07:00
Chunyang Wen
3779150750 refactor: all use schedule (#1973) 2025-03-19 11:24:04 -07:00
403 changed files with 29334 additions and 6521 deletions


@@ -16,6 +16,9 @@ parameters:
linux_release:
type: boolean
default: false
+cuda_release:
+type: boolean
+default: false
jobs:
build_documentation:
@@ -24,8 +27,8 @@ jobs:
type: boolean
default: false
macos:
-xcode: "15.2.0"
-resource_class: macos.m1.medium.gen1
+xcode: "16.2.0"
+resource_class: m2pro.medium
steps:
- checkout
- run:
@@ -38,7 +41,7 @@ jobs:
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
-CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
+pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
@@ -89,16 +92,13 @@ jobs:
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
- run:
name: Install Python package
command: |
-CMAKE_ARGS="-DMLX_BUILD_METAL=OFF
-CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
-CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
python3 setup.py build_ext --inplace
-CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \
-CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
-CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
python3 setup.py develop
- run:
name: Generate package stubs
@@ -110,6 +110,8 @@ jobs:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
+mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
+mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run:
name: Build CPP only
command: |
@@ -124,10 +126,15 @@ jobs:
parameters:
xcode_version:
type: string
-default: "15.2.0"
+default: "16.2.0"
+macosx_deployment_target:
+type: string
+default: ""
macos:
xcode: << parameters.xcode_version >>
-resource_class: macos.m1.medium.gen1
+environment:
+MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
+resource_class: m2pro.medium
steps:
- checkout
- run:
@@ -148,8 +155,7 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
-DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
pip install -e . -v
- run:
name: Generate package stubs
@@ -199,13 +205,34 @@ jobs:
name: Run Python tests with JIT
command: |
source env/bin/activate
-CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
+CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
cuda_build_and_test:
machine:
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- run:
name: Install Python package
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python -m venv env
source env/bin/activate
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install -e ".[dev]"
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
build_release:
parameters:
python_version:
@@ -213,13 +240,18 @@ jobs:
default: "3.9" default: "3.9"
xcode_version: xcode_version:
type: string type: string
default: "15.2.0" default: "16.2.0"
build_env: build_env:
type: string type: string
default: "" default: ""
macosx_deployment_target:
type: string
default: ""
macos: macos:
xcode: << parameters.xcode_version >> xcode: << parameters.xcode_version >>
resource_class: macos.m1.medium.gen1 resource_class: m2pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps: steps:
- checkout - checkout
- run: - run:
@@ -240,8 +272,7 @@ jobs:
name: Install Python package
command: |
source env/bin/activate
-DEV_RELEASE=1 \
-CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
@@ -253,9 +284,7 @@ jobs:
name: Build Python package
command: |
source env/bin/activate
-<< parameters.build_env >> \
-CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-python -m build -w
+<< parameters.build_env >> python -m build -w
- when:
condition: << parameters.build_env >>
steps:
@@ -303,14 +332,10 @@ jobs:
pip install patchelf
pip install build
pip install twine
-<< parameters.extra_env >> \
-CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-pip install . -v
+<< parameters.extra_env >> pip install . -v
pip install typing_extensions
python setup.py generate_stubs
-<< parameters.extra_env >> \
-CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-python -m build --wheel
+<< parameters.extra_env >> python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
@@ -321,6 +346,46 @@ jobs:
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
python_version:
type: string
default: "3.9"
extra_env:
type: string
default: "DEV_RELEASE=1"
machine:
image: linux-cuda-12:default
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- run:
name: Build wheel
command: |
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
python -m venv env
source env/bin/activate
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
pip install ".[dev]" -v
python setup.py generate_stubs
<< parameters.extra_env >> \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build --wheel
bash python/scripts/repair_cuda.sh
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
workflows:
build_and_test:
when:
@@ -335,8 +400,9 @@ workflows:
- mac_build_and_test:
matrix:
parameters:
-xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test
+- cuda_build_and_test
- build_documentation
build_pypi_release:
@@ -355,8 +421,70 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-xcode_version: ["15.0.0", "15.2.0"]
+macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
+xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "PYPI_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "PYPI_RELEASE=1"
- build_documentation:
filters:
tags:
@@ -364,6 +492,16 @@ workflows:
branches:
ignore: /.*/
upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]
prb:
when:
@@ -379,9 +517,11 @@ workflows:
requires: [ hold ]
matrix:
parameters:
-xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+macosx_deployment_target: ["13.5", "14.0"]
- linux_build_and_test:
requires: [ hold ]
+- cuda_build_and_test:
+requires: [ hold ]
nightly_build:
when:
and:
@@ -392,7 +532,54 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-xcode_version: ["15.0.0", "15.2.0"]
+macosx_deployment_target: ["13.5", "14.0", "15.0"]
+xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
weekly_build:
when:
and:
@@ -403,8 +590,70 @@ workflows:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
+xcode_version: ["16.2.0", "15.0.0"]
exclude:
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "13.5"
xcode_version: "16.2.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "14.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.9"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.10"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.11"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.12"
build_env: "DEV_RELEASE=1"
- macosx_deployment_target: "15.0"
xcode_version: "15.0.0"
python_version: "3.13"
build_env: "DEV_RELEASE=1"
linux_test_release:
when:
and:
@@ -416,3 +665,14 @@ workflows:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]
cuda_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.cuda_release >>
jobs:
- build_cuda_release:
matrix:
parameters:
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

.gitignore

@@ -36,6 +36,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
+uv.lock
# vim
*.swp


@@ -34,6 +34,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
+option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -83,6 +84,10 @@ if(MLX_BUILD_METAL)
set(QUARTZ_LIB "-framework QuartzCore")
endif()
+if(MLX_BUILD_CUDA)
+enable_language(CUDA)
+endif()
if(MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF)
@@ -212,24 +217,6 @@ else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
-find_package(MPI)
-if(MPI_FOUND)
-execute_process(
-COMMAND zsh "-c" "mpirun --version"
-OUTPUT_VARIABLE MPI_VERSION
-ERROR_QUIET)
-if(${MPI_VERSION} MATCHES ".*Open MPI.*")
-target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
-elseif(MPI_VERSION STREQUAL "")
-set(MPI_FOUND FALSE)
-message(
-WARNING "MPI found but mpirun is not available. Building without MPI.")
-else()
-set(MPI_FOUND FALSE)
-message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
-endif()
-endif()
message(STATUS "Downloading json")
FetchContent_Declare(
json
@@ -244,6 +231,9 @@ target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
$<INSTALL_INTERFACE:include>)
+# Do not add mlx_EXPORTS define for shared library.
+set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git


@@ -17,13 +17,13 @@ possible.
You can also run the formatters manually as follows:
-```
+```shell
clang-format -i file.cpp
```
-```
+```shell
black file.py
```
or run `pre-commit run --all-files` to check all files in the repo.


@@ -1,4 +1,6 @@
include CMakeLists.txt
+include mlx.pc.in
recursive-include mlx/ *
+include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561


@@ -1,5 +1,6 @@
// Copyright © 2023 Apple Inc.
+#include <cstring>
#include <iostream>
#include <sstream>


@@ -192,6 +192,22 @@ void time_reductions() {
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
TIME(argmin_along_1);
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = scatter(a, {indices}, updates, axes);
mx::eval(b);
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
}
void time_gather_scatter() {


@@ -5,6 +5,7 @@ import os
import time
import torch
+import torch.cuda
import torch.mps
@@ -44,8 +45,10 @@ def bench(f, *args):
def sync_if_needed(x):
-if x.device != torch.device("cpu"):
+if x.device == torch.device("mps"):
torch.mps.synchronize()
+elif x.device == torch.device("cuda"):
+torch.cuda.synchronize()
@torch.no_grad()
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
sync_if_needed(x)
@torch.no_grad()
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
sync_if_needed(x)
@torch.no_grad()
def softmax(axis, x):
ys = []
@@ -340,7 +351,11 @@ if __name__ == "__main__":
args.axis.pop(0)
torch.set_num_threads(1)
-device = "cpu" if args.cpu else "mps"
+device = "mps"
if torch.cuda.is_available():
device = "cuda"
if args.cpu:
device = "cpu"
types = args.dtype
if not types:
@@ -460,5 +475,8 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")


@@ -0,0 +1,107 @@
import math
import time
import mlx.core as mx
import numpy as np
import torch
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
dtype = "float32"
shapes = (
(4, 32, 32, 21, 3, 3, 128),
(4, 32, 32, 21, 3, 3, 37),
(4, 32, 32, 370, 3, 3, 370),
(4, 32, 32, 370, 7, 7, 128),
(2, 320, 640, 21, 7, 7, 21),
)
for N, H, W, C, kh, kw, O in shapes:
time_mlx, time_torch = bench_shape(
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")


@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_mm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = x @ w1.T
x = x @ w2.T
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_mm()


@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.
import mlx.core as mx
from time_utils import time_fn
N = 1024
D = 1024
M = 1024
E = 32
I = 4
def gather_sort(x, indices):
N, M = indices.shape
indices = indices.flatten()
order = mx.argsort(indices)
inv_order = mx.argsort(order)
return x.flatten(0, -3)[order // M], indices[order], inv_order
def scatter_unsort(x, inv_order, shape=None):
x = x[inv_order]
if shape is not None:
x = mx.unflatten(x, 0, shape)
return x
def gather_mm_simulate(x, w, indices):
x, idx, inv_order = gather_sort(x, indices)
for i in range(2):
y = mx.concatenate(
[
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
for i, j in enumerate(idx.tolist())
],
axis=0,
)
x = y[:, None]
x = scatter_unsort(x, inv_order, indices.shape)
return x
def time_gather_qmm():
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
w1 = mx.random.normal((E, M, D)) / 1024**0.5
w2 = mx.random.normal((E, D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
mx.eval(x, w1, w2, indices, sorted_indices)
def gather_mm(x, w1, w2, indices, sort):
idx = indices
inv_order = None
if sort:
x, idx, inv_order = gather_sort(x, indices)
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
if sort:
x = scatter_unsort(x, inv_order, indices.shape)
return x
time_fn(gather_mm, x, w1, w2, indices, False)
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
time_fn(gather_mm, x, w1, w2, indices, True)
x = mx.random.normal((N * I, D)) / 1024**0.5
w1 = mx.random.normal((M, D)) / 1024**0.5
w2 = mx.random.normal((D, M)) / 1024**0.5
w1 = mx.quantize(w1)
w2 = mx.quantize(w2)
mx.eval(x, w1, w2)
def equivalent_matmul(x, w1, w2):
x = mx.quantized_matmul(x, *w1, transpose=True)
x = mx.quantized_matmul(x, *w2, transpose=True)
return x
time_fn(equivalent_matmul, x, w1, w2)
if __name__ == "__main__":
time_gather_qmm()


@@ -1,5 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
+from functools import partial
import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn
@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
return y
-def time_layer_norm():
+def time_layer_norm(N, dt):
+L = 1024
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0, 1, 2))
g2 = mx.grad(f2, argnums=(0, 1, 2))
-x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+w = mx.random.uniform(shape=(N,)).astype(dt)
+b = mx.random.uniform(shape=(N,)).astype(dt)
+y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
-def layer_norm_loop(g, x, w, b):
+def layer_norm_loop(f, x, w, b):
+for _ in range(32):
+x = f(x, w, b)
+return x
+time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
+time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
+def layer_norm_grad_loop(g, x, w, b):
gx, gw, gb = x, w, b
for _ in range(32):
gx, gw, gb = g(gx, gw, gb, y)
return gx, gw, gb
-time_fn(layer_norm_loop, g1, x, w, b)
-time_fn(layer_norm_loop, g2, x, w, b)
-time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
-time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
+time_fn(layer_norm_grad_loop, g1, x, w, b)
+time_fn(layer_norm_grad_loop, g2, x, w, b)
+time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
+time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
g1 = mx.grad(f1, argnums=(0,))
g2 = mx.grad(f2, argnums=(0,))
-x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+w = mx.random.uniform(shape=(N,)).astype(dt)
+b = mx.random.uniform(shape=(N,)).astype(dt)
+y = mx.random.uniform(shape=(8, L, N)).astype(dt)
mx.eval(x, w, b, y)
-def layer_norm_loop(g, x):
+def layer_norm_grad_x_loop(g, x):
gx = x
for _ in range(32):
gx = g(gx, y)
return gx
-time_fn(layer_norm_loop, g1, x)
-time_fn(layer_norm_loop, g2, x)
-time_fn(layer_norm_loop, mx.compile(g1), x)
-time_fn(layer_norm_loop, mx.compile(g2), x)
+time_fn(layer_norm_grad_x_loop, g1, x)
+time_fn(layer_norm_grad_x_loop, g2, x)
+time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
+time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
if __name__ == "__main__":
-time_layer_norm()
+for dt in [mx.float32, mx.float16, mx.bfloat16]:
+for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
+print(dt, n)
+time_layer_norm(n, dt)


@@ -28,11 +28,34 @@ def bench(f, *args):
return (e - s) * 1e-9 return (e - s) * 1e-9
def mlx_sdpa_fused_inner(q, k, v, scale): def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None) np_dtype = getattr(np, dtype)
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
scale = 1.0 / math.sqrt(D)
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
if mask is not None:
if mask == "additive":
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
mask = mx.array(mask_np)
elif mask == "bool":
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
mask = mx.array(mask_np)
return q_mx, k_mx, v_mx, scale, mask
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
q_dtype = q.dtype q_dtype = q.dtype
q = q * mx.array(scale, q_dtype) q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3] n_q_heads = q.shape[-3]
@@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
B = q.shape[0] B = q.shape[0]
L = q.shape[2] L = q.shape[2]
kL = k.shape[2]
if n_repeats > 1: if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1]) q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
@@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
v = mx.expand_dims(v, 2) v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2) scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype) if mask is not None:
else:
scores = mx.softmax(scores, axis=-1) if mask == "causal":
q_offset = max(0, kL - L)
q_indices = mx.arange(q_offset, q_offset + L)
k_indices = mx.arange(kL)
mask = q_indices[:, None] >= k_indices[None]
if n_repeats > 1 and mask.ndim >= 3:
if mask.shape[-3] == 1:
mask = mx.expand_dims(mask, -3)
else:
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
if mask.dtype == mx.bool_:
scores = mx.where(mask, scores, -np.float32(np.inf))
else:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = scores @ v out = scores @ v
if n_repeats > 1: if n_repeats > 1:
@@ -60,74 +101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
return out return out
def mlx_spda_unfused(q, k, v, scale, transpose): def mlx_fused_attn(q, k, v, scale, mask):
q_out = q return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
if transpose: if transpose:
k = mx.transpose(k, (0, 2, 1, 3)) q_t = mx.transpose(q, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3)) k_t = mx.transpose(k, (0, 2, 1, 3))
v_t = mx.transpose(v, (0, 2, 1, 3))
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
return mx.transpose(o_t, (0, 2, 1, 3))
else:
return f(q, k, v, scale=scale, mask=mask)
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
q_out = q
for i in range(N_iter_func): for i in range(N_iter_func):
if transpose: q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out) mx.eval(q_out)
return q_out return q_out
def mlx_spda_fused(q, k, v, scale, transpose): def bench_shape(
q_out = q B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
if transpose: ):
k = mx.transpose(k, (0, 2, 1, 3)) q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
v = mx.transpose(v, (0, 2, 1, 3)) B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
) )
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype) time_mlx_unfused = bench(
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype) do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype) )
time_mlx_fused = bench(
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
scale = math.sqrt(1.0 / head_dim) o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
o_mlx_unfused = do_attention(
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
)
q_mx = mx.array(q_np) atol = 1e-5 if dtype == "float32" else 2e-4
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose) if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print( print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}" f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
) )
return time_mlx_fused, time_mlx_unfused return time_mlx_fused, time_mlx_unfused
@@ -151,39 +173,51 @@ if __name__ == "__main__":
( 1, 128, 128, 64, 32, 32), ( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32), ( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32), ( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32), ( 1, 1024, 1024, 64, 32, 8),
( 1, 2048, 2048, 64, 32, 32), ( 1, 2048, 2048, 64, 32, 8),
( 1, 4096, 4096, 64, 32, 32), ( 1, 4096, 4096, 64, 32, 8),
) )
shapes_80 = ( shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh) # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32), ( 1, 1024, 1024, 80, 32, 8),
( 1, 2048, 2048, 80, 32, 32), ( 1, 2048, 2048, 80, 32, 8),
( 1, 4096, 4096, 80, 32, 32), ( 1, 4096, 4096, 80, 32, 8),
) )
shapes_128 = ( shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh) # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32), ( 1, 1024, 1024, 128, 32, 8),
( 1, 2048, 2048, 128, 32, 32), ( 1, 2048, 2048, 128, 32, 8),
( 1, 4096, 4096, 128, 32, 32), ( 1, 4096, 4096, 128, 32, 8),
) )
# fmt: on # fmt: on
shapes = shapes_64 + shapes_80 + shapes_128 shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%") masks = [None, "bool", "causal"]
print(
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
)
for dtype in dtypes: for dtype in dtypes:
for transpose in transposes: for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes: for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype) for mask_in in masks:
time_mlx_fused, time_mlx_unfused = bench_shape( time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose B,
) qsl,
diff = time_mlx_unfused / time_mlx_fused - 1.0 ksl,
t_str = 1 if transpose else 0 head_dim,
print( n_q_heads,
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%" n_kv_heads,
) dtype,
transpose,
mask_in,
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)


@@ -51,6 +51,20 @@ def time_maximum():
time_fn(mx.maximum, a, b) time_fn(mx.maximum, a, b)
def time_max():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.max, a, 0)
def time_min():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.min, a, 0)
def time_negative(): def time_negative():
a = mx.random.uniform(shape=(10000, 1000)) a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a) mx.eval(a)
@@ -108,6 +122,8 @@ if __name__ == "__main__":
time_add() time_add()
time_matmul() time_matmul()
time_min()
time_max()
time_maximum() time_maximum()
time_exp() time_exp()
time_negative() time_negative()


@@ -11,13 +11,14 @@ include(CMakeParseArguments)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers) # files (like headers) DEBUG: Boolean, if true, enables debug compile options
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
# #
# clang format on # clang format on
macro(mlx_build_metallib) macro(mlx_build_metallib)
# Parse args # Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS) set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
@@ -26,6 +27,10 @@ macro(mlx_build_metallib)
# Collect compile options # Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
-frecord-sources)
endif()
# Prepare metallib build command # Prepare metallib build command
add_custom_command( add_custom_command(


@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES FULL_PATH_NAMES = YES
RECURSIVE = YES RECURSIVE = YES
GENERATE_HTML = YES GENERATE_HTML = NO
GENERATE_LATEX = NO GENERATE_LATEX = NO
GENERATE_XML = YES GENERATE_XML = YES
XML_PROGRAMLISTING = YES XML_PROGRAMLISTING = YES


@@ -10,7 +10,7 @@ import mlx.core as mx
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "MLX" project = "MLX"
copyright = "2023, MLX Contributors" copyright = "2023, Apple"
author = "MLX Contributors" author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3]) version = ".".join(mx.__version__.split(".")[:3])
release = version release = version


@@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example Simple Example
-------------- --------------
.. currentmodule:: mlx.core
Let's write a custom kernel that computes ``exp`` elementwise: Let's write a custom kernel that computes ``exp`` elementwise:
.. code-block:: python .. code-block:: python
def exp_elementwise(a: mx.array): source = """
source = """ uint elem = thread_position_in_grid.x;
uint elem = thread_position_in_grid.x; T tmp = inp[elem];
T tmp = inp[elem]; out[elem] = metal::exp(tmp);
out[elem] = metal::exp(tmp); """
"""
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="myexp", name="myexp",
input_names=["inp"], input_names=["inp"],
output_names=["out"], output_names=["out"],
source=source, source=source,
) )
def exp_elementwise(a: mx.array):
outputs = kernel( outputs = kernel(
inputs=[a], inputs=[a],
template=[("T", mx.float32)], template=[("T", mx.float32)],
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
b = exp_elementwise(a) b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a)) assert mx.allclose(b, mx.exp(a))
Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce that overhead, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
.. note:: .. note::
We are only required to pass the body of the Metal kernel in ``source``. Only pass the body of the Metal kernel in ``source``. The function
signature is generated automatically.
The full function signature will be generated using: The full function signature will be generated using:
@@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>; template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function. Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups. <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension. function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes. Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
generated code for debugging purposes.
Using Shape/Strides Using Shape/Strides
------------------- -------------------
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default. :func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. is ``True`` by default. This will copy the array inputs if needed
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims before the kernel is launched to ensure that the memory layout is row
when indexing. contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
input array ``a`` if any are present in ``source``. ``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
We can then use MLX's built in indexing utils to fetch the right elements for each thread. present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``: Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
.. code-block:: python .. code-block:: python
def exp_elementwise(a: mx.array): source = """
source = """ uint elem = thread_position_in_grid.x;
uint elem = thread_position_in_grid.x; // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim); T tmp = inp[loc];
T tmp = inp[loc]; // Output arrays are always row contiguous
// Output arrays are always row contiguous out[elem] = metal::exp(tmp);
out[elem] = metal::exp(tmp); """
"""
kernel = mx.fast.metal_kernel( kernel = mx.fast.metal_kernel(
name="myexp_strided", name="myexp_strided",
input_names=["inp"], input_names=["inp"],
output_names=["out"], output_names=["out"],
source=source source=source
) )
def exp_elementwise(a: mx.array):
outputs = kernel( outputs = kernel(
inputs=[a], inputs=[a],
template=[("T", mx.float32)], template=[("T", mx.float32)],
@@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
.. code-block:: python .. code-block:: python
def grid_sample_ref(x, grid): def grid_sample_ref(x, grid):
N, H_in, W_in, _ = x.shape N, H_in, W_in, _ = x.shape
ix = ((grid[..., 0] + 1) * W_in - 1) / 2 ix = ((grid[..., 0] + 1) * W_in - 1) / 2
iy = ((grid[..., 1] + 1) * H_in - 1) / 2 iy = ((grid[..., 1] + 1) * H_in - 1) / 2
ix_nw = mx.floor(ix).astype(mx.int32) ix_nw = mx.floor(ix).astype(mx.int32)
iy_nw = mx.floor(iy).astype(mx.int32) iy_nw = mx.floor(iy).astype(mx.int32)
ix_ne = ix_nw + 1 ix_ne = ix_nw + 1
iy_ne = iy_nw iy_ne = iy_nw
ix_sw = ix_nw ix_sw = ix_nw
iy_sw = iy_nw + 1 iy_sw = iy_nw + 1
ix_se = ix_nw + 1 ix_se = ix_nw + 1
iy_se = iy_nw + 1 iy_se = iy_nw + 1
nw = (ix_se - ix) * (iy_se - iy) nw = (ix_se - ix) * (iy_se - iy)
ne = (ix - ix_sw) * (iy_sw - iy) ne = (ix - ix_sw) * (iy_sw - iy)
sw = (ix_ne - ix) * (iy - iy_ne) sw = (ix_ne - ix) * (iy - iy_ne)
se = (ix - ix_nw) * (iy - iy_nw) se = (ix - ix_nw) * (iy - iy_nw)
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :] I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :] I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :] I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :] I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1) mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1) mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1) mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1) mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
I_nw *= mask_nw[..., None] I_nw *= mask_nw[..., None]
I_ne *= mask_ne[..., None] I_ne *= mask_ne[..., None]
I_sw *= mask_sw[..., None] I_sw *= mask_sw[..., None]
I_se *= mask_se[..., None] I_se *= mask_se[..., None]
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
return output return output
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel`` Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes. to write a fast GPU kernel for both the forward and backward passes.
First we'll implement the forward pass as a fused kernel: First we'll implement the forward pass as a fused kernel:
.. code-block:: python .. code-block:: python
@mx.custom_function source = """
def grid_sample(x, grid): uint elem = thread_position_in_grid.x;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
assert x.ndim == 4, "`x` must be 4D." int w_stride = C;
assert grid.ndim == 4, "`grid` must be 4D." int h_stride = W * w_stride;
int b_stride = H * h_stride;
B, _, _, C = x.shape uint grid_idx = elem / C * 2;
_, gN, gM, D = grid.shape float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
out_shape = (B, gN, gM, C) float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
assert D == 2, "Last dim of `grid` must be size 2." int ix_nw = floor(ix);
int iy_nw = floor(iy);
source = """ int ix_ne = ix_nw + 1;
uint elem = thread_position_in_grid.x; int iy_ne = iy_nw;
int H = x_shape[1];
int W = x_shape[2];
int C = x_shape[3];
int gH = grid_shape[1];
int gW = grid_shape[2];
int w_stride = C; int ix_sw = ix_nw;
int h_stride = W * w_stride; int iy_sw = iy_nw + 1;
int b_stride = H * h_stride;
uint grid_idx = elem / C * 2; int ix_se = ix_nw + 1;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2; int iy_se = iy_nw + 1;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix); T nw = (ix_se - ix) * (iy_se - iy);
int iy_nw = floor(iy); T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int ix_ne = ix_nw + 1; int batch_idx = elem / C / gH / gW * b_stride;
int iy_ne = iy_nw; int channel_idx = elem % C;
int base_idx = batch_idx + channel_idx;
int ix_sw = ix_nw; T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
int iy_sw = iy_nw + 1; T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
int ix_se = ix_nw + 1; I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
int iy_se = iy_nw + 1; I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
T nw = (ix_se - ix) * (iy_se - iy); out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
T ne = (ix - ix_sw) * (iy_sw - iy); """
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int batch_idx = elem / C / gH / gW * b_stride; kernel = mx.fast.metal_kernel(
int channel_idx = elem % C; name="grid_sample",
int base_idx = batch_idx + channel_idx; input_names=["x", "grid"],
output_names=["out"],
source=source,
)
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride]; @mx.custom_function
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride]; def grid_sample(x, grid):
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0; assert x.ndim == 4, "`x` must be 4D."
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0; assert grid.ndim == 4, "`grid` must be 4D."
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se; B, _, _, C = x.shape
""" _, gN, gM, D = grid.shape
kernel = mx.fast.metal_kernel( out_shape = (B, gN, gM, C)
name="grid_sample",
input_names=["x", "grid"], assert D == 2, "Last dim of `grid` must be size 2."
output_names=["out"],
source=source, outputs = kernel(
) inputs=[x, grid],
outputs = kernel( template=[("T", x.dtype)],
inputs=[x, grid], output_shapes=[out_shape],
template=[("T", x.dtype)], output_dtypes=[x.dtype],
output_shapes=[out_shape], grid=(np.prod(out_shape), 1, 1),
output_dtypes=[x.dtype], threadgroup=(256, 1, 1),
grid=(np.prod(out_shape), 1, 1), )
threadgroup=(256, 1, 1), return outputs[0]
)
return outputs[0]
For a reasonably sized input such as: For a reasonably sized input such as:
.. code-block:: python .. code-block:: python
x.shape = (8, 1024, 1024, 64) x.shape = (8, 1024, 1024, 64)
grid.shape = (8, 256, 256, 2) grid.shape = (8, 256, 256, 2)
On an M1 Max, we see a big performance improvement: On an M1 Max, we see a big performance improvement:
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
Grid Sample VJP Grid Sample VJP
--------------- ---------------
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
its custom vjp transform so MLX can differentiate it. define its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features: requires a few extra :func:`fast.metal_kernel` features:
* ``init_value=0`` * ``init_value=0``
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel. Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
.. code-block:: python .. code-block:: python
@grid_sample.vjp source = """
def grid_sample_vjp(primals, cotangent, _): uint elem = thread_position_in_grid.x;
x, grid = primals int H = x_shape[1];
B, _, _, C = x.shape int W = x_shape[2];
_, gN, gM, D = grid.shape int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
assert D == 2, "Last dim of `grid` must be size 2." int gH = grid_shape[1];
int gW = grid_shape[2];
source = """ int w_stride = C;
uint elem = thread_position_in_grid.x; int h_stride = W * w_stride;
int H = x_shape[1]; int b_stride = H * h_stride;
int W = x_shape[2];
int C = x_shape[3];
// Pad C to the nearest larger simdgroup size multiple
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
int gH = grid_shape[1]; uint grid_idx = elem / C_padded * 2;
int gW = grid_shape[2]; float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int w_stride = C; int ix_nw = floor(ix);
int h_stride = W * w_stride; int iy_nw = floor(iy);
int b_stride = H * h_stride;
uint grid_idx = elem / C_padded * 2; int ix_ne = ix_nw + 1;
float ix = ((grid[grid_idx] + 1) * W - 1) / 2; int iy_ne = iy_nw;
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
int ix_nw = floor(ix); int ix_sw = ix_nw;
int iy_nw = floor(iy); int iy_sw = iy_nw + 1;
int ix_ne = ix_nw + 1; int ix_se = ix_nw + 1;
int iy_ne = iy_nw; int iy_se = iy_nw + 1;
int ix_sw = ix_nw; T nw = (ix_se - ix) * (iy_se - iy);
int iy_sw = iy_nw + 1; T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
int ix_se = ix_nw + 1; int batch_idx = elem / C_padded / gH / gW * b_stride;
int iy_se = iy_nw + 1; int channel_idx = elem % C_padded;
int base_idx = batch_idx + channel_idx;
T nw = (ix_se - ix) * (iy_se - iy); T gix = T(0);
T ne = (ix - ix_sw) * (iy_sw - iy); T giy = T(0);
T sw = (ix_ne - ix) * (iy - iy_ne); if (channel_idx < C) {
T se = (ix - ix_nw) * (iy - iy_nw); int cot_index = elem / C_padded * C + channel_idx;
T cot = cotangent[cot_index];
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
int batch_idx = elem / C_padded / gH / gW * b_stride; T I_nw = x[offset];
int channel_idx = elem % C_padded; gix -= I_nw * (iy_se - iy) * cot;
int base_idx = batch_idx + channel_idx; giy -= I_nw * (ix_se - ix) * cot;
}
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
T gix = T(0); T I_ne = x[offset];
T giy = T(0); gix += I_ne * (iy_sw - iy) * cot;
if (channel_idx < C) { giy -= I_ne * (ix - ix_sw) * cot;
int cot_index = elem / C_padded * C + channel_idx; }
T cot = cotangent[cot_index]; if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) { int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride; atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
T I_nw = x[offset]; T I_sw = x[offset];
gix -= I_nw * (iy_se - iy) * cot; gix -= I_sw * (iy - iy_ne) * cot;
giy -= I_nw * (ix_se - ix) * cot; giy += I_sw * (ix_ne - ix) * cot;
} }
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) { if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride; int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed); atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_ne = x[offset]; T I_se = x[offset];
gix += I_ne * (iy_sw - iy) * cot; gix += I_se * (iy - iy_nw) * cot;
giy -= I_ne * (ix - ix_sw) * cot; giy += I_se * (ix - ix_nw) * cot;
} }
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) { }
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
T I_sw = x[offset]; T gix_mult = W / 2;
gix -= I_sw * (iy - iy_ne) * cot; T giy_mult = H / 2;
giy += I_sw * (ix_ne - ix) * cot;
}
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
T I_se = x[offset]; // Reduce across each simdgroup first.
gix += I_se * (iy - iy_nw) * cot; // This is much faster than relying purely on atomics.
giy += I_se * (ix - ix_nw) * cot; gix = simd_sum(gix);
} giy = simd_sum(giy);
}
T gix_mult = W / 2; if (thread_index_in_simdgroup == 0) {
T giy_mult = H / 2; atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
}
"""
kernel = mx.fast.metal_kernel(
name="grid_sample_grad",
input_names=["x", "grid", "cotangent"],
output_names=["x_grad", "grid_grad"],
source=source,
atomic_outputs=True,
)
// Reduce across each simdgroup first. @grid_sample.vjp
// This is much faster than relying purely on atomics. def grid_sample_vjp(primals, cotangent, _):
gix = simd_sum(gix); x, grid = primals
giy = simd_sum(giy); B, _, _, C = x.shape
_, gN, gM, D = grid.shape
if (thread_index_in_simdgroup == 0) { assert D == 2, "Last dim of `grid` must be size 2."
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed); # pad the output channels to simd group size
} # so that our `simd_sum`s don't overlap.
""" simdgroup_size = 32
kernel = mx.fast.metal_kernel( C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
name="grid_sample_grad", grid_size = B * gN * gM * C_padded
input_names=["x", "grid", "cotangent"], outputs = kernel(
output_names=["x_grad", "grid_grad"], inputs=[x, grid, cotangent],
source=source, template=[("T", x.dtype)],
atomic_outputs=True, output_shapes=[x.shape, grid.shape],
) output_dtypes=[x.dtype, x.dtype],
# pad the output channels to simd group size grid=(grid_size, 1, 1),
# so that our `simd_sum`s don't overlap. threadgroup=(256, 1, 1),
simdgroup_size = 32 init_value=0,
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size )
grid_size = B * gN * gM * C_padded return outputs[0], outputs[1]
outputs = kernel(
inputs=[x, grid, cotangent],
template=[("T", x.dtype)],
output_shapes=[x.shape, grid.shape],
output_dtypes=[x.dtype, x.dtype],
grid=(grid_size, 1, 1),
threadgroup=(256, 1, 1),
init_value=0,
)
return outputs[0], outputs[1]
There's an even larger speed up for the vjp: There's an even larger speed up for the vjp:


@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^ ^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create outputs arrays given a input arrays. Further, a defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function :class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete: more concrete:
.. code-block:: C++ .. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
/** The vector-Jacobian product. */ /** The vector-Jacobian product. */
std::vector<array> vjp( std::vector<array> vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const array& cotan, const std::vector<array>& cotangents,
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) override; const std::vector<array>& outputs) override;
@@ -247,9 +247,7 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
float alpha_, float alpha_,
float beta_, float beta_,
mx::Stream stream) { mx::Stream stream) {
// Allocate the output with `malloc_or_wait` which synchronously allocates out.set_data(mx::allocator::malloc(out.nbytes()));
// memory, potentially waiting if the system is under memory pressure
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Get the CPU command encoder and register input and output arrays // Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream); auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -393,17 +391,17 @@ below.
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
// Allocate output memory // Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
// Resolve name of kernel // Resolve name of kernel
std::ostringstream kname; std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out); kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available // Load the metal library
d.register_library("mlx_ext"); auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext"); auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
@@ -471,7 +469,7 @@ one we just defined:
const std::vector<array>& tangents, const std::vector<array>& tangents,
const std::vector<int>& argnums) { const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents // Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can built with ops // The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive // that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the // If argnums = {0}, we only push along x in which case the
@@ -483,7 +481,7 @@ one we just defined:
auto scale_arr = array(scale, tangents[0].dtype()); auto scale_arr = array(scale, tangents[0].dtype());
return {multiply(scale_arr, tangents[0], stream())}; return {multiply(scale_arr, tangents[0], stream())};
} }
// If, argnums = {0, 1}, we take contributions from both // If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta // which gives us jvp = tangent_x * alpha + tangent_y * beta
else { else {
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -737,7 +735,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}") print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}") print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}") print(f"c is correct: {mx.all(c == 6.0).item()}")
Output: Output:
@@ -745,7 +743,7 @@ Output:
c shape: [3, 4] c shape: [3, 4]
c dtype: float32 c dtype: float32
c correctness: True c is correct: True
Results Results
^^^^^^^ ^^^^^^^


@@ -70,6 +70,7 @@ are the CPU and GPU.
python/fft python/fft
python/linalg python/linalg
python/metal python/metal
python/memory_management
python/nn python/nn
python/optimizers python/optimizers
python/distributed python/distributed


@@ -30,6 +30,16 @@ MLX is also available on conda-forge. To install MLX with conda do:
conda install conda-forge::mlx conda install conda-forge::mlx
CUDA
^^^^
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
.. code-block:: shell
pip install mlx-cuda
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^
@@ -65,6 +75,8 @@ Build Requirements
Python API Python API
^^^^^^^^^^ ^^^^^^^^^^
.. _python install:
To build and install the MLX python library from source, first, clone MLX from To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_: `its GitHub repo <https://github.com/ml-explore/mlx>`_:
@@ -76,20 +88,20 @@ Then simply build and install MLX using pip:
.. code-block:: shell .. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install . pip install .
For developing, install the package with development dependencies, and use an For developing, install the package with development dependencies, and use an
editable install: editable install:
.. code-block:: shell .. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]" pip install -e ".[dev]"
Once the development dependencies are installed, you can build faster with: Once the development dependencies are installed, you can build faster with:
.. code-block:: shell .. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace python setup.py build_ext --inplace
Run the tests with: Run the tests with:
@@ -107,6 +119,8 @@ IDE:
C++ API C++ API
^^^^^^^ ^^^^^^^
.. _cpp install:
Currently, MLX must be built and installed from source. Currently, MLX must be built and installed from source.
Similarly to the python library, to build and install the MLX C++ library start Similarly to the python library, to build and install the MLX C++ library start
@@ -185,6 +199,7 @@ should point to the path to the built metal library.
xcrun -sdk macosx --show-sdk-version xcrun -sdk macosx --show-sdk-version
Binary Size Minimization Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~
@@ -213,6 +228,50 @@ be anywhere from a few hundred milliseconds to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots. Metal kernel cache persists across reboots.
Linux
^^^^^
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:
.. code-block:: shell
apt-get update -y
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.
CUDA
^^^^
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:
.. code-block:: shell
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
apt-get update -y
apt-get -y install cuda-toolkit-12-9
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
.. code-block:: shell
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
To build the C++ package run:
.. code-block:: shell
mkdir -p build && cd build
cmake .. -DMLX_BUILD_CUDA=ON && make -j
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^


@@ -19,6 +19,8 @@ Array
array.ndim array.ndim
array.shape array.shape
array.size array.size
array.real
array.imag
array.abs array.abs
array.all array.all
array.any array.any
@@ -38,6 +40,7 @@ Array
array.log10 array.log10
array.log1p array.log1p
array.log2 array.log2
array.logcumsumexp
array.logsumexp array.logsumexp
array.max array.max
array.mean array.mean


@@ -20,3 +20,5 @@ FFT
irfft2 irfft2
rfftn rfftn
irfftn irfftn
fftshift
ifftshift


@@ -16,9 +16,12 @@ Linear Algebra
cross cross
qr qr
svd svd
eigvals
eig
eigvalsh eigvalsh
eigh eigh
lu lu
lu_factor lu_factor
pinv
solve solve
solve_triangular solve_triangular


@@ -0,0 +1,16 @@
Memory Management
=================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
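A minimal usage sketch of these helpers (the shape and cache limit below are arbitrary), assuming they behave like their former ``mx.metal`` counterparts:
.. code-block:: python

   import mlx.core as mx

   x = mx.random.normal((2048, 2048))
   mx.eval(x)

   # Allocator statistics, all reported in bytes.
   print(mx.get_active_memory(), mx.get_peak_memory(), mx.get_cache_memory())

   # Cap the buffer cache, then drop anything already cached.
   mx.set_cache_limit(256 * 1024 * 1024)
   mx.clear_cache()
   mx.reset_peak_memory()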


@@ -8,13 +8,5 @@ Metal
is_available is_available
device_info device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture start_capture
stop_capture stop_capture


@@ -36,10 +36,12 @@ Operations
bitwise_or bitwise_or
bitwise_xor bitwise_xor
block_masked_mm block_masked_mm
broadcast_arrays
broadcast_to broadcast_to
ceil ceil
clip clip
concatenate concatenate
contiguous
conj conj
conjugate conjugate
convolve convolve
@@ -101,6 +103,7 @@ Operations
log10 log10
log1p log1p
logaddexp logaddexp
logcumsumexp
logical_not logical_not
logical_and logical_and
logical_or logical_or


@@ -18,3 +18,4 @@ Common Optimizers
AdamW AdamW
Adamax Adamax
Lion Lion
MultiOptimizer


@@ -9,6 +9,7 @@ Transforms
:toctree: _autosummary :toctree: _autosummary
eval eval
async_eval
compile compile
custom_function custom_function
disable_compile disable_compile


@@ -107,6 +107,16 @@ same array:
>>> a >>> a
array([1, 2, 0], dtype=int32) array([1, 2, 0], dtype=int32)
Note, unlike NumPy, updates to the same location are nondeterministic:
.. code-block:: shell
>>> a = mx.array([1, 2, 3])
>>> a[[0, 0]] = mx.array([4, 5])
The first element of ``a`` could be ``4`` or ``5``.
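When every update should take effect, for example when accumulating into repeated indices, the ``array.at`` property is the usual alternative; a short sketch, assuming the ``at`` syntax from the array API:
.. code-block:: shell

   >>> a = mx.array([1, 2, 3])
   >>> a = a.at[mx.array([0, 0])].add(mx.array([4, 5]))
   >>> a
   array([10, 2, 3], dtype=int32)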
Transformations of functions which use in-place updates are allowed and work as Transformations of functions which use in-place updates are allowed and work as
expected. For example: expected. For example:


@@ -72,9 +72,7 @@ void axpby_impl(
float alpha_, float alpha_,
float beta_, float beta_,
mx::Stream stream) { mx::Stream stream) {
// Allocate the output with `malloc_or_wait` which synchronously allocates out.set_data(mx::allocator::malloc(out.nbytes()));
// memory, potentially waiting if the system is under memory pressure
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
// Get the CPU command encoder and register input and output arrays // Get the CPU command encoder and register input and output arrays
auto& encoder = mx::cpu::get_command_encoder(stream); auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -160,12 +158,12 @@ void Axpby::eval_gpu(
// Allocate output memory with strides based on specialization // Allocate output memory with strides based on specialization
if (contiguous_kernel) { if (contiguous_kernel) {
out.set_data( out.set_data(
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()), mx::allocator::malloc(x.data_size() * out.itemsize()),
x.data_size(), x.data_size(),
x.strides(), x.strides(),
x.flags()); x.flags());
} else { } else {
out.set_data(mx::allocator::malloc_or_wait(out.nbytes())); out.set_data(mx::allocator::malloc(out.nbytes()));
} }
// Resolve name of kernel (corresponds to axpby.metal) // Resolve name of kernel (corresponds to axpby.metal)
@@ -174,11 +172,11 @@ void Axpby::eval_gpu(
kname << (contiguous_kernel ? "contiguous_" : "general_"); kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out); kname << type_to_name(out);
// Make sure the metal library is available // Load the metal library
d.register_library("mlx_ext"); auto lib = d.get_library("mlx_ext");
// Make a kernel from this metal library // Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext"); auto kernel = d.get_kernel(kname.str(), lib);
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);


@@ -5,6 +5,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
@@ -20,7 +21,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h) ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
# Define MLX_VERSION only in the version.cpp file. # Define MLX_VERSION only in the version.cpp file.
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp) add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}") target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>) target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
@@ -48,5 +49,19 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_METAL) if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else() else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal) target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
endif()
if(MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
else()
target_sources(mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif()
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
endif() endif()


@@ -4,12 +4,11 @@
#include <sstream> #include <sstream>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator { namespace mlx::core::allocator {
Buffer malloc(size_t size) { Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size, /* allow_swap */ true); auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) { if (size && !buffer.ptr()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes."; msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -22,45 +21,4 @@ void free(Buffer buffer) {
allocator().free(buffer); allocator().free(buffer);
} }
Buffer CommonAllocator::malloc(size_t size, bool) {
void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
}
void CommonAllocator::free(Buffer buffer) {
std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
}
Buffer malloc_or_wait(size_t size) {
auto buffer = allocator().malloc(size);
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
buffer = allocator().malloc(size);
}
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace mlx::core::allocator } // namespace mlx::core::allocator


@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
void free(Buffer buffer); void free(Buffer buffer);
// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);
class Allocator { class Allocator {
/** Abstract base class for a memory allocator. */ /** Abstract base class for a memory allocator. */
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0; virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0; virtual size_t size(Buffer buffer) const = 0;
@@ -53,16 +49,4 @@ class Allocator {
Allocator& allocator(); Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator } // namespace mlx::core::allocator


@@ -224,6 +224,10 @@ class array {
// Not copyable // Not copyable
Data(const Data& d) = delete; Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete; Data& operator=(const Data& d) = delete;
Data(Data&& o) : buffer(o.buffer), d(o.d) {
o.buffer = allocator::Buffer(nullptr);
o.d = [](allocator::Buffer) {};
}
~Data() { ~Data() {
d(buffer); d(buffer);
} }
@@ -339,11 +343,11 @@ class array {
return allocator::allocator().size(buffer()); return allocator::allocator().size(buffer());
} }
// Return a copy of the shared pointer // Return the shared pointer to the array::Data struct
// to the array::Data struct const std::shared_ptr<Data>& data_shared_ptr() const {
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data; return array_desc_->data;
} }
// Return a raw pointer to the arrays data // Return a raw pointer to the arrays data
template <typename T> template <typename T>
T* data() { T* data() {
@@ -356,7 +360,7 @@ class array {
} }
enum Status { enum Status {
// The ouptut of a computation which has not been scheduled. // The output of a computation which has not been scheduled.
// For example, the status of `x` in `auto x = a + b`. // For example, the status of `x` in `auto x = a + b`.
unscheduled, unscheduled,


@@ -1,6 +1,7 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp


@@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
if (b_donatable) { if (b_donatable) {
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()), allocator::malloc(b.data_size() * out.itemsize()),
b.data_size(), b.data_size(),
b.strides(), b.strides(),
b.flags()); b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a); out.copy_shared_buffer(a);
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()), allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(), a.data_size(),
a.strides(), a.strides(),
a.flags()); a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()), allocator::malloc(a.data_size() * out.itemsize()),
a.data_size(), a.data_size(),
a.strides(), a.strides(),
a.flags()); a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) { b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
break; break;
} }


@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
} // namespace mlx::core


@@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void broadcast(const array& in, array& out);
} // namespace mlx::core


@@ -0,0 +1,157 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cassert>
#include <functional>
#include <map>
namespace mlx::core {
template <typename T>
class BufferCache {
public:
BufferCache(
size_t page_size,
std::function<size_t(T*)> get_size,
std::function<void(T*)> free)
: page_size_(page_size),
get_size_(std::move(get_size)),
free_(std::move(free)) {}
~BufferCache() {
clear();
}
BufferCache(const BufferCache&) = delete;
BufferCache& operator=(const BufferCache&) = delete;
T* reuse_from_cache(size_t size) {
// Find the closest buffer in pool.
auto it = buffer_pool_.lower_bound(size);
if (it == buffer_pool_.end() ||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
return nullptr;
}
// Collect from the cache.
T* buf = it->second->buf;
pool_size_ -= it->first;
// Remove from record.
remove_from_list(it->second);
buffer_pool_.erase(it);
return buf;
}
void recycle_to_cache(T* buf) {
assert(buf);
// Add to cache.
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
size_t size = get_size_(buf);
pool_size_ += size;
buffer_pool_.emplace(size, bh);
}
int release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
// Release buffer.
size_t size = get_size_(tail_->buf);
total_bytes_freed += size;
free_(tail_->buf);
n_release++;
// Remove from record.
auto its = buffer_pool_.equal_range(size);
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
return el.second == tail_;
});
assert(it != buffer_pool_.end());
buffer_pool_.erase(it);
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
int clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
free_(holder->buf);
n_release++;
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
size_t cache_size() const {
return pool_size_;
}
size_t page_size() const {
return page_size_;
}
private:
struct BufferHolder {
public:
explicit BufferHolder(T* buf_) : buf(buf_) {}
BufferHolder* prev{nullptr};
BufferHolder* next{nullptr};
T* buf;
};
void add_at_head(BufferHolder* to_add) {
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void remove_from_list(BufferHolder* to_remove) {
if (to_remove->prev && to_remove->next) { // if middle
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // if tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // if head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // if only element
head_ = nullptr;
tail_ = nullptr;
}
delete to_remove;
}
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_{nullptr};
BufferHolder* tail_{nullptr};
size_t pool_size_{0};
const size_t page_size_;
std::function<size_t(T*)> get_size_;
std::function<void(T*)> free_;
};
} // namespace mlx::core
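
A usage sketch for the BufferCache template above, assuming it is reachable through its header (the include path below is an assumption). The pool is keyed by size, hands back the closest cached block, and frees least-recently recycled blocks first; because ::malloc has no portable size query, the sketch tracks block sizes in a side map:

#include "mlx/backend/common/buffer_cache.h" // assumed path for the template above

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <unordered_map>

static std::unordered_map<void*, size_t> g_sizes;

int main() {
  mlx::core::BufferCache<void> cache(
      /* page_size */ 4096,
      /* get_size  */ [](void* p) { return g_sizes.at(p); },
      /* free      */ [](void* p) { g_sizes.erase(p); std::free(p); });

  // Miss: nothing cached yet, so allocate and remember the size.
  void* buf = cache.reuse_from_cache(1 << 20);
  if (!buf) {
    buf = std::malloc(1 << 20);
    g_sizes[buf] = 1 << 20;
  }

  cache.recycle_to_cache(buf);                   // done with it, keep for reuse
  void* again = cache.reuse_from_cache(1 << 20); // hit: the same block comes back
  std::printf("reused: %d\n", again == buf);

  cache.recycle_to_cache(again);
  std::printf("cached bytes: %zu\n", cache.cache_size()); // 1 MiB
  cache.release_cached_buffers(1 << 19); // frees at least 512 KiB (the whole block here)
  return 0;
}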

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <cassert> #include <cassert>
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
} }
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) { void Broadcast::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out); broadcast(inputs[0], out);
} }
@@ -103,7 +87,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) { void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
double numel = 1; double numel = 1;
for (auto ax : axes_) { for (auto ax : axes_) {

View File

@@ -1,8 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/graph_utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
@@ -15,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) {
return print_float_constant<float16_t>(os, x); return print_float_constant<float16_t>(os, x);
case bfloat16: case bfloat16:
return print_float_constant<bfloat16_t>(os, x); return print_float_constant<bfloat16_t>(os, x);
case float64:
return print_float_constant<double>(os, x);
case complex64: case complex64:
return print_complex_constant<complex64_t>(os, x); return print_complex_constant<complex64_t>(os, x);
case int8: case int8:
@@ -51,6 +52,8 @@ std::string get_type_string(Dtype d) {
return "float16_t"; return "float16_t";
case bfloat16: case bfloat16:
return "bfloat16_t"; return "bfloat16_t";
case float64:
return "double";
case complex64: case complex64:
return "complex64_t"; return "complex64_t";
case bool_: case bool_:
@@ -79,55 +82,6 @@ std::string get_type_string(Dtype d) {
} }
} }
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (auto& x : inputs) {
namer.get_name(x);
}
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
bool compiled_check_contiguity( bool compiled_check_contiguity(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const Shape& shape) { const Shape& shape) {
@@ -159,8 +113,7 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs( void compiled_allocate_outputs(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::vector<array>& inputs_, const std::function<bool(size_t)>& is_constant,
const std::unordered_set<uintptr_t>& constant_ids_,
bool contiguous) { bool contiguous) {
if (contiguous) { if (contiguous) {
int o = 0; int o = 0;
@@ -175,8 +128,7 @@ void compiled_allocate_outputs(
// - Donatable // - Donatable
// - Not a constant // - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() && in.is_donatable() && is_constant(i)) {
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in); outputs[o++].copy_shared_buffer(in);
} }
// Get representative input flags to properly set non-donated outputs // Get representative input flags to properly set non-donated outputs
@@ -188,7 +140,7 @@ void compiled_allocate_outputs(
} }
for (; o < outputs.size(); ++o) { for (; o < outputs.size(); ++o) {
outputs[o].set_data( outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()), allocator::malloc(data_size * outputs[o].itemsize()),
data_size, data_size,
strides, strides,
flags); flags);
@@ -204,16 +156,86 @@ void compiled_allocate_outputs(
// - Not a constant // - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() && if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { is_constant(i)) {
outputs[o].copy_shared_buffer( outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size()); in, outputs[o].strides(), in.flags(), in.data_size());
o++; o++;
} }
} }
for (; o < outputs.size(); ++o) { for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes())); outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
} }
} }
} }
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
const std::vector<array>& inputs,
const array& out,
const std::function<bool(size_t)>& is_constant) {
const Shape& shape = out.shape();
bool contiguous = compiled_check_contiguity(inputs, shape);
if (contiguous) {
return {true, shape, {}};
}
std::vector<Strides> strides_vec{out.strides()};
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants.
if (is_constant(i)) {
continue;
}
// Skip scalar inputs.
const auto& x = inputs[i];
if (is_scalar(x)) {
continue;
}
// Broadcast the inputs to the output shape.
Strides xstrides;
size_t j = 0;
for (; j < shape.size() - x.ndim(); ++j) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides_vec.push_back(std::move(xstrides));
}
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
}
bool compiled_use_large_index(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
bool contiguous) {
if (contiguous) {
size_t max_size = 0;
for (const auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
return max_size > UINT32_MAX;
} else {
size_t max_size = 0;
for (const auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
return max_size > UINT32_MAX;
}
}
} // namespace mlx::core } // namespace mlx::core
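
The new compiled_use_large_index above decides whether a compiled kernel needs 64-bit indexing: contiguous kernels index raw buffers, so the largest input data_size matters, while strided kernels index by logical position, so the largest output size matters. A standalone restatement of that rule (the function name and plain size lists are illustrative):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

bool use_large_index(const std::vector<size_t>& input_data_sizes,
                     const std::vector<size_t>& output_sizes,
                     bool contiguous) {
  const auto& sizes = contiguous ? input_data_sizes : output_sizes;
  size_t max_size = 0;
  for (size_t s : sizes) {
    max_size = std::max(max_size, s);
  }
  return max_size > UINT32_MAX; // fall back to 64-bit indices past 2^32 - 1
}

int main() {
  // An ~8.6-billion-element contiguous input overflows 32-bit indexing.
  std::printf("%d\n", use_large_index({size_t(1) << 33}, {1 << 20}, true)); // 1
}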

View File

@@ -1,9 +1,8 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#pragma once #pragma once
#include <functional>
#include <iomanip> #include <iomanip>
#include <sstream>
#include <unordered_set>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -14,19 +13,17 @@ inline bool is_static_cast(const Primitive& p) {
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType)); return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
} }
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids);
std::string get_type_string(Dtype d); std::string get_type_string(Dtype d);
template <typename T> template <typename T>
void print_float_constant(std::ostream& os, const array& x) { void print_float_constant(std::ostream& os, const array& x) {
auto old_precision = os.precision(); auto old_precision = os.precision();
os << std::setprecision(std::numeric_limits<float>::digits10 + 1) if constexpr (std::is_same_v<T, double>) {
<< x.item<T>() << std::setprecision(old_precision); os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
} else {
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
}
os << x.item<T>() << std::setprecision(old_precision);
} }
template <typename T> template <typename T>
@@ -60,8 +57,19 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs( void compiled_allocate_outputs(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::vector<array>& inputs_, const std::function<bool(size_t)>& is_constant,
const std::unordered_set<uintptr_t>& constant_ids_, bool contiguous);
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
const std::vector<array>& inputs,
const array& out,
const std::function<bool(size_t)>& is_constant);
// Return whether the kernel should use large index.
bool compiled_use_large_index(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
bool contiguous); bool contiguous);
} // namespace mlx::core } // namespace mlx::core
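
The print_float_constant change above picks the stream precision by type so that float64 constants written into generated source keep their precision. A tiny standalone illustration of the digits10 + 1 choice (7 significant digits for float, 16 for double, versus the stream default of 6):

#include <iomanip>
#include <iostream>
#include <limits>

int main() {
  double d = 1.0 / 3.0;
  std::cout << d << "\n"; // default precision: 0.333333
  std::cout << std::setprecision(std::numeric_limits<float>::digits10 + 1)
            << d << "\n"; // 7 digits, what a float constant gets
  std::cout << std::setprecision(std::numeric_limits<double>::digits10 + 1)
            << d << "\n"; // 16 digits, what a float64 constant now gets
}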

View File

@@ -2,7 +2,7 @@
#pragma once #pragma once
#include "mlx/array.h" #include "mlx/backend/common/utils.h"
namespace mlx::core { namespace mlx::core {
@@ -26,19 +26,19 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types // If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output. // have the same size, then the input buffer can hold the output.
if (in.is_donatable() && in.itemsize() == out.itemsize()) { if (is_donatable(in, out)) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
return true; return true;
} else { } else {
out.set_data( out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()), allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(), in.data_size(),
in.strides(), in.strides(),
in.flags()); in.flags());
return false; return false;
} }
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
return false; return false;
} }
} }

View File

@@ -99,6 +99,10 @@ inline std::pair<int, int> decompose_hadamard(int n) {
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28)."); "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
} }
} }
if (n > (1 << 26)) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where k <= 26");
}
return {n, m}; return {n, m};
} }
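
The added guard above caps supported Hadamard sizes at n <= 2^26. One way to perform the n = m * 2^k decomposition the error messages describe, with the same cap (an illustrative sketch, not MLX's exact routine, which returns {n, m}):

#include <cstdio>
#include <initializer_list>
#include <stdexcept>
#include <utility>

bool is_power_of_two(int x) {
  return x > 0 && (x & (x - 1)) == 0;
}

std::pair<int, int> decompose(int n) {
  if (n > (1 << 26)) {
    throw std::invalid_argument("[hadamard] Only supports n = m*2^k where k <= 26");
  }
  for (int m : {1, 12, 20, 28}) {
    if (n % m == 0 && is_power_of_two(n / m)) {
      int k = 0;
      for (int p = n / m; p > 1; p >>= 1) {
        ++k; // n == m * 2^k
      }
      return {m, k};
    }
  }
  throw std::invalid_argument(
      "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}

int main() {
  auto [m, k] = decompose(320); // 320 == 20 * 2^4
  std::printf("m = %d, k = %d\n", m, k);
}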

View File

@@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core { namespace mlx::core {
void Load::eval_cpu(const std::vector<array>& inputs, array& out) { void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(), auto read_task = [out_ptr = out.data<char>(),
size = out.size(), size = out.size(),
itemsize = out.itemsize(), itemsize = out.itemsize(),

View File

@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/common/utils.h"
#include "mlx/utils.h"
#include <sstream>
namespace mlx::core {
inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a,
const array& b) {
if (a.ndim() == 2) {
return {{1}, {0}, {0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
auto a_batch_strides = batch_strides[0];
auto b_batch_strides = batch_strides[1];
if (batch_shape.empty()) {
batch_shape.push_back(1);
a_batch_strides.push_back(0);
b_batch_strides.push_back(0);
}
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
}
inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) {
return {{1}, {0}, {0}, {0}};
}
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
auto C_batch_stride = batch_strides[2];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
C_batch_stride.push_back(0);
}
return std::make_tuple(
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
} // namespace mlx::core
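
collapse_batches above folds the leading (batch) dimensions of the matmul operands into as few dimensions as possible, so the backend can loop over one collapsed batch and offset each operand via elem_to_loc. A simplified standalone walk-through with hypothetical numbers (local Shape/Strides aliases):

#include <cstdint>
#include <cstdio>
#include <vector>

using Shape = std::vector<int>;
using Strides = std::vector<int64_t>;

// Map a linear batch index to a storage offset (same idea as elem_to_loc).
int64_t elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A row-contiguous (2, 3, M, K) operand with M * K = 8: its batch dims
  // {2, 3} with strides {24, 8} are contiguous with each other, so they
  // collapse to {6} with a single batch stride of 8.
  Shape batch_shape{6};
  Strides a_batch_strides{8};
  for (int e = 0; e < 6; ++e) {
    std::printf("batch %d -> A offset %lld\n", e,
                static_cast<long long>(elem_to_loc(e, batch_shape, a_batch_strides)));
  }
}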

View File

@@ -5,11 +5,9 @@
namespace mlx::core { namespace mlx::core {
std::pair<Shape, Strides> shapes_without_reduction_axes( std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x, Shape shape,
Strides strides,
const std::vector<int>& axes) { const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) { for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i]; int a = axes[i];
shape.erase(shape.begin() + a); shape.erase(shape.begin() + a);
@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
return std::make_pair(shape, strides); return std::make_pair(shape, strides);
} }
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
return shapes_without_reduction_axes(
std::move(shape), std::move(strides), axes);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) { ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything // The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() && if (x.size() == x.data_size() && axes.size() == x.ndim() &&
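
The refactor above adds a shapes_without_reduction_axes overload that works on an already-extracted shape/strides pair. A standalone sketch of the computation (assuming the reduction axes are sorted ascending, which the back-to-front erase relies on):

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

using Shape = std::vector<int>;
using Strides = std::vector<int64_t>;

std::pair<Shape, Strides> without_axes(Shape shape, Strides strides,
                                       const std::vector<int>& axes) {
  for (int i = static_cast<int>(axes.size()) - 1; i >= 0; --i) {
    shape.erase(shape.begin() + axes[i]);     // erase from the back so the
    strides.erase(strides.begin() + axes[i]); // earlier indices stay valid
  }
  return {shape, strides};
}

int main() {
  // Reduce a row-contiguous (2, 3, 4) array over axes {0, 2}: only the
  // middle axis remains, keeping its original stride of 4.
  auto [shape, strides] = without_axes({2, 3, 4}, {12, 4, 1}, {0, 2});
  std::printf("shape %d, stride %lld\n", shape[0],
              static_cast<long long>(strides[0]));
}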

View File

@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes( std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x, const array& x,
const std::vector<int>& axes); const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape,
Strides strides,
const std::vector<int>& axes);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
switch (topt) { switch (topt) {
case TernaryOpType::ScalarScalarScalar: case TernaryOpType::ScalarScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags()); allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break; break;
case TernaryOpType::VectorVectorVector: case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize() * b.data_size()), allocator::malloc(out.itemsize() * b.data_size()),
b.data_size(), b.data_size(),
b.strides(), b.strides(),
b.flags()); b.flags());
@@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) || if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) || (b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) { (c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
break; break;
} }

View File

@@ -0,0 +1,26 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
inline void set_unary_output_data(const array& in, array& out) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
}
}
} // namespace mlx::core

View File

@@ -1,9 +1,16 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims( std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape, const Shape& shape,
const std::vector<Strides>& strides, const std::vector<Strides>& strides,
@@ -101,4 +108,118 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap); return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
} }
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == pow2) {
break;
}
}
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
}
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
// Dims with strides of 0 are ignored as they
// correspond to broadcasted dimensions
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
Dims get_2d_grid_dims_common(
const Shape& shape,
const Strides& strides,
size_t divisor) {
// Compute the 2d grid dimensions such that the total size of the grid is
// divided by divisor.
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
// No need to add this shape we can just remove it from the divisor.
if (divisor % shape[i] == 0) {
divisor /= shape[i];
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
if (divisor > 1) {
if (grid_x % divisor == 0) {
grid_x /= divisor;
divisor = 1;
} else if (grid_y % divisor == 0) {
grid_y /= divisor;
divisor = 1;
}
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
if (divisor > 1) {
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
}
return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
auto gx = (dim0 + bx - 1) / bx;
auto gy = (dim1 + by - 1) / by;
auto gz = (dim2 + bz - 1) / bz;
return std::make_pair(
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
}
} // namespace mlx::core } // namespace mlx::core
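
A small usage sketch for the helpers added above (declared in mlx/backend/common/utils.h per this diff): ask for a thread block that fits a 100 x 4 x 1 problem and the grid of blocks that covers it.

#include <cstdio>
#include <tuple>

#include "mlx/backend/common/utils.h"

int main() {
  using namespace mlx::core;
  // Block dims are powers of two with at most 2^10 threads in total.
  auto [grid, block] = get_grid_and_block_common(100, 4, 1);
  auto [bx, by, bz] = block; // 64 x 4 x 1 for this input
  auto [gx, gy, gz] = grid;  // ceil(100/64) x ceil(4/4) x 1 = 2 x 1 x 1
  std::printf("block %u %u %u, grid %u %u %u\n", bx, by, bz, gx, gy, gz);
}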

View File

@@ -2,12 +2,15 @@
#pragma once #pragma once
#include <tuple>
#include <vector> #include <vector>
#include "mlx/array.h" #include "mlx/array.h"
namespace mlx::core { namespace mlx::core {
std::string get_primitive_string(Primitive* primitive);
inline int64_t inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) { elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
int64_t loc = 0; int64_t loc = 0;
@@ -70,6 +73,31 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
const array& a, const array& a,
int64_t size_cap = std::numeric_limits<int32_t>::max()); int64_t size_cap = std::numeric_limits<int32_t>::max());
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 2^pow2
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
// - shape and strides correspond to a contiguous (no holes) but
// possibly broadcasted array
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
// Same as above but we do an implicit division with divisor.
// Basically, equivalent to factorizing
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
Dims get_2d_grid_dims_common(
const Shape& shape,
const Strides& strides,
size_t divisor);
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
struct ContiguousIterator { struct ContiguousIterator {
inline void step() { inline void step() {
int dims = shape_.size(); int dims = shape_.size();
@@ -165,4 +193,11 @@ void shared_buffer_reshape(
const array& in, const array& in,
const Strides& out_strides, const Strides& out_strides,
array& out); array& out);
template <typename T>
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}
} // namespace mlx::core } // namespace mlx::core
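
The new remove_index helper above takes the vector by value, erases one position, and returns it; arg_reduce.cpp (later in this diff) uses it to drop the reduced axis from both shape and strides. A self-contained equivalent:

#include <cstddef>
#include <cstdio>
#include <iterator>
#include <vector>

template <typename T>
std::vector<T> remove_index(std::vector<T> vec, size_t index) {
  vec.erase(std::next(vec.begin(), index));
  return vec;
}

int main() {
  std::vector<int> shape{2, 3, 4};
  auto rest = remove_index(shape, 1); // drop axis 1 -> {2, 4}
  std::printf("%d %d\n", rest[0], rest[1]);
}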

View File

@@ -40,11 +40,13 @@ add_dependencies(mlx cpu_compiled_preamble)
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@@ -58,6 +60,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -73,8 +76,8 @@ target_sources(
if(MLX_BUILD_ACCELERATE) if(MLX_BUILD_ACCELERATE)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
else() else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
endif() endif()
if(IOS) if(IOS)

View File

@@ -14,10 +14,8 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) { void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis]; auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis]; auto axis_stride = in.strides()[axis];
Strides strides = in.strides(); Strides strides = remove_index(in.strides(), axis);
Shape shape = in.shape(); Shape shape = remove_index(in.shape(), axis);
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
auto in_ptr = in.data<InT>(); auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>(); auto out_ptr = out.data<uint32_t>();
@@ -68,7 +66,7 @@ void arg_reduce_dispatch(
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) { void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);

View File

@@ -0,0 +1,11 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/available.h"
namespace mlx::core::cpu {
bool is_available() {
return true;
}
} // namespace mlx::core::cpu

View File

@@ -0,0 +1,9 @@
// Copyright © 2025 Apple Inc.
#pragma once
namespace mlx::core::cpu {
bool is_available();
} // namespace mlx::core::cpu

View File

@@ -172,9 +172,12 @@ void binary_float(
case bfloat16: case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt); binary_op<bfloat16_t, Op>(a, b, out, bopt);
break; break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"[binary_float] Only supports non-complex floating point types."); "[binary_float] Only supports floating point types.");
} }
}); });
} }

View File

@@ -40,7 +40,10 @@ struct CompilerCache {
std::shared_mutex mtx; std::shared_mutex mtx;
}; };
static CompilerCache cache{}; static CompilerCache& cache() {
static CompilerCache cache_;
return cache_;
};
// GPU compile is always available if the GPU is available and since we are in // GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available. // this file CPU compile is also available.
@@ -56,14 +59,16 @@ void* compile(
const std::string& kernel_name, const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) { const std::function<std::string(void)>& source_builder) {
{ {
std::shared_lock lock(cache.mtx); std::shared_lock lock(cache().mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
return it->second; return it->second;
} }
} }
std::unique_lock lock(cache.mtx); std::unique_lock lock(cache().mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
return it->second; return it->second;
} }
std::string source_code = source_builder(); std::string source_code = source_builder();
@@ -120,10 +125,10 @@ void* compile(
} }
// load library // load library
cache.libs.emplace_back(shared_lib_path); cache().libs.emplace_back(shared_lib_path);
// Load function // Load function
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str()); void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
if (!fun) { if (!fun) {
std::ostringstream msg; std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function " msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -131,7 +136,7 @@ void* compile(
<< dlerror(); << dlerror();
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
cache.kernels.insert({kernel_name, fun}); cache().kernels.insert({kernel_name, fun});
return fun; return fun;
} }
@@ -141,18 +146,9 @@ inline void build_kernel(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<array>& outputs, const std::vector<array>& outputs,
const std::vector<array>& tape, const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids, const std::function<bool(size_t)>& is_constant,
bool contiguous, bool contiguous,
int ndim) { int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer; NodeNamer namer;
#ifdef _MSC_VER #ifdef _MSC_VER
@@ -165,14 +161,15 @@ inline void build_kernel(
// Add the input arguments // Add the input arguments
int cnt = 0; int cnt = 0;
for (auto& x : inputs) { for (size_t i = 0; i < inputs.size(); ++i) {
auto& xname = namer.get_name(x);
// Skip constants from the input list // Skip constants from the input list
if (is_constant(x)) { if (is_constant(i)) {
continue; continue;
} }
const auto& x = inputs[i];
auto& xname = namer.get_name(x);
auto tstr = get_type_string(x.dtype()); auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++ os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl; << "];" << std::endl;
@@ -206,10 +203,11 @@ inline void build_kernel(
} }
// Read the inputs in tmps // Read the inputs in tmps
for (auto& x : inputs) { for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
if (is_constant(x)) { if (is_constant(i)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "; os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x); print_constant(os, x);
os << ";" << std::endl; os << ";" << std::endl;
@@ -259,8 +257,9 @@ inline void build_kernel(
} else { } else {
for (int d = ndim - 1; d >= 0; --d) { for (int d = ndim - 1; d >= 0; --d) {
// Update pointers // Update pointers
for (auto& x : inputs) { for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant(x) || is_scalar(x)) { const auto& x = inputs[i];
if (is_constant(i) || is_scalar(x)) {
continue; continue;
} }
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
@@ -282,65 +281,37 @@ inline void build_kernel(
void Compiled::eval_cpu( void Compiled::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
// Handle all broadcasting and collect function input arguments // Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Collect function input arguments.
std::vector<void*> args; std::vector<void*> args;
std::vector<std::vector<size_t>> strides; int strides_index = 1;
for (int i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants. if (is_constant_(i)) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue; continue;
} }
auto& x = inputs[i]; const auto& x = inputs[i];
encoder.set_input_array(x); encoder.set_input_array(x);
args.push_back((void*)x.data<void>()); args.push_back((void*)x.data<void>());
if (!contiguous && !is_scalar(x)) {
if (contiguous || is_scalar(x)) { args.push_back(strides[strides_index++].data());
continue;
} }
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
} }
// Get the kernel name from the lib // Get the kernel name from the lib
int ndim = shape.size(); int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_"); auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) { if (!contiguous) {
kernel_name += std::to_string(shape.size()); kernel_name += std::to_string(ndim);
} }
// Get the function // Get the function
auto fn_ptr = compile(kernel_name, [&]() { auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
std::ostringstream kernel; std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl; kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl; kernel << "extern \"C\" {" << std::endl;
@@ -350,7 +321,7 @@ void Compiled::eval_cpu(
inputs_, inputs_,
outputs_, outputs_,
tape_, tape_,
constant_ids_, is_constant_,
contiguous, contiguous,
ndim); ndim);
// Close extern "C" // Close extern "C"
@@ -358,26 +329,22 @@ void Compiled::eval_cpu(
return kernel.str(); return kernel.str();
}); });
compiled_allocate_outputs( compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
inputs, outputs, inputs_, constant_ids_, contiguous);
for (auto& x : outputs) { for (auto& x : outputs) {
args.push_back(x.data<void>()); args.push_back(x.data<void>());
encoder.set_output_array(x); encoder.set_output_array(x);
} }
Shape out_shape;
if (!contiguous) { if (!contiguous) {
out_shape = outputs[0].shape(); args.push_back((void*)shape.data());
args.push_back((void*)out_shape.data());
} else { } else {
args.push_back((void*)outputs[0].data_size()); args.push_back((void*)outputs[0].data_size());
} }
auto fun = (void (*)(void**))fn_ptr; auto fun = (void (*)(void**))fn_ptr;
encoder.dispatch( encoder.dispatch([fun,
[fun, args = std::move(args),
args = std::move(args), strides = std::move(strides),
strides = std::move(strides), shape = std::move(shape)]() mutable { fun(args.data()); });
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
} }
} // namespace mlx::core } // namespace mlx::core
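
The CompilerCache change above swaps a file-scope static for a function-local static reached through cache(), which is constructed on first use and so avoids static-initialization-order problems, while keeping the shared_mutex read-lock fast path. A standalone sketch of the pattern (names simplified; the build callback stands in for the real compile-and-dlopen step):

#include <cstdio>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

struct KernelCache {
  std::unordered_map<std::string, void*> kernels;
  std::shared_mutex mtx;
};

static KernelCache& cache() {
  static KernelCache cache_; // created on first call, thread-safe since C++11
  return cache_;
}

void* compile(const std::string& name, const std::function<void*()>& build) {
  {
    std::shared_lock lock(cache().mtx); // many readers may probe concurrently
    if (auto it = cache().kernels.find(name); it != cache().kernels.end()) {
      return it->second;
    }
  }
  std::unique_lock lock(cache().mtx);   // single writer
  if (auto it = cache().kernels.find(name); it != cache().kernels.end()) {
    return it->second;                  // another thread built it first
  }
  void* fn = build();                   // expensive: generate, compile, dlopen
  cache().kernels.emplace(name, fn);
  return fn;
}

int main() {
  auto fn = compile("unary_abs_contiguous", [] { return (void*)0x1; });
  std::printf("%p\n", fn);
}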

View File

@@ -22,7 +22,8 @@ void slow_conv_1D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@@ -60,7 +61,8 @@ void slow_conv_1D(
out_stride_O = out.strides()[2], out_stride_O = out.strides()[2],
flip, flip,
padding = padding[0], padding_lo = padding_lo[0],
padding_hi = padding_hi[0],
wt_stride = wt_strides[0], wt_stride = wt_strides[0],
wt_dilation = wt_dilation[0], wt_dilation = wt_dilation[0],
in_dilation = in_dilation[0]]() mutable { in_dilation = in_dilation[0]]() mutable {
@@ -77,7 +79,7 @@ void slow_conv_1D(
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh; int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_stride - padding + wh_flip * wt_dilation; int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;
auto ih_div = std::div(ih, in_dilation); auto ih_div = std::div(ih, in_dilation);
@@ -109,7 +111,8 @@ void slow_conv_2D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@@ -120,230 +123,235 @@ void slow_conv_2D(
encoder.set_input_array(wt); encoder.set_input_array(wt);
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.dispatch([st_wt_ptr = wt.data<T>(), encoder.dispatch(
st_in_ptr = in.data<T>(), [st_wt_ptr = wt.data<T>(),
st_out_ptr = out.data<T>(), st_in_ptr = in.data<T>(),
st_out_ptr = out.data<T>(),
N = in.shape( N = in.shape(0), // Batch size, should be the same as out.shape(0)
0), // Batch size, should be the same as out.shape(0) iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
iH = 1 + iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
in_dilation[0] * (in.shape(1) - 1), // Input spatial dim C = in.shape(3), // In channels
iW = 1 + oH = out.shape(1), // Output spatial dim
in_dilation[1] * (in.shape(2) - 1), // Input spatial dim oW = out.shape(2), // Output spatial dim
C = in.shape(3), // In channels O = wt.shape(0), // Out channels
oH = out.shape(1), // Output spatial dim wH = wt.shape(1), // Weight spatial dim
oW = out.shape(2), // Output spatial dim wW = wt.shape(2), // Weight spatial dim
O = wt.shape(0), // Out channels
wH = wt.shape(1), // Weight spatial dim
wW = wt.shape(2), // Weight spatial dim
groups = in.shape(3) / wt.shape(3), groups = in.shape(3) / wt.shape(3),
C_per_group = wt.shape(3), C_per_group = wt.shape(3),
in_stride_N = in.strides()[0], in_stride_N = in.strides()[0],
in_stride_H = in.strides()[1], in_stride_H = in.strides()[1],
in_stride_W = in.strides()[2], in_stride_W = in.strides()[2],
in_stride_C = in.strides()[3], in_stride_C = in.strides()[3],
wt_stride_O = wt.strides()[0], wt_stride_O = wt.strides()[0],
wt_stride_H = wt.strides()[1], wt_stride_H = wt.strides()[1],
wt_stride_W = wt.strides()[2], wt_stride_W = wt.strides()[2],
wt_stride_C = wt.strides()[3], wt_stride_C = wt.strides()[3],
out_stride_N = out.strides()[0], out_stride_N = out.strides()[0],
out_stride_H = out.strides()[1], out_stride_H = out.strides()[1],
out_stride_W = out.strides()[2], out_stride_W = out.strides()[2],
out_stride_O = out.strides()[3], out_stride_O = out.strides()[3],
padding, padding_lo,
wt_strides, padding_hi,
wt_dilation, wt_strides,
in_dilation, wt_dilation,
flip]() mutable { in_dilation,
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1; flip]() mutable {
bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;
const int O_per_group = O / groups; const int O_per_group = O / groups;
auto pt_conv_no_checks = [&](const T* in_ptr, auto pt_conv_no_checks =
const T* wt_ptr, [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
T* out_ptr, out_ptr += oh * out_stride_H + ow * out_stride_W;
int oh, int ih_base = oh * wt_strides[0] - padding_lo[0];
int ow) { int iw_base = ow * wt_strides[1] - padding_lo[1];
out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
for (int g = 0; g < groups; ++g) { for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.; float r = 0.;
for (int wh = 0; wh < wH; ++wh) { for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) { for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh; int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww; int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0]; int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1]; int iw = iw_base + ww_flip * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; const T* wt_ptr_pt =
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { for (int c = g * C_per_group; c < (g + 1) * C_per_group;
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) * ++c) {
static_cast<float>( r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
wt_ptr_pt[(c % C_per_group) * wt_stride_C]); static_cast<float>(
} // c wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // ww } // c
} // wh } // ww
} // wh
out_ptr[0] = static_cast<T>(r); out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O; out_ptr += out_stride_O;
wt_ptr += wt_stride_O; wt_ptr += wt_stride_O;
} // o } // o
} // g } // g
}; };
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_w = flip ? -wt_dilation[1] : wt_dilation[1]; int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];
int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0); int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0); int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);
int f_wgt_jump_h = int f_wgt_jump_h =
std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0]; std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_w = int f_wgt_jump_w =
std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1]; std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_out_jump_h = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0]; int f_out_jump_h =
int f_out_jump_w = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1]; std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_w =
std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
std::vector<int> base_h(f_out_jump_h); std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w); std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_h; ++i) { for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[0] - padding[0] + init_h; int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;
int wh_base = 0; int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[0] != 0) { while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
wh_base++; wh_base++;
ih_loop += jump_h; ih_loop += jump_h;
} }
base_h[i] = wh_base; base_h[i] = wh_base;
} }
for (int j = 0; j < f_out_jump_w; ++j) { for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[1] - padding[1] + init_w; int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;
int ww_base = 0; int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[1] != 0) { while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
ww_base++; ww_base++;
iw_loop += jump_w; iw_loop += jump_w;
} }
base_w[j] = ww_base; base_w[j] = ww_base;
} }
auto pt_conv_all_checks = auto pt_conv_all_checks =
[&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
out_ptr += oh * out_stride_H + ow * out_stride_W; out_ptr += oh * out_stride_H + ow * out_stride_W;
int ih_base = oh * wt_strides[0] - padding[0]; int ih_base = oh * wt_strides[0] - padding_lo[0];
int iw_base = ow * wt_strides[1] - padding[1]; int iw_base = ow * wt_strides[1] - padding_lo[1];
int wh_base = base_h[oh % f_out_jump_h]; int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w]; int ww_base = base_w[ow % f_out_jump_w];
for (int g = 0; g < groups; ++g) { for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.; float r = 0.;
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh; int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww; int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0]; int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1]; int iw = iw_base + ww_flip * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt = const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W; wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt = const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; iw_dil * in_stride_W;
for (int c = g * C_per_group; c < (g + 1) * C_per_group; for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) { ++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) * r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>( static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]); wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c } // c
} // ih, iw check } // ih, iw check
} // ww } // ww
} // wh } // wh
out_ptr[0] = static_cast<T>(r); out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O; out_ptr += out_stride_O;
wt_ptr += wt_stride_O; wt_ptr += wt_stride_O;
} // o } // o
} // g } // g
}; };
int oH_border_0 = 0; int oH_border_0 = 0;
int oH_border_1 = int oH_border_1 = is_idil_one
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oH; ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
int oH_border_2 = std::max( : oH;
oH_border_1, (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]); int oH_border_2 = std::max(
int oH_border_3 = oH; oH_border_1,
(iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
int oH_border_3 = oH;
int oW_border_0 = 0; int oW_border_0 = 0;
int oW_border_1 = int oW_border_1 = is_idil_one
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oW; ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
int oW_border_2 = std::max( : oW;
oW_border_1, (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]); int oW_border_2 = std::max(
int oW_border_3 = oW; oW_border_1,
(iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
// Case 1: oh might put us out of bounds // Case 1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) { for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) { for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow } // ow
} // oh } // oh
// Case 2: oh in bounds // Case 2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) { for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case a: ow might put us out of bounds // Case a: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) { for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow } // ow
// Case b: ow in bounds // Case b: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) { for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow } // ow
// Case c: ow might put us out of bounds // Case c: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) { for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow } // ow
} // oh } // oh
// Case 3: oh might put us out of bounds // Case 3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) { for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) { for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
} // ow } // ow
} // oh } // oh
st_in_ptr += in_stride_N; st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N; st_out_ptr += out_stride_N;
} // n } // n
}); });
} }
template <typename T> template <typename T>
@@ -351,7 +359,8 @@ void slow_conv_3D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@@ -400,7 +409,8 @@ void slow_conv_3D(
out_stride_H = out.strides()[2], out_stride_H = out.strides()[2],
out_stride_W = out.strides()[3], out_stride_W = out.strides()[3],
out_stride_O = out.strides()[4], out_stride_O = out.strides()[4],
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@@ -415,9 +425,9 @@ void slow_conv_3D(
int oh, int oh,
int ow) { int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0]; int id_base = od * wt_strides[0] - padding_lo[0];
int ih_base = oh * wt_strides[1] - padding[1]; int ih_base = oh * wt_strides[1] - padding_lo[1];
int iw_base = ow * wt_strides[2] - padding[2]; int iw_base = ow * wt_strides[2] - padding_lo[2];
for (int o = 0; o < O; ++o) { for (int o = 0; o < O; ++o) {
float r = 0.; float r = 0.;
@@ -478,7 +488,7 @@ void slow_conv_3D(
std::vector<int> base_w(f_out_jump_w); std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_d; ++i) { for (int i = 0; i < f_out_jump_d; ++i) {
int id_loop = i * wt_strides[0] - padding[0] + init_d; int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;
int wd_base = 0; int wd_base = 0;
while (wd_base < wD && id_loop % in_dilation[0] != 0) { while (wd_base < wD && id_loop % in_dilation[0] != 0) {
@@ -490,7 +500,7 @@ void slow_conv_3D(
} }
for (int i = 0; i < f_out_jump_h; ++i) { for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[1] - padding[1] + init_h; int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;
int wh_base = 0; int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[1] != 0) { while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
@@ -502,7 +512,7 @@ void slow_conv_3D(
} }
for (int j = 0; j < f_out_jump_w; ++j) { for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[2] - padding[2] + init_w; int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;
int ww_base = 0; int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[2] != 0) { while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
@@ -521,9 +531,9 @@ void slow_conv_3D(
int ow) { int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W; out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0]; int id_base = od * wt_strides[0] - padding_lo[0];
int ih_base = oh * wt_strides[1] - padding[1]; int ih_base = oh * wt_strides[1] - padding_lo[1];
int iw_base = ow * wt_strides[2] - padding[2]; int iw_base = ow * wt_strides[2] - padding_lo[2];
int wd_base = base_d[od % f_out_jump_d]; int wd_base = base_d[od % f_out_jump_d];
int wh_base = base_h[oh % f_out_jump_h]; int wh_base = base_h[oh % f_out_jump_h];
@@ -573,24 +583,30 @@ void slow_conv_3D(
}; };
int oD_border_0 = 0; int oD_border_0 = 0;
int oD_border_1 = int oD_border_1 = is_idil_one
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD; ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
: oD;
int oD_border_2 = std::max( int oD_border_2 = std::max(
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]); oD_border_1,
(iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
int oD_border_3 = oD; int oD_border_3 = oD;
int oH_border_0 = 0; int oH_border_0 = 0;
int oH_border_1 = int oH_border_1 = is_idil_one
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH; ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
: oH;
int oH_border_2 = std::max( int oH_border_2 = std::max(
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]); oH_border_1,
(iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
int oH_border_3 = oH; int oH_border_3 = oH;
int oW_border_0 = 0; int oW_border_0 = 0;
int oW_border_1 = int oW_border_1 = is_idil_one
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW; ? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
: oW;
int oW_border_2 = std::max( int oW_border_2 = std::max(
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]); oW_border_1,
(iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
int oW_border_3 = oW; int oW_border_3 = oW;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
@@ -658,7 +674,8 @@ void dispatch_slow_conv_1D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@@ -669,7 +686,8 @@ void dispatch_slow_conv_1D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@@ -680,7 +698,8 @@ void dispatch_slow_conv_1D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@@ -691,7 +710,8 @@ void dispatch_slow_conv_1D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@@ -707,7 +727,8 @@ void dispatch_slow_conv_2D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@@ -718,7 +739,8 @@ void dispatch_slow_conv_2D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@@ -729,7 +751,8 @@ void dispatch_slow_conv_2D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@@ -740,7 +763,8 @@ void dispatch_slow_conv_2D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@@ -756,7 +780,8 @@ void dispatch_slow_conv_3D(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@@ -767,7 +792,8 @@ void dispatch_slow_conv_3D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@@ -778,7 +804,8 @@ void dispatch_slow_conv_3D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@@ -789,7 +816,8 @@ void dispatch_slow_conv_3D(
in, in,
wt, wt,
out, out,
padding, padding_lo,
padding_hi,
wt_strides, wt_strides,
wt_dilation, wt_dilation,
in_dilation, in_dilation,
@@ -829,7 +857,8 @@ void explicit_gemm_conv_1D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
Stream stream) { Stream stream) {
@@ -848,7 +877,7 @@ void explicit_gemm_conv_1D_cpu(
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
// Pad input // Pad input
Shape padded_shape = {N, iH + 2 * padding[0], C}; Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
array in_padded(padded_shape, conv_dtype, nullptr, {}); array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros // Fill with zeros
@@ -857,7 +886,7 @@ void explicit_gemm_conv_1D_cpu(
copy(temps.back(), in_padded, CopyType::Scalar, stream); copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded // Pick input slice from padded
size_t data_offset = padding[0] * in_padded.strides()[1]; size_t data_offset = padding_lo[0] * in_padded.strides()[1];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer( in_padded_slice.copy_shared_buffer(
in_padded, in_padded,
@@ -921,7 +950,7 @@ void explicit_gemm_conv_1D_cpu(
if (out.dtype() != float32) { if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out); temps.push_back(gemm_out);
} }
@@ -971,7 +1000,8 @@ void explicit_gemm_conv_2D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
Stream stream) { Stream stream) {
@@ -989,7 +1019,11 @@ void explicit_gemm_conv_2D_cpu(
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
// Pad input // Pad input
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C}; Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {}); array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros // Fill with zeros
@@ -998,8 +1032,8 @@ void explicit_gemm_conv_2D_cpu(
copy(temps.back(), in_padded, CopyType::Scalar, stream); copy(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded // Pick input slice from padded
size_t data_offset = size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2]; padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer( in_padded_slice.copy_shared_buffer(
in_padded, in_padded,
@@ -1048,7 +1082,7 @@ void explicit_gemm_conv_2D_cpu(
if (out.dtype() != float32) { if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out); temps.push_back(gemm_out);
} }
@@ -1091,7 +1125,8 @@ void explicit_gemm_conv_ND_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const bool flip, const bool flip,
@@ -1114,7 +1149,7 @@ void explicit_gemm_conv_ND_cpu(
Shape padded_shape(in.shape().size()); Shape padded_shape(in.shape().size());
padded_shape.front() = N; padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) { for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i]; padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
} }
padded_shape.back() = C; padded_shape.back() = C;
array in_padded(padded_shape, conv_dtype, nullptr, {}); array in_padded(padded_shape, conv_dtype, nullptr, {});
@@ -1125,9 +1160,10 @@ void explicit_gemm_conv_ND_cpu(
// Pick input slice from padded // Pick input slice from padded
size_t data_offset = 0; size_t data_offset = 0;
for (size_t i = 0; i < padding.size(); i++) { for (size_t i = 0; i < padding_lo.size(); i++) {
data_offset += padding[i] * in_padded.strides()[i + 1]; data_offset += padding_lo[i] * in_padded.strides()[i + 1];
} }
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer( in_padded_slice.copy_shared_buffer(
in_padded, in_padded,
@@ -1214,7 +1250,7 @@ void explicit_gemm_conv_ND_cpu(
if (out.dtype() != float32) { if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {}); gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out); temps.push_back(gemm_out);
} }
@@ -1261,7 +1297,8 @@ void conv_1D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@@ -1270,22 +1307,40 @@ void conv_1D_cpu(
const int groups = in.shape().back() / wt.shape().back(); const int groups = in.shape().back() / wt.shape().back();
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) { if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
return explicit_gemm_conv_1D_cpu( return explicit_gemm_conv_1D_cpu(
in, wt, out, padding, wt_strides, wt_dilation, stream); in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
} }
if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) { if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu( return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
} }
return dispatch_slow_conv_1D( return dispatch_slow_conv_1D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
} }
void conv_2D_cpu( void conv_2D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@@ -1295,18 +1350,35 @@ void conv_2D_cpu(
if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 && if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
in_dilation[1] == 1 && groups == 1) { in_dilation[1] == 1 && groups == 1) {
return explicit_gemm_conv_ND_cpu( return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
} }
return dispatch_slow_conv_2D( return dispatch_slow_conv_2D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
} }
void conv_3D_cpu( void conv_3D_cpu(
const array& in, const array& in,
const array& wt, const array& wt,
array out, array out,
const std::vector<int>& padding, const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides, const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation, const std::vector<int>& in_dilation,
@@ -1317,17 +1389,34 @@ void conv_3D_cpu(
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
groups == 1) { groups == 1) {
return explicit_gemm_conv_ND_cpu( return explicit_gemm_conv_ND_cpu(
in, wt, out, padding, wt_strides, wt_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
flip,
stream);
} }
return dispatch_slow_conv_3D( return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip, stream); in,
wt,
out,
padding_lo,
padding_hi,
wt_strides,
wt_dilation,
in_dilation,
flip,
stream);
} }
} // namespace } // namespace
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) { void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& in = inputs[0]; auto& in = inputs[0];
auto& wt = inputs[1]; auto& wt = inputs[1];
@@ -1338,7 +1427,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
padding_hi_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
@@ -1351,7 +1441,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
padding_hi_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
@@ -1364,7 +1455,8 @@ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
padding_hi_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
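For reference, a minimal sketch of the index arithmetic the padding_lo/padding_hi split above introduces (standalone C++ with made-up sizes, not MLX API calls): the padded extent becomes lo + hi instead of 2 * padding, the slice back into the original data starts at lo, and the first input row touched by output position od is od * stride - lo.

#include <cstdio>

int main() {
  int iH = 10, stride = 1;
  int padding_lo = 2, padding_hi = 1;

  // Mirrors: padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C}
  int padded = iH + padding_lo + padding_hi;   // 13 (was iH + 2 * padding)

  // Mirrors: data_offset = padding_lo[0] * in_padded.strides()[1]
  int offset_rows = padding_lo;                // 2

  // Mirrors: id_base = od * wt_strides[0] - padding_lo[0]
  int od = 0;
  int id_base = od * stride - padding_lo;      // -2

  std::printf("padded=%d offset=%d id_base=%d\n", padded, offset_rows, id_base);
  return 0;
}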


@@ -30,7 +30,7 @@ void AllReduce::eval_cpu(
if (in.is_donatable()) { if (in.is_donatable()) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
} }
return in; return in;
} else { } else {
@@ -46,8 +46,15 @@ void AllReduce::eval_cpu(
case Sum: case Sum:
distributed::detail::all_sum(group(), in, outputs[0], stream()); distributed::detail::all_sum(group(), in, outputs[0], stream());
break; break;
case Max:
distributed::detail::all_max(group(), in, outputs[0], stream());
break;
case Min:
distributed::detail::all_min(group(), in, outputs[0], stream());
break;
default: default:
throw std::runtime_error("Only all reduce sum is supported for now"); throw std::runtime_error(
"Only all reduce sum, min and max are supported for now");
} }
} }
@@ -58,7 +65,7 @@ void AllGather::eval_cpu(
assert(outputs.size() == 1); assert(outputs.size() == 1);
auto [in, copied] = ensure_row_contiguous(inputs[0], stream()); auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes())); outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::all_gather(group(), in, outputs[0], stream()); distributed::detail::all_gather(group(), in, outputs[0], stream());
if (copied) { if (copied) {
auto& enc = cpu::get_command_encoder(stream()); auto& enc = cpu::get_command_encoder(stream());
@@ -87,7 +94,7 @@ void Recv::eval_cpu(
assert(inputs.size() == 0); assert(inputs.size() == 0);
assert(outputs.size() == 1); assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes())); outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream()); distributed::detail::recv(group(), outputs[0], src_, stream());
} }

mlx/backend/cpu/eig.cpp (new file, +174 lines)

@@ -0,0 +1,174 @@
// Copyright © 2025 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T>
void eig_impl(
array& a,
array& vectors,
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
OT* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>();
}
encoder.dispatch([a_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
int iwork;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (size_t i = 0; i < size / (N * N); ++i) {
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N;
}
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
});
encoder.add_temporary(a);
}
} // namespace
void Eig::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), complex64, nullptr, {});
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
copy(
a,
a_copy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
stream());
values.set_data(allocator::malloc(values.nbytes()));
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.set_data(
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
}
switch (a.dtype()) {
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
}
}
} // namespace mlx::core
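A minimal standalone sketch of the unpacking convention the eig_impl loop above relies on: real geev packs a complex-conjugate eigenvector pair into two consecutive real vectors (real parts in vector i, imaginary parts in vector i+1), and purely real eigenvalues pass through unchanged. Same indexing and sign convention as the code above; the driver in main uses hypothetical values.

#include <complex>
#include <cstdio>
#include <vector>

// Unpack geev-style packed real eigenvectors (one per row of length N) into
// complex eigenvectors, pairing rows i and i+1 when eigenvalue i is complex.
std::vector<std::complex<float>> unpack_eigenvectors(
    const std::vector<float>& packed,           // N x N, row i = eigenvector i
    const std::vector<std::complex<float>>& w,  // eigenvalues
    int N) {
  std::vector<std::complex<float>> v(N * N);
  for (int i = 0; i < N; ++i) {
    if (w[i].imag() != 0) {
      for (int j = 0; j < N; ++j) {
        v[i * N + j] = {packed[i * N + j], -packed[(i + 1) * N + j]};
        v[(i + 1) * N + j] = {packed[i * N + j], packed[(i + 1) * N + j]};
      }
      ++i;  // rows i and i+1 were consumed as a pair
    } else {
      for (int j = 0; j < N; ++j) {
        v[i * N + j] = {packed[i * N + j], 0.0f};
      }
    }
  }
  return v;
}

int main() {
  // Real eigenvalues only: rows pass through unchanged.
  std::vector<float> packed = {1, 0, 0, 1};
  std::vector<std::complex<float>> w = {{2, 0}, {3, 0}};
  auto v = unpack_eigenvectors(packed, w, 2);
  std::printf("%g %g\n", v[0].real(), v[3].real());  // 1 1
  return 0;
}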


@@ -12,6 +12,133 @@ namespace mlx::core {
namespace { namespace {
template <typename T, class Enable = void>
struct EighWork {};
template <typename T>
struct EighWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
char jobz;
char uplo;
int N;
int lwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, T* values) {
syevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<int*>(buffers[1].buffer.raw_ptr()),
&liwork,
&info);
}
};
template <>
struct EighWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
char jobz;
char uplo;
int N;
int lwork;
int lrwork;
int liwork;
int info;
std::vector<array::Data> buffers;
EighWork(char jobz_, char uplo_, int N_)
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
T work;
R rwork;
int iwork;
heevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&rwork,
&lrwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work.real());
lrwork = static_cast<int>(rwork);
liwork = iwork;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
}
void run(T* vectors, R* values) {
heevd<T>(
&jobz,
&uplo,
&N,
vectors,
&N,
values,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&lrwork,
static_cast<int*>(buffers[2].buffer.raw_ptr()),
&liwork,
&info);
if (jobz == 'V') {
// We have pre-transposed the vectors but we also must conjugate them
// when they are complex.
//
// We could vectorize this but it is so fast in comparison to heevd that
// it doesn't really matter.
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
*vectors = std::conj(*vectors);
vectors++;
}
}
}
}
};
template <typename T> template <typename T>
void eigh_impl( void eigh_impl(
array& vectors, array& vectors,
@@ -19,8 +146,10 @@ void eigh_impl(
const std::string& uplo, const std::string& uplo,
bool compute_eigenvectors, bool compute_eigenvectors,
Stream stream) { Stream stream) {
using R = typename EighWork<T>::R;
auto vec_ptr = vectors.data<T>(); auto vec_ptr = vectors.data<T>();
auto eig_ptr = values.data<T>(); auto eig_ptr = values.data<R>();
char jobz = compute_eigenvectors ? 'V' : 'N'; char jobz = compute_eigenvectors ? 'V' : 'N';
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
@@ -33,50 +162,17 @@ void eigh_impl(
N = vectors.shape(-1), N = vectors.shape(-1),
size = vectors.size()]() mutable { size = vectors.size()]() mutable {
// Work query // Work query
int lwork = -1; EighWork<T> work(jobz, uplo, N);
int liwork = -1;
int info;
{
T work;
int iwork;
syevd<T>(
&jobz,
&uplo,
&N,
nullptr,
&N,
nullptr,
&work,
&lwork,
&iwork,
&liwork,
&info);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; // Work loop
auto iwork_buf =
array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < size / (N * N); ++i) { for (size_t i = 0; i < size / (N * N); ++i) {
syevd<T>( work.run(vec_ptr, eig_ptr);
&jobz,
&uplo,
&N,
vec_ptr,
&N,
eig_ptr,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
&liwork,
&info);
vec_ptr += N * N; vec_ptr += N * N;
eig_ptr += N; eig_ptr += N;
if (info != 0) { if (work.info != 0) {
std::stringstream msg; std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code " msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info; << work.info;
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
} }
@@ -98,7 +194,7 @@ void Eigh::eval_cpu(
? outputs[1] ? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {}); : array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes())); values.set_data(allocator::malloc(values.nbytes()));
copy( copy(
a, a,
@@ -132,6 +228,10 @@ void Eigh::eval_cpu(
eigh_impl<double>( eigh_impl<double>(
vectors, values, uplo_, compute_eigenvectors_, stream()); vectors, values, uplo_, compute_eigenvectors_, stream());
break; break;
case complex64:
eigh_impl<std::complex<float>>(
vectors, values, uplo_, compute_eigenvectors_, stream());
break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"[Eigh::eval_cpu] only supports float32 or float64."); "[Eigh::eval_cpu] only supports float32 or float64.");


@@ -9,6 +9,9 @@
namespace mlx::core::cpu { namespace mlx::core::cpu {
// Number of dispatches per scheduler task
constexpr int DISPATCHES_PER_TASK = 10;
struct CommandEncoder { struct CommandEncoder {
CommandEncoder(Stream stream) : stream_(stream) {} CommandEncoder(Stream stream) : stream_(stream) {}
@@ -39,13 +42,24 @@ struct CommandEncoder {
template <class F, class... Args> template <class F, class... Args>
void dispatch(F&& f, Args&&... args) { void dispatch(F&& f, Args&&... args) {
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...); auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
scheduler::enqueue(stream_, std::move(task)); if (num_ops_ == 0) {
scheduler::notify_new_task(stream_);
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
task();
scheduler::notify_task_completion(s);
};
scheduler::enqueue(stream_, std::move(task_wrap));
} else {
scheduler::enqueue(stream_, std::move(task));
}
} }
private: private:
Stream stream_; Stream stream_;
std::vector<array> temporaries_; std::vector<array> temporaries_;
int num_ops_{0};
}; };
CommandEncoder& get_command_encoder(Stream stream); CommandEncoder& get_command_encoder(Stream stream);
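The encoder change above batches scheduler bookkeeping: only every DISPATCHES_PER_TASK-th dispatch notifies the scheduler of a new task, and that wrapped task signals completion when it runs. A minimal sketch of just the counting pattern, with the scheduler calls reduced to a counter:

#include <cstdio>

// Counting pattern from CommandEncoder::dispatch: pay the notification
// overhead once per kDispatchesPerTask dispatches instead of on every one.
constexpr int kDispatchesPerTask = 10;

int main() {
  int num_ops = 0;
  int notifications = 0;
  for (int i = 0; i < 25; ++i) {
    num_ops = (num_ops + 1) % kDispatchesPerTask;
    if (num_ops == 0) {
      ++notifications;  // would call scheduler::notify_new_task(stream)
    }
  }
  std::printf("notifications for 25 dispatches: %d\n", notifications);  // 2
  return 0;
}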


@@ -33,12 +33,8 @@ void eval(array& arr) {
buffers.erase(it); buffers.erase(it);
} }
auto& encoder = cpu::get_command_encoder(s); auto& encoder = cpu::get_command_encoder(s);
scheduler::notify_new_task(s); encoder.dispatch([buffers = std::move(buffers),
encoder.dispatch([s, temps = std::move(encoder.temporaries())]() {});
buffers = std::move(buffers),
temps = std::move(encoder.temporaries())]() {
scheduler::notify_task_completion(s);
});
} }
} // namespace mlx::core::cpu } // namespace mlx::core::cpu


@@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
s *= out.itemsize(); s *= out.itemsize();
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
std::vector<size_t> shape; std::vector<size_t> shape;
if (out.dtype() == float32) { if (out.dtype() == float32) {


@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t*,
const bfloat16_t*,
bfloat16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
}
} // namespace mlx::core


@@ -1,27 +0,0 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cpu/gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t*,
const float16_t*,
float16_t*,
bool,
bool,
size_t,
size_t,
size_t,
float,
float,
size_t,
const Shape&,
const Strides&,
const Shape&,
const Strides&) {
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
}
} // namespace mlx::core


@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<bfloat16_t>(
const bfloat16_t* a,
const bfloat16_t* b,
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<bfloat16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core


@@ -0,0 +1,45 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/gemms/simd_gemm.h"
namespace mlx::core {
template <>
void matmul<float16_t>(
const float16_t* a,
const float16_t* b,
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
simd_gemm<float16_t, float>(
a + elem_to_loc(M * K * i, a_shape, a_strides),
b + elem_to_loc(K * N * i, b_shape, b_strides),
out + M * N * i,
a_transposed,
b_transposed,
M,
N,
K,
alpha,
beta);
}
}
} // namespace mlx::core


@@ -0,0 +1,139 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
template <int block_size, typename T, typename AccT>
void load_block(
const T* in,
AccT* out,
int M,
int N,
int i,
int j,
bool transpose) {
if (transpose) {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[jj * block_size + ii] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
} else {
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
out[ii * block_size + jj] =
in[(i * block_size + ii) * N + j * block_size + jj];
}
}
}
}
template <typename T, typename AccT>
void simd_gemm(
const T* a,
const T* b,
T* c,
bool a_trans,
bool b_trans,
int M,
int N,
int K,
float alpha,
float beta) {
constexpr int block_size = 16;
constexpr int simd_size = simd::max_size<AccT>;
static_assert(
(block_size % simd_size) == 0,
"Block size must be divisible by SIMD size");
int last_k_block_size = K - block_size * (K / block_size);
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
for (int i = 0; i < ceildiv(M, block_size); i++) {
for (int j = 0; j < ceildiv(N, block_size); j++) {
AccT c_block[block_size * block_size] = {0.0};
AccT a_block[block_size * block_size];
AccT b_block[block_size * block_size];
int k = 0;
for (; k < K / block_size; k++) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
for (int kk = 0; kk < block_size; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
}
}
}
if (last_k_block_size) {
// Load a and b blocks
if (a_trans) {
load_block<block_size>(a, a_block, K, M, k, i, true);
} else {
load_block<block_size>(a, a_block, M, K, i, k, false);
}
if (b_trans) {
load_block<block_size>(b, b_block, N, K, j, k, false);
} else {
load_block<block_size>(b, b_block, K, N, k, j, true);
}
// Multiply and accumulate
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
int kk = 0;
for (; kk < last_k_simd_block; kk += simd_size) {
auto av =
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
auto bv =
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
c_block[ii * block_size + jj] += simd::sum(av * bv);
}
for (; kk < last_k_block_size; ++kk) {
c_block[ii * block_size + jj] +=
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
}
}
}
}
// Store
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
if (beta != 0) {
c[c_idx] = static_cast<T>(
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
} else {
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
}
}
}
}
}
}
} // namespace mlx::core
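A minimal usage sketch of the blocked SIMD GEMM above, which computes c = alpha * a @ b + beta * c for row-major operands; it assumes compilation inside the MLX source tree (the header pulls in mlx/backend/cpu/simd/simd.h), and the concrete matrices are made up.

#include "mlx/backend/cpu/gemms/simd_gemm.h"

#include <cstdio>

int main() {
  using namespace mlx::core;
  // 2x3 times 3x2, row-major, no transposes.
  float a[6] = {1, 2, 3, 4, 5, 6};
  float b[6] = {1, 0, 0, 1, 1, 1};
  float c[4] = {0, 0, 0, 0};
  simd_gemm<float, float>(
      a, b, c, /*a_trans=*/false, /*b_trans=*/false,
      /*M=*/2, /*N=*/2, /*K=*/3, /*alpha=*/1.0f, /*beta=*/0.0f);
  std::printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]);  // 4 5 / 10 11
  return 0;
}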


@@ -197,7 +197,7 @@ void dispatch_gather(
} }
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) { void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0]; auto& src = inputs[0];
std::vector<array> inds; std::vector<array> inds;
@@ -257,15 +257,11 @@ void gather_axis(
const array& ind, const array& ind,
array& out, array& out,
const int axis) { const int axis) {
auto strides = ind.strides(); auto shape = remove_index(ind.shape(), axis);
strides.erase(strides.begin() + axis); ContiguousIterator ind_it(
auto shape = ind.shape(); shape, remove_index(ind.strides(), axis), src.ndim() - 1);
shape.erase(shape.begin() + axis); ContiguousIterator src_it(
ContiguousIterator ind_it(shape, strides, src.ndim() - 1); shape, remove_index(src.strides(), axis), src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>(); auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>(); auto src_ptr = src.data<T>();
@@ -354,7 +350,7 @@ void dispatch_gather_axis(
} }
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) { void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& src = inputs[0]; auto& src = inputs[0];
auto& inds = inputs[1]; auto& inds = inputs[1];
@@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
template <typename T, typename IdxT, typename OpT> template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array idx, const array& upd, int axis) { void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto strides = idx.strides(); auto shape = remove_index(idx.shape(), axis);
strides.erase(strides.begin() + axis); ContiguousIterator idx_it(
auto shape = idx.shape(); shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
shape.erase(shape.begin() + axis); ContiguousIterator upd_it(
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1); shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>(); auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>(); auto upd_ptr = upd.data<T>();


@@ -11,7 +11,7 @@ namespace mlx::core {
template <typename T> template <typename T>
void general_inv(T* inv, int N) { void general_inv(T* inv, int N) {
int info; int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)}; auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
// Compute LU factorization. // Compute LU factorization.
getrf<T>( getrf<T>(
/* m = */ &N, /* m = */ &N,
@@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
} }
const int lwork = workspace_size; const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)}; auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Compute inverse. // Compute inverse.
getri<T>( getri<T>(


@@ -2,14 +2,14 @@
#pragma once #pragma once
// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex> #include <complex>
#define LAPACK_COMPLEX_CUSTOM #define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float> #define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double> #define lapack_complex_double std::complex<double>
#endif #define lapack_complex_float_real(z) ((z).real())
#define lapack_complex_float_imag(z) ((z).imag())
#define lapack_complex_double_real(z) ((z).real())
#define lapack_complex_double_imag(z) ((z).imag())
#ifdef MLX_USE_ACCELERATE #ifdef MLX_USE_ACCELERATE
#include <Accelerate/Accelerate.h> #include <Accelerate/Accelerate.h>
@@ -32,7 +32,7 @@
#endif #endif
#define INSTANTIATE_LAPACK_TYPES(FUNC) \ #define INSTANTIATE_LAPACK_REAL(FUNC) \
template <typename T, typename... Args> \ template <typename T, typename... Args> \
void FUNC(Args... args) { \ void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \ if constexpr (std::is_same_v<T, float>) { \
@@ -42,11 +42,24 @@
} \ } \
} }
INSTANTIATE_LAPACK_TYPES(geqrf) INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_TYPES(orgqr) INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_TYPES(syevd) INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_TYPES(potrf) INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_TYPES(gesvdx) INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_TYPES(getrf) INSTANTIATE_LAPACK_REAL(gesvdx)
INSTANTIATE_LAPACK_TYPES(getri) INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_TYPES(trtri) INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_COMPLEX(heevd)


@@ -0,0 +1,140 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <cmath>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
namespace {
using namespace mlx::core::simd;
template <typename T, typename AccT>
void logsumexp(const array& in, array& out, Stream stream) {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(in);
encoder.set_output_array(out);
const T* in_ptr = in.data<T>();
T* out_ptr = out.data<T>();
int M = in.shape().back();
int L = in.data_size() / M;
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
constexpr int N = std::min(max_size<AccT>, max_size<T>);
const T* current_in_ptr;
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
vmaximum = maximum(vals, vmaximum);
current_in_ptr += N;
s -= N;
}
AccT maximum = max(vmaximum);
while (s-- > 0) {
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
current_in_ptr++;
}
// Compute the normalizer and the exponentials
Simd<AccT, N> vnormalizer(0.0);
current_in_ptr = in_ptr;
s = M;
while (s >= N) {
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
vexp = exp(vexp - maximum);
vnormalizer = vnormalizer + vexp;
current_in_ptr += N;
s -= N;
}
AccT normalizer = sum(vnormalizer);
while (s-- > 0) {
AccT _exp = std::exp(*current_in_ptr - maximum);
normalizer += _exp;
current_in_ptr++;
}
// Normalize
*out_ptr = std::isinf(maximum)
? static_cast<T>(maximum)
: static_cast<T>(std::log(normalizer) + maximum);
}
});
}
} // namespace
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
// Make sure that the last dimension is contiguous
auto s = stream();
auto& encoder = cpu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(allocator::malloc(out.nbytes()));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
auto strides = in.strides();
for (auto& s : strides) {
s /= n;
}
bool col_contig = strides[0] == 1;
for (int i = 1; col_contig && i < strides.size(); ++i) {
col_contig &=
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
}
flags.col_contiguous = col_contig;
out.set_data(
allocator::malloc(in.nbytes() / n),
in.data_size() / n,
std::move(strides),
flags);
}
switch (in.dtype()) {
case float32:
logsumexp<float, float>(in, out, stream());
break;
case float16:
logsumexp<float16_t, float>(in, out, stream());
break;
case bfloat16:
logsumexp<bfloat16_t, float>(in, out, stream());
break;
case float64:
logsumexp<double, double>(in, out, stream());
break;
default:
throw std::runtime_error(
"[logsumexp] only supports floating point types");
break;
}
}
} // namespace mlx::core
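The kernel above is a SIMD version of the standard two-pass log-sum-exp: find the row maximum, sum exp(x - max), then add the max back, returning the max itself when it is infinite. A minimal scalar reference in standalone C++:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Two-pass log-sum-exp over one row, matching the reduction above.
float logsumexp_row(const std::vector<float>& x) {
  float m = -std::numeric_limits<float>::infinity();
  for (float v : x) m = std::max(m, v);
  if (std::isinf(m)) return m;  // all -inf, or a +inf is present
  float norm = 0.0f;
  for (float v : x) norm += std::exp(v - m);
  return std::log(norm) + m;
}

int main() {
  std::printf("%f\n", logsumexp_row({1.0f, 2.0f, 3.0f}));  // ~3.407606
  return 0;
}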


@@ -30,8 +30,7 @@ void luf_impl(
auto strides = lu.strides(); auto strides = lu.strides();
strides[ndim - 1] = M; strides[ndim - 1] = M;
strides[ndim - 2] = 1; strides[ndim - 2] = 1;
lu.set_data( lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace( copy_inplace(
a, a,
lu, lu,
@@ -44,8 +43,8 @@ void luf_impl(
stream); stream);
auto a_ptr = lu.data<T>(); auto a_ptr = lu.data<T>();
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes())); pivots.set_data(allocator::malloc(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes())); row_indices.set_data(allocator::malloc(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>(); auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>(); auto row_indices_ptr = row_indices.data<uint32_t>();
size_t num_matrices = a.size() / (M * N); size_t num_matrices = a.size() / (M * N);


@@ -6,6 +6,7 @@
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
#include "mlx/backend/cpu/lapack.h" #include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
@@ -52,6 +53,58 @@ inline void mask_matrix(
} }
} }
template <typename T>
inline void segmented_mm(
const T* a,
const T* b,
const uint32_t* segments,
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides,
size_t num_segments,
const Shape& segments_shape,
const Strides& segments_strides) {
int ndim = a_shape.size();
Shape a_copy = a_shape;
Shape b_copy = b_shape;
int32_t M = a_copy[ndim - 2];
int32_t N = b_copy[ndim - 1];
for (int i = 0; i < num_segments; i++) {
uint32_t k_start =
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
uint32_t k_end =
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
if (k_end <= k_start) {
std::fill_n(out + i * M * N, M * N, T(0));
continue;
}
a_copy[ndim - 1] = k_end - k_start;
b_copy[ndim - 2] = k_end - k_start;
matmul<T>(
a + k_start * a_strides[ndim - 1],
b + k_start * b_strides[ndim - 2],
out + i * M * N,
a_transposed,
b_transposed,
lda,
ldb,
N,
1.0,
0.0,
1,
a_copy,
a_strides,
b_copy,
b_strides);
}
}
} // namespace } // namespace
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) { void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -59,7 +112,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error( throw std::runtime_error(
"[BlockMaskedMM::eval] Currently only supports float32."); "[BlockMaskedMM::eval] Currently only supports float32.");
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0]; auto& a_pre = inputs[0];
auto& b_pre = inputs[1]; auto& b_pre = inputs[1];
@@ -318,7 +371,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error( throw std::runtime_error(
"[GatherMM::eval] Currently only supports float32."); "[GatherMM::eval] Currently only supports float32.");
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto& a_pre = inputs[0]; auto& a_pre = inputs[0];
auto& b_pre = inputs[1]; auto& b_pre = inputs[1];
@@ -437,4 +490,121 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.add_temporaries(std::move(temps)); encoder.add_temporaries(std::move(temps));
} }
void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cpu::get_command_encoder(stream());
auto check_transpose = [&s, &encoder](const array& x) {
auto stx = x.strides()[x.ndim() - 2];
auto sty = x.strides()[x.ndim() - 1];
if (stx == x.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, x);
} else if (stx == 1 && sty == x.shape(-2)) {
return std::make_tuple(true, sty, x);
} else {
array xc(x.shape(), x.dtype(), nullptr, {});
copy(x, xc, CopyType::General, s);
encoder.add_temporary(xc);
int64_t stx = x.shape(-1);
return std::make_tuple(false, stx, xc);
}
};
auto [a_transposed, lda, a] = check_transpose(inputs[0]);
auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
auto& segments = inputs[2];
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(segments);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
segments = array::unsafe_weak_copy(segments),
out_ptr = out.data<void>(),
a_transposed = a_transposed,
b_transposed = b_transposed,
lda = lda,
ldb = ldb]() {
switch (a.dtype()) {
case float64:
segmented_mm<double>(
a.data<double>(),
b.data<double>(),
segments.data<uint32_t>(),
static_cast<double*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case float32:
segmented_mm<float>(
a.data<float>(),
b.data<float>(),
segments.data<uint32_t>(),
static_cast<float*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case float16:
segmented_mm<float16_t>(
a.data<float16_t>(),
b.data<float16_t>(),
segments.data<uint32_t>(),
static_cast<float16_t*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
case bfloat16:
segmented_mm<bfloat16_t>(
a.data<bfloat16_t>(),
b.data<bfloat16_t>(),
segments.data<uint32_t>(),
static_cast<bfloat16_t*>(out_ptr),
a_transposed,
b_transposed,
lda,
ldb,
a.shape(),
a.strides(),
b.shape(),
b.strides(),
segments.size() / 2,
segments.shape(),
segments.strides());
break;
default:
throw std::invalid_argument(
"Segmented mm supports only real float types.");
}
});
}
} // namespace mlx::core } // namespace mlx::core
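In the new SegmentedMM path, the segments input is read as consecutive (k_start, k_end) pairs over the shared K dimension: segment i multiplies a[:, k_start:k_end] by b[k_start:k_end, :] into output slice i, and an empty segment (k_end <= k_start) yields a zero block. A small illustration of the layout with hypothetical sizes, no MLX calls:

#include <cstdint>
#include <cstdio>

int main() {
  // Three segments over K = 8: [0,3), [3,3) (empty -> zero block), [3,8).
  const uint32_t segments[3][2] = {{0, 3}, {3, 3}, {3, 8}};
  const int M = 2, N = 4;
  for (int i = 0; i < 3; ++i) {
    uint32_t k_start = segments[i][0], k_end = segments[i][1];
    // Output slice i lives at out + i * M * N, as in segmented_mm above.
    std::printf("segment %d: K slice [%u, %u) -> %dx%d block at offset %d\n",
                i, k_start, k_end, M, N, i * M * N);
  }
  return 0;
}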


@@ -115,7 +115,7 @@ void matmul_general(
} }
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) { void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
if (inputs[0].shape(-1) == 0) { if (inputs[0].shape(-1) == 0) {
auto& encoder = cpu::get_command_encoder(stream()); auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out); encoder.set_output_array(out);
@@ -132,6 +132,10 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error( throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32."); "[AddMM::eval_cpu] Currently only supports float32.");
} }
if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes()));
return;
}
// Fill output with C // Fill output with C
auto& c = inputs[2]; auto& c = inputs[2];
@@ -139,7 +143,9 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
? CopyType::Scalar ? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy(c, out, ctype, stream()); copy(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_); matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
} }


@@ -21,7 +21,7 @@ namespace mlx::core {
void reshape(const array& in, array& out) { void reshape(const array& in, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out); auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) { if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
copy_inplace(in, out, CopyType::General, out.primitive().stream()); copy_inplace(in, out, CopyType::General, out.primitive().stream());
} else { } else {
shared_buffer_reshape(in, out_strides, out); shared_buffer_reshape(in, out_strides, out);
@@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
if (donate) { if (donate) {
offset.copy_shared_buffer(indices); offset.copy_shared_buffer(indices);
} else { } else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize())); offset.set_data(allocator::malloc(offset.itemsize()));
} }
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
@@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) { void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0); assert(inputs.size() == 0);
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
switch (out.dtype()) { switch (out.dtype()) {
case bool_: case bool_:
throw std::runtime_error("Bool type unsupported for arange."); throw std::runtime_error("Bool type unsupported for arange.");
@@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto strides = out.strides(); auto strides = out.strides();
auto flags = out.flags(); auto flags = out.flags();
@@ -205,8 +205,10 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) { void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.flags().row_contiguous || constexpr size_t extra_bytes = 16384;
(allow_col_major_ && in.flags().col_contiguous)) { if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
copy(in, out, CopyType::General, stream()); copy(in, out, CopyType::General, stream());
@@ -276,7 +278,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
size_t elems_per_key = out.size() / num_keys; size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key; size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>(); auto kptr = inputs[0].data<uint32_t>();
auto cptr = out.data<char>(); auto cptr = out.data<char>();
@@ -335,7 +337,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
auto [in_offset, donated] = auto [in_offset, donated] =
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream()); compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
copy_inplace( copy_inplace(
@@ -450,7 +452,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
} else { } else {
auto tmp = array( auto tmp = array(
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {}); in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes())); tmp.set_data(allocator::malloc(tmp.nbytes()));
if (in.dtype() == bool_) { if (in.dtype() == bool_) {
auto in_tmp = array(in.shape(), uint8, nullptr, {}); auto in_tmp = array(in.shape(), uint8, nullptr, {});
in_tmp.copy_shared_buffer(in); in_tmp.copy_shared_buffer(in);


@@ -25,12 +25,11 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
auto strides = in.strides(); auto strides = in.strides();
strides[in.ndim() - 2] = 1; strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M; strides[in.ndim() - 1] = M;
in.set_data( in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
copy_inplace(a, in, CopyType::GeneralGeneral, stream); copy_inplace(a, in, CopyType::GeneralGeneral, stream);
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
q.set_data(allocator::malloc_or_wait(q.nbytes())); q.set_data(allocator::malloc(q.nbytes()));
r.set_data(allocator::malloc_or_wait(r.nbytes())); r.set_data(allocator::malloc(r.nbytes()));
auto in_ptr = in.data<T>(); auto in_ptr = in.data<T>();
auto r_ptr = r.data<T>(); auto r_ptr = r.data<T>();
@@ -41,8 +40,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
encoder.set_output_array(r); encoder.set_output_array(r);
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() { encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
int num_reflectors = std::min(M, N); int num_reflectors = std::min(M, N);
auto tau = auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
T optimal_work; T optimal_work;
int lwork = -1; int lwork = -1;
@@ -53,7 +51,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
// Update workspace size // Update workspace size
lwork = optimal_work; lwork = optimal_work;
auto work = allocator::malloc_or_wait(sizeof(T) * lwork); auto work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices // Loop over matrices
for (int i = 0; i < num_matrices; ++i) { for (int i = 0; i < num_matrices; ++i) {
@@ -96,7 +94,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
&lwork, &lwork,
&info); &info);
lwork = optimal_work; lwork = optimal_work;
work = allocator::malloc_or_wait(sizeof(T) * lwork); work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices // Loop over matrices
for (int i = 0; i < num_matrices; ++i) { for (int i = 0; i < num_matrices; ++i) {


@@ -13,9 +13,18 @@ namespace mlx::core {
 namespace {
+inline constexpr short get_pack_factor(int bits, int wsize = 8) {
+  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
+}
+inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
+  auto power_of_2_bits = (bits & (bits - 1)) == 0;
+  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
+}
 template <typename T, int bits>
 void extract_bits(const uint8_t* w_in, T* w_out) {
-  assert(bits == 3 || bits == 6);
+  static_assert(bits == 3 || bits == 5 || bits == 6);
   if (bits == 3) {
     w_out[0] = static_cast<T>(w_in[0] & 0x7);
     w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
@@ -25,6 +34,16 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
     w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
     w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
     w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
+  } else if (bits == 5) {
+    w_out[0] = static_cast<T>(w_in[0] & 0x1f);
+    w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
+    w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
+    w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
+    w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
+    w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
+    w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
+    w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
   } else if (bits == 6) {
     w_out[0] = static_cast<T>(w_in[0] & 0x3f);
     w_out[1] =
@@ -46,8 +65,8 @@ void _qmm(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
+  constexpr int pack_factor = get_pack_factor(bits, 8);
+  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
   constexpr int packs_in_group = group_size / pack_factor;
   for (int m = 0; m < M; m++) {
@@ -65,7 +84,7 @@ void _qmm(
         T scale = *scales_local++;
         T bias = *biases_local++;
         for (int ng = 0; ng < packs_in_group; ng++) {
-          if (bits == 3 || bits == 6) {
+          if constexpr (bits == 3 || bits == 5 || bits == 6) {
             T wl[pack_factor];
             extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
@@ -104,8 +123,9 @@ void _qmm_t(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
+  constexpr int pack_factor = get_pack_factor(bits, 8);
+  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
   constexpr int packs_in_group = group_size / pack_factor;
   for (int m = 0; m < M; m++) {
@@ -121,7 +141,7 @@ void _qmm_t(
         T bias = *biases_local++;
         for (int kw = 0; kw < packs_in_group; kw++) {
-          if (bits == 3 || bits == 6) {
+          if constexpr (bits == 3 || bits == 5 || bits == 6) {
             T wl[pack_factor];
             extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
@@ -304,6 +324,10 @@ void _qmm_dispatch_typed(
       _qmm_dispatch_group<T, 4>(
           result, x, w, scales, biases, M, N, K, group_size, transposed_w);
       break;
+    case 5:
+      _qmm_dispatch_group<T, 5>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
     case 6:
       _qmm_dispatch_group<T, 6>(
           result, x, w, scales, biases, M, N, K, group_size, transposed_w);
@@ -515,7 +539,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto scales = ensure_row_contiguous(scales_pre);
   auto biases = ensure_row_contiguous(biases_pre);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.add_temporaries(std::move(temps));
@@ -565,7 +589,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto scales = ensure_row_contiguous_last_dims(scales_pre);
   auto biases = ensure_row_contiguous_last_dims(biases_pre);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.add_temporaries(std::move(temps));
@@ -613,9 +637,8 @@ void quantize(
   float eps = 1e-7;
   bool power_of_2_bits = is_power_of_2(bits);
-  int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-  // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
-  int bytes_per_pack = power_of_2_bits ? 1 : 3;
+  int el_per_int = get_pack_factor(bits, 32);
+  int bytes_per_pack = get_bytes_per_pack(bits);
   int int_per_group = group_size * bytes_per_pack / el_per_int;
   size_t n_groups = w_size / group_size;
@@ -640,15 +663,21 @@ void quantize(
     }
     size_t out_idx = i * int_per_group;
     for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
-      uint32_t out_el = 0;
+      uint64_t out_el = 0;
       for (int k = 0; k < el_per_int; ++k) {
         float w_el = w[w_idx + j * el_per_int + k];
         w_el = std::rint((w_el - bias) / scale);
        w_el = std::min(std::max(w_el, 0.0f), n_bins);
-        out_el |= static_cast<uint32_t>(w_el) << (k * bits);
+        out_el |= static_cast<uint64_t>(w_el) << (k * bits);
       }
       if (power_of_2_bits) {
         out[out_idx + j] = out_el;
+      } else if (bits == 5) {
+        out[out_idx + bytes_per_pack * j] = out_el & 0xff;
+        out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
+        out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
+        out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
+        out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
       } else {
         out[out_idx + bytes_per_pack * j] = out_el & 0xff;
         out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
@@ -691,12 +720,12 @@ void fast::AffineQuantize::eval_cpu(
   auto [w, copied] = ensure_row_contiguous(inputs[0]);
   auto& out = outputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& scales = outputs[1];
   auto& biases = outputs[2];
-  scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
-  biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
+  scales.set_data(allocator::malloc(scales.nbytes()));
+  biases.set_data(allocator::malloc(biases.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   if (copied) {
     encoder.add_temporary(w);
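With the helpers above, 5-bit quantization packs pack_factor = 8 values into bytes_per_pack = 5 bytes (8 × 5 = 40 bits), stored little-endian within the group. A minimal sketch of that round trip, written independently of the repo's templates just to illustrate the byte layout used by the 5-bit branches of quantize and extract_bits:

#include <cassert>
#include <cstdint>

// Pack eight 5-bit values (0..31) into five bytes, little-endian within the group.
void pack5(const uint8_t vals[8], uint8_t out[5]) {
  uint64_t bits = 0;
  for (int k = 0; k < 8; ++k) {
    bits |= static_cast<uint64_t>(vals[k] & 0x1f) << (k * 5);
  }
  for (int b = 0; b < 5; ++b) {
    out[b] = (bits >> (8 * b)) & 0xff;
  }
}

// Inverse of pack5; mirrors the bit positions read by the 5-bit extract_bits branch.
void unpack5(const uint8_t in[5], uint8_t vals[8]) {
  uint64_t bits = 0;
  for (int b = 0; b < 5; ++b) {
    bits |= static_cast<uint64_t>(in[b]) << (8 * b);
  }
  for (int k = 0; k < 8; ++k) {
    vals[k] = (bits >> (k * 5)) & 0x1f;
  }
}

int main() {
  uint8_t w[8] = {3, 31, 0, 17, 9, 22, 5, 30};
  uint8_t packed[5], back[8];
  pack5(w, packed);
  unpack5(packed, back);
  for (int k = 0; k < 8; ++k) {
    assert(back[k] == w[k]);
  }
}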

View File

@@ -325,7 +325,15 @@ struct MaxReduce {
   };
   template <int N, typename T>
-  T operator()(simd::Simd<T, N> x) {
+  std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
+    return simd::max(x);
+  };
+  template <int N, typename T>
+  std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
+    if (simd::any(x != x)) {
+      return static_cast<T>(NAN);
+    }
     return simd::max(x);
   };
 };
@@ -342,7 +350,15 @@ struct MinReduce {
   };
   template <int N, typename T>
-  T operator()(simd::Simd<T, N> x) {
+  std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
+    return simd::min(x);
+  };
+  template <int N, typename T>
+  std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
+    if (simd::any(x != x)) {
+      return static_cast<T>(NAN);
+    }
     return simd::min(x);
   };
 };
@@ -433,7 +449,7 @@ void reduce_dispatch_min_max(
 void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
   encoder.set_output_array(out);
@@ -527,10 +543,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
       reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
       break;
     case int8:
-      reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+      reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
       break;
     case int16:
-      reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+      reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
       break;
     case int32:
       reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
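The reworked MaxReduce/MinReduce split the operator on std::is_integral_v so only floating (and complex) lanes pay for the NaN check. A scalar sketch of the intended NumPy-style semantics (any NaN in the reduced slice poisons the result), with the function name chosen here purely for illustration:

#include <cmath>
#include <limits>
#include <type_traits>
#include <vector>

template <typename T>
T numpy_style_max(const std::vector<T>& xs) {
  T best = std::numeric_limits<T>::lowest();
  for (T x : xs) {
    if constexpr (!std::is_integral_v<T>) {
      // Mirrors simd::any(x != x): a single NaN makes the whole reduction NaN.
      if (std::isnan(x)) {
        return std::numeric_limits<T>::quiet_NaN();
      }
    }
    best = x > best ? x : best;
  }
  return best;
}

// numpy_style_max(std::vector<float>{1.f, NAN, 3.f}) -> NaN
// numpy_style_max(std::vector<int>{1, 5, 3})         -> 5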

View File

@@ -3,6 +3,7 @@
 #include <cassert>
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
@@ -226,6 +227,16 @@ void scan_dispatch(
       scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
       break;
     }
+    case Scan::LogAddExp: {
+      auto op = [](U a, T b) {
+        return detail::LogAddExp{}(a, static_cast<U>(b));
+      };
+      auto init = (issubdtype(in.dtype(), floating))
+          ? static_cast<U>(-std::numeric_limits<float>::infinity())
+          : std::numeric_limits<U>::min();
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
+      break;
+    }
   }
 }
@@ -244,7 +255,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
     in = arr_copy;
     encoder.add_temporary(arr_copy);
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   encoder.set_input_array(in);
   encoder.set_output_array(out);
@@ -319,7 +330,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
             reduce_type_, in, out, axis_, reverse_, inclusive_);
         break;
       case complex64:
-        throw std::runtime_error("Scan ops do not support complex types yet");
+        scan_dispatch<complex64_t, complex64_t>(
+            reduce_type_, in, out, axis_, reverse_, inclusive_);
        break;
    }
  });
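The new Scan::LogAddExp case uses -inf as the identity for floating inputs, since logaddexp(-inf, b) == b. A serial sketch of an inclusive logaddexp scan with that init, written here with std::log1p/std::exp rather than the repo's detail::LogAddExp:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Numerically stable log(exp(a) + exp(b)).
float logaddexp(float a, float b) {
  if (std::isinf(a) && a < 0) return b;  // -inf is the identity element
  if (std::isinf(b) && b < 0) return a;
  float m = std::max(a, b);
  return m + std::log1p(std::exp(-std::fabs(a - b)));
}

std::vector<float> cumlogaddexp(const std::vector<float>& x) {
  std::vector<float> out(x.size());
  float acc = -std::numeric_limits<float>::infinity();  // init, as in the dispatch above
  for (size_t i = 0; i < x.size(); ++i) {
    acc = logaddexp(acc, x[i]);
    out[i] = acc;
  }
  return out;
}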

View File

@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
 #endif
 template <>
-static constexpr int max_size<float16_t> = N;
+inline constexpr int max_size<float16_t> = N;
 #define SIMD_FP16_DEFAULT_UNARY(op) \
   template <>                       \

View File

@@ -83,25 +83,25 @@ struct Simd {
 // Values chosen based on benchmarks on M3 Max
 // TODO: consider choosing these more optimally
 template <>
-static constexpr int max_size<int8_t> = 16;
+inline constexpr int max_size<int8_t> = 16;
 template <>
-static constexpr int max_size<int16_t> = 16;
+inline constexpr int max_size<int16_t> = 16;
 template <>
-static constexpr int max_size<int> = 8;
+inline constexpr int max_size<int> = 8;
 template <>
-static constexpr int max_size<int64_t> = 4;
+inline constexpr int max_size<int64_t> = 4;
 template <>
-static constexpr int max_size<uint8_t> = 16;
+inline constexpr int max_size<uint8_t> = 16;
 template <>
-static constexpr int max_size<uint16_t> = 16;
+inline constexpr int max_size<uint16_t> = 16;
 template <>
-static constexpr int max_size<uint32_t> = 8;
+inline constexpr int max_size<uint32_t> = 8;
 template <>
-static constexpr int max_size<uint64_t> = 4;
+inline constexpr int max_size<uint64_t> = 4;
 template <>
-static constexpr int max_size<float> = 8;
+inline constexpr int max_size<float> = 8;
 template <>
-static constexpr int max_size<double> = 4;
+inline constexpr int max_size<double> = 4;
 #define SIMD_DEFAULT_UNARY(name, op) \
   template <typename T, int N>       \
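The static-to-inline change matters because an explicit specialization of a variable template at namespace scope may not carry a storage-class specifier, and inline gives each specialization a single definition across translation units. A minimal sketch of the pattern, assuming nothing beyond standard C++17:

// Primary variable template with a default value.
template <typename T>
inline constexpr int max_size = 1;

// Explicit specializations: no storage class allowed here, so use
// `inline constexpr` (not `static constexpr`), as in the diff above.
template <>
inline constexpr int max_size<float> = 8;

template <>
inline constexpr int max_size<double> = 4;

static_assert(max_size<float> == 8);
static_assert(max_size<int> == 1);

int main() {}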

View File

@@ -87,14 +87,45 @@ DEFAULT_UNARY(cosh, std::cosh)
 DEFAULT_UNARY(expm1, std::expm1)
 DEFAULT_UNARY(floor, std::floor)
 DEFAULT_UNARY(log, std::log)
-DEFAULT_UNARY(log2, std::log2)
 DEFAULT_UNARY(log10, std::log10)
-DEFAULT_UNARY(log1p, std::log1p)
 DEFAULT_UNARY(sinh, std::sinh)
 DEFAULT_UNARY(sqrt, std::sqrt)
 DEFAULT_UNARY(tan, std::tan)
 DEFAULT_UNARY(tanh, std::tanh)
+template <typename T>
+Simd<T, 1> log1p(Simd<T, 1> in) {
+  if constexpr (is_complex<T>) {
+    auto x = in.value.real();
+    auto y = in.value.imag();
+    auto zabs = std::abs(in.value);
+    auto theta = std::atan2(y, x + 1);
+    if (zabs < 0.5) {
+      auto r = x * (2 + x) + y * y;
+      if (r == 0) { // handle underflow
+        return Simd<T, 1>{T{x, theta}};
+      }
+      return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
+    } else {
+      auto z0 = std::hypot(x + 1, y);
+      return Simd<T, 1>{T{std::log(z0), theta}};
+    }
+  } else {
+    return Simd<T, 1>{std::log1p(in.value)};
+  }
+}
+template <typename T>
+Simd<T, 1> log2(Simd<T, 1> in) {
+  if constexpr (is_complex<T>) {
+    auto out = std::log(in.value);
+    auto scale = decltype(out.real())(M_LN2);
+    return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
+  } else {
+    return Simd<T, 1>{std::log2(in.value)};
+  }
+}
 template <typename T>
 Simd<T, 1> operator~(Simd<T, 1> in) {
   return ~in.value;
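The scalar fallbacks above compute complex log2 as log(z) scaled by 1/ln 2 on both components, and complex log1p with the usual small-|z| branch for accuracy. A small standalone check of the log2 identity against std::log, using M_LN2 from <cmath> (which the _USE_MATH_DEFINES change elsewhere in this diff makes available on MSVC):

#define _USE_MATH_DEFINES
#include <cassert>
#include <cmath>
#include <complex>

std::complex<float> complex_log2(std::complex<float> z) {
  auto out = std::log(z);
  float scale = static_cast<float>(M_LN2);
  // log2(z) = log(z) / ln(2), applied to the real and imaginary parts alike.
  return {out.real() / scale, out.imag() / scale};
}

int main() {
  std::complex<float> z{3.0f, -4.0f};
  auto a = complex_log2(z);
  auto b = std::log(z) / static_cast<float>(M_LN2);
  assert(std::abs(a - b) < 1e-6f);
}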

View File

@@ -119,17 +119,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Make sure that the last dimension is contiguous
   auto set_output = [s = stream(), &out](const array& x) {
-    bool no_copy = x.strides()[x.ndim() - 1] == 1;
-    if (x.ndim() > 1) {
-      auto s = x.strides()[x.ndim() - 2];
-      no_copy &= (s == 0 || s == x.shape().back());
-    }
-    if (no_copy) {
+    if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       if (x.is_donatable()) {
         out.copy_shared_buffer(x);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+            allocator::malloc(x.data_size() * x.itemsize()),
             x.data_size(),
             x.strides(),
             x.flags());
@@ -146,18 +141,6 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto in = set_output(inputs[0]);
   switch (in.dtype()) {
-    case bool_:
-    case uint8:
-    case uint16:
-    case uint32:
-    case uint64:
-    case int8:
-    case int16:
-    case int32:
-    case int64:
-      throw std::runtime_error(
-          "Softmax is defined only for floating point types");
-      break;
     case float32:
       softmax<float, float>(in, out, stream());
       break;
@@ -178,9 +161,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float64:
       softmax<double, double>(in, out, stream());
       break;
-    case complex64:
-      throw std::invalid_argument(
-          "[Softmax] Not yet implemented for complex64");
+    default:
+      throw std::runtime_error(
+          "[softmax] Only defined for floating point types.");
       break;
   }
 }

View File

@@ -288,7 +288,7 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
   // Allocate output
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
@@ -379,7 +379,7 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
   // Allocate output
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);

View File

@@ -50,9 +50,9 @@ void svd_impl(
     array& s = outputs[1];
     array& vt = outputs[2];
-    u.set_data(allocator::malloc_or_wait(u.nbytes()));
-    s.set_data(allocator::malloc_or_wait(s.nbytes()));
-    vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
+    u.set_data(allocator::malloc(u.nbytes()));
+    s.set_data(allocator::malloc(s.nbytes()));
+    vt.set_data(allocator::malloc(vt.nbytes()));
     encoder.set_output_array(u);
     encoder.set_output_array(s);
@@ -64,7 +64,7 @@ void svd_impl(
   } else {
     array& s = outputs[0];
-    s.set_data(allocator::malloc_or_wait(s.nbytes()));
+    s.set_data(allocator::malloc(s.nbytes()));
     encoder.set_output_array(s);
@@ -91,7 +91,7 @@ void svd_impl(
   // Will contain the indices of eigenvectors that failed to converge (not
   // used here but required by lapack).
-  auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
+  auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
   static const int lwork_query = -1;
@@ -132,7 +132,7 @@ void svd_impl(
   }
   const int lwork = workspace_dimension;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
+  auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
   // Loop over matrices.
   for (int i = 0; i < num_matrices; i++) {

View File

@@ -1,5 +1,8 @@
 // Copyright © 2024 Apple Inc.
+// Required for using M_LN2 in MSVC.
+#define _USE_MATH_DEFINES
 #include <cassert>
 #include "mlx/backend/cpu/unary.h"

View File

@@ -2,32 +2,13 @@
 #pragma once
-#include "mlx/allocator.h"
-#include "mlx/array.h"
-#include "mlx/backend/common/utils.h"
+#include "mlx/backend/common/unary.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/utils.h"
 namespace mlx::core {
-void set_unary_output_data(const array& in, array& out) {
-  if (in.flags().contiguous) {
-    if (is_donatable(in, out)) {
-      out.copy_shared_buffer(in);
-    } else {
-      auto size = in.data_size();
-      out.set_data(
-          allocator::malloc_or_wait(size * out.itemsize()),
-          size,
-          in.strides(),
-          in.flags());
-    }
-  } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  }
-}
 template <typename T, typename U = T, typename Op>
 void unary_op(const T* a, U* out, size_t shape, size_t stride) {
   for (size_t i = 0; i < shape; i += 1) {

View File

@@ -86,13 +86,14 @@ struct Sign {
   template <int N, typename T>
   Simd<T, N> operator()(Simd<T, N> x) {
     auto z = Simd<T, N>{0};
+    auto o = Simd<T, N>{1};
+    auto m = Simd<T, N>{-1};
     if constexpr (std::is_unsigned_v<T>) {
-      return x != z;
+      return simd::select(x == z, z, o);
     } else if constexpr (std::is_same_v<T, complex64_t>) {
       return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
     } else {
-      return simd::select(
-          x < z, Simd<T, N>{-1}, simd::select(x > z, Simd<T, N>{1}, z));
+      return simd::select(x < z, m, simd::select(x > z, o, z));
     }
   }
   SINGLE()
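After the change, Sign returns 0/1 for unsigned types (instead of the bool-valued x != z), -1/0/1 for signed types, and z/|z| for nonzero complex inputs. A scalar reference of those semantics, written independently of the simd wrappers:

#include <cassert>
#include <complex>
#include <type_traits>

template <typename T>
T sign_of(T x) {
  if constexpr (std::is_unsigned_v<T>) {
    return x == T{0} ? T{0} : T{1};  // was effectively a boolean mask before the fix
  } else {
    return x < T{0} ? T{-1} : (x > T{0} ? T{1} : T{0});
  }
}

std::complex<float> sign_of(std::complex<float> z) {
  return z == std::complex<float>{0.f, 0.f} ? z : z / std::abs(z);
}

int main() {
  assert(sign_of(7u) == 1u);
  assert(sign_of(-3) == -1);
  assert(std::abs(sign_of(std::complex<float>{3.f, 4.f}) -
                  std::complex<float>{0.6f, 0.8f}) < 1e-6f);
}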

View File

@@ -0,0 +1,127 @@
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only code should be put in device/ subdir.
# * Files in device/ subdir should not include files outside.
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
# Embed kernel sources in binary for JIT compilation.
file(
GLOB MLX_JIT_SOURCES
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
add_custom_command(
OUTPUT gen/cuda_jit_sources.h
COMMAND
${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
add_dependencies(mlx cuda_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
# Enable defining device lambda functions.
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
# CUDA 12.8 emits warning #20280-D for copy kernels, which is a false positive.
# Explicitly pass this flag to suppress the warning; setting it to true would
# also be safe, but then the warning would not be suppressed.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
target_compile_options(
mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
endif()
# Suppress warning when building for compute capability 7 used by V100.
target_compile_options(
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
set(MLX_CUDA_ARCHITECTURES
"70;80"
CACHE STRING "CUDA architectures")
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")
# Use fixed version of CCCL.
FetchContent_Declare(
cccl
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
nvtx3
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
GIT_TAG v3.1.1
GIT_SHALLOW TRUE
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nvtx3)
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
# Make cuda runtime APIs available in non-cuda files.
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
# Use cublasLt.
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
# Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)

View File

@@ -0,0 +1,215 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h"
#include <cuda_runtime.h>
#include <fmt/format.h>
#include <unistd.h>
#include <cassert>
namespace mlx::core {
namespace cu {
constexpr int page_size = 16384;
CudaAllocator::CudaAllocator()
: buffer_cache_(
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) {
cuda_free(buf->data);
delete buf;
}) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
max_pool_size_ = memory_limit_;
}
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache.
size_t mem_required = get_active_memory() + get_cache_memory() + size;
if (mem_required >= memory_limit_) {
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
}
lock.unlock();
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
}
lock.lock();
}
active_memory_ += size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
return Buffer{buf};
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return;
}
std::unique_lock lock(mutex_);
active_memory_ -= buf->size;
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
lock.unlock();
cuda_free(buf->data);
delete buf;
}
}
size_t CudaAllocator::size(Buffer buffer) const {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return 0;
}
return buf->size;
}
void CudaAllocator::register_this_thread() {
std::lock_guard lock(worker_mutex_);
allowed_threads_.insert(std::this_thread::get_id());
}
void CudaAllocator::cuda_free(void* buf) {
// If cuda_free() is called from an unregistered thread, reschedule the call
// to the worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([this, buf]() { this->cuda_free(buf); });
worker_->end_batch();
worker_->commit();
return;
}
}
cudaFree(buf);
}
size_t CudaAllocator::get_active_memory() const {
return active_memory_;
}
size_t CudaAllocator::get_peak_memory() const {
return peak_memory_;
}
void CudaAllocator::reset_peak_memory() {
std::lock_guard lock(mutex_);
peak_memory_ = 0;
}
size_t CudaAllocator::get_memory_limit() {
return memory_limit_;
}
size_t CudaAllocator::set_memory_limit(size_t limit) {
std::lock_guard lock(mutex_);
std::swap(limit, memory_limit_);
return limit;
}
size_t CudaAllocator::get_cache_memory() const {
return buffer_cache_.cache_size();
}
size_t CudaAllocator::set_cache_limit(size_t limit) {
std::lock_guard lk(mutex_);
std::swap(limit, max_pool_size_);
return limit;
}
void CudaAllocator::clear_cache() {
std::lock_guard lk(mutex_);
buffer_cache_.clear();
}
CudaAllocator& allocator() {
// By creating |allocator_| on the heap, the destructor of CudaAllocator will
// not be called on exit and buffers in the cache will be leaked, which saves
// some time at program exit.
static CudaAllocator* allocator_ = new CudaAllocator;
return *allocator_;
}
} // namespace cu
namespace allocator {
Allocator& allocator() {
return cu::allocator();
}
void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<cu::CudaBuffer*>(ptr_)->data;
}
} // namespace allocator
size_t get_active_memory() {
return cu::allocator().get_active_memory();
}
size_t get_peak_memory() {
return cu::allocator().get_peak_memory();
}
void reset_peak_memory() {
return cu::allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
return cu::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
return cu::allocator().get_memory_limit();
}
size_t get_cache_memory() {
return cu::allocator().get_cache_memory();
}
size_t set_cache_limit(size_t limit) {
return cu::allocator().set_cache_limit(limit);
}
void clear_cache() {
cu::allocator().clear_cache();
}
// Not supported in CUDA.
size_t set_wired_limit(size_t) {
return 0;
}
} // namespace mlx::core
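CudaAllocator::malloc rounds small requests up to the next power of two and larger ones up to a multiple of the 16 KB page size before consulting the buffer cache, so recycled buffers fall into a small set of bucket sizes. A standalone sketch of just that rounding policy (next_power_of_2 below is a hand-written stand-in for the helper used by the allocator):

#include <cassert>
#include <cstddef>

constexpr size_t page_size = 16384;

// Stand-in for the next_power_of_2 utility.
size_t next_power_of_2(size_t n) {
  size_t p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}

size_t bucketed_size(size_t size) {
  if (size < page_size) {
    return next_power_of_2(size);  // e.g. 3000 -> 4096
  }
  return page_size * ((size + page_size - 1) / page_size);  // round up to whole pages
}

int main() {
  assert(bucketed_size(3000) == 4096);
  assert(bucketed_size(16384) == 16384);
  assert(bucketed_size(20000) == 32768);
}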

View File

@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include <mutex>
#include <set>
#include <thread>
#include <utility>
namespace mlx::core::cu {
class Worker;
using allocator::Buffer;
// Stores cuda-managed unified memory.
struct CudaBuffer {
void* data;
size_t size;
};
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In CUDA, freeing a buffer implicitly synchronizes the stream, and for
// threads that the GPU stream may be waiting on (for example CPU stream
// threads), freeing buffers there would result in deadlock.
void register_this_thread();
// Call cudaFree in the safe thread.
void cuda_free(void* buf);
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
size_t get_memory_limit();
size_t set_memory_limit(size_t limit);
size_t get_cache_memory() const;
size_t set_cache_limit(size_t limit);
void clear_cache();
private:
CudaAllocator();
friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
};
CudaAllocator& allocator();
} // namespace mlx::core::cu

View File

@@ -0,0 +1,182 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/iterators/strided_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
#include <cassert>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename T>
struct IndexValPair {
uint32_t index;
T val;
};
template <typename T>
struct ArgMin {
constexpr __device__ T init() {
return Limits<T>::max();
}
__device__ IndexValPair<T> operator()(
const IndexValPair<T>& best,
const IndexValPair<T>& current) {
if (best.val > current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
best.index = offset + i;
}
}
return best;
}
};
template <typename T>
struct ArgMax {
constexpr __device__ T init() {
return Limits<T>::min();
}
__device__ IndexValPair<T> operator()(
const IndexValPair<T>& best,
const IndexValPair<T>& current) {
if (best.val < current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
}
}
template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
best.index = offset + i;
}
}
return best;
}
};
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void arg_reduce_general(
const T* in,
uint32_t* out,
size_t size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides in_strides,
const __grid_constant__ Strides out_strides,
int32_t ndim,
int64_t axis_stride,
int32_t axis_size) {
auto block = cg::this_thread_block();
int64_t index = cg::this_grid().block_rank();
if (index >= size) {
return;
}
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
Op op;
T init = op.init();
IndexValPair<T> best{0, init};
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x;
cub::LoadDirectBlocked(
tid, strided_iterator(in + in_idx, axis_stride), vals, axis_size, init);
best = op.reduce_many(best, vals, tid * N_READS);
}
typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
__shared__ typename BlockReduceT::TempStorage temp;
best = BlockReduceT(temp).Reduce(best, op);
if (block.thread_rank() == 0) {
out[out_idx] = best.index;
}
}
} // namespace cu
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
Strides in_strides = remove_index(in.strides(), axis_);
Strides out_strides = out.ndim() == in.ndim()
? remove_index(out.strides(), axis_)
: out.strides();
int64_t axis_stride = in.strides()[axis_];
int32_t axis_size = in.shape()[axis_];
int32_t ndim = shape.size();
// ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr uint32_t N_READS = 4;
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
auto kernel =
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
if (reduce_type_ == ArgReduce::ArgMin) {
kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
}
encoder.add_kernel_node(
kernel,
num_blocks,
block_dim(),
in.data<T>(),
out.data<uint32_t>(),
out.size(),
const_param(shape),
const_param(in_strides),
const_param(out_strides),
ndim,
axis_stride,
axis_size);
});
});
}
} // namespace mlx::core
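ArgMin/ArgMax combine IndexValPair partials so that on equal values the smaller index wins, matching what a sequential scan would produce regardless of how the block reduction associates the partials. A host-side reference with the same tie-breaking rule (the pair type below is named just for this sketch):

#include <cassert>
#include <cstdint>
#include <vector>

struct IdxVal {
  uint32_t index;
  float val;
};

// Same combine rule as the ArgMax functor: prefer the larger value,
// and on ties prefer the smaller index.
IdxVal argmax_combine(IdxVal best, IdxVal current) {
  if (best.val < current.val ||
      (best.val == current.val && best.index > current.index)) {
    return current;
  }
  return best;
}

uint32_t argmax(const std::vector<float>& xs) {
  IdxVal best{0, xs[0]};
  for (uint32_t i = 1; i < xs.size(); ++i) {
    best = argmax_combine(best, IdxVal{i, xs[i]});
  }
  return best.index;
}

int main() {
  std::vector<float> xs{1.f, 7.f, 7.f, 3.f};
  assert(argmax(xs) == 1);  // first of the tied maxima
}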

View File

@@ -0,0 +1,150 @@
# Based on: https://github.com/sivachandran/cmake-bin2h
#
# Copyright 2020 Sivachandran Paramasivam
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
include(CMakeParseArguments)
# Function to wrap a given string into multiple lines at the given column
# position.
#
# Parameters:
#
# * VARIABLE - The name of the CMake variable holding the string.
# * AT_COLUMN - The column position at which string will be wrapped.
function(WRAP_STRING)
set(oneValueArgs VARIABLE AT_COLUMN)
cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})
string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
math(EXPR offset "0")
while(stringLength GREATER 0)
if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
math(EXPR length "${WRAP_STRING_AT_COLUMN}")
else()
math(EXPR length "${stringLength}")
endif()
string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
set(lines "${lines}\n ${line}")
math(EXPR stringLength "${stringLength} - ${length}")
math(EXPR offset "${offset} + ${length}")
endwhile()
set(${WRAP_STRING_VARIABLE}
"${lines}"
PARENT_SCOPE)
endfunction()
# Function to embed contents of a file as byte array in C/C++ header file(.h).
# The header file will contain a byte array and integer variable holding the
# size of the array.
#
# Parameters:
#
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
# the header file.
# * VARIABLE_NAME - The name of the variable for the byte array. The string
# "_SIZE" will be append to this name and will be used a variable name for
# size variable.
# * HEADER_FILE - The path of header file.
# * APPEND - If specified, appends to the header file instead of overwriting it.
# * HEADER_NAMESPACE - The namespace in which the array should be located.
# * NULL_TERMINATE - If specified, a null byte (zero) will be appended to the
#   byte array.
#
# Usage:
#
# bin2h(SOURCE_FILE "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
function(BIN2H)
set(options APPEND NULL_TERMINATE)
set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
set(multiValueArgs SOURCE_FILES)
cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
set(arrayDefinition "")
foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
# get filename without extension
get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
# convert the filename to a valid C identifier
string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)
# reads source file contents as hex string
file(READ ${SOURCE_FILE} hexString HEX)
# append null
if(BIN2H_NULL_TERMINATE)
string(APPEND hexString "00")
endif()
# wraps the hex string into multiple lines
wrap_string(VARIABLE hexString AT_COLUMN 24)
# strip the © in source code
string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})
string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
${arrayValues})
# make a full variable name for the array
set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")
# declares byte array and the length variables
string(APPEND arrayDefinition
"constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
endforeach()
# add namespace wrapper if defined
if(DEFINED BIN2H_HEADER_NAMESPACE)
set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
endif()
set(arrayIncludes "#pragma once")
string(PREPEND declarations "${arrayIncludes}\n\n")
if(BIN2H_APPEND)
file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
else()
file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
endif()
endfunction()
# ----------------------------- CLI args -----------------------------
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
foreach(source ${MLX_JIT_SOURCES_LIST})
list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
endforeach()
bin2h(
SOURCE_FILES
${MLX_JIT_SOURCES_ABS}
NULL_TERMINATE
VARIABLE_NAME
"jit_source"
HEADER_NAMESPACE
"mlx::core"
HEADER_FILE
"${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")

mlx/backend/cuda/binary.cu
View File

@@ -0,0 +1,362 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (int i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[0], b[0]);
}
} else {
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a[0], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[0], b[i]);
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a[0], b_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[0]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b[0]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
out[i] = Op{}(a[i], b[i]);
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
out_vec.val[i] = Op{}(a_vec.val[i], b_vec.val[i]);
}
store_vector<N_READS>(out, index, out_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_g_nd(
const In* a,
const In* b,
Out* out,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_g(
const In* a,
const In* b,
Out* out,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
out[index] = Op{}(a[a_idx], b[b_idx]);
}
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_op() {
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
return std::is_same_v<In, Out>;
}
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
return std::is_same_v<Out, bool>;
}
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
}
if (std::is_same_v<Op, NaNEqual>) {
return std::is_same_v<Out, bool> && is_inexact_v<In>;
}
if (std::is_same_v<Op, LogAddExp>) {
return std::is_same_v<In, Out> && is_inexact_v<In>;
}
if (std::is_same_v<Op, ArcTan2>) {
return std::is_same_v<In, Out> && is_floating_v<In>;
}
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
std::is_same_v<Op, BitwiseXor>) {
return std::is_same_v<In, Out> && std::is_integral_v<In>;
}
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
!std::is_same_v<In, bool>;
}
return false;
}
} // namespace cu
template <typename Op>
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
if (out.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
out.data_size() > INT32_MAX,
[&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
Shape shape;
std::vector<Strides> strides;
std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::
binary_g_nd<Op, InType, OutType, IdxT, dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out.data_size(),
out.shape(),
out.strides(),
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out.dtype())));
}
});
});
}
template <typename Op>
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
nvtx3::scoped_range r(#func "::eval_gpu"); \
auto& s = out.primitive().stream(); \
binary_op_gpu<cu::func>(inputs, out, get_primitive_string(this), s); \
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Equal::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
if (equal_nan_) {
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
} else {
binary_op_gpu<cu::Equal>(inputs, out, op, s);
}
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
auto& s = out.primitive().stream();
auto op = get_primitive_string(this);
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu<cu::BitwiseAnd>(inputs, out, op, s);
break;
case BitwiseBinary::Or:
binary_op_gpu<cu::BitwiseOr>(inputs, out, op, s);
break;
case BitwiseBinary::Xor:
binary_op_gpu<cu::BitwiseXor>(inputs, out, op, s);
break;
case BitwiseBinary::LeftShift:
binary_op_gpu<cu::LeftShift>(inputs, out, op, s);
break;
case BitwiseBinary::RightShift:
binary_op_gpu<cu::RightShift>(inputs, out, op, s);
break;
}
}
} // namespace mlx::core
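Each contiguous binary kernel above processes N_READS elements per thread: full groups go through aligned vector loads/stores, and the last partial group falls back to a scalar loop guarded by "(index + 1) * N_READS > size". The same structure in plain C++, with the vector helpers simplified away (this is a sketch of the control flow, not the CUDA helpers themselves):

#include <cassert>
#include <cstddef>
#include <vector>

constexpr int N_READS = 4;

// Process `size` elements in groups of N_READS, one `index` per would-be thread.
void add_chunked(const float* a, const float* b, float* out, size_t size) {
  size_t num_groups = (size + N_READS - 1) / N_READS;
  for (size_t index = 0; index < num_groups; ++index) {
    if ((index + 1) * N_READS > size) {
      // Tail: not enough elements left for a full vector, go scalar.
      for (size_t i = index * N_READS; i < size; ++i) {
        out[i] = a[i] + b[i];
      }
    } else {
      // Full group: the real kernel uses an aligned vector load/store here.
      for (int i = 0; i < N_READS; ++i) {
        size_t j = index * N_READS + i;
        out[j] = a[j] + b[j];
      }
    }
  }
}

int main() {
  std::vector<float> a(10, 1.f), b(10, 2.f), out(10, 0.f);
  add_chunked(a.data(), b.data(), out.data(), a.size());
  for (float v : out) {
    assert(v == 3.f);
  }
}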

View File

@@ -0,0 +1,335 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
namespace cg = cooperative_groups;
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[0], b[0]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[0], b[i]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a[0], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[i], b[0]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b[0]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
IdxT index = cg::this_grid().thread_rank();
if ((index + 1) * N_READS > size) {
for (IdxT i = index * N_READS; i < size; ++i) {
auto out = Op{}(a[i], b[i]);
out_a[i] = out[0];
out_b[i] = out[1];
}
} else {
auto a_vec = load_vector<N_READS>(a, index);
auto b_vec = load_vector<N_READS>(b, index);
AlignedVector<Out, N_READS> out_a_vec;
AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
auto out = Op{}(a_vec.val[i], b_vec.val[i]);
out_a_vec.val[i] = out[0];
out_b_vec.val[i] = out[1];
}
store_vector<N_READS>(out_a, index, out_a_vec);
store_vector<N_READS>(out_b, index, out_b_vec);
}
}
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_two_g_nd(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
index, shape.data(), a_strides.data(), b_strides.data());
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_two_g(
const In* a,
const In* b,
Out* out_a,
Out* out_b,
IdxT size,
const __grid_constant__ Shape shape,
const __grid_constant__ Strides a_strides,
const __grid_constant__ Strides b_strides,
int ndim) {
IdxT index = cg::this_grid().thread_rank();
if (index < size) {
auto [a_idx, b_idx] = elem_to_loc_4d(
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
auto out = Op{}(a[a_idx], b[b_idx]);
out_a[index] = out[0];
out_b[index] = out[1];
}
}
template <typename Op, typename In, typename Out>
constexpr bool supports_binary_two_op() {
if (std::is_same_v<Op, DivMod>) {
return std::is_same_v<In, Out> &&
(std::is_integral_v<Out> || is_floating_v<Out>);
}
return false;
}
} // namespace cu
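The contiguous kernels above rely on the launch being sized so that each thread owns N_READS consecutive elements; only the final, partially filled thread takes the scalar tail branch guarded by (index + 1) * N_READS > size. The host code below passes N_READS to get_launch_args for exactly this purpose. A minimal host-side sketch of that sizing, assuming a fixed block size (this is not the repository's get_launch_args; the name and defaults here are illustrative):
#include <cuda_runtime.h>
#include <cstddef>
#include <utility>
// Hedged sketch: one thread per N_READS-element chunk, grid rounded up so the
// last thread lands in the scalar tail loop of the kernels above.
inline std::pair<dim3, dim3> vectorized_launch_sketch(
    size_t size, int n_reads, int block_dim = 256) {
  size_t threads = (size + n_reads - 1) / n_reads;
  size_t blocks = (threads + block_dim - 1) / block_dim;
  return {
      dim3(static_cast<unsigned>(blocks)),
      dim3(static_cast<unsigned>(block_dim))};
}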
template <typename Op>
void binary_two_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
assert(inputs.size() > 1);
const auto& a = inputs[0];
const auto& b = inputs[1];
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) {
dispatch_bool(
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
out_a.data_size() > INT32_MAX,
[&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
Shape shape;
std::vector<Strides> strides;
std::tie(shape, strides) =
collapse_contiguous_dims(a, b, out_a);
auto& a_strides = strides[0];
auto& b_strides = strides[1];
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = cu::binary_two_g_nd<
Op,
InType,
OutType,
IdxT,
dims_constant()>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_two_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else {
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
// TODO: Choose optimized value based on type size.
constexpr int N_READS = 4;
auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorScalar) {
kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
} else if (bopt == BinaryOpType::VectorVector) {
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel,
out_a.data_size(),
out_a.shape(),
out_a.strides(),
large(),
N_READS);
encoder.add_kernel_node(
kernel,
num_blocks,
block_dims,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.data_size());
});
}
} else {
throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
op,
dtype_to_string(a.dtype()),
dtype_to_string(out_a.dtype())));
}
});
});
}
template <typename Op>
void binary_two_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
std::string_view op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("DivMod::eval_gpu");
auto& s = outputs[0].primitive().stream();
binary_two_op_gpu<cu::DivMod>(inputs, outputs, get_primitive_string(this), s);
}
} // namespace mlx::core
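For reference, the binary_two kernels assume the Op functor produces a pair-like result that can be indexed with [0] and [1], and supports_binary_two_op restricts this path to DivMod with matching integral or floating-point input and output types. A minimal sketch of such a two-output functor (the shape is an assumption; the repository's cu::DivMod is defined in binary_ops.cuh and may differ):
#include <cuda/std/array>
// Hedged sketch of a two-output op: element i of out_a receives result[0]
// and element i of out_b receives result[1] in the kernels above.
struct DivModSketch {
  __device__ cuda::std::array<int, 2> operator()(int x, int y) {
    return {x / y, x % y};
  }
};
DivMod::eval_gpu above routes through binary_two_op_gpu<cu::DivMod>, so outputs[0] and outputs[1] receive the two components of the result per element.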


@@ -0,0 +1,233 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace cu {
struct FusedKernelBuilder {
std::string os;
const std::string& kernel_name;
const std::vector<array>& inputs;
const std::vector<array>& outputs;
const std::vector<array>& tape;
const std::function<bool(size_t)>& is_constant;
void build(const char* name, bool contiguous) {
NodeNamer namer;
// Function parameters.
std::vector<std::string> params;
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant(i)) {
continue;
}
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
params.push_back(
fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname));
if (!is_scalar(x) && !contiguous) {
params.push_back(fmt::format(
"const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides",
xname));
}
}
for (const auto& x : outputs) {
params.push_back(fmt::format(
"{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));
}
if (!contiguous) {
params.push_back(
"const __grid_constant__ cuda::std::array<int32_t, NDIM> shape");
}
params.push_back("IdxT size");
// Build function signature.
if (contiguous) {
os += "template <typename IdxT = uint32_t>\n";
} else {
os += "template <int NDIM, typename IdxT = uint32_t>\n";
}
os += fmt::format("__global__ void {}(\n", kernel_name + name);
for (size_t i = 0; i < params.size(); ++i) {
os += " ";
os += params[i];
if (i != params.size() - 1) {
os += ",\n";
}
}
os += ") {\n";
// Index.
os +=
" IdxT index = cg::this_grid().thread_rank();\n"
" if (index >= size) {\n"
" return;\n"
" }\n";
// Read inputs.
for (size_t i = 0; i < inputs.size(); ++i) {
const auto& x = inputs[i];
const std::string& xname = namer.get_name(x);
std::string type = dtype_to_cuda_type(x.dtype());
std::string value;
if (is_constant(i)) {
std::ostringstream ss;
print_constant(ss, x);
value = fmt::format("static_cast<{}>({})", type, ss.str());
} else if (is_scalar(x)) {
value = fmt::format("{}[0]", xname);
} else if (contiguous) {
value = fmt::format("{}[index]", xname);
} else {
std::string index = fmt::format(
"elem_to_loc_nd<NDIM>(index, shape.data(), {}_strides.data())",
xname);
value = fmt::format("{}[{}]", xname, index);
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write tape.
for (const auto& x : tape) {
const std::string& xname = namer.get_name(x);
std::string type = dtype_to_cuda_type(x.dtype());
std::string value;
if (is_static_cast(x.primitive())) {
value = fmt::format(
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
} else {
std::ostringstream ss;
x.primitive().print(ss);
value = ss.str();
value += "{}(";
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
}
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
}
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
}
// Write output.
for (const auto& x : outputs) {
os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
}
os += "}\n";
}
};
} // namespace cu
constexpr const char* g_jit_includes = R"(
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/ternary_ops.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cooperative_groups.h>
#define inf cuda::std::numeric_limits<float>::infinity()
)";
void Compiled::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("Compiled::eval_gpu");
auto& s = stream();
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
// Build source code.
cu::FusedKernelBuilder builder{
g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
builder.os +=
"namespace mlx::core::cu {\n\n"
"namespace cg = cooperative_groups;\n\n";
builder.build("_contiguous", true);
builder.os += "\n";
builder.build("_strided", false);
builder.os += "\n} // namespace mlx::core::cu\n";
// Build kernel names.
std::vector<std::string> kernel_names = {
fmt::format("mlx::core::cu::{}_contiguous<uint32_t>", lib_name()),
fmt::format("mlx::core::cu::{}_contiguous<int64_t>", lib_name()),
};
for (int i = 1; i <= MAX_NDIM; ++i) {
kernel_names.push_back(fmt::format(
"mlx::core::cu::{}_strided<{}, uint32_t>", lib_name(), i));
kernel_names.push_back(
fmt::format("mlx::core::cu::{}_strided<{}, int64_t>", lib_name(), i));
}
return std::make_pair(std::move(builder.os), std::move(kernel_names));
});
// Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting.
auto [contiguous, shape, strides_vec] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Whether to use large index.
bool large = compiled_use_large_index(inputs, outputs, contiguous);
cu::KernelArgs args;
// Put inputs.
int strides_index = 1;
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_constant_(i)) {
continue;
}
const auto& x = inputs[i];
args.append(x);
if (!contiguous && !is_scalar(x)) {
args.append_ptr(strides_vec[strides_index++].data());
}
}
// Put outputs.
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
for (auto& x : outputs) {
args.append(x);
}
// Put shape and size.
if (!contiguous) {
args.append_ptr(shape.data());
}
if (large) {
args.append<int64_t>(outputs[0].data_size());
} else {
args.append<uint32_t>(outputs[0].data_size());
}
// Launch kernel.
const char* index_type = large ? "int64_t" : "uint32_t";
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
if (contiguous) {
kernel_name += fmt::format("_contiguous<{}>", index_type);
} else {
kernel_name += fmt::format("_strided<{}, {}>", shape.size(), index_type);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}
for (const auto& out : outputs) {
encoder.set_output_array(out);
}
auto kernel = mod.get_kernel(kernel_name);
auto [num_blocks, block_dims] = get_launch_args(kernel, outputs[0], large);
encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
}
} // namespace mlx::core
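To make the generated source concrete, the contiguous kernel that FusedKernelBuilder::build would emit for a hypothetical two-input fusion computing exp(x) + y looks roughly like the following. The variable names come from NodeNamer, and Exp/Add stand in for whatever the primitive printer emits (resolved by the functors pulled in through g_jit_includes), so everything here is illustrative rather than exact output:
// Prepended in practice: g_jit_includes and the namespace/cg alias lines added
// in Compiled::eval_gpu, which provide the op functors and cooperative_groups.
template <typename IdxT = uint32_t>
__global__ void fused_example_contiguous(
  const float* x_0,
  const float* x_1,
  float* x_3,
  IdxT size) {
  IdxT index = cg::this_grid().thread_rank();
  if (index >= size) {
    return;
  }
  float tmp_x_0 = x_0[index];
  float tmp_x_1 = x_1[index];
  float tmp_x_2 = Exp{}(tmp_x_0);
  float tmp_x_3 = Add{}(tmp_x_2, tmp_x_1);
  x_3[index] = tmp_x_3;
}
The strided variant additionally takes a shape array and per-input stride arrays and reads through elem_to_loc_nd<NDIM>, matching the non-contiguous read path in build().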

Some files were not shown because too many files have changed in this diff.