Compare commits

..

127 Commits

Author SHA1 Message Date
Angelos Katharopoulos
d2bc340df4 Update the mlx.launch and mlx.distributed_config docs 2025-12-12 17:03:38 -08:00
Angelos Katharopoulos
fabc947df4 Possible doc fix 2025-12-12 15:34:59 -08:00
Angelos Katharopoulos
5523087cfb Finish the distributed docs 2025-12-12 14:16:16 -08:00
Angelos Katharopoulos
2f939acefa Progress with the docs 2025-12-12 04:36:46 -08:00
Angelos Katharopoulos
3b416a2e36 Comments 2025-12-12 01:21:43 -08:00
Angelos Katharopoulos
753c6a4d0f Disable echo for interactive distributed 2025-12-11 17:47:42 -08:00
Angelos Katharopoulos
d3a754c8aa Fix config when more cables are connected 2025-12-10 02:14:29 -08:00
Angelos Katharopoulos
595a4ad206 Improve interactivity of mlx.launch 2025-12-10 00:54:27 -08:00
Angelos Katharopoulos
a4dc1fac6c Set the shell to bash explicitly 2025-12-09 15:17:09 -08:00
Angelos Katharopoulos
ebda161a86 Remove old joined script 2025-12-09 13:39:57 -08:00
Angelos Katharopoulos
fa31a4b295 Add more checks and improve errors 2025-12-09 13:36:17 -08:00
Angelos Katharopoulos
9d707ba3b5 Remove python from the launch script 2025-12-09 13:04:37 -08:00
Angelos Katharopoulos
405d30b6e5 Refactor distributed config 2025-12-09 05:58:44 -08:00
Angelos Katharopoulos
cd4b12ce1b Refactoring launcher 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
425043ccca Change the name to a fun pun 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
95d92af8a0 Add headers for gcc 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
bfdddd644b Expose per-backend availability in C++ and python 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
1216afdc91 Add a no_ibv 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
04e94d78bb Add empty sum_scatter 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
60d4e8b2a8 Add send/recv 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
c5745fddd2 Make sure that there is space for work completions 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
e937a8033f Add working reduce and semi-working all gather 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
4dfe02d7c6 Fix ring 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
5c2cff9329 Fix side channel initialization for more than 2 peers 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
325dab9559 All gather 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
67e454ab0a Initial working all reduce 2025-12-08 15:50:05 -08:00
Awni Hannun
27232db1ba [CUDA] Enable more graphs to be updatable (#2883)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-08 06:18:01 -08:00
Awni Hannun
a4b3bc969b Try not to fail when there should be memory available (#2869)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-07 06:11:00 -08:00
Awni Hannun
667c0f3bb9 [Metal] No copy array init (#2875)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 13:36:45 -08:00
Cheng
6245824d42 Make allocator::malloc throw on allocation failure (#2874)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 17:44:38 +09:00
Awni Hannun
39289ef025 [CUDA] Release build for cuda 13 (#2872) 2025-12-04 21:42:26 -08:00
Awni Hannun
aefc9bd3f6 [CUDA] Faster general copy (#2873) 2025-12-04 21:42:15 -08:00
Angelos Katharopoulos
997cfc7699 Add a 2-pass col reduce for CUDA (#2863)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-04 15:53:59 -08:00
Awni Hannun
1fa8dc5797 Do a PyPi release for cuda on arm (#2866) 2025-12-04 15:28:29 -08:00
Awni Hannun
a6d6717181 fix compile copying (#2871) 2025-12-04 12:32:56 -08:00
Awni Hannun
941cfe23d7 Layer norm throws on dimension mismatch (#2870) 2025-12-04 11:21:05 -08:00
romanoneg
9abb0b8123 Added support for pytree types that inherit from tuple and typing.namedtuple (#2845) 2025-12-04 11:06:45 -08:00
Tian En "TianHeng
50d3914c67 Update gumbel function signature parameters (#2868)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-03 15:37:35 -08:00
Awni Hannun
cacbdbf995 Fix init from double (#2861)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-03 06:08:11 -08:00
Awni Hannun
193cdcd81a Fix graph updating (#2857)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-02 17:12:24 -08:00
Awni Hannun
d8ceae7b77 Reduce JVP (#2854) 2025-12-02 16:17:47 -08:00
Awni Hannun
eff0e31f00 Fix export scatters (#2852)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-02 11:24:40 -08:00
Awni Hannun
6c5785bc2f use thread local cpature mode (#2850)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-01 19:02:47 -08:00
CCYeh
8879ee00eb Support more Numpy interfaces for masked_scatter (#2832) 2025-12-01 17:51:02 -08:00
Cheng
6e762fe2e2 [CUDA] Migrate conv code to new cuDNN APIs (#2847)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-02 07:55:43 +09:00
Cheng
2b95d0c270 [CUDA] Use cuDNN attention when T_q != T_kv (#2843)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-11-27 09:58:43 +09:00
Chaoran Yu
b054838780 Added clarification to apply_fn parameter of apply_to_modules (#2831)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-26 15:40:56 -08:00
Awni Hannun
dd79d3c465 [CUDA] Faster rms norm for small dimension (#2838) 2025-11-26 15:10:41 -08:00
Cheng
704fd1ae28 [CUDA] Support array mask in SDPA (#2822)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-26 11:08:58 +09:00
Cheng
c9f4dc851f Merge build-cuda and build-linux actions (#2783)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-11-25 20:06:42 +09:00
Cheng
f8bd675655 [CUDA] Output of SDPA should have same layout with inputs (#2826)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-25 15:22:58 +09:00
Cheng
23a9168d34 [CUDA] Add debug env to save cuda graphs to dot files (#2825) 2025-11-25 15:22:36 +09:00
Awni Hannun
bca205e287 [CUDA] Exit on crash and more helpful errors (#2830) 2025-11-24 19:46:03 -08:00
CCYeh
1d4eacb737 Fix mx.core.linspace type annotation (#2820)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-11-24 14:15:08 -08:00
dependabot[bot]
8abd37ad05 Bump actions/checkout from 5 to 6 (#2828)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-24 06:04:46 -08:00
Andrey Portnoy
3e05cea9f8 Force cudaGraphExec reinstantiation when clusters are used (#2813)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-22 12:43:49 -08:00
CCYeh
5b0f047226 Fix mx.core.load type annotation (#2819) 2025-11-22 11:09:44 -08:00
Harsh Sutaria
618c87af8c Add float64 Eig and complex64 SVD/Eig support (Fixes #2708) (#2737)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-22 06:51:36 -08:00
Cheng
d5f61a93fa Fix typo: refs/head/main => refs/heads/main (#2818)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-22 09:43:35 +09:00
Awni Hannun
4a09264236 Tolerance for some ops tests on cuda (#2815) 2025-11-21 16:06:16 -08:00
Awni Hannun
0dbc7e5bee Centralize NAX condition (#2811)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-21 13:28:15 -08:00
Awni Hannun
0d68efd461 patch bump for future version (#2804)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-20 09:26:20 -08:00
Awni Hannun
f9e1a14135 [CUDA] Partly fix random for large sizes (#2798) 2025-11-20 07:27:50 -08:00
Awni Hannun
d8e9ded928 Fix cuda allocator copy condition (#2800) 2025-11-20 07:06:55 -08:00
Awni Hannun
60939d010c Fix macos release target and linux arm release (#2802)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-19 21:37:50 -08:00
Awni Hannun
fdcd2923fd patch + fix docs build (#2799) 2025-11-19 16:16:26 -08:00
Jagrit Digani
54f1cc6e3e Add Neural Accelerator Support (#2772) 2025-11-19 15:06:00 -08:00
CCYeh
b3825ac149 Add Masked Scatter (#2663)
Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-11-19 14:53:32 -08:00
Awni Hannun
7f4b7e553c version (#2797) 2025-11-19 14:11:16 -08:00
Awni Hannun
ad16f41a7f Fix version tag (#2790)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-19 08:55:57 -08:00
Awni Hannun
f46877bc08 more accurate rope fallback (#2792) 2025-11-19 06:07:21 -08:00
Cheng
6f35017d1b [CUDA] cuDNN backward attention (#2762)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-19 08:13:50 +09:00
Awni Hannun
b167f0df1c build docs on linux (#2787)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-18 08:01:03 -08:00
Cheng
a9f0d6b160 Avoid duplicate CI runs when starting a PR from upstream branch (#2788)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-18 15:16:25 +09:00
Cheng
940f4c7818 Fix building with CUDA < 12.8 (#2782)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-18 12:55:19 +09:00
Cheng
35f81728f1 Remove unneeded tests in nightly build (#2786) 2025-11-18 08:09:58 +09:00
Cheng
4442ed86c1 Fix nightly build (#2785) 2025-11-18 08:07:51 +09:00
Cheng
698559c231 Test every commit in main branch (#2781) 2025-11-18 08:07:22 +09:00
Cheng
ecc4879b07 Do not run CPU tests in CUDA builds (#2784) 2025-11-18 07:27:09 +09:00
Cheng
32b18d8b66 Use std::optional for mask_arr arg (#2763)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.8) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (cuda-12.8) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (cuda-12.9) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-17 10:43:33 +09:00
Cheng
472c43a0c8 Build and test with multiple CUDA versions (#2780) 2025-11-17 09:19:02 +09:00
Cheng
b7214ff01e Remove pip cache in GitHub Actions (#2776)
* Correctly set pip cache key

* [Debug] Try disabling pip cache
2025-11-17 08:19:59 +09:00
Cheng
76414c8971 Run CI for pushes (#2777) 2025-11-17 07:19:01 +09:00
Awni Hannun
49e4566df3 fix release 2 (#2767)
* fix release 2

* login

* fix
2025-11-16 11:39:53 -08:00
Awni Hannun
aad49f932f [CUDA] Tune ops per buffer based on device (#2761)
* tune ops per buffer based on device

* tune memory limit as well

* add tuning for spark
2025-11-16 06:29:49 -08:00
Cheng
86765cce34 Use ccache in GitHub Actions (#2773)
* Remove unnecessary steps

* Use ccache

* Log when using ccache

* Set max-size to 1GB

* Pass --no-build-isolation

* Remove more unused things
2025-11-16 07:58:14 +09:00
Cheng
1bedcbd556 Fix warnings with cmake 4.1 (#2774) 2025-11-16 07:12:47 +09:00
Cheng
9ac7dbe877 Fix MPI distributed tests with CUDA backend (#2775) 2025-11-16 07:12:18 +09:00
Awni Hannun
1bf605d56d use arch specific targets when possible (#2771) 2025-11-14 20:04:18 -08:00
Cheng
3c622ddd1d Separate test-linux from build-linux/cuda in GitHub Actions (#2765)
* Separate test-linux from build-linux/cuda in GitHub Actions

* Prefer unittest when possible

Co-authored-by: Mike Drob <mdrob@apache.org>

---------

Co-authored-by: Mike Drob <mdrob@apache.org>
2025-11-15 11:14:09 +09:00
Awni Hannun
27ff069175 Fix exporting with constants (#2769) 2025-11-14 12:52:08 -08:00
Cheng
3b2ffcefc3 [CUDA] cuDNN forward attention (#2743)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* Separate sdpa kernels in another file

* Initial support for cuDNN SDPA

* Diable a few corner cases

* Remove scaled_dot_product_attention.h

* Use cuDNN attention for prefilling

* cuDNN SDPA requires Ampere and later

* Address reviews

* Do contiguous copy of inputs
2025-11-14 09:23:56 +09:00
Awni Hannun
b65f882df3 fix release (#2759) 2025-11-13 15:34:01 -08:00
Cheng
b704e9e77a [CUDA] Check CUDA error in synchronize (#2757) 2025-11-14 07:10:23 +09:00
Awni Hannun
66519fb348 fix slice (#2758) 2025-11-13 11:30:02 -08:00
Awni Hannun
8973550ff3 export custom kernel (#2756) 2025-11-13 11:29:50 -08:00
Mike Drob
3f866be665 minor debugging for publishing (#2739)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* minor debugging for publishing

* fix logic
2025-11-12 06:33:39 -08:00
Awni Hannun
23f81ed1c1 Linux on arm (#2751)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* try linux on arm

* ssh

* fix
2025-11-11 11:41:14 -08:00
wrmsr
3fe2250c00 Fix irregular_strides benchmark shape type (#2754) 2025-11-11 11:40:22 -08:00
Awni Hannun
047114b988 remove circle (#2753) 2025-11-11 11:39:47 -08:00
wrmsr
9320eb89a8 Fix dequantize python sig (dtype default) (#2752) 2025-11-11 09:55:24 -08:00
Awni Hannun
75819d70ea patch bump (#2750) 2025-11-11 08:49:14 -08:00
Awni Hannun
60d80a3728 fix release builds (#2746) 2025-11-11 07:44:30 -08:00
Pedro Cuenca
eba6a9d163 Compatibility with pip-installed openmpi (#2741)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-07 16:58:31 -08:00
CCYeh
be9e2aebd6 Shapeless support for zeros/ones_like (#2726)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* shapeless support for zeros/ones_like

* Improvements

* fix access after moved
2025-11-06 19:12:20 -08:00
Awni Hannun
df58b4133a [CUDA] Reduce use of managed memory (#2725)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* Use async cuda malloc managed with cuda 13

* add pool threshold

* refactor for regular cuda malloc

* load eval gpu for cuda

* remove use of cuda pool, use cuda free async

* fix

* fix

* fix

* fix

* fix + comment
2025-11-05 16:05:23 -08:00
Anastasiia Filippova
27778156dc Nccl reduce scatter, all gather (#2727)
* Added reduce scatter and all gather for nccl

* fix unused import, delete unused file

* small fix

* deleted useless condition

* fixed comments

* fix bug in eval_gpu, renamed to sum_scatter, fix docs

* final fix docs

* remove and

* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output

* fixes set output

* typo

* fix typo

* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-11-05 08:21:11 -08:00
Mike Drob
761f901a41 fix property name (#2736) 2025-11-05 06:31:56 -06:00
Angelos Katharopoulos
6ece97f69b Make cpu binary_op easily accessible (#2733) 2025-11-05 01:08:41 -08:00
Awni Hannun
d3bc6a9bff don't test when doing release (#2734)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-04 15:54:23 -08:00
Awni Hannun
26ceb507eb only build for macos 14 and up (#2731)
* only build for macos 14 and up

* bump metal cpp
2025-11-04 09:44:15 -08:00
Mike Drob
910b3e3299 skip self-hosted runners on forks (#2730)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-03 16:22:13 -06:00
Harsh Sutaria
50fa315d18 Fix addmm with empty matrices and beta != 1.0 (#2715) 2025-11-03 14:16:15 -08:00
AN Long
1ff2b713b6 Check isnan in maximum / minimum with CPU backend (#2652)
* Check isnan in maximum / minimum with CPU backend

* Add tests

* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-03 08:51:14 -08:00
Mike Drob
50514a6146 Set up publishing to PyPI and Test-PyPI (#2721) 2025-11-03 07:20:11 -08:00
Awni Hannun
93d76b0f30 Fix compile multi capture (#2678)
* fix compile when compiling multiple lambdas with the same capture

* add test
2025-11-03 06:33:43 -08:00
David Koski
78678de0cd add null check -- the bundleIdentifier is optional (#2709)
* add null check -- the bundleIdentifier is optional

* use variable
2025-11-03 06:33:21 -08:00
Melissa Kilby
ed9c6b1117 update: add linux fedora container CI - CPP build test only (#2722)
* update: add linux_fedora_build_cpp CI - CPP build test only - x86-64

Signed-off-by: Melissa Kilby <mkilby@apple.com>

* update: add linux_fedora_build_cpp_aarch64 CI - CPP build test only - arm64

Co-authored-by: Mike Drob <mdrob@apple.com>
Signed-off-by: Melissa Kilby <mkilby@apple.com>

* update: convert linux_fedora_build_cpp to matrix.arch loop

Co-authored-by: Mike Drob <mdrob@apple.com>
Signed-off-by: Melissa Kilby <mkilby@apple.com>

---------

Signed-off-by: Melissa Kilby <mkilby@apple.com>
Co-authored-by: Mike Drob <mdrob@apple.com>
2025-11-03 06:33:00 -08:00
Awni Hannun
39b04ce638 use faster dequant for fp4 qmv (#2720)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-10-31 11:49:59 -07:00
Mike Drob
d9e6349657 fix docs path (#2719) 2025-10-30 19:12:49 -05:00
Angelos Katharopoulos
b901a9f311 Fix the order of hosts in the ring (#2718)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-10-30 15:02:39 -07:00
Awni Hannun
68c5fa1c95 fix memory count bug (#2717) 2025-10-30 14:27:15 -07:00
Christopher Webb
793a31eeb6 Fix missing domain_uuid_key in thunderbolt ring setup (#2682) 2025-10-30 13:17:20 -07:00
Mike Drob
74c1ed25bb Migrate CircleCI to GitHub Actions (#2716)
Co-authored-by: Joseph Heck <j_heck@apple.com>
2025-10-30 12:26:55 -05:00
Awni Hannun
ec72b44417 Add quantize/dequantize for mxfp8 and nvfp4 (#2688)
* Add quantize/dequantize slow path for mxfp8 and nvfp4

* fast cuda kernel for mx/nv quantization

* fallback for cuda < 12.8 (#2697)

* format (#2700)

* fix (#2701)

* metal kernels

* docs

* fix jit

* add default bits and group sizes

* improve quant docs

* fix output type of mxfp4 matmuls
2025-10-28 16:23:12 -07:00
Melissa Kilby
460691a0e8 fix: linux-{fedora}x86_64-build (#2707)
Signed-off-by: Melissa Kilby <mkilby@apple.com>
2025-10-27 16:36:08 -07:00
Awni Hannun
969924cc69 Fp8 conversion (#2686)
* add fp8 e4m3 converters

* add cuda

* default saturate to min/max

* fix for older OS

* fix no gpu/cpu

* fix saturate

* fix compile
2025-10-27 16:35:50 -07:00
316 changed files with 20024 additions and 6015 deletions

View File

@@ -1,579 +0,0 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "26.0.0"
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.10
brew install doxygen
python3.10 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
- run:
name: Install dependencies
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "26.0.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
- run:
name: Install Python package
command: |
uv venv --python 3.10
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build example extension
command: |
source .venv/bin/activate
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
- run:
name: Build small binary
command: |
source .venv/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: "linux-cuda-12:<< parameters.image_date >>"
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Set CCache size
command: ccache --max-size 1G
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release:
parameters:
python_version:
type: string
default: "3.10"
xcode_version:
type: string
default: "26.0.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m4pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
mkdir -p ~/miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install twine
pip install build
- run:
name: Install Python package
command: |
conda activate env
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
command: |
conda activate env
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
conda activate env
twine upload dist/*
- store_artifacts:
path: dist/
build_linux_release:
parameters:
python_version:
type: string
default: "3.10"
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: xlarge
steps:
- checkout
- run:
name: Build wheel
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
bash python/scripts/repair_cuda.sh
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
- build_cuda_release
build_dev_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]

View File

@@ -0,0 +1,24 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}

38
.github/actions/build-docs/action.yml vendored Normal file
View File

@@ -0,0 +1,38 @@
name: 'Build Documentation'
description: 'Build documentation'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-linux
- name: Install dependencies
shell: bash
run: |
sudo apt-get install -y doxygen
source .venv/bin/activate
pip install -r docs/requirements.txt
pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: bash
run: tar -cf artifact.tar -C docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error

View File

@@ -0,0 +1,40 @@
name: 'Build Linux wheel'
description: 'Build Linux wheel'
inputs:
build-backend:
description: 'Build the backend mlx-cpu package'
type: boolean
required: false
default: false
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
steps:
- name: Generate package stubs
shell: bash
run: |
pip install -e ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
- name: Build Python package
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
bash python/scripts/repair_linux.sh ${{ inputs.arch }}
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}

41
.github/actions/build-linux/action.yml vendored Normal file
View File

@@ -0,0 +1,41 @@
name: 'Build and Test on Linux'
inputs:
toolkit:
description: 'The toolkit to build with'
required: false
default: 'cpu'
runs:
using: "composite"
steps:
- name: Install Python package
id: python_build
shell: sh
env:
DEBUG: 1
CMAKE_ARGS: >-
-DCMAKE_COMPILE_WARNING_AS_ERROR=ON
-DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
run: |
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
# There is no GPU in arm64 runner, use a common arch.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
# Can not build tests when the built executables can not run.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
fi
pip install --no-build-isolation -e ".[dev]" -v
# Pass the CMAKE_ARGS to following steps.
echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Build CPP only
shell: bash
run: |
cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
cmake --build build -j $(nproc)

View File

@@ -0,0 +1,34 @@
name: 'Build macOS release'
description: 'Build MLX releases macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
build-backend:
description: 'Build the backend mlx-metal package'
type: boolean
required: false
default: false
runs:
using: "composite"
steps:
- name: Build Python package
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
pip install build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w

88
.github/actions/build-macos/action.yml vendored Normal file
View File

@@ -0,0 +1,88 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
runs:
using: "composite"
steps:
- name: Install dependencies
env:
DEBUG: 1
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash -l {0}
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Install tests dependencies
shell: bash -l {0}
run: |
pip install numpy torch tensorflow unittest-xml-reporting
- name: Run Python tests
shell: bash -l {0}
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
shell: bash -l {0}
run: |
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext --inplace
python test.py
- name: Build CPP only
shell: bash -l {0}
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
shell: bash -l {0}
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
shell: bash -l {0}
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
shell: bash -l {0}
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
pip install -e . -v
python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit

87
.github/actions/setup-linux/action.yml vendored Normal file
View File

@@ -0,0 +1,87 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
toolkit:
description: 'Which toolkit to install'
required: false
default: 'cpu'
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
max-size: 1GB
- name: Install common dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
- name: Setup Python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
pip install setuptools cmake nanobind==2.4.0
echo PATH=$PATH >> $GITHUB_ENV
# Make cmake search .venv for nanobind
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
- name: Install MPI
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Install CUDA toolkit
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
env:
# Note: the CI machine does not meet CUDA 13's driver requirement.
# Compatibility matrix:
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
PACKAGES: |
{
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
"cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
}
run: |
# The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
# Jetson specific. SBSA means Arm Server Base System Architecture.
ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y \
libnccl2 libnccl-dev \
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
- name: CUDA packages and driver report
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"

24
.github/actions/setup-macos/action.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- uses: conda-incubator/setup-miniconda@v3
with:
miniconda-version: "latest"
python-version: ${{ inputs.python-version }}

69
.github/actions/test-linux/action.yml vendored Normal file
View File

@@ -0,0 +1,69 @@
name: 'Run Linux tests'
inputs:
has-gpu:
description: 'Run GPU tests'
required: false
default: false
runs:
using: "composite"
steps:
- name: Run MPI tests
shell: bash
run: |
echo "::group::MPI tests"
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
echo "::endgroup::"
- name: Run distributed tests
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
run: |
echo "::group::Distributed tests"
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
echo "::endgroup::"
- name: Run Python tests - CPU
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::Python tests - CPU"
python -m unittest discover python/tests -v
echo "::endgroup::"
- name: Run Python tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::Python tests - GPU"
python -m tests discover python/tests -v
echo "::endgroup::"
- name: Run CPP tests - CPU
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::CPP tests - CPU"
./build/tests/tests
echo "::endgroup::"
- name: Run CPP tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::CPP tests - GPU"
./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
echo "::endgroup::"

6
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"

View File

@@ -0,0 +1,27 @@
#!/bin/bash
set -ex
# [Setup] Install dependencies inside the container.
dnf update -y
dnf install -y \
blas-devel \
lapack-devel \
openblas-devel \
make \
cmake \
clang \
git
dnf clean all
# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
export DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++
mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
./tests/tests
popd

108
.github/workflows/build_and_test.yml vendored Normal file
View File

@@ -0,0 +1,108 @@
name: Build and Test
on:
pull_request:
push:
branches:
- main
# For testing CI without starting a pull request:
- test/*
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
check_lint:
name: Check Lint
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
name: Linux (cpu, ${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
cuda_build_and_test:
name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
if: github.repository == 'ml-explore/mlx'
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
if: matrix.arch == 'x86_64'
with:
has-gpu: true
mac_build_and_test:
name: macOS (${{ matrix.macos-target }})
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
macos-target: ["14.0", "15.0"]
runs-on: [self-hosted, macos]
env:
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
build_documentation:
name: Build Documentation
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora (${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh

28
.github/workflows/documentation.yml vendored Normal file
View File

@@ -0,0 +1,28 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

96
.github/workflows/nightly.yml vendored Normal file
View File

@@ -0,0 +1,96 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: "x86_64"
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12", "3.13", "3.14"]
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.13"]
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
toolkit: 'cuda-12.9'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7

View File

@@ -1,20 +0,0 @@
on:
pull_request:
branches:
- main
jobs:
check_lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files

244
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,244 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
inputs:
dev_release:
description: "Do a dev release or regular release"
required: true
default: "false"
permissions:
contents: read
jobs:
setup:
runs-on: ubuntu-latest
steps:
- name: Set publishing variables
run: echo "Publishing setup complete"
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: ${{ matrix.arch }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cpu-${{ matrix.arch }}
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
shell: bash -l {0}
run: |
pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash -l {0}
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiple: true
path: dist
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_cuda_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
pattern: mlx-cpu-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_mac_release]
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://upload.pypi.org/legacy/

View File

@@ -1,4 +1,10 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-yaml
# - id: end-of-file-fixer
# - id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
hooks:

View File

@@ -20,13 +20,9 @@ project(
LANGUAGES C CXX
VERSION ${MLX_PROJECT_VERSION})
if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
add_compile_options(-Wall -Wextra)
endif()
# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
@@ -78,6 +74,7 @@ endif()
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
@@ -92,6 +89,11 @@ cmake_policy(SET CMP0135 NEW)
add_library(mlx)
# Supress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
if(MLX_BUILD_CUDA)
enable_language(CUDA)
endif()
@@ -126,9 +128,12 @@ if(MLX_BUILD_METAL)
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS >= 14.0")
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
execute_process(
@@ -137,7 +142,6 @@ if(MLX_BUILD_METAL)
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>

View File

@@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() {
void time_irregular_binary_ops_4D() {
auto device = mx::default_device();
std::vector<int> shape = {8, 8, 512, 512};
mx::Shape shape = {8, 8, 512, 512};
auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape);
@@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() {
void time_irregular_reshape() {
auto device = mx::default_device();
std::vector<int> shape;
mx::Shape shape;
auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device);
};
@@ -170,7 +170,7 @@ void time_irregular_astype_1D() {
void time_irregular_astype_2D() {
auto device = mx::default_device();
int size = 2048;
std::vector<int> shape = {size, size};
mx::Shape shape = {size, size};
auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device);

View File

@@ -1,6 +1,5 @@
# Copyright © 2023 Apple Inc.
import argparse
import os
import subprocess
import time

View File

@@ -0,0 +1,212 @@
import math
import os
import subprocess
import time
from copy import copy
from functools import partial
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from matplotlib.ticker import FuncFormatter
RESULTS_DIR = "./results"
if not os.path.isdir(RESULTS_DIR):
os.mkdir(RESULTS_DIR)
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
TORCH_DEVICE = torch.device(
"mps"
if torch.backends.mps.is_available()
else ("cuda" if torch.cuda.is_available() else "cpu")
)
N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")
def _power_of_two_formatter(value, _position):
if value <= 0:
return ""
exponent = int(round(math.log2(value)))
if abs(value - (1 << exponent)) / value > 1e-6:
return f"{value:g}"
return f"$2^{{{exponent}}}$"
def torch_sync():
if TORCH_DEVICE.type == "cuda":
torch.cuda.synchronize()
elif TORCH_DEVICE.type == "mps":
torch.mps.synchronize()
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
outs = []
for _ in range(N_ITER_FUNC):
out = copy(self_arr)
out[mask_arr] = src_arr
outs.append(out)
mx.eval(outs)
return outs
@torch.no_grad()
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
outs = []
for _ in range(N_ITER_FUNC):
out = self_tensor.clone()
out.masked_scatter_(mask_tensor, src_tensor)
outs.append(out)
torch_sync()
return outs
def measure(fn):
for _ in range(N_WARMUP):
fn()
start = time.perf_counter_ns()
for _ in range(N_ITER_BENCH):
fn()
end = time.perf_counter_ns()
return (end - start) * 1e-9
def bytes_touched(length, true_count, item_size):
mask_bytes = length
self_bytes = length * item_size * 2 # read + write
src_bytes = true_count * item_size
return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
def build_case(length, density, np_dtype, torch_dtype):
true_count = max(1, int(round(length * density)))
rng = np.random.default_rng()
self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
mask_np = np.zeros(length, dtype=bool)
mask_np[:true_count] = True
rng.shuffle(mask_np)
src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
self_mlx = mx.array(self_np)
mask_mlx = mx.array(mask_np)
src_mlx = mx.array(src_np)
self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
# Correctness check once per configuration
mx_out = mx.array(self_np)
mx_out[mask_mlx] = src_mlx
mx.eval(mx_out)
torch_out = self_torch.clone()
torch_out.masked_scatter_(mask_torch, src_torch)
atol = 5e-3 if np_dtype == np.float16 else 1e-5
if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
raise AssertionError("masked_scatter results diverged between MLX and Torch")
return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
def bench_case(length, density, dtype):
np_dtype = getattr(np, dtype)
torch_dtype = getattr(torch, dtype)
(
self_mlx,
mask_mlx,
src_mlx,
self_torch,
mask_torch,
src_torch,
true_count,
) = build_case(length, density, np_dtype, torch_dtype)
time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
time_torch = measure(
partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
)
total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
bytes_per_gb = float(1024**3)
mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
torch_gbps = (total_bytes / bytes_per_gb) / time_torch
return time_mlx, time_torch, mlx_gbps, torch_gbps
def plot_density(ax_perf, ax_speedup, density, dtype):
mlx_gbps = []
torch_gbps = []
mlx_times = []
torch_times = []
for length in VECTOR_LENGTHS:
t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
mlx_gbps.append(gbps_mlx)
torch_gbps.append(gbps_torch)
mlx_times.append(t_mlx)
torch_times.append(t_torch)
ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
ax_perf.set_xscale("log", base=2)
ax_perf.set_xticks(VECTOR_LENGTHS)
formatter = FuncFormatter(_power_of_two_formatter)
ax_perf.xaxis.set_major_formatter(formatter)
ax_perf.set_title(f"density={density:.2f}")
ax_perf.set_ylabel("GB/s")
ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
ax_perf.legend()
speedup = np.array(torch_times) / np.array(mlx_times)
ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
ax_speedup.set_xscale("log", base=2)
ax_speedup.set_xticks(VECTOR_LENGTHS)
ax_speedup.xaxis.set_major_formatter(formatter)
ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
def main():
for dtype in D_TYPES:
fig, axs = plt.subplots(
len(MASK_DENSITIES),
2,
figsize=(10, 12),
layout="constrained",
sharex=True,
)
for i, density in enumerate(MASK_DENSITIES):
plot_density(axs[i][0], axs[i][1], density, dtype)
axs[i][0].set_xlabel("vector length")
axs[i][1].set_xlabel("vector length")
fig.suptitle(
f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
)
output_path = os.path.join(
RESULTS_DIR,
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
)
fig.savefig(output_path)
plt.close(fig)
if __name__ == "__main__":
main()

3
cmake/Findnvpl.cmake Normal file
View File

@@ -0,0 +1,3 @@
# This file does nothing but to suppress the cmake warning: "By not providing
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

View File

@@ -17,11 +17,10 @@ To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.10
- macOS >= 13.5
- macOS >= 14.0
.. note::
MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
MLX is only available on devices running macOS >= 14.0 and higher.
CUDA
^^^^
@@ -30,17 +29,20 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install mlx[cuda]
pip install mlx[cuda12]
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux)
^^^^^^^^^^^^^^^^

View File

@@ -7,21 +7,29 @@ Distributed Communication
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:
moment we support several different communication backends introduced below.
.. list-table::
:widths: 20 80
:header-rows: 1
* - Backend
- Description
* - :ref:`MPI <mpi_section>`
- A full featured and mature distributed communications library.
* - :ref:`RING <ring_section>`
- Ring all reduce and all gather over TCP sockets. Always available and
usually faster than MPI.
* - :ref:`JACCL <ring_section>`
- Low latency communication with RDMA over thunderbolt. Necessary for
things like tensor parallelism.
* - :ref:`NCCL <nccl_section>`
- The backend of choice for CUDA environments.
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
faster for thunderbolt connections.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
.. note::
Some operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
@@ -84,9 +92,8 @@ Selecting Backend
^^^^^^^^^^^^^^^^^
You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.
one of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
available backends. If they all fail then a singleton group is created.
.. note::
After a distributed backend is successfully initialized :func:`init` will
@@ -110,6 +117,8 @@ The following examples aim to clarify the backend initialization logic in MLX:
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init() # same as MPI because it was initialized first!
.. _training_example:
Training Example
----------------
@@ -192,80 +201,7 @@ almost identical to the example above:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.
The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:
.. code:: shell
$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install conda-forge::openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
$ # or simply
$ mlx.launch -n 2 test.py
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^
.. note::
For faster all reduce consider using the ring backend either with Thunderbolt
connections or over Ethernet.
Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.
Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
.. _ring_section:
Getting Started with Ring
-------------------------
@@ -316,22 +252,13 @@ utility as follows:
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
To validate your connection without configuring anything
``mlx.distributed_config`` can also plot the ring using DOT format.
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png
If you want to go through the process manually, the steps are as follows:
* Disable the thunderbolt bridge interface
@@ -342,3 +269,382 @@ If you want to go through the process manually, the steps are as follows:
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.
.. _jaccl_section:
Getting Started with RDMA over Thunderbolt
------------------------------------------
Starting from version 26.2 RDMA over thunderbolt is available in MacOS and
enables low-latency communication between Macs with thunderbolt 5. MLX provides
the JACCL backend that uses this functionality to achieve communication latency
an order of magnitude lower than the ring backend.
.. note::
The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective
Communication Library* and it is an obvious pun to Nvidia's NCCL but also
tribute to *Jack Beasley* who led the development of RDMA over Thunderbolt
at Apple.
Enabling RDMA
^^^^^^^^^^^^^
Until the feature matures, enabling RDMA over thunderbolt is slightly more
involved and **cannot** be done remotely even with sudo. In fact, it has to be
done in macOS recovery:
1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.
2. Open the Terminal by going to Utilities -> Terminal.
3. Run ``rdma_ctl enable``.
4. Reboot.
To verify that you have successfully enabled Thunderbolt RDMA you can run
``ibv_devices`` which should produce something like the following for an M3 Ultra.
.. code-block:: bash
~ % ibv_devices
device node GUID
------ ----------------
rdma_en2 8096a9d9edbaac05
rdma_en3 8196a9d9edbaac05
rdma_en5 8396a9d9edbaac05
rdma_en4 8296a9d9edbaac05
rdma_en6 8496a9d9edbaac05
rdma_en7 8596a9d9edbaac05
Defining a Mesh
^^^^^^^^^^^^^^^
The JACCL backend supports only fully connected topologies. Namely, there needs
to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
the following topology visualizations, the left one is valid because there is a
connection from any node to any other node, while for the one on the right M3
Ultra 1 is not connected to M3 Ultra 2.
.. raw:: html
<div style="display: flex; text-align: center; align-items: end; font-size: 80%;">
<div>
<img src="/_static/distributed/m3-ultra-mesh.png" alt="M3 Ultra thunderbolt mesh" style="width: 55%">
<p>Fully connected mesh of four M3 Ultra.</p>
</div>
<div>
<img src="/_static/distributed/m3-ultra-mesh-broken.png" alt="M3 Ultra broken thunderbolt mesh" style="width: 55%">
<p>Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).</p>
</div>
</div>
Similar to the ring backend, the easiest way to use JACCL with MLX is to write
a JSON hostfile that will be used by ``mlx.launch``. The hostfile needs to contain
- Hostnames to use for launching scripts via ssh
- An IP for rank 0 that is reachable by all nodes
- A list of rdma devices that connect each node to each other node
The following JSON defines the valid 4-node mesh from the image above.
.. code-block:: json
[
{
"ssh": "m3-ultra-1",
"ips": ["123.123.123.1"],
"rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]
},
{
"ssh": "m3-ultra-2",
"ips": [],
"rdma": ["rdma_en5", null, "rdma_en3", "rdma_en4"]
},
{
"ssh": "m3-ultra-3",
"ips": [],
"rdma": ["rdma_en4", "rdma_en3", null, "rdma_en5"]
},
{
"ssh": "m3-ultra-4",
"ips": [],
"rdma": ["rdma_en3", "rdma_en4", "rdma_en5", null]
}
]
Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
disabling the thunderbolt bridge is still required as well as setting up
isolated local networks for each thunderbolt connection.
All of the above can be done instead via ``mlx.distributed_config``. This helper
script will
- ssh into each node
- extract the thunderbolt connectivity
- check for a valid mesh
- provide the commands to configure each node (or run them if sudo is available)
- generate the hostfile to be used with ``mlx.launch``
Putting it All Together
^^^^^^^^^^^^^^^^^^^^^^^^
For example launching a distributed MLX script that uses JACCL is fairly simple
if the nodes are reachable via ssh and have password-less sudo.
First, connect all the thunderbolt cables. Then we can verify the connections
by using the ``mlx.distributed_config`` script to visualize them.
.. code-block::
mlx.distributed_config --verbose \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
--over thunderbolt --dot | dot -Tpng | open -f -a Preview
After making sure that everything looks right we can auto-configure the nodes
and save the hostfile to ``m3-ultra-jaccl.json`` by running:
.. code-block::
mlx.distributed_config --verbose \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
--over thunderbolt --backend jaccl \
--auto-setup --output m3-ultra-jaccl.json
And now we are ready to run a distributed MLX script such as distributed inference
of a gigantic model using MLX-LM.
.. code-block::
mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
--env MLX_METAL_FAST_SYNCH=1 -- \ # <--- important
/path/to/remote/python -m mlx_lm chat --model mlx-community/DeepSeek-V3.2-8bit --shard
.. note::
Defining the environment variable ``MLX_METAL_FAST_SYNCH=1`` enables a
different, faster way of synchronizing between the GPU and the CPU. It is
not specific to the JACCL backend and can be used in all cases where the CPU
and GPU need to collaborate for some computation and is pretty critical for
low-latency communication since the communication is done by the CPU.
.. _nccl_section:
Getting Started with NCCL
-------------------------
MLX on CUDA environments ships with the ability to talk to `NCCL
<https://developer.nvidia.com/nccl>`_ which is a high-performance collective
communication library that supports both multi-gpu and multi-node setups.
For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all
it takes to run a distributed job is
.. code-block::
mlx.launch -n 8 test.py
# perfect for interactive scripts
mlx.launch -n 8 python -m mlx_lm chat --model my-model --shard
You can also use ``mlx.launch`` to ssh to a remote node and launch a script
with the same ease
.. code-block::
mlx.launch --hosts my-cuda-node -n 8 test.py
In many cases you may not want to use ``mlx.launch`` with the NCCL backend
because the cluster scheduler will be the one launching the processes. You can
:ref:`see which environment variables need to be defined <no_mlx_launch>` in
order for the MLX NCCL backend to be initialized correctly.
.. _mpi_section:
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to `MPI
<https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ if it is installed
on the machine. Launching distributed MLX programs that use MPI can be done
with ``mpirun`` as expected. However, in the following examples we will be
using ``mlx.launch --backend mpi`` which takes care of some nuisances such as
setting absolute paths for the ``mpirun`` executable and the ``libmpi.dyld``
shared library.
The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:
.. code:: shell
$ mlx.launch --backend mpi -n 2 test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.
Installing MPI
^^^^^^^^^^^^^^
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install conda-forge::openmpi
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``. Some environments use a non-standard
library filename that can be specified using the ``MPI_LIBNAME`` environment
variable. This is automatically taken care of by ``mlx.launch`` as well.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
$ # or simply
$ mlx.launch -n 2 test.py
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^
.. note::
For faster all reduce consider using the ring backend either with Thunderbolt
connections or over Ethernet.
Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.
Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
.. _no_mlx_launch:
Distributed Without ``mlx.launch``
----------------------------------
None of the implementations of the distributed backends require launching with
``mlx.launch``. The script simply connects to each host. Starts a process per
rank and sets up the necessary environment variables before delegating to your
MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
for more details.
For many use-cases this will be the easiest way to perform distributed
computations in MLX. However, there may be reasons that you cannot or should
not use ``mlx.launch``. A common such case is the use of a scheduler that
starts all the processes for you on machines undetermined at the time of
scheduling the job.
Below we list the environment variables required to use each backend.
Ring
^^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_HOSTFILE** should contain the path to a json file that contains IPs and
ports for each rank to listen to, something like the following:
.. code-block:: json
[
["123.123.1.1:5000", "123.123.1.2:5000"],
["123.123.2.1:5000", "123.123.2.2:5000"],
["123.123.3.1:5000", "123.123.3.2:5000"],
["123.123.4.1:5000", "123.123.4.2:5000"]
]
**MLX_RING_VERBOSE** is optional and if set to 1 it enables some more logging
from the distributed backend.
JACCL
^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 can listen
to all the other ranks connect to in order to establish the RDMA connections.
**MLX_IBV_DEVICES** should contain the path to a json file that contains the
ibverbs device names that connect each node to each other node, something like
the following:
.. code-block:: json
[
[null, "rdma_en5", "rdma_en4", "rdma_en3"],
["rdma_en5", null, "rdma_en3", "rdma_en4"],
["rdma_en4", "rdma_en3", null, "rdma_en5"],
["rdma_en3", "rdma_en4", "rdma_en5", null]
]
NCCL
^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_WORLD_SIZE** should contain the total number of processes that will be
launched.
**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
hosts can connect to to establish the NCCL communication.
**CUDA_VISIBLE_DEVICES** should contain the local index of the gpu that
corresponds to this process.
Of course any `other environment variable
<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is
used by NCCL can be set.
.. _tips_and_tricks:
Tips and Tricks
----------------
This is a small collection of tips to help you utilize better the distributed
communication capabilities of MLX.
- *Test locally first.*
You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a small
scale test on a single node first.
- *Batch your communication.*
As described in the :ref:`training example <training_example>`, performing a
lot of small communication can hurt performance. Copy the approach of
:func:`mlx.nn.average_gradients` to gather many small communications in a
single large one.
- *Visualize the connectivity.*
Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
visualize the connnections and make sure that the cables are connected
correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
- *Use the debugger.*
``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
processes and gathers stdout from all processes. This makes using ``pdb`` a
breeze.

View File

@@ -70,7 +70,8 @@ Differences from NumPy
* Indexing does not perform bounds checking. Indexing out of bounds is
undefined behavior.
* Boolean mask based indexing is not yet supported.
* Boolean mask based indexing is supported for assignment only (see
:ref:`boolean-mask-assignment`).
The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
@@ -143,3 +144,51 @@ expected. For example:
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.
.. _boolean-mask-assignment:
Boolean Mask Assignment
-----------------------
MLX supports boolean indices using NumPy syntax. A mask must already be
a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
Other index types are routed through the standard scatter code.
.. code-block:: shell
>>> a = mx.array([1.0, 2.0, 3.0])
>>> mask = mx.array([True, False, True])
>>> updates = mx.array([5.0, 6.0])
>>> a[mask] = updates
>>> a
array([5.0, 2.0, 6.0], dtype=float32)
Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
assignments, ``updates`` must provide at least as many elements as there are
``True`` entries in ``mask``.
.. code-block:: shell
>>> a = mx.zeros((2, 3))
>>> mask = mx.array([[True, False, True],
[False, False, True]])
>>> a[mask] = 1.0
>>> a
array([[1.0, 0.0, 1.0],
[0.0, 0.0, 1.0]], dtype=float32)
Boolean masks follow NumPy semantics:
- The mask shape must match the shape of the axes it indexes exactly. The only
exception is a scalar boolean mask, which broadcasts to the full array.
- Any axes not covered by the mask are taken in full.
.. code-block:: shell
>>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
axes and therefore raise errors.

View File

@@ -7,13 +7,106 @@ Launching Distributed Programs
.. currentmodule:: mlx.core.distributed
Installing the MLX python package provides a helper script ``mlx.launch`` that
can be used to run python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.
Installing the MLX python package provides two utilities to help you configure
your Macs for distributed computation and also launch distributed programs on
multiple nodes or with many processes in a single node. These utilities are aptly named
Usage
-----
- ``mlx.launch``
- ``mlx.distributed_config``
See the :doc:`distributed docs <distributed>` for an introduction and
getting-started guides to the various backends.
``mlx.distributed_config``
---------------------------
Unless you are launching distributed jobs locally for development or multi-gpu
CUDA environments, then you have several Macs that you need to configure for
distributed communication with MLX.
``mlx.distributed_config`` aims to automate the process of configuring the
network interfaces (especially for communication over thunderbolt) and also
creating the hostfile to be used with ``mlx.launch``.
We will analyse 3 cases of using ``mlx.distributed_config``
1. RDMA over thunderbolt using JACCL
2. TCP/IP over thunderbolt using the ring backend
3. TCP/IP over ethernet using the ring backend
JACCL
^^^^^^^
After following :ref:`the steps to enable RDMA <jaccl_section>` you can run the
following command to configure the nodes and create the hostfile.
.. code-block::
mlx.distributed_config --verbose --backend jaccl \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 --over thunderbolt \
--auto-setup --output m3-ultra-jaccl.json
Let's walk through the steps that the script takes to configure the nodes.
1. Ssh to all nodes to verify that they are reachable
2. Extract the thunderbolt connectivity. Namely run commands on each node to
calculate which node is connected to which other node.
3. Verify that we have a valid fully connected mesh
4. Check that RDMA is enabled
5. Extract the ethernet IP from interface en0
6. Disable the thunderbolt bridge and set up peer to peer networks for each
thunderbolt cable
7. Write the hostfile
Knowing the above steps allows you to manually configure the nodes but also
debug any configuration issue. For instance changing the Ethernet IP to a
different interface directly in the config is possible (as long as it is
reachable from all nodes).
The ``--auto-setup`` argument requires password-less sudo on each node. If it
isn't available then the configuration script will print commands to be run on
each node.
Ring over thunderbolt
^^^^^^^^^^^^^^^^^^^^^
Setting up a ring backend over thunderbolt only requires changing the
``--backend`` from ``jaccl`` to ``ring``.
The steps are very similar with the main difference being that instead of
verifying that the nodes are fully connected, the script attempts to identify a
ring topology (or multiple rings).
Ring over Ethernet
^^^^^^^^^^^^^^^^^^
Configuring the ring backend over ethernet doesn't require setting up network
interface and as such it simply extracts the ``en0`` IP from each node and
writes the hostfile.
Debugging cable connections
^^^^^^^^^^^^^^^^^^^^^^^^^^^
``mlx.distributed_config`` can help you debug the connectivity of your nodes
over thunderbolt by exporting a graph of the connections.
Running
.. code-block::
mlx.distributed_config --verbose \
--hosts host1,host2,host3,host4 \
--over thunderbolt --dot
will export a `GraphViz <https://graphviz.org>`_ representation of the
connections between the nodes which makes it very easy to figure out which
cable is not connected correctly.
See :ref:`the JACCL section <jaccl_section>` for an example.
``mlx.launch``
--------------
The minimal usage example of ``mlx.launch`` is simply
@@ -33,6 +126,10 @@ the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the output of each remote process to stdout
and stderr respectively.
Importantly, it also broadcasts stdin to each process which enables interactive
programs to work in distributed mode as well as debugging using the interactive
debugger.
Providing Hosts
^^^^^^^^^^^^^^^^
@@ -63,10 +160,62 @@ host and on the same path. A good checklist to debug errors is the following:
``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path
If you are launching from a node with a completely different setup than the
nodes that the program will run on, you can specify ``--no-verify-script`` so
that ``mlx.launch`` does not attempt to verify that the executable and script
exist locally before launching the distributed job.
.. _ring_specifics:
Ring Specifics
^^^^^^^^^^^^^^
The :ref:`ring <ring_section>` backend, which is also the default
backend, can be explicitly selected with the argument ``--backend ring``. The
ring backend has some specific requirements and arguments that are different to
other backends:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.
.. _jaccl_specifics:
JACCL Specifics
^^^^^^^^^^^^^^^^
The :ref:`JACCL <jaccl_section>` backend can be selected with the argument
``--backend jaccl``. A hostfile is necessary to launch with this backend
because it needs to contain the RDMA devices connecting each node to each other
node.
NCCL Specifics
^^^^^^^^^^^^^^
The :ref:`NCCL <nccl_section>` backend is the default backend for CUDA
environments. When launching from a Mac to a Linux machine with CUDA then the
backend should be selected using ``--backend nccl``.
The ``--repeat-hosts, -n`` argument should be used to launch multi-node and
multi-gpu jobs. For instance
.. code-block::
mlx.launch --backend nccl --hosts linux-1,linux-2 -n 8 --no-verify-script -- ./my-job.sh
will attempt to launch 16 processes, 8 on each node that will all run
``my-job.sh``.
.. _mpi_specifics:
MPI Specifics
-------------
^^^^^^^^^^^^^
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
@@ -83,23 +232,3 @@ to choose a specific interface for the byte-transfer-layer of MPI we can call
.. code:: shell
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
.. _ring_specifics:
Ring Specifics
--------------
The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that are different to MPI:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.

View File

@@ -14,17 +14,14 @@ void array_basics() {
// Get the value out of it:
auto s = x.item<float>();
assert(s == 1.0);
(void)s;
// Scalars have a size of 1:
int64_t size = x.size();
size_t size = x.size();
assert(size == 1);
(void)size;
// Scalars have 0 dimensions:
int ndim = x.ndim();
assert(ndim == 0);
(void)ndim;
// The shape should be an empty vector:
auto shape = x.shape();
@@ -33,7 +30,6 @@ void array_basics() {
// The datatype should be float32:
auto dtype = x.dtype();
assert(dtype == mx::float32);
(void)dtype;
// Specify the dtype when constructing the array:
x = mx::array(1, mx::int32);

View File

@@ -1,7 +1,6 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp

View File

@@ -1,24 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
Buffer(void* ptr) : ptr_(ptr) {};
explicit Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer
void* raw_ptr();
@@ -28,16 +28,16 @@ class Buffer {
};
};
Buffer malloc(size_t size);
void free(Buffer buffer);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
virtual Buffer make_buffer(void* ptr, size_t size) {
return Buffer{nullptr};
};
virtual void release(Buffer buffer) {}
Allocator() = default;
Allocator(const Allocator& other) = delete;
@@ -49,4 +49,25 @@ class Allocator {
Allocator& allocator();
inline Buffer malloc(size_t size) {
return allocator().malloc(size);
}
inline void free(Buffer buffer) {
allocator().free(buffer);
}
// Make a Buffer from a raw pointer of the given size without a copy. If a
// no-copy conversion is not possible then the returned buffer.ptr() will be
// nullptr. Any buffer created with this function must be released with
// release(buffer)
inline Buffer make_buffer(void* ptr, size_t size) {
return allocator().make_buffer(ptr, size);
};
// Release a buffer from the allocator made with make_buffer
inline void release(Buffer buffer) {
allocator().release(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -44,11 +44,11 @@ std::vector<array> array::make_arrays(
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (int i = 0; i < std::ssize(shapes); ++i) {
for (size_t i = 0; i < shapes.size(); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
}
// For each node in |outputs|, its siblings are the other nodes.
for (int i = 0; i < std::ssize(outputs); ++i) {
for (size_t i = 0; i < outputs.size(); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
other.strides(),
other.flags(),
[](auto) {});
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
cpy.array_desc_->offset = other.array_desc_->offset;
return cpy;
}
@@ -82,6 +82,28 @@ array::array(std::initializer_list<int> data, Dtype dtype)
init(data.begin());
}
array::array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
auto buffer = allocator::make_buffer(data, nbytes());
if (buffer.ptr() == nullptr) {
set_data(allocator::malloc(nbytes()));
auto ptr = static_cast<char*>(data);
std::copy(ptr, ptr + nbytes(), this->data<char>());
deleter(data);
} else {
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
auto ptr = buffer.ptr();
allocator::release(buffer);
return deleter(ptr);
};
set_data(buffer, std::move(wrapped_deleter));
}
}
/* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
@@ -141,13 +163,12 @@ bool array::is_tracer() const {
void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->offset = 0;
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
auto max_dim =
static_cast<int64_t>(*std::max_element(shape().begin(), shape().end()));
array_desc_->flags.col_contiguous = size() <= 1 || size() == max_dim;
auto max_dim = std::max_element(shape().begin(), shape().end());
array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
}
void array::set_data(
@@ -157,7 +178,7 @@ void array::set_data(
Flags flags,
Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->offset = 0;
array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides);
array_desc_->flags = flags;
@@ -168,14 +189,13 @@ void array::copy_shared_buffer(
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
int64_t offset /* = 0 */) {
array_desc_->data = other.array_desc_->data;
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
array_desc_->offset =
sizeof(char) * itemsize() * offset + other.array_desc_->offset;
}
void array::copy_shared_buffer(const array& other) {
@@ -193,7 +213,7 @@ array::~array() {
}
// Break circular reference for non-detached arrays with siblings
if (auto n = std::ssize(siblings()); n > 0) {
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
@@ -275,7 +295,7 @@ array::ArrayDesc::~ArrayDesc() {
ad.inputs.clear();
for (auto& [_, a] : input_map) {
bool is_deletable =
(a.array_desc_.use_count() <= std::ssize(a.siblings()) + 1);
(a.array_desc_.use_count() <= a.siblings().size() + 1);
// An array with siblings is deletable only if all of its siblings
// are deletable
for (auto& s : a.siblings()) {
@@ -284,7 +304,7 @@ array::ArrayDesc::~ArrayDesc() {
}
int is_input = (input_map.find(s.id()) != input_map.end());
is_deletable &=
s.array_desc_.use_count() <= std::ssize(a.siblings()) + is_input;
s.array_desc_.use_count() <= a.siblings().size() + is_input;
}
if (is_deletable) {
for_deletion.push_back(std::move(a.array_desc_));

View File

@@ -57,6 +57,16 @@ class array {
Shape shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a raw pointer. The constructor will attempt to use the
* input data without a copy. The deleter will be called when the array no
* longer needs the underlying memory - after the array is destroyed in the
* no-copy case and after the copy otherwise. */
explicit array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter);
/* Build an array from a buffer */
explicit array(
allocator::Buffer data,
@@ -81,22 +91,22 @@ class array {
}
/** The size of the array's datatype in bytes. */
int itemsize() const {
size_t itemsize() const {
return size_of(dtype());
}
/** The number of elements in the array. */
int64_t size() const {
size_t size() const {
return array_desc_->size;
}
/** The number of bytes in the array. */
int64_t nbytes() const {
size_t nbytes() const {
return size() * itemsize();
}
/** The number of dimensions of the array. */
int ndim() const {
size_t ndim() const {
return array_desc_->shape.size();
}
@@ -294,6 +304,11 @@ class array {
return array_desc_->siblings;
}
/** The array's position in the sibling list. */
int sibling_position() const {
return array_desc_->position;
}
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;
@@ -329,7 +344,7 @@ class array {
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
* Note, ``data_size`` is in units of ``item_size`` (not bytes).
**/
int64_t data_size() const {
size_t data_size() const {
return array_desc_->data_size;
}
@@ -340,7 +355,7 @@ class array {
return array_desc_->data->buffer;
}
int64_t buffer_size() const {
size_t buffer_size() const {
return allocator::allocator().size(buffer());
}
@@ -349,15 +364,23 @@ class array {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
// Return a raw pointer to the arrays data. This function may do a copy if
// the underlying buffer is not accessible on the CPU. When accessing the
// data for GPU kernels, be sure to use the correct method / function for the
// given backend to access the GPU pointer.
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
return reinterpret_cast<T*>(
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
}
template <typename T>
const T* data() const {
return static_cast<T*>(array_desc_->data_ptr);
return const_cast<array&>(*this).data<T>();
}
int64_t offset() const {
return array_desc_->offset;
}
enum Status {
@@ -426,7 +449,7 @@ class array {
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
int64_t offset = 0);
void copy_shared_buffer(const array& other);
@@ -461,8 +484,8 @@ class array {
// can share the underlying data buffer.
std::shared_ptr<Data> data;
// Properly offset data pointer
void* data_ptr{nullptr};
// Offset from beginning of data pointer
int64_t offset{0};
// The size in elements of the data buffer the array accesses
size_t data_size;
@@ -530,7 +553,7 @@ array::array(
Shape shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (std::ssize(data) != size()) {
if (data.size() != size()) {
throw std::invalid_argument(
"Data size and provided shape mismatch in array construction.");
}

View File

@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
const array& a,
const array& b,
array& out,
BinaryOpType bopt) {
BinaryOpType bopt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc(b.data_size() * out.itemsize()),
mallocfn(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a);
} else {
out.set_data(
allocator::malloc(a.data_size() * out.itemsize()),
mallocfn(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b);
} else {
out.set_data(
allocator::malloc(a.data_size() * out.itemsize()),
mallocfn(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
break;
}

View File

@@ -6,7 +6,7 @@ namespace mlx::core {
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
Strides strides(out.ndim(), 0);

View File

@@ -21,8 +21,8 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
int64_t r = 1, c = 1;
for (int i = std::ssize(strides_) - 1, j = 0; i >= 0; i--, j++) {
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
@@ -60,8 +60,7 @@ void CustomTransforms::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = std::ssize(inputs) - std::ssize(outputs);
i < std::ssize(outputs);
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
@@ -71,7 +70,7 @@ void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < std::ssize(outputs); i++) {
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
@@ -207,11 +206,11 @@ void Split::eval(
auto compute_new_flags = [](const auto& shape,
const auto& strides,
int64_t in_data_size,
size_t in_data_size,
auto flags) {
int64_t data_size = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
@@ -241,7 +240,7 @@ void Split::eval(
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < std::ssize(indices); i++) {
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
@@ -255,7 +254,7 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0];
Strides strides;
for (int i = 0, j = 0; i < in.ndim(); ++i) {
if (j < std::ssize(axes_) && i == axes_[j]) {
if (j < axes_.size() && i == axes_[j]) {
j++;
} else {
strides.push_back(in.strides(i));
@@ -273,7 +272,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
Strides out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < std::ssize(axes_); ++ax) {
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}

View File

@@ -114,13 +114,15 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous) {
bool contiguous,
const std::function<allocator::Buffer(size_t)>&
mallocfn /* = allocator::malloc */) {
if (contiguous) {
int o = 0;
Strides strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Correct size
@@ -138,16 +140,16 @@ void compiled_allocate_outputs(
data_size = in.data_size();
}
}
for (; o < std::ssize(outputs); ++o) {
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc(data_size * outputs[o].itemsize()),
mallocfn(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < std::ssize(inputs) && o < std::ssize(outputs); ++i) {
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
@@ -162,8 +164,8 @@ void compiled_allocate_outputs(
o++;
}
}
for (; o < std::ssize(outputs); ++o) {
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
for (; o < outputs.size(); ++o) {
outputs[o].set_data(mallocfn(outputs[o].nbytes()));
}
}
}
@@ -193,7 +195,7 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
// Broadcast the inputs to the output shape.
Strides xstrides;
int j = 0;
size_t j = 0;
for (; j < shape.size() - x.ndim(); ++j) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
@@ -201,7 +203,7 @@ std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); ++i, ++j) {
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(out.strides()[j]);
@@ -224,13 +226,13 @@ bool compiled_use_large_index(
const std::vector<array>& outputs,
bool contiguous) {
if (contiguous) {
int64_t max_size = 0;
size_t max_size = 0;
for (const auto& in : inputs) {
max_size = std::max(max_size, in.data_size());
}
return max_size > UINT32_MAX;
} else {
int64_t max_size = 0;
size_t max_size = 0;
for (const auto& o : outputs) {
max_size = std::max(max_size, o.size());
}

View File

@@ -58,7 +58,9 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant,
bool contiguous);
bool contiguous,
const std::function<allocator::Buffer(size_t)>& mallocfn =
allocator::malloc);
// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(

View File

@@ -22,7 +22,11 @@ enum class CopyType {
GeneralGeneral
};
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
inline bool set_copy_output_data(
const array& in,
array& out,
CopyType ctype,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
@@ -31,14 +35,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
return true;
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
return false;
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
return false;
}
}

View File

@@ -27,7 +27,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
namespace mlx::core {
void Load::eval_cpu(const std::vector<array>& /* inputs */, array& out) {
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc(out.nbytes()));
auto read_task = [out_ptr = out.data<char>(),
size = out.size(),

View File

@@ -28,7 +28,7 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && std::ssize(axes) == x.ndim() &&
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
}
@@ -38,7 +38,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Merge consecutive axes
Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < std::ssize(axes); i++) {
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];

View File

@@ -14,10 +14,6 @@ std::tuple<int64_t, Strides> prepare_slice(
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
}
// Normalize the offset
if (data_offset < 0) {
data_offset += in.data_size();
}
return std::make_tuple(data_offset, inp_strides);
}
@@ -25,7 +21,7 @@ void shared_buffer_slice(
const array& in,
const Strides& out_strides,
int64_t data_offset,
int64_t data_size,
size_t data_size,
array& out) {
// Compute row/col contiguity
auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
@@ -45,23 +41,30 @@ void slice(
const Shape& start_indices,
const Shape& strides) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
// Calculate out strides, initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
int64_t data_end = 1;
for (int i = 0; i < start_indices.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
// Get the location of the end based on the inp strides and out.shape()
int64_t low_idx = 0;
int64_t high_idx = 0;
for (int i = 0; i < inp_strides.size(); ++i) {
auto delta = inp_strides[i] * (out.shape()[i] - 1);
if (inp_strides[i] > 0) {
high_idx += delta;
} else {
low_idx += delta;
}
}
if (data_end < 0) {
data_end += in.data_size();
int64_t data_size = (high_idx - low_idx) + 1;
if (data_size < 0) {
std::ostringstream msg;
msg << "[slice] Computed invalid data size: " << data_size << ".";
throw std::runtime_error(msg.str());
}
int64_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}

View File

@@ -46,7 +46,8 @@ inline void set_ternary_op_output_data(
const array& b,
const array& c,
array& out,
TernaryOpType topt) {
TernaryOpType topt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) {
out.copy_shared_buffer(x);
@@ -57,13 +58,12 @@ inline void set_ternary_op_output_data(
switch (topt) {
case TernaryOpType::ScalarScalarScalar:
out.set_data(
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
break;
case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data(
allocator::malloc(out.itemsize() * b.data_size()),
mallocfn(out.itemsize() * b.data_size()),
b.data_size(),
b.strides(),
b.flags());
@@ -76,7 +76,7 @@ inline void set_ternary_op_output_data(
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
break;
}

View File

@@ -7,19 +7,22 @@
namespace mlx::core {
inline void set_unary_output_data(const array& in, array& out) {
inline void set_unary_output_data(
const array& in,
array& out,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (in.flags().contiguous) {
if (is_donatable(in, out)) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc(in.data_size() * out.itemsize()),
mallocfn(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(mallocfn(out.nbytes()));
}
}

View File

@@ -28,7 +28,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
if (shape[0] != 1) {
to_collapse.push_back(0);
}
int64_t size = shape[0];
size_t size = shape[0];
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
size *= shape[i];
@@ -64,7 +64,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
current_shape *= shape[to_collapse[k]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < std::ssize(strides); j++) {
for (int j = 0; j < strides.size(); j++) {
const auto& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]);
}

View File

@@ -162,7 +162,7 @@ struct ContiguousIterator {
};
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
int64_t no_broadcast_data_size = 1;
size_t no_broadcast_data_size = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
bool is_row_contiguous = true;
@@ -183,7 +183,7 @@ inline auto check_contiguity(const Shape& shape, const Strides& strides) {
}
inline bool is_donatable(const array& in, const array& out) {
constexpr int64_t donation_extra = 16384;
constexpr size_t donation_extra = 16384;
return in.is_donatable() && in.itemsize() == out.itemsize() &&
in.buffer_size() <= out.nbytes() + donation_extra;

View File

@@ -10,7 +10,7 @@ namespace mlx::core {
namespace {
template <typename T>
void arange(T start, T next, array& out, int64_t size, Stream stream) {
void arange(T start, T next, array& out, size_t size, Stream stream) {
auto ptr = out.data<T>();
auto step_size = next - start;
auto& encoder = cpu::get_command_encoder(stream);

View File

@@ -19,12 +19,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();
for (int64_t i = 0; i < out.size(); ++i) {
for (uint32_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto local_in_ptr = in_ptr + loc;
uint32_t ind_v = 0;
InT v = (*local_in_ptr);
for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*local_in_ptr), &ind_v, &v);
}
out_ptr[i] = ind_v;

View File

@@ -14,238 +14,11 @@
namespace mlx::core {
namespace {
template <typename Op>
void binary(
const array& a,
const array& b,
array& out,
Op /* op */,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op /* op */,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op /* op */,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int(
const array& a,
const array& b,
array& out,
Op /* op */,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Add(), stream());
binary_op_cpu(a, b, out, detail::Add(), stream());
}
void DivMod::eval_cpu(
@@ -329,14 +102,14 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Divide(), stream());
binary_op_cpu(a, b, out, detail::Divide(), stream());
}
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Remainder(), stream());
binary_op_cpu(a, b, out, detail::Remainder(), stream());
}
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -377,89 +150,90 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
}
});
} else {
comparison_op(a, b, out, detail::Equal(), stream());
comparison_op_cpu(a, b, out, detail::Equal(), stream());
}
}
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());
}
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
comparison_op_cpu(
inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
}
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
}
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());
}
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary_float(a, b, out, detail::LogAddExp(), stream());
binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());
}
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd(), stream());
binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());
}
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr(), stream());
binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());
}
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Maximum(), stream());
binary_op_cpu(a, b, out, detail::Maximum(), stream());
}
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Minimum(), stream());
binary_op_cpu(a, b, out, detail::Minimum(), stream());
}
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Multiply(), stream());
binary_op_cpu(a, b, out, detail::Multiply(), stream());
}
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());
}
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Power(), stream());
binary_op_cpu(a, b, out, detail::Power(), stream());
}
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
binary(a, b, out, detail::Subtract(), stream());
binary_op_cpu(a, b, out, detail::Subtract(), stream());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -468,19 +242,19 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
switch (op_) {
case BitwiseBinary::And:
binary_int(a, b, out, detail::BitwiseAnd(), stream());
binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());
break;
case BitwiseBinary::Or:
binary_int(a, b, out, detail::BitwiseOr(), stream());
binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());
break;
case BitwiseBinary::Xor:
binary_int(a, b, out, detail::BitwiseXor(), stream());
binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());
break;
case BitwiseBinary::LeftShift:
binary_int(a, b, out, detail::LeftShift(), stream());
binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());
break;
case BitwiseBinary::RightShift:
binary_int(a, b, out, detail::RightShift(), stream());
binary_int_op_cpu(a, b, out, detail::RightShift(), stream());
break;
}
}
@@ -489,7 +263,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
binary_float(a, b, out, detail::ArcTan2(), stream());
binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());
}
} // namespace mlx::core

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core {
@@ -290,4 +291,227 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T, Op>(a, b, out, bopt);
}
template <typename Op>
void binary_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace mlx::core

View File

@@ -99,7 +99,7 @@ void binary_op_dispatch_dims(
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (int64_t elem = 0; elem < std::ssize(a); elem += stride) {
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
@@ -137,21 +137,21 @@ void binary_op(
if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
} else if (bopt == BinaryOpType::ScalarVector) {
for (int64_t i = 0; i < b.data_size(); ++i) {
for (size_t i = 0; i < b.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
b_ptr++;
}
} else if (bopt == BinaryOpType::VectorScalar) {
for (int64_t i = 0; i < a.data_size(); ++i) {
for (size_t i = 0; i < a.data_size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;
a_ptr++;
}
} else { // VectorVector
for (int64_t i = 0; i < a.size(); ++i) {
for (size_t i = 0; i < a.size(); ++i) {
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
out_a_ptr++;
out_b_ptr++;

View File

@@ -33,8 +33,8 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
N = a.shape(-1),
size = a.size()]() mutable {
char uplo = (upper) ? 'L' : 'U';
int64_t num_matrices = size / (N * N);
for (int64_t i = 0; i < num_matrices; i++) {
size_t num_matrices = size / (N * N);
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
potrf<T>(

View File

@@ -49,7 +49,7 @@ static CompilerCache& cache() {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& /* device */) {
bool compile_available_for_device(const Device& device) {
return true;
}
@@ -168,7 +168,7 @@ inline void build_kernel(
// Add the input arguments
int cnt = 0;
int strides_index = 1;
for (int i = 0; i < std::ssize(inputs); ++i) {
for (size_t i = 0; i < inputs.size(); ++i) {
// Skip constants from the input list
if (is_constant(i)) {
continue;
@@ -238,7 +238,7 @@ inline void build_kernel(
} else {
os << x.primitive().name();
os << "()(";
for (int i = 0; i < std::ssize(x.inputs()) - 1; i++) {
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;

View File

@@ -860,7 +860,7 @@ void explicit_gemm_conv_1D_cpu(
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& /* wt_dilation */,
const std::vector<int>& wt_dilation,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
@@ -1003,7 +1003,7 @@ void explicit_gemm_conv_ND_cpu(
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& /* wt_dilation */,
const std::vector<int>& wt_dilation,
const bool flip,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
@@ -1023,7 +1023,7 @@ void explicit_gemm_conv_ND_cpu(
// Pad input
Shape padded_shape(in.shape().size());
padded_shape.front() = N;
for (int i = 0; i < iDim.size(); i++) {
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
}
padded_shape.back() = C;
@@ -1054,20 +1054,20 @@ void explicit_gemm_conv_ND_cpu(
// Make strided view
Shape strided_shape(oDim.size() + wDim.size() + 2);
strided_shape.front() = N;
for (int i = 0; i < oDim.size(); i++) {
for (size_t i = 0; i < oDim.size(); i++) {
strided_shape[i + 1] = oDim[i];
}
for (int i = 0; i < wDim.size(); i++) {
for (size_t i = 0; i < wDim.size(); i++) {
strided_shape[i + 1 + oDim.size()] = wDim[i];
}
strided_shape.back() = C;
Strides strided_strides(in.shape().size() * 2 - 2);
strided_strides[0] = in_padded.strides()[0];
for (int i = 0; i < std::ssize(wt_strides); i++) {
for (size_t i = 0; i < wt_strides.size(); i++) {
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
}
for (int i = 1; i < std::ssize(in_padded.strides()); i++) {
for (size_t i = 1; i < in_padded.strides().size(); i++) {
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
}

View File

@@ -90,10 +90,14 @@ void Recv::eval_cpu(
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
(void)inputs;
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_, stream());
}
void ReduceScatter::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
}
} // namespace mlx::core::distributed

View File

@@ -12,6 +12,167 @@ namespace mlx::core {
namespace {
template <typename T>
complex64_t to_complex(T r, T i) {
return {static_cast<float>(r), static_cast<float>(i)};
}
template <typename T, class Enable = void>
struct EigWork {};
template <typename T>
struct EigWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using O = complex64_t;
char jobl;
char jobr;
int N;
int lwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
T work;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
if (compute_eigenvectors) {
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
}
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, O* values, O* vectors) {
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
T* vec_tmp = nullptr;
if (vectors) {
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
}
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
eig_tmp,
eig_tmp + N,
vectors ? vec_tmp : nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
work,
&lwork,
&info);
for (int i = 0; i < N; ++i) {
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
}
if (vectors) {
for (int i = 0; i < N; ++i) {
if (values[i].imag() != 0) {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] =
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
vectors[(i + 1) * N + j] =
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
}
}
}
}
}
};
template <>
struct EigWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
using O = T;
char jobl;
char jobr;
int N;
int lwork;
int lrwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
T work;
R rwork;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&rwork,
&info);
lwork = static_cast<int>(work.real());
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
}
void run(T* a, T* values, T* vectors) {
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
values,
vectors,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&info);
}
};
template <typename T>
void eig_impl(
array& a,
@@ -19,101 +180,39 @@ void eig_impl(
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>();
auto val_ptr = values.data<complex64_t>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
OT* vec_ptr = nullptr;
complex64_t* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>();
vec_ptr = vectors.data<complex64_t>();
}
encoder.dispatch([a_ptr,
val_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (int64_t i = 0; i < size / (N * N); ++i) {
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
for (size_t i = 0; i < size / (N * N); ++i) {
work.run(a_ptr, val_ptr, vec_ptr);
a_ptr += N * N;
val_ptr += N;
if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N;
}
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
if (work.info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
<< work.info;
throw std::runtime_error(msg.str());
}
}
@@ -165,8 +264,17 @@ void Eig::eval_cpu(
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case float64:
eig_impl<double>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case complex64:
eig_impl<std::complex<float>>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
throw std::runtime_error(
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
}
}

View File

@@ -165,7 +165,7 @@ void eigh_impl(
EighWork<T> work(jobz, uplo, N);
// Work loop
for (int64_t i = 0; i < size / (N * N); ++i) {
for (size_t i = 0; i < size / (N * N); ++i) {
work.run(vec_ptr, eig_ptr);
vec_ptr += N * N;
eig_ptr += N;

View File

@@ -20,8 +20,8 @@ struct CommandEncoder {
CommandEncoder(CommandEncoder&&) = delete;
CommandEncoder& operator=(CommandEncoder&&) = delete;
void set_input_array(const array& /* a */) {}
void set_output_array(array& /* a */) {}
void set_input_array(const array& a) {}
void set_output_array(array& a) {}
// Hold onto a temporary until any already scheduled tasks which use it as
// an input are complete.

View File

@@ -12,12 +12,12 @@ void matmul(
T* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,

View File

@@ -34,7 +34,7 @@ void matmul_bnns(
bool b_transposed,
size_t lda,
size_t ldb,
size_t /* ldc */,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
@@ -52,7 +52,7 @@ void matmul_bnns(
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (size_t i = 0; i < batch_size * M * N; ++i) {
for (auto i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
@@ -127,7 +127,7 @@ void matmul_bnns(
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (size_t i = 0; i < batch_size; ++i) {
for (int i = 0; i < batch_size; ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
reinterpret_cast<const uint8_t*>(
@@ -148,12 +148,12 @@ void matmul<float16_t>(
float16_t* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
@@ -183,12 +183,12 @@ void matmul<bfloat16_t>(
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,

View File

@@ -13,20 +13,20 @@ void matmul<float>(
float* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
int64_t M = a_shape[ndim - 2];
int64_t N = b_shape[ndim - 1];
int64_t K = a_shape[ndim - 1];
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
cblas_sgemm(
@@ -54,20 +54,20 @@ void matmul<double>(
double* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
int64_t M = a_shape[ndim - 2];
int64_t N = b_shape[ndim - 1];
int64_t K = a_shape[ndim - 1];
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(
@@ -95,20 +95,20 @@ void matmul<complex64_t>(
complex64_t* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
int64_t ldc,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
int64_t batch_size,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
int64_t M = a_shape[ndim - 2];
int64_t N = b_shape[ndim - 1];
int64_t K = a_shape[ndim - 1];
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);

View File

@@ -11,9 +11,9 @@ namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) {
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
for (int b = 0; b < size / n; b++) {
int64_t loc = b * n;
size_t loc = b * n;
T* data_ptr = out + loc;
int h = 1;
int n_over_2 = n / 2;
@@ -37,7 +37,7 @@ void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) {
// m component
template <typename T>
void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
@@ -45,7 +45,7 @@ void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
std::vector<bool> hmat_vec;
while (end != std::string_view::npos) {
auto row = matrix.substr(start, end - start);
for (int i = 0; i < std::ssize(row); i++) {
for (int i = 0; i < row.length(); i++) {
hmat_vec.push_back(row[i] == '+');
}
start = end + 1;
@@ -53,7 +53,7 @@ void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
}
for (int b = 0; b < size / m / n; b++) {
int64_t loc = b * n * m;
size_t loc = b * n * m;
T* data_ptr = out + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);

View File

@@ -78,7 +78,7 @@ void gather(
can_copy = true;
// Ignore leading 1s
int64_t i = 0;
int i = 0;
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
;
@@ -91,7 +91,7 @@ void gather(
can_copy = true;
// Ignore trailing 1s
int64_t i = slice_sizes.size() - 1;
int i = slice_sizes.size() - 1;
for (; i >= 0 && slice_sizes[i] == 1; --i)
;
@@ -101,11 +101,11 @@ void gather(
can_copy = (src.shape(i) == slice_sizes[i]);
}
}
int64_t slice_size = 1;
size_t slice_size = 1;
for (auto s : slice_sizes) {
slice_size *= s;
}
int64_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
@@ -115,10 +115,10 @@ void gather(
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}
int64_t out_idx = 0;
for (int64_t idx = 0; idx < ind_size; idx++) {
int64_t src_idx = 0;
for (int ii = 0; ii < std::ssize(inds); ++ii) {
size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii];
auto idx_loc = its[ii].loc;
its[ii].step();
@@ -134,7 +134,7 @@ void gather(
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
out_idx += slice_size;
} else {
for (int64_t jj = 0; jj < slice_size; jj++) {
for (int jj = 0; jj < slice_size; jj++) {
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
src_it.step();
}
@@ -403,11 +403,11 @@ void scatter(
const std::vector<int>& axes) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
int64_t n_updates = nind ? inds[0].size() : 1;
size_t n_updates = nind ? inds[0].size() : 1;
Shape update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end());
int64_t update_size = 1;
size_t update_size = 1;
for (auto us : update_shape) {
update_size *= us;
}
@@ -418,9 +418,9 @@ void scatter(
auto out_ptr = out.data<InT>();
auto upd_ptr = updates.data<InT>();
for (int64_t i = 0; i < n_updates; ++i) {
int64_t out_offset = 0;
for (int j = 0; j < std::ssize(inds); ++j) {
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < inds.size(); ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
@@ -429,7 +429,7 @@ void scatter(
out_offset += (idx_val * out.strides()[ax]);
}
update_it.seek(i * update_size);
for (int64_t j = 0; j < update_size; ++j) {
for (int j = 0; j < update_size; ++j) {
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
update_it.step();
out_it.step();
@@ -747,4 +747,108 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
});
}
template <typename T>
void masked_scatter_impl(const array& mask, const array& src, array& out) {
ContiguousIterator mask_it(mask);
ContiguousIterator src_it(src);
ContiguousIterator out_it(out);
const bool* mask_ptr = mask.data<bool>();
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
const size_t batch_count = mask.shape(0);
const size_t mask_batch_size = mask.size() / batch_count;
const size_t src_batch_size = src.size() / batch_count;
for (uint b = 0; b < batch_count; ++b) {
size_t src_consumed = 0;
src_it.seek(b * src_batch_size);
for (size_t i = 0; i < mask_batch_size; ++i) {
if (mask_ptr[mask_it.loc]) {
if (src_consumed >= src_batch_size) {
throw std::runtime_error(
"[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
}
dst_ptr[out_it.loc] = src_ptr[src_it.loc];
src_it.step();
++src_consumed;
}
mask_it.step();
out_it.step();
}
}
}
void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& dst = inputs[0];
auto& mask = inputs[1];
auto& src = inputs[2];
// Copy src into out (copy allocates memory for out)
auto ctype =
dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(dst, out, ctype, stream());
if (mask.size() == 0) {
return;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(mask);
encoder.set_input_array(src);
encoder.set_output_array(out);
encoder.dispatch([mask = array::unsafe_weak_copy(mask),
src = array::unsafe_weak_copy(src),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
masked_scatter_impl<bool>(mask, src, out);
break;
case uint8:
masked_scatter_impl<uint8_t>(mask, src, out);
break;
case uint16:
masked_scatter_impl<uint16_t>(mask, src, out);
break;
case uint32:
masked_scatter_impl<uint32_t>(mask, src, out);
break;
case uint64:
masked_scatter_impl<uint64_t>(mask, src, out);
break;
case int8:
masked_scatter_impl<int8_t>(mask, src, out);
break;
case int16:
masked_scatter_impl<int16_t>(mask, src, out);
break;
case int32:
masked_scatter_impl<int32_t>(mask, src, out);
break;
case int64:
masked_scatter_impl<int64_t>(mask, src, out);
break;
case float16:
masked_scatter_impl<float16_t>(mask, src, out);
break;
case float32:
masked_scatter_impl<float>(mask, src, out);
break;
case float64:
masked_scatter_impl<double>(mask, src, out);
break;
case bfloat16:
masked_scatter_impl<bfloat16_t>(mask, src, out);
break;
case complex64:
masked_scatter_impl<complex64_t>(mask, src, out);
break;
}
});
}
} // namespace mlx::core

View File

@@ -122,7 +122,7 @@ void inverse_impl(
stream);
const int N = a.shape(-1);
const int64_t num_matrices = a.size() / (N * N);
const size_t num_matrices = a.size() / (N * N);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(inv);
@@ -130,13 +130,13 @@ void inverse_impl(
auto inv_ptr = inv.data<T>();
if (tri) {
encoder.dispatch([inv_ptr, N, num_matrices, upper]() {
for (int64_t i = 0; i < num_matrices; i++) {
for (int i = 0; i < num_matrices; i++) {
tri_inv<T>(inv_ptr + N * N * i, N, upper);
}
});
} else {
encoder.dispatch([inv_ptr, N, num_matrices]() {
for (int64_t i = 0; i < num_matrices; i++) {
for (int i = 0; i < num_matrices; i++) {
general_inv<T>(inv_ptr + N * N * i, N);
}
});

View File

@@ -45,9 +45,7 @@
INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesdd)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)
@@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri)
}
INSTANTIATE_LAPACK_COMPLEX(heevd)
#define INSTANTIATE_LAPACK_ALL(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, double>) { \
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_ALL(geev)
INSTANTIATE_LAPACK_ALL(gesdd)

View File

@@ -25,7 +25,7 @@ inline void mask_matrix(
const int64_t Y_data_str,
const int64_t X_mask_str,
const int64_t Y_mask_str,
const int64_t mask_offset) {
const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
@@ -61,13 +61,13 @@ inline void segmented_mm(
T* out,
bool a_transposed,
bool b_transposed,
int64_t lda,
int64_t ldb,
size_t lda,
size_t ldb,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides,
int64_t num_segments,
size_t num_segments,
const Shape& segments_shape,
const Strides& segments_strides) {
int ndim = a_shape.size();
@@ -149,9 +149,9 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto [b_transposed, ldb, b, b_copied] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
int64_t M = a.shape(-2);
int64_t N = b.shape(-1);
int64_t K = a.shape(-1);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
@@ -172,8 +172,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
int batch_idx,
int X,
int Y,
int64_t X_data_str,
int64_t Y_data_str,
size_t X_data_str,
size_t Y_data_str,
const Shape& mask_shape,
const Strides& mask_strides,
bool is_bool) {
@@ -253,7 +253,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto a_ptr = a.data<float>();
auto b_ptr = b.data<float>();
auto out_ptr = out.data<float>();
int64_t num_matrices = out.size() / (M * int64_t(N));
size_t num_matrices = out.size() / (M * size_t(N));
auto ldc = out.shape(-1);
encoder.dispatch([a_ptr,
@@ -394,9 +394,9 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
int64_t M = a.shape(-2);
int64_t N = b.shape(-1);
int64_t K = a.shape(-1);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
@@ -413,7 +413,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
// Get batch dims
auto batch_size_out = out.size() / (M * N);
int64_t matrix_stride_out = M * N;
size_t matrix_stride_out = M * N;
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};

View File

@@ -2,6 +2,8 @@
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h"
@@ -135,15 +137,29 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
return;
}
// Handle empty matrix case (K=0)
if (inputs[0].shape(-1) == 0) {
auto& c = inputs[2];
if (beta_ == 1.0f) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream());
} else {
array beta_scalar = array(beta_, c.dtype());
auto& encoder = cpu::get_command_encoder(stream());
binary_float_op_cpu(c, beta_scalar, out, detail::Multiply(), stream());
encoder.add_temporary(std::move(beta_scalar));
}
return;
}
// Fill output with C
auto& c = inputs[2];
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
}

View File

@@ -48,7 +48,7 @@ static std::pair<array, bool> compute_dynamic_offset(
auto compute_offset =
[strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
int64_t offset_ = 0;
for (int i = 0; i < std::ssize(axes); ++i) {
for (int i = 0; i < axes.size(); ++i) {
offset_ += indices[i] * strides[axes[i]];
}
offset[0] = offset_;
@@ -124,7 +124,6 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 0);
(void)inputs;
out.set_data(allocator::malloc(out.nbytes()));
switch (out.dtype()) {
case bool_:
@@ -194,9 +193,9 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
for (int i = 0; i < std::ssize(inputs); i++) {
for (int i = 0; i < inputs.size(); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
int64_t data_offset = strides[axis_] * sizes[i];
size_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
@@ -206,7 +205,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr int64_t extra_bytes = 16384;
constexpr size_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
@@ -255,8 +254,8 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
copy_cpu(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
int64_t data_offset = 0;
for (int i = 0; i < std::ssize(axes_); i++) {
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
@@ -275,10 +274,10 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
int64_t num_keys = keys.size() / 2;
size_t num_keys = keys.size() / 2;
int64_t elems_per_key = out.size() / num_keys;
int64_t bytes_per_key = out.itemsize() * elems_per_key;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
@@ -292,8 +291,8 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
num_keys,
kshape = keys.shape(),
kstrides = keys.strides()]() mutable {
int64_t out_skip = (bytes_per_key + 4 - 1) / 4;
uintptr_t half_size = out_skip / 2;
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);
@@ -334,7 +333,7 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
auto& in = inputs[0];
@@ -362,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs,
array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}
@@ -397,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
out.set_data(allocator::malloc(0));
return;
}

View File

@@ -13,7 +13,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = M;
int64_t num_matrices = a.size() / (M * N);
size_t num_matrices = a.size() / (M * N);
// Copy A to inplace input and make it col-contiguous
array in(a.shape(), a.dtype(), nullptr, {});
@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
auto work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices
for (int64_t i = 0; i < num_matrices; ++i) {
for (int i = 0; i < num_matrices; ++i) {
// Solve
geqrf<T>(
&M,
@@ -68,7 +68,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
}
allocator::free(work);
for (int64_t i = 0; i < num_matrices; ++i) {
for (int i = 0; i < num_matrices; ++i) {
/// num_reflectors x N
for (int j = 0; j < num_reflectors; ++j) {
for (int k = 0; k < j; ++k) {
@@ -97,7 +97,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices
for (int64_t i = 0; i < num_matrices; ++i) {
for (int i = 0; i < num_matrices; ++i) {
// Compute Q
orgqr<T>(
&M,
@@ -111,7 +111,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
&info);
}
for (int64_t i = 0; i < num_matrices; ++i) {
for (int i = 0; i < num_matrices; ++i) {
// M x num_reflectors
for (int j = 0; j < M; ++j) {
for (int k = 0; k < num_reflectors; ++k) {

View File

@@ -1,8 +1,11 @@
// Copyright © 2023 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -1102,4 +1105,44 @@ void fast::Quantize::eval_cpu(
});
}
void fast::ConvertFP8::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& in = inputs[0];
auto& out = outputs[0];
set_unary_output_data(in, out);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
to_fp8 = to_fp8_]() mutable {
if (to_fp8) {
switch (in.dtype()) {
case float16:
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
break;
case bfloat16:
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
break;
default:
unary_op<float, uint8_t>(in, out, detail::ToFP8());
break;
}
} else {
switch (out.dtype()) {
case float16:
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
break;
case bfloat16:
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
break;
default:
unary_op<uint8_t, float>(in, out, detail::FromFP8());
break;
}
}
});
}
} // namespace mlx::core

View File

@@ -1,5 +1,6 @@
#pragma once
#include <arm_neon.h>
#include <simd/math.h>
#include <simd/vector.h>
@@ -200,6 +201,15 @@ SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> clz(Simd<T, N> x) {
auto a = *(uint32x4_t*)(&x);
auto b = *((uint32x4_t*)(&x) + 1);
a = vclzq_u32(a);
b = vclzq_u32(b);
return asd::make_uint8(a, b);
}
template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value);
@@ -207,14 +217,20 @@ Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
template <typename T, int N>
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::max(a.value, b.value);
auto out = Simd<T, N>(asd::max(a.value, b.value));
if constexpr (!std::is_integral_v<T>) {
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
}
template <typename T, int N>
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan
return asd::min(a.value, b.value);
auto out = Simd<T, N>(asd::min(a.value, b.value));
if constexpr (!std::is_integral_v<T>) {
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
}
template <typename T, int N>
@@ -253,12 +269,12 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
} else {
Simd<T, N> res = 1;
// Raising an integer to a negative power is undefined
if (any(exp < static_cast<T>(0))) {
if (any(exp < 0)) {
return 0;
}
while (any(exp > static_cast<T>(0))) {
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > static_cast<T>(0), base * base, base);
base = select(exp > 0, base * base, base);
exp = exp >> 1;
}
return res;

View File

@@ -171,6 +171,11 @@ DEFAULT_BINARY(&)
DEFAULT_BINARY(&&)
DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> clz(Simd<T, 1> x_) {
return __builtin_clz(x_.value);
}
template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value;

View File

@@ -79,8 +79,7 @@ Simd<T, N> sincos(Simd<T, N> in) {
// Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
// and another one for Pi/4<x<=Pi/2. Both branches will be computed.
auto poly_mask =
(emm2 & static_cast<uint32_t>(2)) != static_cast<uint32_t>(0);
auto poly_mask = (emm2 & 2) != 0;
// The magic pass: "Extended precision modular arithmetic"
// x = ((x - y * DP1) - y * DP2) - y * DP3
@@ -88,8 +87,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != static_cast<uint32_t>(0));
auto sign_mask_cos = ((emm2 - 2) & 4) != static_cast<uint32_t>(0);
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
// Evaluate the first polynom (0 <= x <= Pi/4) in y1,
// and the second polynom (Pi/4 <= x <= 0) in y2

View File

@@ -120,8 +120,8 @@ template <typename T>
void sort(array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + out.ndim() : axis;
int64_t in_size = out.size();
int64_t n_rows = in_size / out.shape(axis);
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -136,7 +136,7 @@ void sort(array& out, int axis) {
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int64_t i = 0; i < n_rows; i++) {
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0);
@@ -151,7 +151,7 @@ template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
int64_t n_rows = in.size() / in.shape(axis);
size_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -176,7 +176,7 @@ void argsort(const array& in, array& out, int axis) {
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int64_t i = 0; i < n_rows; i++) {
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
@@ -214,8 +214,8 @@ template <typename T>
void partition(array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + out.ndim() : axis;
int64_t in_size = out.size();
int64_t n_rows = in_size / out.shape(axis);
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -232,7 +232,7 @@ void partition(array& out, int axis, int kth) {
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int64_t i = 0; i < n_rows; i++) {
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
src_it.step();
@@ -248,7 +248,7 @@ template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
int64_t n_rows = in.size() / in.shape(axis);
size_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -277,7 +277,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int64_t i = 0; i < n_rows; i++) {
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step();

View File

@@ -8,6 +8,183 @@
namespace mlx::core {
template <typename T, class Enable = void>
struct SVDWork {};
template <typename T>
struct SVDWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <>
struct SVDWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
const int lrwork =
jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
int lwork_query = -1;
int work_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension.real();
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <typename T>
void svd_impl(
const array& a,
@@ -27,7 +204,9 @@ void svd_impl(
const int N = a.shape(-1);
const int K = std::min(M, N);
int64_t num_matrices = a.size() / (M * N);
using R = typename SVDWork<T>::R;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), a.dtype(), nullptr, {});
@@ -42,7 +221,7 @@ void svd_impl(
encoder.set_input_array(a);
auto in_ptr = in.data<T>();
T* u_ptr;
T* s_ptr;
R* s_ptr;
T* vt_ptr;
if (compute_uv) {
@@ -58,7 +237,7 @@ void svd_impl(
encoder.set_output_array(s);
encoder.set_output_array(vt);
s_ptr = s.data<T>();
s_ptr = s.data<R>();
u_ptr = u.data<T>();
vt_ptr = vt.data<T>();
} else {
@@ -68,96 +247,26 @@ void svd_impl(
encoder.set_output_array(s);
s_ptr = s.data<T>();
s_ptr = s.data<R>();
u_ptr = nullptr;
vt_ptr = nullptr;
}
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto jobz = (u_ptr) ? "A" : "N";
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
static const int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto jobz = (u_ptr) ? 'A' : 'N';
SVDWork<T> svd_work(N, M, K, jobz);
// Loop over matrices.
for (int64_t i = 0; i < num_matrices; i++) {
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
for (int i = 0; i < num_matrices; i++) {
svd_work.run(
in_ptr + M * N * i,
s_ptr + K * i,
vt_ptr ? vt_ptr + N * N * i : nullptr,
u_ptr ? u_ptr + M * M * i : nullptr);
}
});
encoder.add_temporary(in);
}
template <typename T>
void compute_svd(
const array& /* a */,
bool /* compute_uv */,
std::vector<array>& /* outputs */,
Stream /* stream */) {}
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -168,9 +277,12 @@ void SVD::eval_cpu(
case float64:
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
break;
case complex64:
svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
break;
default:
throw std::runtime_error(
"[SVD::eval_cpu] only supports float32 or float64.");
"[SVD::eval_cpu] only supports float32, float64, or complex64.");
}
}

View File

@@ -136,7 +136,7 @@ void ternary_op(
if (topt == TernaryOpType::ScalarScalarScalar) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
} else if (topt == TernaryOpType::VectorVectorVector) {
for (int64_t i = 0; i < out.size(); ++i) {
for (size_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;

View File

@@ -10,8 +10,8 @@
namespace mlx::core {
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, int64_t shape, int64_t stride) {
for (int64_t i = 0; i < shape; i += 1) {
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = Op{}(*a);
a += stride;
}
@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
auto ndim = a.ndim();
if (a.flags().contiguous) {
auto size = a.data_size();
constexpr int N = simd::max_size<T>;
constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
while (size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src)));
simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
size -= N;
src += N;
dst += N;
@@ -38,14 +38,14 @@ void unary_op(const array& a, array& out, Op) {
src++;
}
} else {
int64_t shape = ndim > 0 ? a.shape().back() : 1;
int64_t stride = ndim > 0 ? a.strides().back() : 1;
size_t shape = ndim > 0 ? a.shape().back() : 1;
size_t stride = ndim > 0 ? a.strides().back() : 1;
if (ndim <= 1) {
unary_op<T, U, Op>(src, dst, shape, stride);
return;
}
auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
for (int64_t elem = 0; elem < a.size(); elem += shape) {
for (size_t elem = 0; elem < a.size(); elem += shape) {
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step();
}

View File

@@ -108,4 +108,73 @@ struct Square {
SINGLE()
};
template <int N>
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
return *(Simd<float, N>*)(&x);
}
template <int N>
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
return *(Simd<uint32_t, N>*)(&x);
}
struct ToFP8 {
template <typename T, int N>
Simd<uint8_t, N> operator()(Simd<T, N> f) {
uint32_t fp8_max = 543 << 21;
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
Simd<uint32_t, N> f_bits;
Simd<float, N> f32 = f;
f_bits = fp32_to_bits(f32);
Simd<uint8_t, N> result = 0u;
auto sign = f_bits & 0x80000000;
f_bits = f_bits ^ sign;
auto f_bits_low =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
result = select(f_bits < (121 << 23), result_low, result_high);
auto result_sat = Simd<uint8_t, N>(0x7E);
result = select(f_bits >= fp8_max, result_sat, result);
return result | Simd<uint8_t, N>(sign >> 24);
}
template <typename T>
uint8_t operator()(T x) {
return (*this)(Simd<T, 1>(x)).value;
}
};
struct FromFP8 {
template <int N>
Simd<float, N> operator()(Simd<uint8_t, N> x) {
auto w = Simd<uint32_t, N>(x) << 24;
auto sign = w & 0x80000000;
auto nonsign = w & 0x7FFFFFFF;
auto renorm_shift = clz(nonsign);
renorm_shift = simd::select(
renorm_shift > Simd<uint32_t, N>{4},
renorm_shift - Simd<uint32_t, N>{4},
Simd<uint32_t, N>{0});
Simd<int32_t, N> inf_nan_mask =
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
auto result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
float operator()(uint8_t x) {
return (*this)(Simd<uint8_t, 1>(x)).value;
}
};
} // namespace mlx::core::detail

View File

@@ -32,6 +32,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
@@ -43,6 +44,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
@@ -51,12 +53,19 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
@@ -114,10 +123,21 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
# Use native CUDA arch by default.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES "native")
execute_process(
COMMAND __nvcc_device_query
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
OUTPUT_STRIP_TRAILING_WHITESPACE)
set(UPGRADABLE_ARCHITECTURES "90;100;121")
if(MLX_CUDA_ARCHITECTURES STREQUAL "")
message(
FATAL_ERROR
"Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
# Use arch-specific compute capability whenever possible.
set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
endif()
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -129,6 +149,7 @@ FetchContent_Declare(
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
set_target_properties(mlx PROPERTIES CCCL_DIR "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
@@ -154,7 +175,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.14.0
GIT_TAG v1.16.0
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
@@ -170,11 +191,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>)
# Supress warnings: note: parameter passing for argument of type
# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
# 10.1
target_compile_options(mlx PRIVATE -Wno-psabi)
# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/utils.h"
@@ -19,6 +20,19 @@ constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
#if CUDART_VERSION >= 13000
inline cudaMemLocation cuda_mem_loc(int i) {
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
return loc;
}
#else
inline int cuda_mem_loc(int i) {
return i;
}
#endif // CUDART_VERSION >= 13000
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
@@ -34,13 +48,7 @@ SmallSizePool::SmallSizePool() {
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
auto loc = cuda_mem_loc(i);
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
@@ -67,6 +75,7 @@ CudaBuffer* SmallSizePool::malloc() {
next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size;
b->buf.device = -1;
return &b->buf;
}
@@ -88,16 +97,47 @@ CudaAllocator::CudaAllocator()
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
size_t free;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
memory_limit_ = total_memory_ * 0.95;
free_limit_ = total_memory_ - memory_limit_;
max_pool_size_ = memory_limit_;
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
int curr;
CHECK_CUDA_ERROR(cudaGetDevice(&curr));
for (int i = 0; i < device_count; ++i) {
CHECK_CUDA_ERROR(cudaSetDevice(i));
cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s);
cudaMemPool_t mem_pool;
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
mem_pools_.push_back(mem_pool);
}
CHECK_CUDA_ERROR(cudaSetDevice(curr));
}
Buffer CudaAllocator::malloc(size_t size) {
void copy_to_managed(CudaBuffer& buf) {
// TODO maybe make this async on a i/o stream to avoid synchronizing the
// device on malloc/and free
void* new_data;
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, buf.size));
buf.device = -1;
CHECK_CUDA_ERROR(cudaMemcpy(new_data, buf.data, buf.size, cudaMemcpyDefault));
CHECK_CUDA_ERROR(cudaFree(buf.data));
buf.data = new_data;
}
Buffer
CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
if (size == 0) {
return Buffer{new CudaBuffer{nullptr, 0, -1}};
}
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
size = 8;
@@ -107,6 +147,10 @@ Buffer CudaAllocator::malloc(size_t size) {
size = page_size * ((size + page_size - 1) / page_size);
}
if (size <= small_block_size || stream == nullptr) {
device = -1;
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure try to reclaim memory from the cache.
@@ -122,30 +166,63 @@ Buffer CudaAllocator::malloc(size_t size) {
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size};
cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
void* data = nullptr;
if (device == -1) {
CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
} else {
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
}
if (!data) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
buf = new CudaBuffer{data, size, device};
}
lock.lock();
// If any cuda memory pool has too much reserved memory, clear some
// memory from the cache. This prevents graph / kernel execution failing
// from OOM
if (get_cache_memory() > 0) {
for (auto p : mem_pools_) {
size_t used = 0;
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
p, cudaMemPoolAttrReservedMemCurrent, &used));
if (used > (total_memory_ - free_limit_)) {
buffer_cache_.release_cached_buffers(free_limit_);
break;
}
}
}
}
active_memory_ += size;
active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.
if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
// Copy to managed here if the buffer is not on the right device
if (buf->device >= 0 && buf->device != device) {
copy_to_managed(*buf);
}
return Buffer{buf};
}
Buffer CudaAllocator::malloc(size_t size) {
return malloc_async(size, -1, nullptr);
}
void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) {
return;
}
if (buf->size == 0) {
delete buf;
return;
}
std::unique_lock lock(mutex_);
active_memory_ -= buf->size;
@@ -169,7 +246,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf);
} else {
cudaFree(buf->data);
if (buf->device >= 0) {
CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
} else {
CHECK_CUDA_ERROR(cudaFree(buf->data));
}
delete buf;
}
}
@@ -220,6 +301,17 @@ CudaAllocator& allocator() {
return *allocator_;
}
Buffer malloc_async(size_t size, CommandEncoder& encoder) {
auto buffer = allocator().malloc_async(
size, encoder.device().cuda_device(), encoder.stream());
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_async] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace cu
namespace allocator {
@@ -232,7 +324,11 @@ void* Buffer::raw_ptr() {
if (!ptr_) {
return nullptr;
}
return static_cast<cu::CudaBuffer*>(ptr_)->data;
auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
if (cbuf.device != -1) {
copy_to_managed(cbuf);
}
return cbuf.data;
}
} // namespace allocator

View File

@@ -4,19 +4,24 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/cuda/cuda_utils.h"
#include <cuda_runtime.h>
#include <mutex>
#include <set>
#include <utility>
namespace mlx::core::cu {
class CommandEncoder;
using allocator::Buffer;
// Stores cuda-managed unified memory.
struct CudaBuffer {
void* data;
size_t size;
int device; // -1 for managed
};
class SmallSizePool {
@@ -45,6 +50,7 @@ class SmallSizePool {
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
Buffer malloc_async(size_t size, int device, cudaStream_t stream);
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
@@ -65,13 +71,19 @@ class CudaAllocator : public allocator::Allocator {
std::mutex mutex_;
size_t memory_limit_;
size_t free_limit_;
size_t total_memory_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_;
std::vector<cudaMemPool_t> mem_pools_;
SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();
Buffer malloc_async(size_t size, CommandEncoder& encoder);
} // namespace mlx::core::cu

View File

@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream());
out.set_data(cu::malloc_async(out.nbytes(), encoder));
encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
@@ -58,7 +57,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dims,
0,
out.data<OutType>(),
gpu_ptr<OutType>(out),
out.data_size(),
static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));

View File

@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder));
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int32_t ndim = shape.size();
// ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
@@ -172,8 +173,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks,
block_dim(),
0,
in.data<T>(),
out.data<uint32_t>(),
gpu_ptr<T>(in),
gpu_ptr<uint32_t>(out),
out.size(),
const_param(shape),
const_param(in_strides),

View File

@@ -292,9 +292,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -310,9 +310,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
rest,
const_param(shape),
const_param(a_strides),
@@ -339,9 +339,9 @@ void binary_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out),
out.data_size());
});
}
@@ -365,7 +365,10 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(
a, b, out, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}

View File

@@ -245,14 +245,16 @@ void binary_two_op_gpu_inplace(
auto& out_a = outputs[0];
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt);
set_binary_op_output_data(a, b, out_b, bopt);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(
a, b, out_a, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
set_binary_op_output_data(
a, b, out_b, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
if (out_a.size() == 0) {
return;
}
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out_a);
@@ -313,10 +315,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
rest,
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
@@ -332,10 +334,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y},
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
rest,
const_param(shape),
const_param(a_strides),
@@ -366,10 +368,10 @@ void binary_two_op_gpu_inplace(
num_blocks,
block_dims,
0,
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
gpu_ptr<InType>(a),
gpu_ptr<InType>(b),
gpu_ptr<OutType>(out_a),
gpu_ptr<OutType>(out_b),
out_a.data_size());
});
}

View File

@@ -293,8 +293,13 @@ void Compiled::eval_gpu(
}
}
auto& encoder = cu::get_command_encoder(s);
// Put outputs.
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
compiled_allocate_outputs(
inputs, outputs, is_constant_, contiguous, [&](auto n) {
return cu::malloc_async(n, encoder);
});
for (auto& x : outputs) {
args.append(x);
}
@@ -324,7 +329,6 @@ void Compiled::eval_gpu(
kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
}
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}

View File

@@ -15,19 +15,16 @@ namespace mlx::core {
namespace {
// Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
#define CONV_BACKWARD_WEIGHT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
// Custom placeholder representing fallback kernel.
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
enum ConvBackendType {
CONV_FALLBACK,
CONV_FORWARD,
CONV_BACKWARD_INPUT,
CONV_BACKWARD_WEIGHT,
};
struct ConvCacheKey {
int device_id;
cudnnDataType_t cudnn_dtype;
fe::DataType_t cudnn_dtype;
std::array<int, MAX_NDIM> input_shape;
std::array<int, MAX_NDIM> weight_shape;
std::array<int, MAX_NDIM> stride;
@@ -44,15 +41,13 @@ struct ConvCacheKey {
auto& conv_cache() {
static LRUBytesKeyCache<
ConvCacheKey,
std::pair<
cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>>
std::pair<ConvBackendType, std::optional<DnnGraph>>>
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
return cache;
}
auto get_conv_op_settings(
cudnnBackendDescriptorType_t backend_type,
auto get_conv_settings(
ConvBackendType backend_type,
array& x,
array& w,
array& y,
@@ -68,8 +63,8 @@ auto get_conv_op_settings(
for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_lo[i] - 1;
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + padding_hi[i];
}
return std::make_tuple(
@@ -95,49 +90,57 @@ auto get_conv_op_settings(
}
}
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
std::optional<DnnGraph> build_conv_graph(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
ConvBackendType backend_type,
Dtype dtype,
array& x,
array& w,
array& y,
const SmallVector<int64_t>& stride,
const SmallVector<int64_t>& padding_lo,
const SmallVector<int64_t>& padding_hi,
const SmallVector<int64_t>& dilation) {
try {
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
? CUDNN_DATA_FLOAT
: dtype_to_cudnn_type(dtype);
auto conv_desc = cudnn_frontend::ConvDescBuilder()
.setDataType(compute_dtype)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding_lo,
const std::vector<int64_t>& padding_hi,
const std::vector<int64_t>& dilation) {
auto compute_dtype =
(dtype == float16 || dtype == bfloat16) ? float32 : dtype;
DnnGraph graph(encoder.device().cudnn_handle(), dtype, compute_dtype);
auto x_ = graph.tensor_nchw("X", 'x', x);
auto w_ = graph.tensor_nchw("W", 'w', w);
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_cudnn_tensor_nchw('x', x))
.setwDesc(build_cudnn_tensor_nchw('w', w))
.setyDesc(build_cudnn_tensor_nchw('y', y))
.setcDesc(conv_desc)
.build();
auto set_options = [&](auto& options) {
options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))
.set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
.set_stride(stride)
.set_pre_padding(padding_lo)
.set_post_padding(padding_hi)
.set_dilation(dilation);
};
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
return cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
throw;
}
std::shared_ptr<fe::graph::Tensor_attributes> y_;
if (backend_type == CONV_FORWARD) {
auto options = fe::graph::Conv_fprop_attributes();
set_options(options);
y_ = graph.conv_fprop(x_, w_, options);
} else if (backend_type == CONV_BACKWARD_INPUT) {
auto options = fe::graph::Conv_dgrad_attributes();
set_options(options);
y_ = graph.conv_dgrad(x_, w_, options);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
auto options = fe::graph::Conv_wgrad_attributes();
set_options(options);
y_ = graph.conv_wgrad(w_, x_, options);
}
graph.tensor_nchw(y_, 'y', y)->set_output(true);
if (graph.prepare().is_bad()) {
return std::nullopt;
}
graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
if (dtype == float32 && !env::enable_tf32()) {
graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});
}
CHECK_CUDNN_FE_ERROR(graph.build());
return graph;
}
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
@@ -181,7 +184,7 @@ array group_transpose(
// eval_gpu, with cost of possible redundant copies.
std::tuple<array, array, array> prepare_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
ConvBackendType backend_type,
array in,
array wt,
array out,
@@ -221,27 +224,11 @@ std::tuple<array, array, array> prepare_args(
return {std::move(in), std::move(wt), std::move(out)};
}
// Get the x/w/y args from the in/wt/out args depending on backend type.
inline std::tuple<array&, array&, array&> dispatch_args(
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& out) {
switch (backend_type) {
case CONV_BACKWARD_INPUT:
return {out, wt, in};
case CONV_BACKWARD_WEIGHT:
return {in, out, wt};
default:
return {in, wt, out};
}
}
// Register inputs and outputs before actually running conv op. Can only be
// called once per eval_gpu.
void register_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
ConvBackendType backend_type,
array& in,
array& wt,
array& intermediate_out,
@@ -270,19 +257,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
if (out_.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 2);
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
Dtype dtype = out.dtype();
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Search cache.
ConvCacheKey cache_key{
BytesKey<ConvCacheKey> cache_key;
cache_key.pod = {
encoder.device().cuda_device(),
dtype_to_cudnn_type(dtype),
vector_key(in.shape()),
@@ -297,16 +284,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
get_alignment(wt),
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second;
if (plan) {
// Run cached plan.
auto& [backend_type, graph] = it->second;
if (graph) {
// Run cached graph.
std::tie(in, wt, out) =
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
throw std::runtime_error("[conv] Cached plan failed to execute.");
}
CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
encoder,
{
{'x', gpu_ptr<void>(in)},
{'w', gpu_ptr<void>(wt)},
{'y', gpu_ptr<void>(out)},
}));
} else {
// Run fallback kernel.
gemm_conv(
@@ -327,7 +317,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
// There is no reliable way to deduce the proper cuDNN backend for the
// convolution, so we make a best guess and then try.
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
SmallVector<ConvBackendType, 2> try_backends;
if (flip_) {
// When weight is flipped, we assume it is backward input convolution.
try_backends.push_back(CONV_BACKWARD_INPUT);
@@ -345,13 +335,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
}
// Try to build op graph.
cudnnBackendDescriptorType_t backend_type;
std::optional<cudnn_frontend::OperationGraph> op_graph;
ConvBackendType backend_type;
std::optional<DnnGraph> graph;
for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] =
auto [x, w, y] =
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(
try_backend,
x,
w,
@@ -361,7 +350,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
padding_hi_,
kernel_dilation_,
input_dilation_);
op_graph = build_conv_op_graph(
graph = build_conv_graph(
encoder,
try_backend,
dtype,
@@ -372,30 +361,27 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
padding_lo,
padding_hi,
dilation);
if (op_graph) {
if (graph) {
backend_type = try_backend;
in = std::move(in_copy);
wt = std::move(wt_copy);
out = std::move(out_copy);
in = std::move(x);
wt = std::move(w);
out = std::move(y);
break;
}
}
if (op_graph) {
// Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (plan) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
}
}
if (graph) {
register_args(encoder, backend_type, in, wt, out, out_);
CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
encoder,
{
{'x', gpu_ptr<void>(in)},
{'w', gpu_ptr<void>(wt)},
{'y', gpu_ptr<void>(out)},
}));
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*graph)));
return;
}
// Use fallback kernel for settings not supported by cuDNN.

View File

@@ -86,7 +86,7 @@ array unfold_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
encoder.add_temporary(unfolded);
int filter_size = params.C;
@@ -118,8 +118,8 @@ array unfold_inputs_nd(
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(unfolded),
filter_size,
out_pixels,
params);

View File

@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
encoder.add_temporary(unfolded);
int filter_size = params.C;
@@ -121,8 +121,8 @@ array grouped_unfold_transpose_inputs_nd(
num_blocks,
block_dims,
0,
in.data<DataType>(),
unfolded.data<DataType>(),
gpu_ptr<DataType>(in),
gpu_ptr<DataType>(unfolded),
filter_size,
out_pixels,
params);

View File

@@ -5,6 +5,21 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
bool donated = set_copy_output_data(
in, out, ctype, [&](auto n) { return cu::malloc_async(n, encoder); });
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
@@ -87,11 +102,31 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
if (out.size() == 0) {
return;
}
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder));
encoder.set_input_array(in);
encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace mlx::core

View File

@@ -77,8 +77,8 @@ void copy_contiguous(
num_blocks,
block_dims,
0,
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
gpu_ptr<InType>(in) + in_offset,
gpu_ptr<OutType>(out) + out_offset,
out.data_size());
});
});

View File

@@ -106,8 +106,8 @@ void copy_general(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)

View File

@@ -69,8 +69,8 @@ void copy_general_dynamic(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) {
@@ -90,8 +90,8 @@ void copy_general_dynamic(
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in),
const_param<dims_constant()>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
gpu_ptr<int64_t>(dynamic_offset_in),
gpu_ptr<int64_t>(dynamic_offset_out));
});
} else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large());
@@ -107,8 +107,8 @@ void copy_general_dynamic(
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
gpu_ptr<int64_t>(dynamic_offset_in),
gpu_ptr<int64_t>(dynamic_offset_out));
}
});
});

View File

@@ -92,14 +92,17 @@ void copy_general_input(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
int work_per_thread = 1;
int work_per_thread = 8;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
if (dim0 >= 4 && dim0 < 8) {
work_per_thread = 4;
} else if (dim0 < 4) {
work_per_thread = 1;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
@@ -110,7 +113,10 @@ void copy_general_input(
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) {
if (work_per_thread == 8) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
} else if (work_per_thread == 4) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
}
@@ -127,7 +133,9 @@ void copy_general_input(
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
if (work_per_thread == 8) {
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
} else if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(

View File

@@ -0,0 +1,89 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn.h>
namespace mlx::core {
// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
void check_cudnn_error(const char* name, cudnnStatus_t err);
// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
public:
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
assert(this != &other);
other.handle_ = nullptr;
}
~CudaHandle() {
// Skip if there was an error to avoid throwing in the destructors
if (cudaPeekAtLastError() != cudaSuccess) {
return;
}
reset();
}
CudaHandle(const CudaHandle&) = delete;
CudaHandle& operator=(const CudaHandle&) = delete;
CudaHandle& operator=(CudaHandle&& other) {
assert(this != &other);
reset();
std::swap(handle_, other.handle_);
return *this;
}
void reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(Destroy(handle_));
handle_ = nullptr;
}
}
operator Handle() const {
return handle_;
}
protected:
Handle handle_;
};
namespace cu {
class Device;
}; // namespace cu
// Wrappers of CUDA resources.
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
public:
using CudaHandle::CudaHandle;
explicit CudaGraph(cu::Device& device);
void end_capture(cudaStream_t stream);
};
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
public:
void instantiate(cudaGraph_t graph);
};
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
public:
explicit CudaStream(cu::Device& device);
};
} // namespace mlx::core

View File

@@ -7,32 +7,26 @@ namespace mlx::core {
namespace {
// Create a cudnn tensor descriptor.
template <typename Vec>
inline cudnn_frontend::Tensor build_cudnn_tensor(
int64_t id,
const array& x,
const Vec& shape,
const Vec& strides) {
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
#define RETURN_IF_ERROR(cmd) \
if (auto ret = cmd; ret.is_bad()) { \
return ret; \
}
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
// whether a tensor is contiguous is determined with:
// shape[dim] == shape[dim + 1] * strides[dim + 1]
// So a contiguous array with singleton dims in MLX may be mistakenly treated
// as strided in cuDNN, and we work around it by normalizing the strides.
Strides normalized_strides(const array& x) {
if (!x.flags().row_contiguous || x.ndim() < 2) {
return x.strides();
std::vector<int64_t> normalized_strides(const array& x) {
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
if (std::all_of(
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
strides.back() = 1;
return strides;
}
if (!x.flags().row_contiguous || x.ndim() < 2) {
return strides;
}
Strides strides = x.strides();
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
@@ -42,7 +36,9 @@ Strides normalized_strides(const array& x) {
}
// Return the shape and strides after transposing from NHWC to NCHW.
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
inline auto nhwc_to_nchw(const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
auto strides = normalized_strides(x);
assert(shape.size() >= 3);
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
@@ -51,225 +47,95 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
return std::make_tuple(std::move(shape), std::move(strides));
}
inline auto nhwc_to_nchw(const array& x) {
return nhwc_to_nchw(
convert_vector<int64_t>(x.shape()), normalized_strides(x));
}
// Return available engines for a |op_graph|.
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = true) {
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
sources.push_back([](auto& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
});
if (use_fallback) {
sources.push_back([&backend_type](auto& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
});
}
auto configs =
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
// Take |engine_configs| and |op_graph| and find a working execution plans
// from them.
std::optional<cudnn_frontend::ExecutionPlan>
find_cudnn_plan_from_engine_configs(
cudnnHandle_t handle,
const cudnn_frontend::EngineConfigList& engine_configs,
const cudnn_frontend::OperationGraph& op_graph) {
auto op_graph_tag = op_graph.getTag();
for (const auto& config : engine_configs) {
try {
return cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, op_graph_tag)
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return std::nullopt;
}
// Prepare workspace and args to execute plan.
template <typename F>
bool prepare_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
array workspace(
workspace_size > 0 ? allocator::malloc(workspace_size)
: allocator::Buffer(nullptr),
{workspace_size},
uint8);
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>())
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
return false;
}
encoder.add_temporary(workspace);
return true;
}
} // namespace
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
return build_cudnn_tensor(id, x, shape, normalized_strides(x));
fe::error_t DnnGraph::prepare() {
RETURN_IF_ERROR(validate());
try {
RETURN_IF_ERROR(build_operation_graph(handle_));
} catch (cudnn_frontend::cudnnException& error) {
// cuDNN bug: they did not catch all exceptions in the API.
return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()};
}
RETURN_IF_ERROR(create_execution_plans({fe::HeurMode_t::A}));
return {};
}
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
fe::error_t DnnGraph::build() {
RETURN_IF_ERROR(check_support(handle_));
RETURN_IF_ERROR(build_plans(handle_));
return {};
}
fe::error_t DnnGraph::encode_graph(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack) {
cudnnSetStream(handle_, encoder.stream());
CudaGraph cuda_graph(encoder.device());
RETURN_IF_ERROR(populate_cuda_graph(
handle_, variant_pack, prepare_workspace(encoder), cuda_graph));
encoder.add_graph_node(cuda_graph);
return {};
}
fe::error_t DnnGraph::encode_capturing(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack) {
auto* workspace_ptr = prepare_workspace(encoder);
auto capture = encoder.capture_context();
cudnnSetStream(handle_, encoder.stream());
auto ret = execute(handle_, variant_pack, workspace_ptr);
if (ret.is_bad()) {
capture.discard = true;
}
return ret;
}
void* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) {
int64_t workspace_size = 0;
CHECK_CUDNN_FE_ERROR(get_workspace_size(workspace_size));
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder),
{static_cast<int>(workspace_size)},
uint8);
encoder.add_temporary(workspace);
return gpu_ptr<void>(workspace);
}
return nullptr;
}
void DnnGraph::set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& strides) {
tensor->set_uid(uid)
.set_alignment(get_alignment(x))
.set_data_type(dtype_to_cudnn_type(x.dtype()))
.set_dim(shape)
.set_stride(strides);
}
void DnnGraph::set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
set_tensor_attrs(
tensor,
uid,
x,
convert_vector<int64_t>(x.shape()),
normalized_strides(x));
}
void DnnGraph::set_tensor_attrs_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return build_cudnn_tensor(id, x, shape, strides);
set_tensor_attrs(tensor, uid, x, shape, strides);
}
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
if (x.ndim() == 0) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
}
if (x.ndim() == 1) {
int64_t s = x.shape(0);
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
SmallVector<int64_t, 4> strides = {s, 1, s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 2) {
int64_t s =
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 3 || x.ndim() == 4) {
return build_cudnn_tensor_nchw(id, x);
}
throw std::runtime_error(
fmt::format("Unsupported array with {} dims.", x.ndim()));
}
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return cudnn_frontend::TensorBuilder()
.setDim(scalar_dims.size(), scalar_dims.data())
.setStrides(scalar_dims.size(), scalar_dims.data())
.setId(id)
.setAlignment(16)
.setDataType(dtype_to_cudnn_type(dtype))
.setByValue(true)
.build();
}
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
if (engine_configs.empty()) {
return std::nullopt;
}
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
auto capture = encoder.capture_context();
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
return true;
});
}
#if CUDNN_VERSION >= 90500
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
if (!graph) {
graph = CudaGraph(encoder.device());
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
} else {
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
}
encoder.add_graph_node(graph);
return true;
});
}
#endif
} // namespace mlx::core

View File

@@ -2,28 +2,34 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <algorithm>
#include <array>
namespace mlx::core {
namespace cu {
class CommandEncoder;
}
namespace fe = cudnn_frontend;
#define CHECK_CUDNN_FE_ERROR(cmd) \
do { \
auto error = cmd; \
if (!error.is_good()) { \
throw std::runtime_error( \
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
} \
} while (0)
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
@@ -34,8 +40,31 @@ inline uint8_t get_alignment(const array& x) {
// Convert the type of elements in |vec| to |T|.
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
inline std::vector<T> convert_vector(const Vec& vec) {
return std::vector<T>(vec.begin(), vec.end());
}
// Map dtype to cudnn data type.
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return fe::DataType_t::INT8;
case int32:
return fe::DataType_t::INT32;
case uint8:
return fe::DataType_t::UINT8;
case float16:
return fe::DataType_t::HALF;
case bfloat16:
return fe::DataType_t::BFLOAT16;
case float32:
return fe::DataType_t::FLOAT;
case float64:
return fe::DataType_t::DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype)));
}
}
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
@@ -43,122 +72,100 @@ inline SmallVector<T> convert_vector(const Vec& vec) {
// There are 2 differences from the const_param util from kernel_utils.cuh:
// 1. The rest of array is filled with 0.
// 2. This util can be used in .cpp files.
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
template <int NDIM = MAX_NDIM, typename T, template <typename U> class Vec>
inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
fmt::format("ndim can not be larger than {}.", NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::array<T, NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>());
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
inline void* get_data_ptr(T& scalar) {
return &scalar;
}
// Return an array filled with data pointers of args.
template <typename... Args>
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
return {get_data_ptr(args)...};
}
// Map dtype to cudnn data type.
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
// Extends cuDNN graph with helpers.
class DnnGraph : public fe::graph::Graph {
public:
DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32)
: handle_(handle) {
set_io_data_type(dtype_to_cudnn_type(io_dtype));
set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype));
set_compute_data_type(dtype_to_cudnn_type(compute_dtype));
}
}
// Create a tensor descriptor from |x|.
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
// Create a cuDNN tensor description from MLX array |x|.
auto& tensor(
std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
int64_t uid,
const array& x) {
set_tensor_attrs(attrs, uid, x);
return attrs;
}
auto tensor(const char* name, int64_t uid, const array& x) {
auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
tensor(attrs, uid, x);
return attrs;
}
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
// Create a cuDNN tensor description from MLX array |x|, and transpose it from
// NHWC layout to NCHW.
auto& tensor_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
int64_t uid,
const array& x) {
set_tensor_attrs_nchw(attrs, uid, x);
return attrs;
}
auto tensor_nchw(const char* name, int64_t uid, const array& x) {
auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
tensor_nchw(attrs, uid, x);
return attrs;
}
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
// from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
// Create a cuDNN tensor for scalar.
auto scalar(const char* name, int64_t uid, Dtype dtype) {
return Graph::tensor(fe::graph::Tensor_attributes()
.set_name(name)
.set_uid(uid)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(dtype_to_cudnn_type(dtype)));
}
// Create a 4D scalar tensor descriptor, which is passed by value.
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
// Call this before setting notes.
fe::error_t prepare();
// Call this after setting notes.
fe::error_t build();
// Find a working plan for |op_graph|.
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph);
// Add cuDNN graph to CUDA graph, using native CUDA graph API.
fe::error_t encode_graph(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack);
// Add cuDNN graph to CUDA graph, using stream capture.
fe::error_t encode_capturing(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack);
// Encode the plan to command buffer by capturing.
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs);
private:
void* prepare_workspace(cu::CommandEncoder& encoder);
#if CUDNN_VERSION >= 90500
// Encode the plan to command buffer by using native graph api of cudnn. If the
// |graph| is empty it will be populated, otherwise it will be updated.
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs);
#endif
void set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& strides);
void set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x);
void set_tensor_attrs_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x);
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_capturing(
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
}
#if CUDNN_VERSION >= 90500
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_graph_api(
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
}
#endif
cudnnHandle_t handle_;
};
} // namespace mlx::core

View File

@@ -57,7 +57,7 @@ std::string build_kernel(
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
@@ -81,17 +81,17 @@ std::string build_kernel(
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
if (std::get<0>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
if (std::get<1>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
if (std::get<2>(shape_infos[i])) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
"[custom_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
std::vector<std::tuple<bool, bool, bool>> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
std::tuple<bool, bool, bool> shape_info;
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
std::vector<std::tuple<bool, bool, bool>> shape_infos(
inputs.size(), {false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
@@ -279,6 +279,7 @@ void CustomKernel::eval_gpu(
std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
std::vector<array> copies;
@@ -288,7 +289,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(allocator::malloc(out.nbytes()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
}
}
@@ -326,13 +327,13 @@ void CustomKernel::eval_gpu(
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
if (std::get<0>(shape_info)) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
if (std::get<1>(shape_info)) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
if (std::get<2>(shape_info)) {
args.append<int32_t>(in.ndim());
}
}
@@ -356,7 +357,6 @@ void CustomKernel::eval_gpu(
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) {
encoder.set_input_array(in);
}

Some files were not shown because too many files have changed in this diff Show More