Compare commits

..

79 Commits

Author SHA1 Message Date
Angelos Katharopoulos
d2bc340df4 Update the mlx.launch and mlx.distributed_config docs 2025-12-12 17:03:38 -08:00
Angelos Katharopoulos
fabc947df4 Possible doc fix 2025-12-12 15:34:59 -08:00
Angelos Katharopoulos
5523087cfb Finish the distributed docs 2025-12-12 14:16:16 -08:00
Angelos Katharopoulos
2f939acefa Progress with the docs 2025-12-12 04:36:46 -08:00
Angelos Katharopoulos
3b416a2e36 Comments 2025-12-12 01:21:43 -08:00
Angelos Katharopoulos
753c6a4d0f Disable echo for interactive distributed 2025-12-11 17:47:42 -08:00
Angelos Katharopoulos
d3a754c8aa Fix config when more cables are connected 2025-12-10 02:14:29 -08:00
Angelos Katharopoulos
595a4ad206 Improve interactivity of mlx.launch 2025-12-10 00:54:27 -08:00
Angelos Katharopoulos
a4dc1fac6c Set the shell to bash explicitly 2025-12-09 15:17:09 -08:00
Angelos Katharopoulos
ebda161a86 Remove old joined script 2025-12-09 13:39:57 -08:00
Angelos Katharopoulos
fa31a4b295 Add more checks and improve errors 2025-12-09 13:36:17 -08:00
Angelos Katharopoulos
9d707ba3b5 Remove python from the launch script 2025-12-09 13:04:37 -08:00
Angelos Katharopoulos
405d30b6e5 Refactor distributed config 2025-12-09 05:58:44 -08:00
Angelos Katharopoulos
cd4b12ce1b Refactoring launcher 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
425043ccca Change the name to a fun pun 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
95d92af8a0 Add headers for gcc 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
bfdddd644b Expose per-backend availability in C++ and python 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
1216afdc91 Add a no_ibv 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
04e94d78bb Add empty sum_scatter 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
60d4e8b2a8 Add send/recv 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
c5745fddd2 Make sure that there is space for work completions 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
e937a8033f Add working reduce and semi-working all gather 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
4dfe02d7c6 Fix ring 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
5c2cff9329 Fix side channel initialization for more than 2 peers 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
325dab9559 All gather 2025-12-08 15:50:05 -08:00
Angelos Katharopoulos
67e454ab0a Initial working all reduce 2025-12-08 15:50:05 -08:00
Awni Hannun
27232db1ba [CUDA] Enable more graphs to be updatable (#2883)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-08 06:18:01 -08:00
Awni Hannun
a4b3bc969b Try not to fail when there should be memory available (#2869)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-07 06:11:00 -08:00
Awni Hannun
667c0f3bb9 [Metal] No copy array init (#2875)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 13:36:45 -08:00
Cheng
6245824d42 Make allocator::malloc throw on allocation failure (#2874)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 17:44:38 +09:00
Awni Hannun
39289ef025 [CUDA] Release build for cuda 13 (#2872) 2025-12-04 21:42:26 -08:00
Awni Hannun
aefc9bd3f6 [CUDA] Faster general copy (#2873) 2025-12-04 21:42:15 -08:00
Angelos Katharopoulos
997cfc7699 Add a 2-pass col reduce for CUDA (#2863)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-04 15:53:59 -08:00
Awni Hannun
1fa8dc5797 Do a PyPi release for cuda on arm (#2866) 2025-12-04 15:28:29 -08:00
Awni Hannun
a6d6717181 fix compile copying (#2871) 2025-12-04 12:32:56 -08:00
Awni Hannun
941cfe23d7 Layer norm throws on dimension mismatch (#2870) 2025-12-04 11:21:05 -08:00
romanoneg
9abb0b8123 Added support for pytree types that inherit from tuple and typing.namedtuple (#2845) 2025-12-04 11:06:45 -08:00
Tian En "TianHeng
50d3914c67 Update gumbel function signature parameters (#2868)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-03 15:37:35 -08:00
Awni Hannun
cacbdbf995 Fix init from double (#2861)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-03 06:08:11 -08:00
Awni Hannun
193cdcd81a Fix graph updating (#2857)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-02 17:12:24 -08:00
Awni Hannun
d8ceae7b77 Reduce JVP (#2854) 2025-12-02 16:17:47 -08:00
Awni Hannun
eff0e31f00 Fix export scatters (#2852)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-02 11:24:40 -08:00
Awni Hannun
6c5785bc2f use thread local cpature mode (#2850)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-01 19:02:47 -08:00
CCYeh
8879ee00eb Support more Numpy interfaces for masked_scatter (#2832) 2025-12-01 17:51:02 -08:00
Cheng
6e762fe2e2 [CUDA] Migrate conv code to new cuDNN APIs (#2847)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-02 07:55:43 +09:00
Cheng
2b95d0c270 [CUDA] Use cuDNN attention when T_q != T_kv (#2843)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-11-27 09:58:43 +09:00
Chaoran Yu
b054838780 Added clarification to apply_fn parameter of apply_to_modules (#2831)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-26 15:40:56 -08:00
Awni Hannun
dd79d3c465 [CUDA] Faster rms norm for small dimension (#2838) 2025-11-26 15:10:41 -08:00
Cheng
704fd1ae28 [CUDA] Support array mask in SDPA (#2822)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-26 11:08:58 +09:00
Cheng
c9f4dc851f Merge build-cuda and build-linux actions (#2783)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-11-25 20:06:42 +09:00
Cheng
f8bd675655 [CUDA] Output of SDPA should have same layout with inputs (#2826)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-25 15:22:58 +09:00
Cheng
23a9168d34 [CUDA] Add debug env to save cuda graphs to dot files (#2825) 2025-11-25 15:22:36 +09:00
Awni Hannun
bca205e287 [CUDA] Exit on crash and more helpful errors (#2830) 2025-11-24 19:46:03 -08:00
CCYeh
1d4eacb737 Fix mx.core.linspace type annotation (#2820)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-11-24 14:15:08 -08:00
dependabot[bot]
8abd37ad05 Bump actions/checkout from 5 to 6 (#2828)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-24 06:04:46 -08:00
Andrey Portnoy
3e05cea9f8 Force cudaGraphExec reinstantiation when clusters are used (#2813)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-22 12:43:49 -08:00
CCYeh
5b0f047226 Fix mx.core.load type annotation (#2819) 2025-11-22 11:09:44 -08:00
Harsh Sutaria
618c87af8c Add float64 Eig and complex64 SVD/Eig support (Fixes #2708) (#2737)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-22 06:51:36 -08:00
Cheng
d5f61a93fa Fix typo: refs/head/main => refs/heads/main (#2818)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-22 09:43:35 +09:00
Awni Hannun
4a09264236 Tolerance for some ops tests on cuda (#2815) 2025-11-21 16:06:16 -08:00
Awni Hannun
0dbc7e5bee Centralize NAX condition (#2811)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-21 13:28:15 -08:00
Awni Hannun
0d68efd461 patch bump for future version (#2804)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-20 09:26:20 -08:00
Awni Hannun
f9e1a14135 [CUDA] Partly fix random for large sizes (#2798) 2025-11-20 07:27:50 -08:00
Awni Hannun
d8e9ded928 Fix cuda allocator copy condition (#2800) 2025-11-20 07:06:55 -08:00
Awni Hannun
60939d010c Fix macos release target and linux arm release (#2802)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-19 21:37:50 -08:00
Awni Hannun
fdcd2923fd patch + fix docs build (#2799) 2025-11-19 16:16:26 -08:00
Jagrit Digani
54f1cc6e3e Add Neural Accelerator Support (#2772) 2025-11-19 15:06:00 -08:00
CCYeh
b3825ac149 Add Masked Scatter (#2663)
Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-11-19 14:53:32 -08:00
Awni Hannun
7f4b7e553c version (#2797) 2025-11-19 14:11:16 -08:00
Awni Hannun
ad16f41a7f Fix version tag (#2790)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-19 08:55:57 -08:00
Awni Hannun
f46877bc08 more accurate rope fallback (#2792) 2025-11-19 06:07:21 -08:00
Cheng
6f35017d1b [CUDA] cuDNN backward attention (#2762)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-19 08:13:50 +09:00
Awni Hannun
b167f0df1c build docs on linux (#2787)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-18 08:01:03 -08:00
Cheng
a9f0d6b160 Avoid duplicate CI runs when starting a PR from upstream branch (#2788)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-18 15:16:25 +09:00
Cheng
940f4c7818 Fix building with CUDA < 12.8 (#2782)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-18 12:55:19 +09:00
Cheng
35f81728f1 Remove unneeded tests in nightly build (#2786) 2025-11-18 08:09:58 +09:00
Cheng
4442ed86c1 Fix nightly build (#2785) 2025-11-18 08:07:51 +09:00
Cheng
698559c231 Test every commit in main branch (#2781) 2025-11-18 08:07:22 +09:00
Cheng
ecc4879b07 Do not run CPU tests in CUDA builds (#2784) 2025-11-18 07:27:09 +09:00
165 changed files with 14597 additions and 2675 deletions

View File

@@ -2,9 +2,13 @@ name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
toolkit:
description: 'The CUDA toolkit'
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
@@ -12,9 +16,9 @@ runs:
- name: Build package
shell: bash
env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}

View File

@@ -1,26 +0,0 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
toolkit:
description: 'The CUDA toolkit'
required: true
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
run: pip install --no-build-isolation -e ".[dev]" -v
- name: Build CPP only
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)

View File

@@ -1,25 +1,19 @@
name: 'Build Documentation'
description: 'Build documentation on a mac'
description: 'Build documentation'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-macos
- name: Setup uv
uses: astral-sh/setup-uv@v6
with:
python-version: "3.10"
activate-environment: true
uses: ./.github/actions/setup-linux
- name: Install dependencies
shell: sh
shell: bash
run: |
brew install doxygen
uv pip install --upgrade pip cmake
uv pip install -r docs/requirements.txt
uv pip install . -v
sudo apt-get install -y doxygen
source .venv/bin/activate
pip install -r docs/requirements.txt
pip install . -v
- name: Build documentation
shell: bash
@@ -30,8 +24,8 @@ runs:
make html O=-W
- name: Create artifact tar
shell: sh
run: tar -cf artifact.tar --cd docs --dereference build/html index.html
shell: bash
run: tar -cf artifact.tar -C docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact

View File

@@ -1,15 +1,32 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'
inputs:
toolkit:
description: 'The toolkit to build with'
required: false
default: 'cpu'
runs:
using: "composite"
steps:
- name: Install Python package
id: python_build
shell: sh
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
DEBUG: 1
run: pip install --no-build-isolation -e ".[dev]" -v
CMAKE_ARGS: >-
-DCMAKE_COMPILE_WARNING_AS_ERROR=ON
-DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
run: |
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
# There is no GPU in arm64 runner, use a common arch.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
# Can not build tests when the built executables can not run.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
fi
pip install --no-build-isolation -e ".[dev]" -v
# Pass the CMAKE_ARGS to following steps.
echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT
- name: Generate package stubs
shell: sh
@@ -20,6 +37,5 @@ runs:
- name: Build CPP only
shell: bash
run: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
cmake --build build -j $(nproc)

View File

@@ -17,6 +17,8 @@ runs:
steps:
- name: Build Python package
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
pip install build
python setup.py clean --all
@@ -25,6 +27,8 @@ runs:
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w

View File

@@ -15,6 +15,7 @@ runs:
using: "composite"
steps:
- name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
@@ -35,7 +36,7 @@ runs:
run: |
python -m venv .venv
source .venv/bin/activate
pip install cmake nanobind==2.4.0
pip install setuptools cmake nanobind==2.4.0
echo PATH=$PATH >> $GITHUB_ENV
# Make cmake search .venv for nanobind
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
@@ -51,23 +52,23 @@ runs:
# Note: the CI machine does not meet CUDA 13's driver requirement.
# Compatibility matrix:
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# The `nvcc` is installed into `/usr/local/cuda-VERSION/bin/nvcc` - but
# it's *not* on the default toolkit path.
PACKAGES: |
{
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
"cuda-12.8": "libcudnn9-dev-cuda-12 cuda-toolkit-12-8",
"cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
}
run: |
export ARCH=${{ runner.arch == 'arm64' && 'arm64' || 'x86_64' }}
# The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
# Jetson specific. SBSA means Arm Server Base System Architecture.
ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y \
libnccl2 libnccl-dev \
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
- name: CUDA packages and driver report
if: ${{ startsWith(inputs.toolkit, 'cuda') }}

View File

@@ -1,8 +1,8 @@
name: 'Run Linux tests'
inputs:
cpu-only:
description: 'Skip GPU tests'
has-gpu:
description: 'Run GPU tests'
required: false
default: false
@@ -17,7 +17,7 @@ runs:
echo "::endgroup::"
- name: Run distributed tests
if: ${{ inputs.cpu-only == 'true' }}
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
run: |
echo "::group::Distributed tests"
@@ -30,6 +30,7 @@ runs:
echo "::endgroup::"
- name: Run Python tests - CPU
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
env:
DEVICE: cpu
@@ -39,7 +40,7 @@ runs:
echo "::endgroup::"
- name: Run Python tests - GPU
if: ${{ inputs.cpu-only == 'false' }}
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
@@ -58,7 +59,7 @@ runs:
echo "::endgroup::"
- name: Run CPP tests - GPU
if: ${{ inputs.cpu-only == 'false' }}
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu

View File

@@ -1,39 +1,67 @@
name: Build and Test
on: [pull_request, push]
on:
pull_request:
push:
branches:
- main
# For testing CI without starting a pull request:
- test/*
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
check_lint:
name: Check Lint
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: actions/checkout@v6
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
name: Linux (cpu, ${{ matrix.arch }})
needs: check_lint
strategy:
matrix:
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
fail-fast: false
runs-on: ${{ matrix.runner }}
matrix:
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
cuda_build_and_test:
name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
if: github.repository == 'ml-explore/mlx'
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
cpu-only: true
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
if: matrix.arch == 'x86_64'
with:
has-gpu: true
mac_build_and_test:
name: macOS (${{ matrix.macos-target }})
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
@@ -43,38 +71,22 @@ jobs:
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
cuda_build_and_test:
if: github.repository == 'ml-explore/mlx'
strategy:
fail-fast: false
matrix:
toolkit: ['cuda-12.8', 'cuda-12.9']
runs-on: gpu-t4-4-core
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-cuda
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
build_documentation:
name: Build Documentation
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
runs-on: ubuntu-22.04
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
name: Linux Fedora (${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
@@ -89,7 +101,7 @@ jobs:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
uses: actions/checkout@v6
- name: CPP Build Test - No Release
run: |

View File

@@ -8,9 +8,9 @@ permissions:
jobs:
build:
runs-on: [self-hosted, macos]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy:

View File

@@ -16,7 +16,7 @@ jobs:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux-release
with:
@@ -40,20 +40,18 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python_version: ["3.11", "3.12", "3.13", "3.14"]
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
with:
cpu-only: true
build_mac_release:
if: github.repository == 'ml-explore/mlx'
@@ -62,7 +60,7 @@ jobs:
python-version: ["3.10", "3.13"]
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
@@ -78,28 +76,11 @@ jobs:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
build_cuda_with_tests:
if: github.repository == 'ml-explore/mlx'
strategy:
fail-fast: false
matrix:
toolkit: ['cuda-12.8', 'cuda-12.9']
runs-on: gpu-t4-4-core
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-cuda
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
@@ -113,25 +94,3 @@ jobs:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh

View File

@@ -23,9 +23,9 @@ jobs:
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
deploy_documentation:
@@ -53,7 +53,7 @@ jobs:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
@@ -65,14 +65,14 @@ jobs:
uses: actions/upload-artifact@v5
with:
overwrite: true
name: linux-wheels-${{ matrix.python_version }}
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cpu
name: mlx-cpu-${{ matrix.arch }}
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
@@ -86,7 +86,7 @@ jobs:
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
@@ -128,19 +128,23 @@ jobs:
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: 'cuda-12.9'
toolkit: ${{ matrix.toolkit }}
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
toolkit: 'cuda-12.9'
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
@@ -208,7 +212,8 @@ jobs:
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cpu
pattern: mlx-cpu-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist

View File

@@ -1,6 +1,5 @@
# Copyright © 2023 Apple Inc.
import argparse
import os
import subprocess
import time

View File

@@ -0,0 +1,212 @@
import math
import os
import subprocess
import time
from copy import copy
from functools import partial
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from matplotlib.ticker import FuncFormatter
# Directory that receives the generated benchmark plots.
RESULTS_DIR = "./results"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Human-readable CPU/SoC name used in plot titles and output file names.
# The ``machdep.cpu.brand_string`` sysctl key is macOS-specific; when the query
# is unavailable or fails (e.g. on a Linux/CUDA host) fall back to the machine
# architecture so that the benchmark can still run.
try:
    _brand = (
        subprocess.check_output(
            ["sysctl", "-n", "machdep.cpu.brand_string"],
            stderr=subprocess.DEVNULL,
        )
        .decode("utf-8")
        .strip("\n")
    )
except (OSError, subprocess.CalledProcessError):
    _brand = ""
DEVICE_NAME = _brand or os.uname().machine

# Torch device used for the reference implementation: MPS if available,
# otherwise CUDA, otherwise the CPU.
TORCH_DEVICE = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)

# Untimed calls made by ``measure`` before timing starts.
N_WARMUP = 5
# Timed calls made by ``measure``.
N_ITER_BENCH = 50
# Masked-scatter operations issued per benchmarked call.
N_ITER_FUNC = 20

# Benchmark grid: vector lengths 4096 * 2**i for i in 0..9, mask densities
# (fraction of True entries) and element types.
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")
def _power_of_two_formatter(value, _position):
if value <= 0:
return ""
exponent = int(round(math.log2(value)))
if abs(value - (1 << exponent)) / value > 1e-6:
return f"{value:g}"
return f"$2^{{{exponent}}}$"
def torch_sync():
    """Call the Torch synchronize routine that matches ``TORCH_DEVICE``.

    Does nothing when the device type is neither ``"cuda"`` nor ``"mps"``.
    """
    device_type = TORCH_DEVICE.type
    if device_type == "mps":
        torch.mps.synchronize()
        return
    if device_type == "cuda":
        torch.cuda.synchronize()
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
    """Issue ``N_ITER_FUNC`` masked assignments with MLX and evaluate them.

    Each iteration works on a fresh ``copy`` of ``self_arr`` and assigns
    ``src_arr`` at the positions selected by ``mask_arr``.

    Returns:
        The list of result arrays after ``mx.eval`` has been called on it.
    """

    def _scatter_once():
        result = copy(self_arr)
        result[mask_arr] = src_arr
        return result

    outs = [_scatter_once() for _ in range(N_ITER_FUNC)]
    mx.eval(outs)
    return outs
@torch.no_grad()
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
    """Issue ``N_ITER_FUNC`` in-place ``masked_scatter_`` calls with Torch.

    Each iteration clones ``self_tensor`` and scatters ``src_tensor`` into the
    clone at the positions selected by ``mask_tensor``; ``torch_sync`` is
    called once after all iterations.

    Returns:
        The list of scattered clones.
    """
    outs = []
    remaining = N_ITER_FUNC
    while remaining > 0:
        scattered = self_tensor.clone()
        scattered.masked_scatter_(mask_tensor, src_tensor)
        outs.append(scattered)
        remaining -= 1
    torch_sync()
    return outs
def measure(fn):
    """Time ``N_ITER_BENCH`` calls of ``fn`` after ``N_WARMUP`` untimed calls.

    Args:
        fn: Zero-argument callable to benchmark.

    Returns:
        Total wall-clock time of the timed calls, in seconds.
    """
    for _ in range(N_WARMUP):
        fn()

    tic = time.perf_counter_ns()
    for _ in range(N_ITER_BENCH):
        fn()
    toc = time.perf_counter_ns()

    elapsed_ns = toc - tic
    return elapsed_ns * 1e-9
def bytes_touched(length, true_count, item_size):
    """Estimate the bytes moved over one full ``measure`` run.

    Per scatter the estimate counts one byte per mask element, a read plus a
    write of every destination element, and one read of every source element.
    The per-scatter figure is multiplied by ``N_ITER_FUNC * N_ITER_BENCH``.
    """
    per_scatter = sum(
        (
            length,  # mask
            length * item_size * 2,  # destination: read + write
            true_count * item_size,  # source
        )
    )
    return per_scatter * N_ITER_FUNC * N_ITER_BENCH
def build_case(length, density, np_dtype, torch_dtype):
    """Create matching MLX and Torch inputs for one benchmark configuration.

    A random destination vector of ``length`` elements, a shuffled boolean
    mask with ``true_count`` entries set, and a random source vector of
    ``true_count`` elements are generated with NumPy and converted to both
    frameworks. One masked scatter is run in each framework and the results
    are compared before the inputs are returned.

    Returns:
        ``(self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch,
        true_count)``.

    Raises:
        AssertionError: If the MLX and Torch results are not close.
    """
    true_count = max(1, int(round(length * density)))
    rng = np.random.default_rng()

    # Host-side data shared by both frameworks.
    self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
    mask_np = np.zeros(length, dtype=bool)
    mask_np[:true_count] = True
    rng.shuffle(mask_np)
    src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)

    self_mlx, mask_mlx, src_mlx = (
        mx.array(host) for host in (self_np, mask_np, src_np)
    )

    def _to_torch(host, **kwargs):
        return torch.from_numpy(host).to(device=TORCH_DEVICE, **kwargs)

    self_torch = _to_torch(self_np, dtype=torch_dtype)
    mask_torch = _to_torch(mask_np)
    src_torch = _to_torch(src_np, dtype=torch_dtype)

    # Correctness check once per configuration
    mx_out = mx.array(self_np)
    mx_out[mask_mlx] = src_mlx
    mx.eval(mx_out)
    torch_out = self_torch.clone()
    torch_out.masked_scatter_(mask_torch, src_torch)

    if np_dtype == np.float16:
        atol = 5e-3
    else:
        atol = 1e-5
    results_match = np.allclose(
        np.array(mx_out), torch_out.cpu().numpy(), atol=atol
    )
    if not results_match:
        raise AssertionError("masked_scatter results diverged between MLX and Torch")

    return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
def bench_case(length, density, dtype):
    """Benchmark MLX against Torch for one (length, density, dtype) setting.

    Args:
        length: Number of elements in the destination vector.
        density: Fraction of mask entries that are set.
        dtype: Attribute name of the element type, looked up on both ``np``
            and ``torch``.

    Returns:
        ``(time_mlx, time_torch, mlx_gbps, torch_gbps)`` where the times are
        those returned by ``measure`` and the throughputs divide the
        ``bytes_touched`` estimate (scaled by ``1024**3``) by those times.
    """
    np_dtype = getattr(np, dtype)
    torch_dtype = getattr(torch, dtype)

    *tensors, true_count = build_case(length, density, np_dtype, torch_dtype)
    mlx_inputs, torch_inputs = tensors[:3], tensors[3:]

    time_mlx = measure(partial(masked_scatter_mlx, *mlx_inputs))
    time_torch = measure(partial(masked_scatter_torch, *torch_inputs))

    total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
    scaled_bytes = total_bytes / float(1024**3)

    return time_mlx, time_torch, scaled_bytes / time_mlx, scaled_bytes / time_torch
def plot_density(ax_perf, ax_speedup, density, dtype):
    """Benchmark every vector length for one density and draw both panels.

    Args:
        ax_perf: Axes that receives the MLX and Torch throughput curves.
        ax_speedup: Axes that receives the Torch-time / MLX-time ratio.
        density: Fraction of mask entries that are set.
        dtype: Element type name forwarded to ``bench_case``.
    """
    results = [bench_case(length, density, dtype) for length in VECTOR_LENGTHS]
    mlx_times, torch_times, mlx_gbps, torch_gbps = (
        list(column) for column in zip(*results)
    )

    formatter = FuncFormatter(_power_of_two_formatter)

    def _configure_x_axis(ax):
        # Log-2 x axis with one tick per benchmarked length.
        ax.set_xscale("log", base=2)
        ax.set_xticks(VECTOR_LENGTHS)
        ax.xaxis.set_major_formatter(formatter)

    # Throughput panel.
    ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
    ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
    _configure_x_axis(ax_perf)
    ax_perf.set_title(f"density={density:.2f}")
    ax_perf.set_ylabel("GB/s")
    ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
    ax_perf.legend()

    # Speedup panel; the dashed line marks parity between the two frameworks.
    speedup = np.array(torch_times) / np.array(mlx_times)
    ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
    ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
    _configure_x_axis(ax_speedup)
    ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
    ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
def main():
    """Run the full benchmark grid and save one PDF figure per dtype.

    Each figure has one row per entry of ``MASK_DENSITIES`` and two columns
    (throughput and speedup), and is written into ``RESULTS_DIR``.
    """
    n_rows = len(MASK_DENSITIES)
    for dtype in D_TYPES:
        fig, axs = plt.subplots(
            n_rows,
            2,
            figsize=(10, 12),
            layout="constrained",
            sharex=True,
        )

        for density, (ax_perf, ax_speedup) in zip(MASK_DENSITIES, axs):
            plot_density(ax_perf, ax_speedup, density, dtype)
            for ax in (ax_perf, ax_speedup):
                ax.set_xlabel("vector length")

        title = (
            f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
        )
        fig.suptitle(title)

        file_name = f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf"
        fig.savefig(os.path.join(RESULTS_DIR, file_name))
        plt.close(fig)


if __name__ == "__main__":
    main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

View File

@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install mlx[cuda]
pip install mlx[cuda12]
To install the CUDA package from PyPI your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux)
^^^^^^^^^^^^^^^^

View File

@@ -7,22 +7,29 @@ Distributed Communication
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support three different communication backends:
moment we support several different communication backends introduced below.
.. list-table::
:widths: 20 80
:header-rows: 1
* - Backend
- Description
* - :ref:`MPI <mpi_section>`
- A full featured and mature distributed communications library.
* - :ref:`RING <ring_section>`
- Ring all reduce and all gather over TCP sockets. Always available and
usually faster than MPI.
* - :ref:`JACCL <jaccl_section>`
- Low latency communication with RDMA over thunderbolt. Necessary for
things like tensor parallelism.
* - :ref:`NCCL <nccl_section>`
- The backend of choice for CUDA environments.
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets. It should be
faster for thunderbolt connections, but it also works over Ethernet.
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
.. note::
Some operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
@@ -85,7 +92,7 @@ Selecting Backend
^^^^^^^^^^^^^^^^^
You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
one of ``{'any', 'ring', 'jaccl', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
available backends. If they all fail then a singleton group is created.
.. note::
@@ -110,6 +117,8 @@ The following examples aim to clarify the backend initialization logic in MLX:
world_ring = mx.distributed.init(backend="ring")
world_any = mx.distributed.init() # same as MPI because it was initialized first!
.. _training_example:
Training Example
----------------
@@ -192,16 +201,273 @@ almost identical to the example above:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
.. _ring_section:
Getting Started with Ring
-------------------------
The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
and so on and so forth. As a result :func:`send` and :func:`recv` with
arbitrary sender and receiver are not supported in the ring backend.
Defining a Ring
^^^^^^^^^^^^^^^
The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.
.. code:: json
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
]
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node, run the script which will listen for connections in each of the provided
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4`` and so on and so forth.
Thunderbolt Ring
^^^^^^^^^^^^^^^^
Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
utility as follows:
.. code:: shell
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --backend ring
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
If you want to go through the process manually, the steps are as follows:
* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.
.. _jaccl_section:
Getting Started with RDMA over Thunderbolt
------------------------------------------
Starting from macOS 26.2, RDMA over thunderbolt is available and
enables low-latency communication between Macs with thunderbolt 5. MLX provides
the JACCL backend that uses this functionality to achieve communication latency
an order of magnitude lower than the ring backend.
.. note::
The name JACCL (pronounced Jackal) stands for *Jack and Angelos' Collective
Communication Library* and it is an obvious pun on Nvidia's NCCL but also a
tribute to *Jack Beasley* who led the development of RDMA over Thunderbolt
at Apple.
Enabling RDMA
^^^^^^^^^^^^^
Until the feature matures, enabling RDMA over thunderbolt is slightly more
involved and **cannot** be done remotely even with sudo. In fact, it has to be
done in macOS recovery:
1. `Start your computer in recovery <https://support.apple.com/en-us/102518>`_.
2. Open the Terminal by going to Utilities -> Terminal.
3. Run ``rdma_ctl enable``.
4. Reboot.
To verify that you have successfully enabled Thunderbolt RDMA you can run
``ibv_devices`` which should produce something like the following for an M3 Ultra.
.. code-block:: bash
~ % ibv_devices
device node GUID
------ ----------------
rdma_en2 8096a9d9edbaac05
rdma_en3 8196a9d9edbaac05
rdma_en5 8396a9d9edbaac05
rdma_en4 8296a9d9edbaac05
rdma_en6 8496a9d9edbaac05
rdma_en7 8596a9d9edbaac05
Defining a Mesh
^^^^^^^^^^^^^^^
The JACCL backend supports only fully connected topologies. Namely, there needs
to be a thunderbolt cable connecting all pairs of Macs directly. For example, in
the following topology visualizations, the left one is valid because there is a
connection from any node to any other node, while for the one on the right M3
Ultra 1 is not connected to M3 Ultra 2.
.. raw:: html
<div style="display: flex; text-align: center; align-items: end; font-size: 80%;">
<div>
<img src="/_static/distributed/m3-ultra-mesh.png" alt="M3 Ultra thunderbolt mesh" style="width: 55%">
<p>Fully connected mesh of four M3 Ultra.</p>
</div>
<div>
<img src="/_static/distributed/m3-ultra-mesh-broken.png" alt="M3 Ultra broken thunderbolt mesh" style="width: 55%">
<p>Not a valid mesh (M3 Ultra 1 is not connected to M3 Ultra 2).</p>
</div>
</div>
Similar to the ring backend, the easiest way to use JACCL with MLX is to write
a JSON hostfile that will be used by ``mlx.launch``. The hostfile needs to contain
- Hostnames to use for launching scripts via ssh
- An IP for rank 0 that is reachable by all nodes
- A list of rdma devices that connect each node to each other node
The following JSON defines the valid 4-node mesh from the image above.
.. code-block:: json
[
{
"ssh": "m3-ultra-1",
"ips": ["123.123.123.1"],
"rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"]
},
{
"ssh": "m3-ultra-2",
"ips": [],
"rdma": ["rdma_en5", null, "rdma_en3", "rdma_en4"]
},
{
"ssh": "m3-ultra-3",
"ips": [],
"rdma": ["rdma_en4", "rdma_en3", null, "rdma_en5"]
},
{
"ssh": "m3-ultra-4",
"ips": [],
"rdma": ["rdma_en3", "rdma_en4", "rdma_en5", null]
}
]
Even though TCP/IP is not used when communicating with Thunderbolt RDMA,
disabling the thunderbolt bridge is still required as well as setting up
isolated local networks for each thunderbolt connection.
All of the above can be done instead via ``mlx.distributed_config``. This helper
script will
- ssh into each node
- extract the thunderbolt connectivity
- check for a valid mesh
- provide the commands to configure each node (or run them if sudo is available)
- generate the hostfile to be used with ``mlx.launch``
Putting it All Together
^^^^^^^^^^^^^^^^^^^^^^^^
For example launching a distributed MLX script that uses JACCL is fairly simple
if the nodes are reachable via ssh and have password-less sudo.
First, connect all the thunderbolt cables. Then we can verify the connections
by using the ``mlx.distributed_config`` script to visualize them.
.. code-block::
mlx.distributed_config --verbose \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
--over thunderbolt --dot | dot -Tpng | open -f -a Preview
After making sure that everything looks right we can auto-configure the nodes
and save the hostfile to ``m3-ultra-jaccl.json`` by running:
.. code-block::
mlx.distributed_config --verbose \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 \
--over thunderbolt --backend jaccl \
--auto-setup --output m3-ultra-jaccl.json
And now we are ready to run a distributed MLX script such as distributed inference
of a gigantic model using MLX-LM.
.. code-block::
mlx.launch --verbose --backend jaccl --hostfile m3-ultra-jaccl.json \
--env MLX_METAL_FAST_SYNCH=1 -- \ # <--- important
/path/to/remote/python -m mlx_lm chat --model mlx-community/DeepSeek-V3.2-8bit --shard
.. note::
Defining the environment variable ``MLX_METAL_FAST_SYNCH=1`` enables a
different, faster way of synchronizing between the GPU and the CPU. It is
not specific to the JACCL backend and can be used in all cases where the CPU
and GPU need to collaborate for some computation and is pretty critical for
low-latency communication since the communication is done by the CPU.
.. _nccl_section:
Getting Started with NCCL
-------------------------
MLX on CUDA environments ships with the ability to talk to `NCCL
<https://developer.nvidia.com/nccl>`_ which is a high-performance collective
communication library that supports both multi-gpu and multi-node setups.
For CUDA environments, NCCL is the default backend for ``mlx.launch`` and all
it takes to run a distributed job is
.. code-block::
mlx.launch -n 8 test.py
# perfect for interactive scripts
mlx.launch -n 8 python -m mlx_lm chat --model my-model --shard
You can also use ``mlx.launch`` to ssh to a remote node and launch a script
with the same ease
.. code-block::
mlx.launch --hosts my-cuda-node -n 8 test.py
In many cases you may not want to use ``mlx.launch`` with the NCCL backend
because the cluster scheduler will be the one launching the processes. You can
:ref:`see which environment variables need to be defined <no_mlx_launch>` in
order for the MLX NCCL backend to be initialized correctly.
.. _mpi_section:
Getting Started with MPI
------------------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.
MLX already comes with the ability to "talk" to `MPI
<https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ if it is installed
on the machine. Launching distributed MLX programs that use MPI can be done
with ``mpirun`` as expected. However, in the following examples we will be
using ``mlx.launch --backend mpi`` which takes care of some nuisances such as
setting absolute paths for the ``mpirun`` executable and the ``libmpi.dylib``
shared library.
The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:
@@ -269,78 +535,116 @@ Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
Getting Started with Ring
-------------------------
.. _no_mlx_launch:
The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
and so on and so forth. As a result :func:`send` and :func:`recv` with
arbitrary sender and receiver is not supported in the ring backend.
Distributed Without ``mlx.launch``
----------------------------------
Defining a Ring
^^^^^^^^^^^^^^^
None of the implementations of the distributed backends require launching with
``mlx.launch``. The script simply connects to each host, starts a process per
rank and sets up the necessary environment variables before delegating to your
MLX script. See the :doc:`dedicated documentation page <launching_distributed>`
for more details.
The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.
For many use-cases this will be the easiest way to perform distributed
computations in MLX. However, there may be reasons that you cannot or should
not use ``mlx.launch``. A common such case is the use of a scheduler that
starts all the processes for you on machines undetermined at the time of
scheduling the job.
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.
Below we list the environment variables required to use each backend.
.. code:: json
Ring
^^^^^^
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
]
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node, run the script which will listen for connections in each of the provided
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4`` and so on and so forth.
**MLX_HOSTFILE** should contain the path to a json file that contains IPs and
ports for each rank to listen to, something like the following:
Thunderbolt Ring
^^^^^^^^^^^^^^^^
.. code-block:: json
Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
[
["123.123.1.1:5000", "123.123.1.2:5000"],
["123.123.2.1:5000", "123.123.2.2:5000"],
["123.123.3.1:5000", "123.123.3.2:5000"],
["123.123.4.1:5000", "123.123.4.2:5000"]
]
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
utility as follows:
**MLX_RING_VERBOSE** is optional and if set to 1 it enables some more logging
from the distributed backend.
.. code:: shell
JACCL
^^^^^
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.
**MLX_JACCL_COORDINATOR** should contain the IP and port that rank 0 listens on
and that all the other ranks connect to in order to establish the RDMA connections.
To validate your connection without configuring anything
``mlx.distributed_config`` can also plot the ring using DOT format.
**MLX_IBV_DEVICES** should contain the path to a json file that contains the
ibverbs device names that connect each node to each other node, something like
the following:
.. code:: shell
.. code-block:: json
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
dot -Tpng ring.dot >ring.png
open ring.png
[
[null, "rdma_en5", "rdma_en4", "rdma_en3"],
["rdma_en5", null, "rdma_en3", "rdma_en4"],
["rdma_en4", "rdma_en3", null, "rdma_en5"],
["rdma_en3", "rdma_en4", "rdma_en5", null]
]
If you want to go through the process manually, the steps are as follows:
* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
``192.168.0.2`` respectively to the two nodes. For more details you can see
the commands prepared by the utility script.
NCCL
^^^^^
**MLX_RANK** should contain a single 0-based integer that defines the rank of
the process.
**MLX_WORLD_SIZE** should contain the total number of processes that will be
launched.
**NCCL_HOST_IP** and **NCCL_PORT** should contain the IP and port that all
hosts can connect to to establish the NCCL communication.
**CUDA_VISIBLE_DEVICES** should contain the local index of the gpu that
corresponds to this process.
Of course any `other environment variable
<https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`_ that is
used by NCCL can be set.
.. _tips_and_tricks:
Tips and Tricks
----------------
This is a small collection of tips to help you better utilize the distributed
communication capabilities of MLX.
- *Test locally first.*
You can use the pattern ``mlx.launch -n2 -- my_script.py`` to run a small
scale test on a single node first.
- *Batch your communication.*
As described in the :ref:`training example <training_example>`, performing a
lot of small communication can hurt performance. Copy the approach of
:func:`mlx.nn.average_gradients` to gather many small communications in a
single large one.
- *Visualize the connectivity.*
Use ``mlx.distributed_config --hosts h1,h2,h3 --over thunderbolt --dot`` to
visualize the connections and make sure that the cables are connected
correctly. See the :ref:`JACCL section <jaccl_section>` for examples.
- *Use the debugger.*
``mlx.launch`` is meant for interactive use. It broadcasts stdin to all
processes and gathers stdout from all processes. This makes using ``pdb`` a
breeze.

View File

@@ -70,7 +70,8 @@ Differences from NumPy
* Indexing does not perform bounds checking. Indexing out of bounds is
undefined behavior.
* Boolean mask based indexing is not yet supported.
* Boolean mask based indexing is supported for assignment only (see
:ref:`boolean-mask-assignment`).
The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
@@ -143,3 +144,51 @@ expected. For example:
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.
.. _boolean-mask-assignment:
Boolean Mask Assignment
-----------------------
MLX supports boolean indices using NumPy syntax. A mask must already be
a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
Other index types are routed through the standard scatter code.
.. code-block:: shell
>>> a = mx.array([1.0, 2.0, 3.0])
>>> mask = mx.array([True, False, True])
>>> updates = mx.array([5.0, 6.0])
>>> a[mask] = updates
>>> a
array([5.0, 2.0, 6.0], dtype=float32)
Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
assignments, ``updates`` must provide at least as many elements as there are
``True`` entries in ``mask``.
.. code-block:: shell
>>> a = mx.zeros((2, 3))
>>> mask = mx.array([[True, False, True],
[False, False, True]])
>>> a[mask] = 1.0
>>> a
array([[1.0, 0.0, 1.0],
[0.0, 0.0, 1.0]], dtype=float32)
Boolean masks follow NumPy semantics:
- The mask shape must match the shape of the axes it indexes exactly. The only
exception is a scalar boolean mask, which broadcasts to the full array.
- Any axes not covered by the mask are taken in full.
.. code-block:: shell
>>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
axes and therefore raise errors.

View File

@@ -7,13 +7,106 @@ Launching Distributed Programs
.. currentmodule:: mlx.core.distributed
Installing the MLX python package provides a helper script ``mlx.launch`` that
can be used to run python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.
Installing the MLX python package provides two utilities to help you configure
your Macs for distributed computation and also launch distributed programs on
multiple nodes or with many processes in a single node. These utilities are aptly named
Usage
-----
- ``mlx.launch``
- ``mlx.distributed_config``
See the :doc:`distributed docs <distributed>` for an introduction and
getting-started guides to the various backends.
``mlx.distributed_config``
---------------------------
Unless you are launching distributed jobs locally for development or multi-gpu
CUDA environments, you have several Macs that you need to configure for
distributed communication with MLX.
``mlx.distributed_config`` aims to automate the process of configuring the
network interfaces (especially for communication over thunderbolt) and also
creating the hostfile to be used with ``mlx.launch``.
We will analyse 3 cases of using ``mlx.distributed_config``
1. RDMA over thunderbolt using JACCL
2. TCP/IP over thunderbolt using the ring backend
3. TCP/IP over ethernet using the ring backend
JACCL
^^^^^^^
After following :ref:`the steps to enable RDMA <jaccl_section>` you can run the
following command to configure the nodes and create the hostfile.
.. code-block::
mlx.distributed_config --verbose --backend jaccl \
--hosts m3-ultra-1,m3-ultra-2,m3-ultra-3,m3-ultra-4 --over thunderbolt \
--auto-setup --output m3-ultra-jaccl.json
Let's walk through the steps that the script takes to configure the nodes.
1. Ssh to all nodes to verify that they are reachable
2. Extract the thunderbolt connectivity. Namely run commands on each node to
calculate which node is connected to which other node.
3. Verify that we have a valid fully connected mesh
4. Check that RDMA is enabled
5. Extract the ethernet IP from interface en0
6. Disable the thunderbolt bridge and set up peer to peer networks for each
thunderbolt cable
7. Write the hostfile
Knowing the above steps allows you to manually configure the nodes but also
debug any configuration issue. For instance changing the Ethernet IP to a
different interface directly in the config is possible (as long as it is
reachable from all nodes).
The ``--auto-setup`` argument requires password-less sudo on each node. If it
isn't available then the configuration script will print commands to be run on
each node.
Ring over thunderbolt
^^^^^^^^^^^^^^^^^^^^^
Setting up a ring backend over thunderbolt only requires changing the
``--backend`` from ``jaccl`` to ``ring``.
The steps are very similar with the main difference being that instead of
verifying that the nodes are fully connected, the script attempts to identify a
ring topology (or multiple rings).
Ring over Ethernet
^^^^^^^^^^^^^^^^^^
Configuring the ring backend over ethernet doesn't require setting up network
interfaces and as such it simply extracts the ``en0`` IP from each node and
writes the hostfile.
Debugging cable connections
^^^^^^^^^^^^^^^^^^^^^^^^^^^
``mlx.distributed_config`` can help you debug the connectivity of your nodes
over thunderbolt by exporting a graph of the connections.
Running
.. code-block::
mlx.distributed_config --verbose \
--hosts host1,host2,host3,host4 \
--over thunderbolt --dot
will export a `GraphViz <https://graphviz.org>`_ representation of the
connections between the nodes which makes it very easy to figure out which
cable is not connected correctly.
See :ref:`the JACCL section <jaccl_section>` for an example.
``mlx.launch``
--------------
The minimal usage example of ``mlx.launch`` is simply
@@ -33,6 +126,10 @@ the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the output of each remote process to stdout
and stderr respectively.
Importantly, it also broadcasts stdin to each process which enables interactive
programs to work in distributed mode as well as debugging using the interactive
debugger.
Providing Hosts
^^^^^^^^^^^^^^^^
@@ -63,10 +160,62 @@ host and on the same path. A good checklist to debug errors is the following:
``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path
If you are launching from a node with a completely different setup than the
nodes that the program will run on, you can specify ``--no-verify-script`` so
that ``mlx.launch`` does not attempt to verify that the executable and script
exist locally before launching the distributed job.
.. _ring_specifics:
Ring Specifics
^^^^^^^^^^^^^^
The :ref:`ring <ring_section>` backend, which is also the default
backend, can be explicitly selected with the argument ``--backend ring``. The
ring backend has some specific requirements and arguments that are different to
other backends:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.
.. _jaccl_specifics:
JACCL Specifics
^^^^^^^^^^^^^^^^
The :ref:`JACCL <jaccl_section>` backend can be selected with the argument
``--backend jaccl``. A hostfile is necessary to launch with this backend
because it needs to contain the RDMA devices connecting each node to each other
node.
NCCL Specifics
^^^^^^^^^^^^^^
The :ref:`NCCL <nccl_section>` backend is the default backend for CUDA
environments. When launching from a Mac to a Linux machine with CUDA then the
backend should be selected using ``--backend nccl``.
The ``--repeat-hosts, -n`` argument should be used to launch multi-node and
multi-gpu jobs. For instance
.. code-block::
mlx.launch --backend nccl --hosts linux-1,linux-2 -n 8 --no-verify-script -- ./my-job.sh
will attempt to launch 16 processes, 8 on each node that will all run
``my-job.sh``.
.. _mpi_specifics:
MPI Specifics
-------------
^^^^^^^^^^^^^
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
@@ -83,23 +232,3 @@ to choose a specific interface for the byte-transfer-layer of MPI we can call
.. code:: shell
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
.. _ring_specifics:
Ring Specifics
--------------
The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that are different to MPI:
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
ssh to a hostname that does not correspond to the IP we want to bind to we
have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
Specifically rank 0 for the first IP will use this port and each subsequent
IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
``mpirun``.

View File

@@ -1,7 +1,6 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp

View File

@@ -1,24 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
// Allocate `size` bytes through allocator(). A non-zero request that comes
// back with a null pointer is reported as a std::runtime_error; a zero-sized
// request is returned as-is.
Buffer malloc(size_t size) {
  Buffer result = allocator().malloc(size);
  const bool failed = size != 0 && result.ptr() == nullptr;
  if (!failed) {
    return result;
  }
  std::ostringstream msg;
  msg << "[malloc] Unable to allocate " << size << " bytes.";
  throw std::runtime_error(msg.str());
}
// Hand `buffer` back to the allocator returned by allocator().
void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -28,16 +28,16 @@ class Buffer {
};
};
Buffer malloc(size_t size);
void free(Buffer buffer);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
virtual Buffer make_buffer(void* ptr, size_t size) {
return Buffer{nullptr};
};
virtual void release(Buffer buffer) {}
Allocator() = default;
Allocator(const Allocator& other) = delete;
@@ -49,4 +49,25 @@ class Allocator {
Allocator& allocator();
// Allocate `size` bytes by forwarding to allocator().malloc().
// NOTE(review): no null check is performed here; presumably reporting a
// failed allocation is left to the Allocator implementation - confirm.
inline Buffer malloc(size_t size) {
return allocator().malloc(size);
}
// Free `buffer` by forwarding to allocator().free().
inline void free(Buffer buffer) {
allocator().free(buffer);
}
// Make a Buffer from a raw pointer of the given size without a copy by
// forwarding to allocator().make_buffer(). If a no-copy conversion is not
// possible then the returned buffer.ptr() will be nullptr. Any buffer created
// with this function must be released with release(buffer).
inline Buffer make_buffer(void* ptr, size_t size) {
  return allocator().make_buffer(ptr, size);
}
// Release a buffer from the allocator made with make_buffer. Forwards to
// allocator().release().
inline void release(Buffer buffer) {
allocator().release(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -82,6 +82,28 @@ array::array(std::initializer_list<int> data, Dtype dtype)
init(data.begin());
}
/* Build an array from a raw pointer, avoiding a copy when the allocator can
 * wrap the memory directly and falling back to a copy otherwise. */
array::array(
    void* data,
    Shape shape,
    Dtype dtype,
    const std::function<void(void*)>& deleter)
    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
  auto wrapped = allocator::make_buffer(data, nbytes());

  // No-copy path: the array keeps using the caller's memory, so the caller's
  // deleter only runs once the array lets go of the wrapped buffer.
  if (wrapped.ptr() != nullptr) {
    set_data(wrapped, [deleter](allocator::Buffer buf) {
      auto raw = buf.ptr();
      allocator::release(buf);
      return deleter(raw);
    });
    return;
  }

  // Copy path: a null ptr() means the memory could not be wrapped. Copy the
  // bytes into a fresh allocation and call the deleter right after the copy.
  set_data(allocator::malloc(nbytes()));
  auto src = static_cast<char*>(data);
  std::copy(src, src + nbytes(), this->data<char>());
  deleter(data);
}
/* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {

View File

@@ -57,6 +57,16 @@ class array {
Shape shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a raw pointer. The constructor will attempt to use the
* input data without a copy. The deleter will be called when the array no
* longer needs the underlying memory - after the array is destroyed in the
* no-copy case and after the copy otherwise. */
explicit array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter);
/* Build an array from a buffer */
explicit array(
allocator::Buffer data,

View File

@@ -12,6 +12,167 @@ namespace mlx::core {
namespace {
// Pack a (real, imaginary) pair of scalars into a complex64_t, converting
// both parts to float.
template <typename T>
complex64_t to_complex(T r, T i) {
  const auto real_part = static_cast<float>(r);
  const auto imag_part = static_cast<float>(i);
  return complex64_t{real_part, imag_part};
}
// Primary template for the eigendecomposition work object. It is intentionally
// empty; the usable implementations are the specializations selected through
// the Enable parameter or by explicit specialization.
template <typename T, class Enable = void>
struct EigWork {};
// Work object for the eigendecomposition of real (floating point) matrices
// through geev. It owns the scratch buffers and converts geev's split
// real/imaginary output into complex64_t results.
template <typename T>
struct EigWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
// Output element type written by run().
using O = complex64_t;
char jobl;
char jobr;
int N;
int lwork;
// Status written by the most recent geev call.
int info;
// Scratch buffers, in order: eigenvalue scratch (2 * N elements of T), the
// eigenvector scratch (N * N * 2 elements, only when compute_eigenvectors is
// set) and, always last, the geev work array of lwork elements.
std::vector<array::Data> buffers;
// Performs a geev workspace query (lwork == -1) to learn the work array size
// and then allocates all scratch buffers.
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
T work;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
// The query reports the work array size through the first work element.
lwork = static_cast<int>(work);
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
if (compute_eigenvectors) {
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
}
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
// Decomposes one N x N matrix `a`. Writes N eigenvalues to `values` and, when
// `vectors` is non-null, N * N eigenvector entries to `vectors`. A non-null
// `vectors` requires that the object was built with compute_eigenvectors set,
// since buffers[1] is only the eigenvector scratch in that case.
void run(T* a, O* values, O* vectors) {
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
T* vec_tmp = nullptr;
if (vectors) {
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
}
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
// Real parts go to eig_tmp[0, N) and imaginary parts to eig_tmp[N, 2N).
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
eig_tmp,
eig_tmp + N,
vectors ? vec_tmp : nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
work,
&lwork,
&info);
for (int i = 0; i < N; ++i) {
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
}
if (vectors) {
for (int i = 0; i < N; ++i) {
if (values[i].imag() != 0) {
// This vector and the next are a pair: row i holds the shared real part
// and row i + 1 the imaginary part, combined with opposite signs.
for (int j = 0; j < N; ++j) {
vectors[i * N + j] =
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
vectors[(i + 1) * N + j] =
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
}
// Skip the second member of the pair, it was filled in above.
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
}
}
}
}
}
};
// Work object for the eigendecomposition of complex64 matrices through geev.
// Eigenvalues and eigenvectors are written straight into the caller's output
// pointers; only the work arrays are owned here.
template <>
struct EigWork<std::complex<float>> {
  using T = std::complex<float>;
  using R = float;
  using O = T;

  char jobl;
  char jobr;
  int N;
  int lwork;
  int lrwork;
  // Status written by the most recent geev call.
  int info;
  // buffers[0] is the complex work array (lwork elements), buffers[1] the
  // real work array (lrwork elements).
  std::vector<array::Data> buffers;

  // Performs a geev workspace query (lwork == -1) and allocates both work
  // arrays.
  EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
      : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
    T queried_size;
    R rwork_placeholder;
    int ld_left = compute_eigenvectors ? N_ : 1;
    int ld_right = 1;

    geev<T>(
        &jobl,
        &jobr,
        &N,
        nullptr,
        &N,
        nullptr,
        nullptr,
        &ld_left,
        nullptr,
        &ld_right,
        &queried_size,
        &lwork,
        &rwork_placeholder,
        &info);

    // The query reports the work array size in the real part of the first
    // work element.
    lwork = static_cast<int>(queried_size.real());

    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
  }

  // Decomposes one N x N matrix `a`, writing the eigenvalues to `values` and,
  // when `vectors` is non-null, the eigenvectors to `vectors`.
  void run(T* a, T* values, T* vectors) {
    T* complex_work = static_cast<T*>(buffers[0].buffer.raw_ptr());
    R* real_work = static_cast<R*>(buffers[1].buffer.raw_ptr());
    int ld_left = vectors ? N : 1;
    int ld_right = 1;

    geev<T>(
        &jobl,
        &jobr,
        &N,
        a,
        &N,
        values,
        vectors,
        &ld_left,
        nullptr,
        &ld_right,
        complex_work,
        &lwork,
        real_work,
        &info);
  }
};
template <typename T>
void eig_impl(
array& a,
@@ -19,101 +180,39 @@ void eig_impl(
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>();
auto val_ptr = values.data<complex64_t>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
OT* vec_ptr = nullptr;
complex64_t* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>();
vec_ptr = vectors.data<complex64_t>();
}
encoder.dispatch([a_ptr,
val_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
for (size_t i = 0; i < size / (N * N); ++i) {
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
work.run(a_ptr, val_ptr, vec_ptr);
a_ptr += N * N;
val_ptr += N;
if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N;
}
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
if (work.info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
<< work.info;
throw std::runtime_error(msg.str());
}
}
@@ -165,8 +264,17 @@ void Eig::eval_cpu(
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case float64:
eig_impl<double>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case complex64:
eig_impl<std::complex<float>>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
throw std::runtime_error(
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
}
}

View File

@@ -747,4 +747,108 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
});
}
// Scatter elements of `src` into `out` at the positions where `mask` is true.
//
// The leading axis of `mask` is treated as the batch axis. For each batch the
// mask and output are walked in lock step and every true mask entry takes the
// next unread element of that batch's slice of `src`.
//
// Throws std::runtime_error if a batch has more true mask entries than source
// elements available for that batch.
template <typename T>
void masked_scatter_impl(const array& mask, const array& src, array& out) {
  ContiguousIterator mask_it(mask);
  ContiguousIterator src_it(src);
  ContiguousIterator out_it(out);

  const bool* mask_ptr = mask.data<bool>();
  const T* src_ptr = src.data<T>();
  T* dst_ptr = out.data<T>();

  const size_t batch_count = mask.shape(0);
  // Nothing to scatter, and the per-batch sizes below would divide by zero.
  if (batch_count == 0) {
    return;
  }
  const size_t mask_batch_size = mask.size() / batch_count;
  const size_t src_batch_size = src.size() / batch_count;

  // size_t index: matches the type of batch_count and avoids the non-standard
  // `uint` typedef.
  for (size_t b = 0; b < batch_count; ++b) {
    size_t src_consumed = 0;
    // Each batch reads from the start of its own slice of the source.
    src_it.seek(b * src_batch_size);

    for (size_t i = 0; i < mask_batch_size; ++i) {
      if (mask_ptr[mask_it.loc]) {
        if (src_consumed >= src_batch_size) {
          throw std::runtime_error(
              "[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
        }
        dst_ptr[out_it.loc] = src_ptr[src_it.loc];
        src_it.step();
        ++src_consumed;
      }
      mask_it.step();
      out_it.step();
    }
  }
}
// CPU evaluation of MaskedScatter. Inputs are, in order: the destination
// array, the boolean mask and the source array. The destination is first
// copied into `out`, then a task is dispatched on the stream's command encoder
// that overwrites the masked positions of `out` with elements from the source.
void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& dst = inputs[0];
auto& mask = inputs[1];
auto& src = inputs[2];
// Copy dst into out (copy allocates memory for out)
auto ctype =
dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(dst, out, ctype, stream());
// With an empty mask there is nothing to scatter; out is just the copy.
if (mask.size() == 0) {
return;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(mask);
encoder.set_input_array(src);
encoder.set_output_array(out);
// Select the element type from the output dtype and run the scatter.
encoder.dispatch([mask = array::unsafe_weak_copy(mask),
src = array::unsafe_weak_copy(src),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
masked_scatter_impl<bool>(mask, src, out);
break;
case uint8:
masked_scatter_impl<uint8_t>(mask, src, out);
break;
case uint16:
masked_scatter_impl<uint16_t>(mask, src, out);
break;
case uint32:
masked_scatter_impl<uint32_t>(mask, src, out);
break;
case uint64:
masked_scatter_impl<uint64_t>(mask, src, out);
break;
case int8:
masked_scatter_impl<int8_t>(mask, src, out);
break;
case int16:
masked_scatter_impl<int16_t>(mask, src, out);
break;
case int32:
masked_scatter_impl<int32_t>(mask, src, out);
break;
case int64:
masked_scatter_impl<int64_t>(mask, src, out);
break;
case float16:
masked_scatter_impl<float16_t>(mask, src, out);
break;
case float32:
masked_scatter_impl<float>(mask, src, out);
break;
case float64:
masked_scatter_impl<double>(mask, src, out);
break;
case bfloat16:
masked_scatter_impl<bfloat16_t>(mask, src, out);
break;
case complex64:
masked_scatter_impl<complex64_t>(mask, src, out);
break;
}
});
}
} // namespace mlx::core

View File

@@ -45,9 +45,7 @@
INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesdd)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)
@@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri)
}
INSTANTIATE_LAPACK_COMPLEX(heevd)
// Defines a function template FUNC<T>(args...) that forwards to the LAPACK
// routine with the precision prefix matching T: s (float), d (double),
// c (std::complex<float>) or z (std::complex<double>). For any other T the
// generated function body is empty, i.e. the call does nothing.
#define INSTANTIATE_LAPACK_ALL(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, double>) { \
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_ALL(geev)
INSTANTIATE_LAPACK_ALL(gesdd)

View File

@@ -8,6 +8,183 @@
namespace mlx::core {
// Primary template for the SVD work object. It is intentionally empty; the
// usable implementations are the specializations selected through the Enable
// parameter or by explicit specialization.
template <typename T, class Enable = void>
struct SVDWork {};
// Work object for the SVD of real (floating point) matrices through gesdd.
// The constructor queries the workspace size and allocates the scratch
// buffers once; run() can then be called for every matrix in a batch.
template <typename T>
struct SVDWork<
    T,
    typename std::enable_if<std::is_floating_point<T>::value>::type> {
  // Element type of the singular values.
  using R = T;

  int N;
  int M;
  int K;
  int lda;
  int ldu;
  int ldvt;
  char jobz;
  // buffers[0] is the integer workspace (iwork), buffers[1] the gesdd work
  // array of lwork elements.
  std::vector<array::Data> buffers;
  int lwork;

  // N, M are the matrix dimensions as passed to gesdd, K the number of
  // singular values used to size iwork, jobz the gesdd job selector.
  // Throws std::runtime_error if the workspace query fails.
  SVDWork(int N, int M, int K, char jobz)
      : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
    T workspace_dimension = 0;

    // Integer workspace required by gesdd (8 * K entries).
    buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));

    int lwork_query = -1;
    int info;

    // Compute workspace size.
    gesdd<T>(
        /* jobz = */ &jobz,
        // M and N are swapped since lapack expects column-major.
        /* m = */ &N,
        /* n = */ &M,
        /* a = */ nullptr,
        /* lda = */ &lda,
        /* s = */ nullptr,
        /* u = */ nullptr,
        /* ldu = */ &ldu,
        /* vt = */ nullptr,
        /* ldvt = */ &ldvt,
        /* work = */ &workspace_dimension,
        /* lwork = */ &lwork_query,
        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
        /* info = */ &info);

    if (info != 0) {
      std::stringstream ss;
      ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
      throw std::runtime_error(ss.str());
    }

    // The query reports the work array size as a floating point value.
    lwork = static_cast<int>(workspace_dimension);
    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
  }

  // Decomposes one matrix `a` (overwritten by lapack), writing the singular
  // values to `s` and the factor matrices to `u` and `vt`.
  // Throws std::runtime_error if gesdd reports a non-zero status.
  void run(T* a, R* s, T* u, T* vt) {
    int info;
    gesdd<T>(
        /* jobz = */ &jobz,
        // M and N are swapped since lapack expects column-major.
        /* m = */ &N,
        /* n = */ &M,
        /* a = */ a,
        /* lda = */ &lda,
        /* s = */ s,
        // According to the identity above, lapack will write Vᵀᵀ as U.
        /* u = */ u,
        /* ldu = */ &ldu,
        // According to the identity above, lapack will write Uᵀ as Vᵀ.
        /* vt = */ vt,
        /* ldvt = */ &ldvt,
        /* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
        /* lwork = */ &lwork,
        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
        /* info = */ &info);

    if (info != 0) {
      std::stringstream ss;
      ss << "[SVD::eval_cpu] gesdd failed with code " << info;
      throw std::runtime_error(ss.str());
    }
  }
};
// Work object for the SVD of complex64 matrices through gesdd. The
// constructor queries the workspace size and allocates the scratch buffers
// once; run() can then be called for every matrix in a batch.
template <>
struct SVDWork<std::complex<float>> {
  using T = std::complex<float>;
  // Element type of the singular values.
  using R = float;

  int N;
  int M;
  int K;
  int lda;
  int ldu;
  int ldvt;
  char jobz;
  // buffers[0] is the integer workspace (iwork), buffers[1] the real
  // workspace (rwork), buffers[2] the complex work array of lwork elements.
  std::vector<array::Data> buffers;
  int lwork;

  // N, M are the matrix dimensions as passed to gesdd, K = min(M, N) is the
  // number of singular values, jobz the gesdd job selector.
  // Throws std::runtime_error if the workspace query fails.
  SVDWork(int N, int M, int K, char jobz)
      : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
    T workspace_dimension = 0;

    // Integer workspace required by gesdd (8 * K entries).
    buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));

    // Real workspace. With mn = min(M, N) and mx = max(M, N), gesdd needs at
    // most 7 * mn entries when no vectors are computed and otherwise
    // max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn). The second
    // term dominates for moderately rectangular matrices, so it must be
    // included to avoid an undersized buffer. Computed in size_t since the
    // value is only used as an allocation size.
    const size_t mn = static_cast<size_t>(K);
    const size_t mx = static_cast<size_t>(std::max(M, N));
    const size_t lrwork = std::max<size_t>(
        1,
        jobz == 'N'
            ? 7 * mn
            : std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn));
    buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));

    int lwork_query = -1;
    int info;

    // Compute workspace size.
    gesdd<T>(
        /* jobz = */ &jobz,
        // M and N are swapped since lapack expects column-major.
        /* m = */ &N,
        /* n = */ &M,
        /* a = */ nullptr,
        /* lda = */ &lda,
        /* s = */ nullptr,
        /* u = */ nullptr,
        /* ldu = */ &ldu,
        /* vt = */ nullptr,
        /* ldvt = */ &ldvt,
        /* work = */ &workspace_dimension,
        /* lwork = */ &lwork_query,
        /* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
        /* info = */ &info);

    if (info != 0) {
      std::stringstream ss;
      ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
      throw std::runtime_error(ss.str());
    }

    // The query reports the work array size in the real part of the value.
    lwork = static_cast<int>(workspace_dimension.real());
    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
  }

  // Decomposes one matrix `a` (overwritten by lapack), writing the singular
  // values to `s` and the factor matrices to `u` and `vt`.
  // Throws std::runtime_error if gesdd reports a non-zero status.
  void run(T* a, R* s, T* u, T* vt) {
    int info;
    gesdd<T>(
        /* jobz = */ &jobz,
        // M and N are swapped since lapack expects column-major.
        /* m = */ &N,
        /* n = */ &M,
        /* a = */ a,
        /* lda = */ &lda,
        /* s = */ s,
        // According to the identity above, lapack will write Vᵀᵀ as U.
        /* u = */ u,
        /* ldu = */ &ldu,
        // According to the identity above, lapack will write Uᵀ as Vᵀ.
        /* vt = */ vt,
        /* ldvt = */ &ldvt,
        /* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
        /* lwork = */ &lwork,
        /* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
        /* info = */ &info);

    if (info != 0) {
      std::stringstream ss;
      ss << "[SVD::eval_cpu] gesdd failed with code " << info;
      throw std::runtime_error(ss.str());
    }
  }
};
template <typename T>
void svd_impl(
const array& a,
@@ -27,6 +204,8 @@ void svd_impl(
const int N = a.shape(-1);
const int K = std::min(M, N);
using R = typename SVDWork<T>::R;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
@@ -42,7 +221,7 @@ void svd_impl(
encoder.set_input_array(a);
auto in_ptr = in.data<T>();
T* u_ptr;
T* s_ptr;
R* s_ptr;
T* vt_ptr;
if (compute_uv) {
@@ -58,7 +237,7 @@ void svd_impl(
encoder.set_output_array(s);
encoder.set_output_array(vt);
s_ptr = s.data<T>();
s_ptr = s.data<R>();
u_ptr = u.data<T>();
vt_ptr = vt.data<T>();
} else {
@@ -68,96 +247,26 @@ void svd_impl(
encoder.set_output_array(s);
s_ptr = s.data<T>();
s_ptr = s.data<R>();
u_ptr = nullptr;
vt_ptr = nullptr;
}
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto jobz = (u_ptr) ? "A" : "N";
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
static const int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto jobz = (u_ptr) ? 'A' : 'N';
SVDWork<T> svd_work(N, M, K, jobz);
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
svd_work.run(
in_ptr + M * N * i,
s_ptr + K * i,
vt_ptr ? vt_ptr + N * N * i : nullptr,
u_ptr ? u_ptr + M * M * i : nullptr);
}
});
encoder.add_temporary(in);
}
// Empty stub: the body does nothing and none of the parameters are used.
// NOTE(review): no caller is visible here; svd_impl does the actual work, so
// this looks like a leftover - confirm before relying on it.
template <typename T>
void compute_svd(
const array& a,
bool compute_uv,
std::vector<array>& outputs,
Stream stream) {}
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -168,9 +277,12 @@ void SVD::eval_cpu(
case float64:
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
break;
case complex64:
svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
break;
default:
throw std::runtime_error(
"[SVD::eval_cpu] only supports float32 or float64.");
"[SVD::eval_cpu] only supports float32, float64, or complex64.");
}
}

View File

@@ -123,14 +123,21 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
# Use native CUDA arch by default.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
execute_process(
COMMAND bash detect_cuda_arch.sh
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMAND __nvcc_device_query
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
OUTPUT_STRIP_TRAILING_WHITESPACE)
set(UPGRADABLE_ARCHITECTURES "90;100;121")
if(MLX_CUDA_ARCHITECTURES STREQUAL "")
message(
FATAL_ERROR
"Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
# Use arch-specific compute capability whenever possible.
set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
endif()
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -142,6 +149,7 @@ FetchContent_Declare(
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
set_target_properties(mlx PROPERTIES CCCL_DIR "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
@@ -167,7 +175,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.14.0
GIT_TAG v1.16.0
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)

View File

@@ -20,6 +20,19 @@ constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;
// Build the location argument for device `i` in the form expected by the
// CUDA runtime being compiled against: a cudaMemLocation of device type when
// CUDART_VERSION >= 13000, the plain device ordinal otherwise.
inline auto cuda_mem_loc(int i) {
#if CUDART_VERSION >= 13000
  cudaMemLocation loc;
  loc.type = cudaMemLocationTypeDevice;
  loc.id = i;
  return loc;
#else
  return i;
#endif // CUDART_VERSION >= 13000
}
// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
@@ -35,13 +48,7 @@ SmallSizePool::SmallSizePool() {
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
auto loc = cuda_mem_loc(i);
CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
@@ -90,9 +97,10 @@ CudaAllocator::CudaAllocator()
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) {
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
size_t free;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
memory_limit_ = total_memory_ * 0.95;
free_limit_ = total_memory_ - memory_limit_;
max_pool_size_ = memory_limit_;
int device_count = 0;
@@ -104,6 +112,10 @@ CudaAllocator::CudaAllocator()
cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s);
cudaMemPool_t mem_pool;
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
mem_pools_.push_back(mem_pool);
}
CHECK_CUDA_ERROR(cudaSetDevice(curr));
}
@@ -119,7 +131,8 @@ void copy_to_managed(CudaBuffer& buf) {
buf.data = new_data;
}
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
Buffer
CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
if (size == 0) {
return Buffer{new CudaBuffer{nullptr, 0, -1}};
}
@@ -134,9 +147,8 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
size = page_size * ((size + page_size - 1) / page_size);
}
int device = -1;
if (size > small_block_size && stream != nullptr) {
CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
if (size <= small_block_size || stream == nullptr) {
device = -1;
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
@@ -154,19 +166,35 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
}
lock.unlock();
if (!buf) {
buf = new CudaBuffer{nullptr, size, device};
cudaError_t err;
void* data = nullptr;
if (device == -1) {
err = cudaMallocManaged(&buf->data, size);
CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
} else {
err = cudaMallocAsync(&buf->data, size, stream);
CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
if (!data) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
buf = new CudaBuffer{data, size, device};
}
lock.lock();
// If any cuda memory pool has too much reserved memory, clear some
// memory from the cache. This prevents graph / kernel execution failing
// from OOM
if (get_cache_memory() > 0) {
for (auto p : mem_pools_) {
size_t used = 0;
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
p, cudaMemPoolAttrReservedMemCurrent, &used));
if (used > (total_memory_ - free_limit_)) {
buffer_cache_.release_cached_buffers(free_limit_);
break;
}
}
}
}
active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_);
@@ -176,18 +204,14 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
// Copy to managed here if the buffer is not on the right device
if (buf->device != device) {
if (buf->device >= 0 && buf->device != device) {
copy_to_managed(*buf);
}
return Buffer{buf};
}
Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
return malloc_impl(size, stream);
}
Buffer CudaAllocator::malloc(size_t size) {
return malloc_impl(size, nullptr);
return malloc_async(size, -1, nullptr);
}
void CudaAllocator::free(Buffer buffer) {
@@ -223,9 +247,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
scalar_pool_.free(buf);
} else {
if (buf->device >= 0) {
cudaFreeAsync(buf->data, free_streams_[buf->device]);
CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
} else {
cudaFree(buf->data);
CHECK_CUDA_ERROR(cudaFree(buf->data));
}
delete buf;
}
@@ -277,8 +301,9 @@ CudaAllocator& allocator() {
return *allocator_;
}
Buffer malloc_async(size_t size, cudaStream_t stream) {
auto buffer = allocator().malloc_async(size, stream);
Buffer malloc_async(size_t size, CommandEncoder& encoder) {
auto buffer = allocator().malloc_async(
size, encoder.device().cuda_device(), encoder.stream());
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_async] Unable to allocate " << size << " bytes.";

View File

@@ -13,6 +13,8 @@
namespace mlx::core::cu {
class CommandEncoder;
using allocator::Buffer;
// Stores cuda-managed unified memory.
@@ -48,7 +50,7 @@ class SmallSizePool {
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
Buffer malloc_async(size_t size, cudaStream_t stream);
Buffer malloc_async(size_t size, int device, cudaStream_t stream);
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
@@ -62,7 +64,6 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache();
private:
Buffer malloc_impl(size_t size, cudaStream_t stream);
void cuda_free(CudaBuffer* buf);
CudaAllocator();
@@ -70,16 +71,19 @@ class CudaAllocator : public allocator::Allocator {
std::mutex mutex_;
size_t memory_limit_;
size_t free_limit_;
size_t total_memory_;
size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_;
std::vector<cudaMemPool_t> mem_pools_;
SmallSizePool scalar_pool_;
};
CudaAllocator& allocator();
Buffer malloc_async(size_t size, cudaStream_t stream);
Buffer malloc_async(size_t size, CommandEncoder& encoder);
} // namespace mlx::core::cu

View File

@@ -42,7 +42,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
auto& encoder = cu::get_command_encoder(stream());
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {

View File

@@ -143,7 +143,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);

View File

@@ -367,9 +367,8 @@ void binary_op_gpu(
auto bopt = get_binary_op_type(a, b);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(
a, b, out, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}

View File

@@ -246,12 +246,10 @@ void binary_two_op_gpu_inplace(
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(
a, b, out_a, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
set_binary_op_output_data(
a, b, out_b, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
if (out_a.size() == 0) {
return;

View File

@@ -298,7 +298,7 @@ void Compiled::eval_gpu(
// Put outputs.
compiled_allocate_outputs(
inputs, outputs, is_constant_, contiguous, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
return cu::malloc_async(n, encoder);
});
for (auto& x : outputs) {
args.append(x);

View File

@@ -15,19 +15,16 @@ namespace mlx::core {
namespace {
// Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
#define CONV_BACKWARD_WEIGHT \
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
// Custom placeholder representing fallback kernel.
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
// Selects how a convolution is executed; stored alongside the cached graph in
// the conv cache.
enum ConvBackendType {
// Placeholder representing the fallback kernel.
CONV_FALLBACK,
// Forward convolution (built with conv_fprop).
CONV_FORWARD,
// Gradient with respect to the input (built with conv_dgrad).
CONV_BACKWARD_INPUT,
// Gradient with respect to the weight (built with conv_wgrad).
CONV_BACKWARD_WEIGHT,
};
struct ConvCacheKey {
int device_id;
cudnnDataType_t cudnn_dtype;
fe::DataType_t cudnn_dtype;
std::array<int, MAX_NDIM> input_shape;
std::array<int, MAX_NDIM> weight_shape;
std::array<int, MAX_NDIM> stride;
@@ -44,15 +41,13 @@ struct ConvCacheKey {
auto& conv_cache() {
static LRUBytesKeyCache<
ConvCacheKey,
std::pair<
cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>>
std::pair<ConvBackendType, std::optional<DnnGraph>>>
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
return cache;
}
auto get_conv_op_settings(
cudnnBackendDescriptorType_t backend_type,
auto get_conv_settings(
ConvBackendType backend_type,
array& x,
array& w,
array& y,
@@ -68,8 +63,8 @@ auto get_conv_op_settings(
for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_lo[i] - 1;
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + padding_hi[i];
}
return std::make_tuple(
@@ -95,49 +90,57 @@ auto get_conv_op_settings(
}
}
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
std::optional<DnnGraph> build_conv_graph(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
ConvBackendType backend_type,
Dtype dtype,
array& x,
array& w,
array& y,
const SmallVector<int64_t>& stride,
const SmallVector<int64_t>& padding_lo,
const SmallVector<int64_t>& padding_hi,
const SmallVector<int64_t>& dilation) {
try {
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
? CUDNN_DATA_FLOAT
: dtype_to_cudnn_type(dtype);
auto conv_desc = cudnn_frontend::ConvDescBuilder()
.setDataType(compute_dtype)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding_lo,
const std::vector<int64_t>& padding_hi,
const std::vector<int64_t>& dilation) {
auto compute_dtype =
(dtype == float16 || dtype == bfloat16) ? float32 : dtype;
DnnGraph graph(encoder.device().cudnn_handle(), dtype, compute_dtype);
auto x_ = graph.tensor_nchw("X", 'x', x);
auto w_ = graph.tensor_nchw("W", 'w', w);
auto op = cudnn_frontend::OperationBuilder(backend_type)
.setxDesc(build_cudnn_tensor_nchw('x', x))
.setwDesc(build_cudnn_tensor_nchw('w', w))
.setyDesc(build_cudnn_tensor_nchw('y', y))
.setcDesc(conv_desc)
.build();
auto set_options = [&](auto& options) {
options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))
.set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
.set_stride(stride)
.set_pre_padding(padding_lo)
.set_post_padding(padding_hi)
.set_dilation(dilation);
};
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
return cudnn_frontend::OperationGraphBuilder()
.setHandle(encoder.device().cudnn_handle())
.setOperationGraph(ops.size(), ops.data())
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
throw;
}
std::shared_ptr<fe::graph::Tensor_attributes> y_;
if (backend_type == CONV_FORWARD) {
auto options = fe::graph::Conv_fprop_attributes();
set_options(options);
y_ = graph.conv_fprop(x_, w_, options);
} else if (backend_type == CONV_BACKWARD_INPUT) {
auto options = fe::graph::Conv_dgrad_attributes();
set_options(options);
y_ = graph.conv_dgrad(x_, w_, options);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
auto options = fe::graph::Conv_wgrad_attributes();
set_options(options);
y_ = graph.conv_wgrad(w_, x_, options);
}
graph.tensor_nchw(y_, 'y', y)->set_output(true);
if (graph.prepare().is_bad()) {
return std::nullopt;
}
graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
if (dtype == float32 && !env::enable_tf32()) {
graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});
}
CHECK_CUDNN_FE_ERROR(graph.build());
return graph;
}
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
@@ -181,7 +184,7 @@ array group_transpose(
// eval_gpu, with cost of possible redundant copies.
std::tuple<array, array, array> prepare_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
ConvBackendType backend_type,
array in,
array wt,
array out,
@@ -221,27 +224,11 @@ std::tuple<array, array, array> prepare_args(
return {std::move(in), std::move(wt), std::move(out)};
}
// Get the x/w/y args from the in/wt/out args depending on backend type.
inline std::tuple<array&, array&, array&> dispatch_args(
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& out) {
switch (backend_type) {
case CONV_BACKWARD_INPUT:
return {out, wt, in};
case CONV_BACKWARD_WEIGHT:
return {in, out, wt};
default:
return {in, wt, out};
}
}
// Register inputs and outputs before actually running conv op. Can only be
// called once per eval_gpu.
void register_args(
cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type,
ConvBackendType backend_type,
array& in,
array& wt,
array& intermediate_out,
@@ -277,7 +264,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
Dtype dtype = out.dtype();
// Search cache.
@@ -297,16 +284,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
get_alignment(wt),
get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second;
if (plan) {
// Run cached plan.
auto& [backend_type, graph] = it->second;
if (graph) {
// Run cached graph.
std::tie(in, wt, out) =
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
throw std::runtime_error("[conv] Cached plan failed to execute.");
}
CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
encoder,
{
{'x', gpu_ptr<void>(in)},
{'w', gpu_ptr<void>(wt)},
{'y', gpu_ptr<void>(out)},
}));
} else {
// Run fallback kernel.
gemm_conv(
@@ -327,7 +317,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
// There is no reliable way to deduce the proper cuDNN backend for the
// convolution, so we make a best guess and then try.
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
SmallVector<ConvBackendType, 2> try_backends;
if (flip_) {
// When weight is flipped, we assume it is backward input convolution.
try_backends.push_back(CONV_BACKWARD_INPUT);
@@ -345,13 +335,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
}
// Try to build op graph.
cudnnBackendDescriptorType_t backend_type;
std::optional<cudnn_frontend::OperationGraph> op_graph;
ConvBackendType backend_type;
std::optional<DnnGraph> graph;
for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] =
auto [x, w, y] =
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(
try_backend,
x,
w,
@@ -361,7 +350,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
padding_hi_,
kernel_dilation_,
input_dilation_);
op_graph = build_conv_op_graph(
graph = build_conv_graph(
encoder,
try_backend,
dtype,
@@ -372,30 +361,27 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
padding_lo,
padding_hi,
dilation);
if (op_graph) {
if (graph) {
backend_type = try_backend;
in = std::move(in_copy);
wt = std::move(wt_copy);
out = std::move(out_copy);
in = std::move(x);
wt = std::move(w);
out = std::move(y);
break;
}
}
if (op_graph) {
// Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (plan) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
}
}
if (graph) {
register_args(encoder, backend_type, in, wt, out, out_);
CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
encoder,
{
{'x', gpu_ptr<void>(in)},
{'w', gpu_ptr<void>(wt)},
{'y', gpu_ptr<void>(out)},
}));
conv_cache().emplace(
cache_key, std::make_pair(backend_type, std::move(*graph)));
return;
}
// Use fallback kernel for settings not supported by cuDNN.

View File

@@ -86,7 +86,7 @@ array unfold_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
encoder.add_temporary(unfolded);
int filter_size = params.C;

View File

@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
encoder.add_temporary(unfolded);
int filter_size = params.C;

View File

@@ -7,9 +7,8 @@ namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
bool donated = set_copy_output_data(
in, out, ctype, [&](auto n) { return cu::malloc_async(n, encoder); });
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
@@ -104,7 +103,7 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
return;
}
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
encoder.set_input_array(in);
encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
@@ -114,7 +113,7 @@ void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
copy_gpu_inplace(
in,
out,

View File

@@ -95,11 +95,14 @@ void copy_general_input(
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size();
int work_per_thread = 1;
int work_per_thread = 8;
auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0;
if (dim0 >= 4) {
if (dim0 >= 4 && dim0 < 8) {
work_per_thread = 4;
} else if (dim0 < 4) {
work_per_thread = 1;
}
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1);
@@ -110,7 +113,10 @@ void copy_general_input(
dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) {
if (work_per_thread == 8) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
} else if (work_per_thread == 4) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
}
@@ -127,7 +133,9 @@ void copy_general_input(
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) {
if (work_per_thread == 8) {
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
} else if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>;
}
encoder.add_kernel_node(

View File

@@ -5,6 +5,7 @@
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cudnn.h>
namespace mlx::core {
@@ -12,10 +13,12 @@ namespace mlx::core {
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
void check_cudnn_error(const char* name, cudnnStatus_t err);
// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
@@ -29,6 +32,10 @@ class CudaHandle {
}
~CudaHandle() {
// Skip if there was an error to avoid throwing in the destructors
if (cudaPeekAtLastError() != cudaSuccess) {
return;
}
reset();
}

View File

@@ -7,32 +7,26 @@ namespace mlx::core {
namespace {
// Create a cudnn tensor descriptor.
template <typename Vec>
inline cudnn_frontend::Tensor build_cudnn_tensor(
int64_t id,
const array& x,
const Vec& shape,
const Vec& strides) {
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
// Evaluate |cmd| exactly once and, when the result reports an error
// (is_bad()), return that result from the enclosing function.
//
// The body is wrapped in do { } while (0) so the macro expands to a single
// statement: a bare `if { }` followed by the caller's `;` would break when the
// macro is used as the unbraced body of an if/else.
#define RETURN_IF_ERROR(cmd)              \
  do {                                    \
    if (auto ret = (cmd); ret.is_bad()) { \
      return ret;                         \
    }                                     \
  } while (0)
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
// whether a tensor is contiguous is determined with:
// shape[dim] == shape[dim + 1] * strides[dim + 1]
// So a contiguous array with singleton dims in MLX may be mistakenly treated
// as strided in cuDNN, and we work around it by normalizing the strides.
Strides normalized_strides(const array& x) {
if (!x.flags().row_contiguous || x.ndim() < 2) {
return x.strides();
std::vector<int64_t> normalized_strides(const array& x) {
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
if (std::all_of(
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
strides.back() = 1;
return strides;
}
if (!x.flags().row_contiguous || x.ndim() < 2) {
return strides;
}
Strides strides = x.strides();
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
@@ -42,7 +36,9 @@ Strides normalized_strides(const array& x) {
}
// Return the shape and strides after transposing from NHWC to NCHW.
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
inline auto nhwc_to_nchw(const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
auto strides = normalized_strides(x);
assert(shape.size() >= 3);
shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1);
@@ -51,228 +47,95 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
return std::make_tuple(std::move(shape), std::move(strides));
}
inline auto nhwc_to_nchw(const array& x) {
return nhwc_to_nchw(
convert_vector<int64_t>(x.shape()), normalized_strides(x));
}
// Return available engines for a |op_graph|.
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = true) {
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
sources.push_back([](auto& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
});
if (use_fallback) {
sources.push_back([&backend_type](auto& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
});
}
auto configs =
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
// Take |engine_configs| and |op_graph| and find a working execution plan
// from them.
std::optional<cudnn_frontend::ExecutionPlan>
find_cudnn_plan_from_engine_configs(
cudnnHandle_t handle,
const cudnn_frontend::EngineConfigList& engine_configs,
const cudnn_frontend::OperationGraph& op_graph) {
auto op_graph_tag = op_graph.getTag();
for (const auto& config : engine_configs) {
try {
return cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, op_graph_tag)
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return std::nullopt;
}
// Prepare workspace and args to execute plan.
template <typename F>
bool prepare_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder.stream()),
{workspace_size},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
return false;
}
return true;
}
} // namespace
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
return build_cudnn_tensor(id, x, shape, normalized_strides(x));
// Validates the graph, builds the cuDNN operation graph and creates execution
// plans with heuristics mode A. Returns the first error reported by any of
// these steps, or a default-constructed fe::error_t when all of them pass.
fe::error_t DnnGraph::prepare() {
  if (auto status = validate(); status.is_bad()) {
    return status;
  }
  try {
    if (auto status = build_operation_graph(handle_); status.is_bad()) {
      return status;
    }
  } catch (cudnn_frontend::cudnnException& error) {
    // cuDNN bug: they did not catch all exceptions in the API, so convert the
    // exception into a regular error value here.
    return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()};
  }
  if (auto status = create_execution_plans({fe::HeurMode_t::A});
      status.is_bad()) {
    return status;
  }
  return {};
}
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
// Checks that the graph is supported on |handle_| and then builds its plans.
// Returns the first error reported, or a default-constructed fe::error_t when
// both steps pass.
fe::error_t DnnGraph::build() {
  if (auto status = check_support(handle_); status.is_bad()) {
    return status;
  }
  if (auto status = build_plans(handle_); status.is_bad()) {
    return status;
  }
  return {};
}
// Adds the cuDNN graph to |encoder| through the native CUDA graph path: a
// fresh CudaGraph is populated from this graph and appended to the encoder as
// a graph node.
//
// |variant_pack| maps tensor uids to the pointers to bind for this run.
// Returns the first error encountered; the graph node is only added when
// populating the CUDA graph succeeded.
fe::error_t DnnGraph::encode_graph(
    cu::CommandEncoder& encoder,
    std::unordered_map<int64_t, void*> variant_pack) {
  // Report a failed stream binding instead of ignoring it, otherwise the work
  // would be issued on whatever stream the handle was bound to before.
  if (auto status = cudnnSetStream(handle_, encoder.stream());
      status != CUDNN_STATUS_SUCCESS) {
    return {
        fe::error_code_t::CUDNN_BACKEND_API_FAILED,
        cudnnGetErrorString(status)};
  }
  CudaGraph cuda_graph(encoder.device());
  RETURN_IF_ERROR(populate_cuda_graph(
      handle_, variant_pack, prepare_workspace(encoder), cuda_graph));
  encoder.add_graph_node(cuda_graph);
  return {};
}
// Adds the cuDNN graph to |encoder| by executing it inside the encoder's
// capture context.
//
// |variant_pack| maps tensor uids to the pointers to bind for this run.
// Returns the error from binding the stream or from executing the graph; on
// any failure the capture is marked as discarded.
fe::error_t DnnGraph::encode_capturing(
    cu::CommandEncoder& encoder,
    std::unordered_map<int64_t, void*> variant_pack) {
  // The workspace is prepared before the capture context is created, matching
  // the original statement order.
  auto* workspace_ptr = prepare_workspace(encoder);
  auto capture = encoder.capture_context();
  // Report a failed stream binding instead of ignoring it, otherwise the
  // execution below would not be issued on the encoder's stream.
  if (auto status = cudnnSetStream(handle_, encoder.stream());
      status != CUDNN_STATUS_SUCCESS) {
    capture.discard = true;
    return {
        fe::error_code_t::CUDNN_BACKEND_API_FAILED,
        cudnnGetErrorString(status)};
  }
  auto ret = execute(handle_, variant_pack, workspace_ptr);
  if (ret.is_bad()) {
    // Discard the captured work when execution failed.
    capture.discard = true;
  }
  return ret;
}
// Queries the workspace size required by the graph and, when it is non-zero,
// allocates a uint8 array of that many bytes, registers it with |encoder| as
// a temporary and returns its device pointer. Returns nullptr when no
// workspace is needed. Throws if the size query fails.
void* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) {
  int64_t nbytes = 0;
  CHECK_CUDNN_FE_ERROR(get_workspace_size(nbytes));
  if (nbytes <= 0) {
    return nullptr;
  }
  // NOTE(review): the array shape is built from an int, so the 64-bit size is
  // narrowed here -- confirm workspaces cannot exceed the int range.
  array workspace(
      cu::malloc_async(nbytes, encoder), {static_cast<int>(nbytes)}, uint8);
  encoder.add_temporary(workspace);
  return gpu_ptr<void>(workspace);
}
// Fills |tensor| with |uid|, the pointer alignment and dtype of |x|, and the
// explicitly supplied |shape| and |strides|.
void DnnGraph::set_tensor_attrs(
    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
    int64_t uid,
    const array& x,
    const std::vector<int64_t>& shape,
    const std::vector<int64_t>& strides) {
  auto& attrs = *tensor;
  attrs.set_uid(uid);
  attrs.set_alignment(get_alignment(x));
  attrs.set_data_type(dtype_to_cudnn_type(x.dtype()));
  attrs.set_dim(shape);
  attrs.set_stride(strides);
}
// Convenience overload: takes the shape from |x| (converted to int64_t) and
// the strides from normalized_strides(x), then forwards to the overload that
// accepts explicit shape and strides.
void DnnGraph::set_tensor_attrs(
    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
    int64_t uid,
    const array& x) {
  const auto dims = convert_vector<int64_t>(x.shape());
  const auto steps = normalized_strides(x);
  set_tensor_attrs(tensor, uid, x, dims, steps);
}
void DnnGraph::set_tensor_attrs_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
auto [shape, strides] = nhwc_to_nchw(x);
return build_cudnn_tensor(id, x, shape, strides);
set_tensor_attrs(tensor, uid, x, shape, strides);
}
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
if (x.ndim() == 0) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
}
if (x.ndim() == 1) {
int64_t s = x.shape(0);
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
SmallVector<int64_t, 4> strides = {s, 1, s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 2) {
int64_t s =
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 3 || x.ndim() == 4) {
return build_cudnn_tensor_nchw(id, x);
}
throw std::runtime_error(
fmt::format("Unsupported array with {} dims.", x.ndim()));
}
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return cudnn_frontend::TensorBuilder()
.setDim(scalar_dims.size(), scalar_dims.data())
.setStrides(scalar_dims.size(), scalar_dims.data())
.setId(id)
.setAlignment(16)
.setDataType(dtype_to_cudnn_type(dtype))
.setByValue(true)
.build();
}
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
if (engine_configs.empty()) {
return std::nullopt;
}
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
auto capture = encoder.capture_context();
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
return true;
});
}
#if CUDNN_VERSION >= 90500
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
if (!graph) {
graph = CudaGraph(encoder.device());
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
} else {
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
}
encoder.add_graph_node(graph);
return true;
});
}
#endif
} // namespace mlx::core

View File

@@ -2,25 +2,30 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <algorithm>
#include <array>
namespace mlx::core {
namespace cu {
class CommandEncoder;
}
namespace fe = cudnn_frontend;
// Evaluate |cmd| once and throw std::runtime_error when the returned cuDNN
// frontend error is not good. The message contains the stringified expression
// followed by the error's own message.
#define CHECK_CUDNN_FE_ERROR(cmd)                                    \
  do {                                                               \
    if (auto error = (cmd); !error.is_good()) {                      \
      throw std::runtime_error(                                      \
          fmt::format("{} failed: {}.", #cmd, error.get_message())); \
    }                                                                \
  } while (0)
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
@@ -35,8 +40,31 @@ inline uint8_t get_alignment(const array& x) {
// Convert the type of elements in |vec| to |T|.
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
inline std::vector<T> convert_vector(const Vec& vec) {
return std::vector<T>(vec.begin(), vec.end());
}
// Translate an MLX dtype into the corresponding cuDNN frontend data type.
// Throws std::runtime_error for any dtype that has no mapping listed here.
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
  if (dtype == int8) {
    return fe::DataType_t::INT8;
  }
  if (dtype == int32) {
    return fe::DataType_t::INT32;
  }
  if (dtype == uint8) {
    return fe::DataType_t::UINT8;
  }
  if (dtype == float16) {
    return fe::DataType_t::HALF;
  }
  if (dtype == bfloat16) {
    return fe::DataType_t::BFLOAT16;
  }
  if (dtype == float32) {
    return fe::DataType_t::FLOAT;
  }
  if (dtype == float64) {
    return fe::DataType_t::DOUBLE;
  }
  throw std::runtime_error(fmt::format(
      "Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype)));
}
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
@@ -55,111 +83,89 @@ inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
return result;
}
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(gpu_ptr<void>(arr));
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
inline void* get_data_ptr(T& scalar) {
return &scalar;
}
// Return an array filled with data pointers of args.
template <typename... Args>
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
return {get_data_ptr(args)...};
}
// Map dtype to cudnn data type.
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
// Extends cuDNN graph with helpers.
class DnnGraph : public fe::graph::Graph {
public:
DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32)
: handle_(handle) {
set_io_data_type(dtype_to_cudnn_type(io_dtype));
set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype));
set_compute_data_type(dtype_to_cudnn_type(compute_dtype));
}
}
// Create a tensor descriptor from |x|.
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
// Create a cuDNN tensor description from MLX array |x|.
auto& tensor(
std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
int64_t uid,
const array& x) {
set_tensor_attrs(attrs, uid, x);
return attrs;
}
auto tensor(const char* name, int64_t uid, const array& x) {
auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
tensor(attrs, uid, x);
return attrs;
}
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
// Create a cuDNN tensor description from MLX array |x|, and transpose it from
// NHWC layout to NCHW.
auto& tensor_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
int64_t uid,
const array& x) {
set_tensor_attrs_nchw(attrs, uid, x);
return attrs;
}
auto tensor_nchw(const char* name, int64_t uid, const array& x) {
auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
tensor_nchw(attrs, uid, x);
return attrs;
}
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
// from NHWC to NCHW.
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
// Create a cuDNN tensor for scalar.
auto scalar(const char* name, int64_t uid, Dtype dtype) {
return Graph::tensor(fe::graph::Tensor_attributes()
.set_name(name)
.set_uid(uid)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(dtype_to_cudnn_type(dtype)));
}
// Create a 4D scalar tensor descriptor, which is passed by value.
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
// Call this before setting notes.
fe::error_t prepare();
// Call this after setting notes.
fe::error_t build();
// Find a working plan for |op_graph|.
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph);
// Add cuDNN graph to CUDA graph, using native CUDA graph API.
fe::error_t encode_graph(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack);
// Add cuDNN graph to CUDA graph, using stream capture.
fe::error_t encode_capturing(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack);
// Encode the plan to command buffer by capturing.
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs);
private:
void* prepare_workspace(cu::CommandEncoder& encoder);
#if CUDNN_VERSION >= 90500
// Encode the plan to command buffer by using native graph api of cudnn. If the
// |graph| is empty it will be populated, otherwise it will be updated.
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs);
#endif
void set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& strides);
void set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x);
void set_tensor_attrs_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x);
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_capturing(
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
}
#if CUDNN_VERSION >= 90500
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_graph_api(
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
}
#endif
cudnnHandle_t handle_;
};
} // namespace mlx::core

View File

@@ -289,7 +289,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
}
}

View File

@@ -1,13 +0,0 @@
#!/bin/bash
# Maps the output of `__nvcc_device_query` to the architecture string that is
# printed: 90, 100 and 121 get an "a" suffix, anything else prints "native".

arch="$(__nvcc_device_query)"
case "$arch" in
  90 | 100 | 121)
    echo "${arch}a"
    ;;
  *)
    echo "native"
    ;;
esac

View File

@@ -14,20 +14,20 @@ namespace mlx::core::cu {
namespace {
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
void check_cudnn_error(const char* name, cudnnStatus_t err) {
if (err != CUDNN_STATUS_SUCCESS) {
throw std::runtime_error(
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
}
// Whether CUDA graphs should be used. The value comes from
// env::get_var("MLX_USE_CUDA_GRAPHS", true) (the second argument is presumably
// the default -- confirm against env::get_var) and is cached in a
// function-local static, so it is only read on the first call.
bool use_cuda_graphs() {
  static const bool enabled = env::get_var("MLX_USE_CUDA_GRAPHS", true);
  return enabled;
}
bool use_cuda_graphs() {
static bool use_graphs = []() {
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
const char* save_cuda_graphs_dot_file() {
static const char* filename = []() -> const char* {
const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
if (env && std::strlen(env) == 0) {
return nullptr;
}
return env;
}();
return use_graphs;
return filename;
}
} // namespace
@@ -87,7 +87,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
return;
}
CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeThreadLocal));
}
CommandEncoder::CaptureContext::~CaptureContext() {
@@ -115,18 +115,17 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
}
// Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
enc.empty_node_count_++;
CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
// Insert the concurrent -> empty node dependencies
for (auto& from : enc.concurrent_nodes_) {
enc.from_nodes_.push_back(from.node);
enc.to_nodes_.push_back(empty.node);
enc.graph_key_ += from.id;
enc.graph_key_ += from.node_type;
enc.graph_key_ += empty.id;
enc.graph_key_ += empty.node_type;
enc.graph_deps_key_ += from.id;
enc.graph_deps_key_ += "-";
enc.graph_deps_key_ += empty.id;
enc.graph_deps_key_ += "-";
}
// Insert the input -> concurrent node dependencies without updating output
@@ -141,9 +140,6 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
}
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
if (node.node_type == 'G') {
graph_node_count_++;
}
node.id = std::to_string(node_count_++);
if (in_concurrent_) {
concurrent_nodes_.push_back(std::move(node));
@@ -155,6 +151,10 @@ void CommandEncoder::insert_graph_dependencies(GraphNode node) {
}
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
for (auto& node : nodes) {
graph_nodes_key_ += node.node_type;
graph_nodes_key_ += "-";
}
std::vector<GraphNode> deps;
{
// Dependencies must be added in the same order to produce a consistent
@@ -182,10 +182,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
for (auto& to : nodes) {
from_nodes_.push_back(from.node);
to_nodes_.push_back(to.node);
graph_key_ += from.id;
graph_key_ += from.node_type;
graph_key_ += to.id;
graph_key_ += to.node_type;
graph_deps_key_ += from.id;
graph_deps_key_ += "-";
graph_deps_key_ += to.id;
graph_deps_key_ += "-";
}
}
}
@@ -309,13 +309,61 @@ void CommandEncoder::add_kernel_node(
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
insert_graph_dependencies(GraphNode{node, "K"});
}
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
CUgraphNode node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
insert_graph_dependencies(GraphNode{node, "K"});
}
// Builds a string signature for the nodes of `graph` and reports whether the
// graph may be updated in place.
//
// Returns {key, is_updatable}:
//   * key is "(" + one token per visited node + ")". A kernel node contributes
//     "K" followed by its cluster dimension x, a memset node contributes "M"
//     and a child graph contributes its own (recursively built) key.
//   * is_updatable becomes false when a node is neither a kernel, a memset nor
//     a child graph, when a kernel has a cluster dimension y or z greater
//     than 1, or when a child graph is itself not updatable. CUDA graphs do
//     not get updated correctly if a kernel node is updated with a node that
//     has a different cluster shape.
//
// Nodes after the first non-updatable one are not visited, so the key is only
// complete when is_updatable is true.
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
  // First query only the node count.
  size_t num_nodes = 0;
  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
  if (num_nodes == 0) {
    return {std::string("()"), true};
  }

  std::vector<cudaGraphNode_t> nodes(num_nodes);
  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));

  std::string key = "(";
  bool is_updatable = true;
  for (size_t i = 0; is_updatable && i < nodes.size(); ++i) {
    cudaGraphNode_t node = nodes[i];
    cudaGraphNodeType type;
    CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));

    switch (type) {
      case cudaGraphNodeTypeGraph: {
        // Try to be updatable for a structure like graph -> graph -> kernel
        cudaGraph_t child;
        CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
        auto [child_key, child_updatable] = subgraph_to_key(child);
        is_updatable = is_updatable && child_updatable;
        key += child_key;
        break;
      }
      case cudaGraphNodeTypeMemset:
        key += "M";
        break;
      case cudaGraphNodeTypeKernel: {
        cudaLaunchAttributeValue cluster_dim;
        CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
            node, cudaLaunchAttributeClusterDimension, &cluster_dim));
        // Only allow dim.x to be greater than 1
        const bool multi_dim_cluster =
            cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1;
        if (multi_dim_cluster) {
          is_updatable = false;
        } else {
          key += "K";
          key += std::to_string(cluster_dim.clusterDim.x);
        }
        break;
      }
      default:
        // Any other node type makes the graph non-updatable.
        is_updatable = false;
        break;
    }
  }

  key += ")";
  return {key, is_updatable};
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -328,8 +376,10 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return;
}
cudaGraphNode_t node;
auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
is_graph_updatable_ &= is_updatable;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
insert_graph_dependencies(GraphNode{node, sub_graph_key});
}
bool CommandEncoder::needs_commit() {
@@ -354,44 +404,53 @@ void CommandEncoder::commit() {
from_nodes_.size()));
}
graph_key_ += ".";
graph_key_ += std::to_string(node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(graph_node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(empty_node_count_);
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo info;
cudaGraphExecUpdate(graph_exec, graph_, &info);
update_result = info.result;
#else
cudaGraphNode_t error_node;
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
graph_exec.reset();
}
}
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
}
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
if (!is_graph_updatable_) {
CudaGraphExec graph_exec;
graph_exec.instantiate(graph_);
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
} else {
auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_;
auto& graph_exec = graph_cache_[graph_key];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo info;
cudaGraphExecUpdate(graph_exec, graph_, &info);
update_result = info.result;
#else
cudaGraphNode_t error_node;
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
graph_exec.reset();
}
}
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
}
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
}
// Save cuda graph to dot file
if (const char* filename = save_cuda_graphs_dot_file(); filename) {
static int count = 0;
auto path = fmt::format("{}_{}.dot", filename, ++count);
CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
}
// Reset state
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
to_nodes_.clear();
graph_key_.clear();
graph_deps_key_.clear();
graph_nodes_key_.clear();
node_map_.clear();
graph_ = CudaGraph(device_);
is_graph_updatable_ = true;
}
// Put completion handlers in a batch.

View File

@@ -106,8 +106,9 @@ class CommandEncoder {
cudaGraphNode_t node;
// K = kernel
// E = empty
// G = subgraph
char node_type;
// () = subgraph (with metadata)
// Symbols ':', '-' are reserved as separators
std::string node_type;
std::string id;
};
@@ -119,12 +120,11 @@ class CommandEncoder {
CudaGraph graph_;
Worker worker_;
char node_count_{0};
char graph_node_count_{0};
char empty_node_count_{0};
bool in_concurrent_{false};
std::vector<cudaGraphNode_t> from_nodes_;
std::vector<cudaGraphNode_t> to_nodes_;
std::string graph_key_;
std::string graph_nodes_key_;
std::string graph_deps_key_;
std::vector<GraphNode> concurrent_nodes_;
std::vector<std::shared_ptr<array::Data>> temporaries_;
LRUCache<std::string, CudaGraphExec> graph_cache_;
@@ -132,6 +132,7 @@ class CommandEncoder {
std::vector<std::uintptr_t> active_outputs_;
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
size_t bytes_in_graph_{0};
bool is_graph_updatable_{true};
int max_ops_per_graph_;
int max_mb_per_graph_;
};

View File

@@ -26,7 +26,7 @@ void AllReduce::eval_gpu(
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
return {in, out};
}
};
@@ -74,7 +74,7 @@ void AllGather::eval_gpu(
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
@@ -103,7 +103,7 @@ void ReduceScatter::eval_gpu(
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);

View File

@@ -305,6 +305,7 @@ void Event::wait() {
} else {
event->atomic->wait(value());
}
CHECK_CUDA_ERROR(cudaPeekAtLastError());
}
void Event::wait(Stream s) {

View File

@@ -370,7 +370,7 @@ void CublasGemm::execute(
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
cu::malloc_async(nbytes, encoder.stream()),
cu::malloc_async(nbytes, encoder),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);

View File

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder),
{batch_count * 3},
uint64);
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder),
{batch_count * 4},
uint64);

View File

@@ -61,7 +61,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
if (out.size() == 0) {
return;
}
@@ -241,7 +241,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
if (out.size() == 0) {
return;
}

View File

@@ -244,7 +244,7 @@ void LayerNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(),
x.strides(),
x.flags());
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
g_in_gw = true;
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
encoder.add_temporary(gw_temp);
}
}

View File

@@ -32,7 +32,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream());
auto size = out.size();
auto nbytes = size * out.itemsize();
out.set_data(cu::malloc_async(nbytes, encoder.stream()));
out.set_data(cu::malloc_async(nbytes, encoder));
auto out_ptr = malloc(nbytes);
reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
if (swap_endianness_) {

View File

@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
}
flags.col_contiguous = col_contig;
out.set_data(
cu::malloc_async(in.nbytes() / n, encoder.stream()),
cu::malloc_async(in.nbytes() / n, encoder),
in.data_size() / n,
std::move(strides),
flags);

View File

@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
gemm_and_bias(
encoder,
M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);

View File

@@ -37,6 +37,7 @@ NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
NO_GPU(MaskedScatter)
namespace distributed {
NO_GPU_MULTI(Send)

View File

@@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu(
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto& w = outputs[0];
w.set_data(cu::malloc_async(w.nbytes(), enc.stream()));
w.set_data(cu::malloc_async(w.nbytes(), enc));
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s);
@@ -72,11 +72,11 @@ void fast::Quantize::eval_gpu(
auto& wq = outputs[0];
auto& scales = outputs[1];
wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream()));
scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
wq.set_data(cu::malloc_async(wq.nbytes(), enc));
scales.set_data(cu::malloc_async(scales.nbytes(), enc));
if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
biases.set_data(cu::malloc_async(biases.nbytes(), enc));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);

View File

@@ -139,30 +139,36 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
uint32_t num_keys = keys.size() / 2;
size_t num_keys = keys.size() / 2;
uint32_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
if (out.size() == 0) {
return;
}
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
uint32_t half_size = out_per_key / 2;
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
size_t half_size = out_per_key / 2;
bool odd = out_per_key % 2;
if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
}
encoder.set_input_array(keys);
encoder.set_output_array(out);
dim3 grid_dims{num_keys, half_size + odd};
int64_t total = grid_dims.x * grid_dims.y;
int32_t threads_y = 1;
while ((total / threads_y) >= (1U << 31)) {
int64_t total = num_keys * (half_size + odd);
uint32_t threads_y = 1;
while ((total / threads_y) >= UINT_MAX) {
threads_y *= 2;
}
int32_t threads_x = cuda::ceil_div(total, threads_y);
uint32_t threads_x = cuda::ceil_div(total, threads_y);
dim3 grid_dims{
static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
auto& stream = encoder.stream();
if (keys.flags().row_contiguous) {

View File

@@ -66,7 +66,7 @@ void all_reduce(
Reduce::ReduceType reduce_type) {
constexpr int N_READS = 8;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
auto get_args = [](size_t size, int N) {
int threads = std::min(512UL, (size + N - 1) / N);
@@ -107,8 +107,7 @@ void all_reduce(
encoder.set_input_array(in);
if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate);
dispatch_all_types(dt, [&](auto type_tag) {

View File

@@ -89,9 +89,13 @@ template <
int NDIM,
int BM,
int BN,
int N_READS = 4>
__global__ void
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
int N_READS = 4,
int BLOCKS = 1>
__global__ void col_reduce_looped(
T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
int64_t out_size) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
size_t tile_out = tile_y / out_size;
tile_y = tile_y % out_size;
// Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row;
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
totals[i] = ReduceInit<Op, T>::value();
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size;
size_t per_block, start, end;
if constexpr (BLOCKS > 1) {
per_block = (total + BLOCKS - 1) / BLOCKS;
start = tile_out * per_block + thread_y;
end = min((tile_out + 1) * per_block, total);
} else {
per_block = total;
start = thread_y;
end = total;
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
if (tile_x * BN + BN <= args.reduction_stride) {
if (args.reduction_stride % N_READS == 0) {
for (size_t r = thread_y; r < total; r += BM) {
for (size_t r = start; r < end; r += BM) {
T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
for (size_t r = start; r < end; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) {
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
}
}
} else {
for (size_t r = thread_y; r < total; r += BM) {
for (size_t r = start; r < end; r += BM) {
T vals[N_READS];
cub::LoadDirectBlocked(
thread_x,
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
// Write result.
if (warp.thread_rank() == 0) {
if (BLOCKS > 1) {
out += tile_out * out_size * args.reduction_stride;
}
cub::StoreDirectBlocked(
warp.meta_group_rank(),
out + tile_y * args.reduction_stride + tile_x * BN,
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
inline auto output_grid_for_col_reduce(
const array& out,
const cu::ColReduceArgs& args,
int bn) {
int bn,
int outer = 1) {
int gx, gy = 1;
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
size_t n_outer_blocks = out.size() / args.reduction_stride;
size_t n_blocks = n_outer_blocks * n_inner_blocks;
size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
while (n_blocks / gy > INT32_MAX) {
gy *= 2;
}
@@ -277,7 +298,8 @@ void col_reduce_looped(
0,
indata,
gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args));
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
@@ -320,6 +342,117 @@ void col_reduce_small(
});
}
// Column reduction performed with two launches of cu::col_reduce_looped.
//
// Pass 1 instantiates the kernel with BLOCKS = outer (32) and writes into a
// temporary `intermediate` array of shape [outer, out.shape...]. Pass 2 uses
// the default BLOCKS = 1 instantiation to reduce `intermediate` over its
// leading axis into `out`.
//
// Parameters:
//   encoder     - command encoder used for allocation and kernel launches.
//   in          - array to reduce.
//   out         - destination; its data is allocated here.
//   reduce_type - selects the reduction op via dispatch_reduce_ops.
//   axes        - forwarded to allocate_same_layout.
//   plan        - not referenced in this function body.
//   args        - column-reduce arguments for pass 1; a modified copy is
//                 used for pass 2.
void col_reduce_two_pass(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes, encoder);
// Allocate an intermediate array to hold the 1st pass result
// Its shape is out's shape with a new leading axis of length `outer`; the
// leading stride is out.size() and the remaining strides are copied from out.
constexpr int outer = 32;
Shape intermediate_shape;
intermediate_shape.push_back(outer);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
Strides intermediate_strides;
intermediate_strides.push_back(out.size());
intermediate_strides.insert(
intermediate_strides.end(), out.strides().begin(), out.strides().end());
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
// Start from out's flags and overwrite the contiguity bits with the values
// computed for the intermediate shape/strides.
auto [data_size, rc, cc] =
check_contiguity(intermediate_shape, intermediate_strides);
auto fl = out.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder),
data_size,
intermediate_strides,
fl,
allocator::free);
// Hand the scratch array to the encoder as a temporary.
encoder.add_temporary(intermediate);
// Pass 1: in -> intermediate.
encoder.set_input_array(in);
encoder.set_output_array(intermediate);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
// The grid is sized from out, scaled by `outer` (last argument).
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
// NOTE: despite its name, `blocks` is passed in the block-dimension
// position of add_kernel_node (after `grid`).
int blocks = BM * BN / N_READS;
auto kernel = cu::
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
// The trailing argument is the kernel's out_size parameter.
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
indata,
gpu_ptr<U>(intermediate),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
// Prepare the reduction arguments for the 2nd pass
// The 2nd pass reduces a single axis of length `outer` whose stride is
// out.size(), i.e. the leading axis of `intermediate`.
cu::ColReduceArgs second_args = args;
second_args.reduction_size = outer;
second_args.reduction_stride = out.size();
second_args.ndim = 0;
second_args.reduce_shape[0] = outer;
second_args.reduce_strides[0] = out.size();
second_args.reduce_ndim = 1;
second_args.non_col_reductions = 1;
// Pass 2: intermediate -> out.
encoder.set_input_array(intermediate);
encoder.set_output_array(out);
// intermediate was constructed with out.dtype(), so T here is derived from
// the output dtype rather than from in.dtype().
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
int blocks = BM * BN / N_READS;
// Default BLOCKS (= 1) instantiation: no splitting of the reduced axis.
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
// out_size is second_args.reduction_stride, which was set to out.size().
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
gpu_ptr<T>(intermediate),
gpu_ptr<U>(out),
second_args,
second_args.reduction_stride);
});
});
});
}
void col_reduce(
cu::CommandEncoder& encoder,
const array& in,
@@ -334,6 +467,18 @@ void col_reduce(
// It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements.
//
// - col_reduce_small
//
// It is a column reduce for small columns. Each thread loops over the whole
// column without communicating with any other thread.
//
// - col_reduce_two_pass
//
// It is a reduce for long columns. To increase parallelism, we split the
// reduction in two passes. First we do a column reduce where many
// threadblocks operate on different parts of the reduced axis. Then we
// perform a final column reduce.
//
// Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend).
//
@@ -349,6 +494,14 @@ void col_reduce(
return;
}
// Long column with smallish row
size_t total_sums = args.non_col_reductions * args.reduction_size;
size_t approx_threads = out.size();
if (total_sums / approx_threads > 32) {
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
}

View File

@@ -28,7 +28,7 @@ void init_reduce(
Reduce::ReduceType reduce_type) {
// Allocate if needed
if (out.data_shared_ptr() == nullptr) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
}
encoder.set_output_array(out);

View File

@@ -96,7 +96,7 @@ inline void allocate_same_layout(
const std::vector<int>& axes,
cu::CommandEncoder& encoder) {
if (in.flags().row_contiguous) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
return;
}
@@ -135,7 +135,7 @@ inline void allocate_same_layout(
fl.col_contiguous = cc;
fl.contiguous = true;
out.set_data(
cu::malloc_async(out.nbytes(), encoder.stream()),
cu::malloc_async(out.nbytes(), encoder),
data_size,
final_strides,
fl,

View File

@@ -22,26 +22,28 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
}
// Similar to cub::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM>
template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
struct BlockBroadcastReduce {
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
static_assert(BLOCK_DIM % WARP_SIZE == 0);
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
cg::thread_block& block;
TempStorage& temp;
template <typename Op>
__device__ T Reduce(const T& input, const Op& op, const T& init_value) {
auto warp = cg::tiled_partition<WARP_SIZE>(block);
auto warp = cg::tiled_partition<GROUP_DIM>(block);
T x = cg::reduce(warp, input, op);
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
if constexpr (BLOCK_DIM > GROUP_DIM) {
if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
} else {
return x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
}
__device__ T Sum(const T& input) {
@@ -49,6 +51,52 @@ struct BlockBroadcastReduce {
}
};
// RMS normalization kernel for short rows.
//
// A block covers block.dim_threads().y rows: the row handled by a thread is
// derived from the block rank and the thread's y index, and threads whose row
// is >= n_rows exit immediately. Each thread loads N_READS elements of its
// row, the squared values are summed with
// BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>, and the row is scaled by
// rsqrt(mean_of_squares + eps) and multiplied by w before being stored.
//
// x, out   - input / output, advanced by row * axis_size.
// w        - weights, read through load_vector with stride w_stride.
// eps      - added to the mean of squares before rsqrt.
template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
__global__ void rms_norm_small(
    const T* x,
    const T* w,
    T* out,
    float eps,
    uint32_t axis_size,
    uint32_t n_rows,
    int64_t w_stride) {
  using RowReducer = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
  __shared__ typename RowReducer::TempStorage temp;

  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  // Locate this thread's row; surplus threads have nothing to do.
  auto row =
      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
  if (row >= n_rows) {
    return;
  }
  x += row * axis_size;
  out += row * axis_size;

  // Per-thread partial sum of squares. T(0) is handed to load_vector as the
  // fill value (presumably used for elements past axis_size).
  auto index = block.thread_index().x;
  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
  float sum_sq = 0;
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    float v = static_cast<float>(xn[i]);
    sum_sq += v * v;
  }

  // Combine the partial sums and turn them into the scaling factor.
  float normalizer =
      rsqrt(RowReducer{block, temp}.Sum(sum_sq) / axis_size + eps);

  // Scale, apply the weights and write back.
  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
#pragma unroll
  for (int i = 0; i < N_READS; ++i) {
    float scaled = static_cast<float>(xn[i]) * normalizer;
    xn[i] = wn[i] * static_cast<T>(scaled);
  }
  store_vector<N_READS>(out, index, xn, axis_size);
}
template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm(
const T* x,
@@ -94,6 +142,74 @@ __global__ void rms_norm(
}
}
// Backward (VJP) kernel of RMS normalization for short rows.
//
// A block covers block.dim_threads().y rows; threads whose row is >= n_rows
// exit immediately. For its row each thread loads N_READS elements of x, g
// and w, accumulates the pair {sum(w*g*x), sum(x*x)}, reduces the pair with
// BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM> and then writes
//   gx = normalizer * w * g - x * mean(w*g*x) * normalizer^3
// and, only when HAS_W is true,
//   gw = g * x * normalizer
// where normalizer = rsqrt(mean(x*x) + eps).
//
// x, g, gx, gw are all advanced by row * axis_size; w is read through
// load_vector with stride w_stride.
template <
    typename T,
    bool HAS_W,
    int BLOCK_DIM,
    int REDUCE_DIM,
    int N_READS = 4>
__global__ void rms_norm_vjp_small(
    const T* x,
    const T* w,
    const T* g,
    T* gx,
    T* gw,
    float eps,
    int32_t axis_size,
    int32_t n_rows,
    int64_t w_stride) {
  using PairReducer = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
  __shared__ typename PairReducer::TempStorage temp;

  auto grid = cg::this_grid();
  auto block = cg::this_thread_block();

  // Locate this thread's row; surplus threads have nothing to do.
  auto row =
      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
  if (row >= n_rows) {
    return;
  }
  x += row * axis_size;
  g += row * axis_size;
  gx += row * axis_size;
  gw += row * axis_size;

  // Load this thread's slice of the row and accumulate both sums at once.
  auto index = block.thread_index().x;
  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
  auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
  float2 factors = {};
  for (int i = 0; i < N_READS; i++) {
    float t = static_cast<float>(xn[i]);
    float wi = wn[i];
    float gi = gn[i];
    float weighted_grad = wi * gi;
    factors = plus_f2(factors, {weighted_grad * t, t * t});
  }
  factors = PairReducer{block, temp}.Reduce(factors, plus_f2, {});

  float meangwx = factors.x / axis_size;
  float normalizer = rsqrt(factors.y / axis_size + eps);
  float normalizer3 = normalizer * normalizer * normalizer;

  // Overwrite the loaded registers with the gradients, then store them.
  for (int i = 0; i < N_READS; i++) {
    float xi = xn[i];
    float wi = wn[i];
    float gi = gn[i];
    xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
    if constexpr (HAS_W) {
      wn[i] = static_cast<T>(gi * xi * normalizer);
    }
  }
  store_vector<N_READS>(gx, index, xn, axis_size);
  if constexpr (HAS_W) {
    store_vector<N_READS>(gw, index, wn, axis_size);
  }
}
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm_vjp(
const T* x,
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
__shared__ union {
typename BlockReduceF::TempStorage f;
typename BlockReduceF2::TempStorage f2;
} temp;
__shared__ typename BlockReduceF2::TempStorage temp;
x += grid.block_rank() * axis_size;
g += grid.block_rank() * axis_size;
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
factors = plus_f2(factors, {wg * t, t * t});
}
}
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
float meangwx = factors.x / axis_size;
float normalizer = rsqrt(factors.y / axis_size + eps);
float normalizer3 = normalizer * normalizer * normalizer;
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
return s.device == Device::cpu;
}
// Chooses compile-time launch constants from a runtime axis_size and invokes
// f(group_dim, n_groups, groups_per_block), each argument being a
// std::integral_constant<int, N>.
//
// With each thread covering n_per_thread elements, the first bucket whose
// capacity is >= axis_size is selected:
//
//   axis_size <= n_per_thread * ...   group_dim  n_groups  groups_per_block
//   8                                     8         1          16
//   16                                   16         1           8
//   32                                   32         1           4
//   32 * 2                               32         2           2
//   32 * 4                               32         4           1
//   32 * 8                               32         8           1
//   32 * 16                              32        16           1
//   anything larger                      32        32           1
//
// NOTE(review): the `32 * 2` bucket is the only one that combines
// n_groups > 1 with groups_per_block > 1. Confirm that the callers' block-level
// reduction (and the size of its shared temp storage) is intended to handle
// more than one row per block in that configuration.
template <int n_per_thread, typename F>
void dispatch_group_dim(int axis_size, F&& f) {
  // Sub-warp groups: one group per row, several rows per block.
  if (axis_size <= n_per_thread * 8) {
    f(std::integral_constant<int, 8>{},
      std::integral_constant<int, 1>{},
      std::integral_constant<int, 16>{});
    return;
  }
  if (axis_size <= n_per_thread * 16) {
    f(std::integral_constant<int, 16>{},
      std::integral_constant<int, 1>{},
      std::integral_constant<int, 8>{});
    return;
  }
  if (axis_size <= n_per_thread * 32) {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 1>{},
      std::integral_constant<int, 4>{});
    return;
  }

  // Full-warp groups: several groups cooperate on one row.
  if (axis_size <= n_per_thread * 32 * 2) {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 2>{},
      std::integral_constant<int, 2>{});
    return;
  }
  if (axis_size <= n_per_thread * 32 * 4) {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 4>{},
      std::integral_constant<int, 1>{});
    return;
  }
  if (axis_size <= n_per_thread * 32 * 8) {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 8>{},
      std::integral_constant<int, 1>{});
    return;
  }
  if (axis_size <= n_per_thread * 32 * 16) {
    f(std::integral_constant<int, 32>{},
      std::integral_constant<int, 16>{},
      std::integral_constant<int, 1>{});
    return;
  }

  // Largest configuration.
  f(std::integral_constant<int, 32>{},
    std::integral_constant<int, 32>{},
    std::integral_constant<int, 1>{});
}
// TODO: This duplicates code in backend/metal/normalization.cpp
void RMSNorm::eval_gpu(
const std::vector<array>& inputs,
@@ -190,7 +339,7 @@ void RMSNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(),
x.strides(),
x.flags());
@@ -216,12 +365,33 @@ void RMSNorm::eval_gpu(
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
if (axis_size <= N_READS * 1024) {
dispatch_group_dim<N_READS>(
axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
constexpr int block_dim = n_groups() * group_dim();
auto kernel =
cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
auto n_blocks =
(n_rows + groups_per_block() - 1) / groups_per_block();
encoder.add_kernel_node(
kernel,
n_blocks,
{block_dim, groups_per_block()},
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(out),
eps_,
axis_size,
n_rows,
w_stride);
});
} else {
auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
block_dim(),
1024,
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
@@ -229,7 +399,7 @@ void RMSNorm::eval_gpu(
eps_,
axis_size,
w_stride);
});
}
});
}
@@ -274,7 +444,7 @@ void RMSNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -292,7 +462,7 @@ void RMSNormVJP::eval_gpu(
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
encoder.add_temporary(gw_temp);
}
}
@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
dispatch_bool(has_w, [&](auto has_w_constant) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::rms_norm_vjp<
DataType,
has_w_constant.value,
block_dim(),
N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
block_dim(),
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);
});
if (axis_size <= N_READS * 1024) {
dispatch_group_dim<N_READS>(
axis_size,
[&](auto group_dim, auto n_groups, auto groups_per_block) {
constexpr int block_dim = group_dim() * n_groups();
auto kernel = cu::rms_norm_vjp_small<
DataType,
has_w_constant.value,
block_dim,
group_dim(),
N_READS>;
auto n_blocks =
(n_rows + groups_per_block() - 1) / groups_per_block();
encoder.add_kernel_node(
kernel,
n_blocks,
{block_dim, groups_per_block()},
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
n_rows,
w_stride);
});
} else {
auto kernel =
cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
1024,
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);
}
});
});

View File

@@ -292,14 +292,14 @@ void RoPE::eval_gpu(
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];

View File

@@ -5,47 +5,13 @@
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace fe = cudnn_frontend;
namespace {
// Evaluates `cmd` exactly once and throws std::runtime_error when the
// returned error object reports !is_good(). The exception message contains
// the stringified expression followed by error.get_message().
// The do { ... } while (0) wrapper makes the macro usable as one statement.
#define CHECK_CUDNN_FE_ERROR(cmd) \
do { \
auto error = cmd; \
if (!error.is_good()) { \
throw std::runtime_error( \
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
} \
} while (0)
// Returns a copy of x's strides as int64_t. For row-contiguous arrays with at
// least two dimensions, the stride of every size-1 axis (except the last axis)
// is rewritten to shape(i + 1) * stride(i + 1), walking from the inner axes
// outwards so that already-rewritten strides feed the outer ones.
std::vector<int64_t> normalized_strides(const array& x) {
  const auto& original = x.strides();
  std::vector<int64_t> result(original.begin(), original.end());

  const bool needs_fixup = x.flags().row_contiguous && x.ndim() >= 2;
  if (needs_fixup) {
    for (int axis = x.ndim() - 2; axis >= 0; --axis) {
      if (x.shape(axis) != 1) {
        continue;
      }
      result[axis] = x.shape(axis + 1) * result[axis + 1];
    }
  }
  return result;
}
// Fills in the attributes of an existing cuDNN frontend tensor from `x`:
// the given uid, the dims copied from x.shape(), and the strides produced
// by normalized_strides(x).
void set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
tensor->set_uid(uid)
.set_dim({x.shape().begin(), x.shape().end()})
.set_stride(normalized_strides(x));
}
array prepare_sdpa_input(const array& x, Stream s) {
// SDPA kernel's requirements on inputs:
// 1. last dim's stride be 1;
@@ -59,11 +25,43 @@ array prepare_sdpa_input(const array& x, Stream s) {
return x;
}
// Allocates the buffer of `o` so that its axis ordering in memory follows the
// stride ordering of `q`.
//
// When `q` is row contiguous, `o` gets a plain allocation. Otherwise the axes
// are ordered by ascending stride of `q` (non-positive strides are treated as
// 1, ties keep their original axis order) and `o` receives dense strides that
// are filled in that order using o's own shape.
void malloc_with_same_layout(
    cu::CommandEncoder& encoder,
    array& o,
    const array& q) {
  if (!q.flags().row_contiguous) {
    // Sort key for an axis: its stride in q, clamped to 1 when not positive.
    auto stride_key = [&q](int axis) {
      auto stride = q.strides(axis);
      return stride > 0 ? stride : 1;
    };

    // axes_by_stride = argsort(q.strides())
    Shape axes_by_stride(q.ndim());
    std::iota(axes_by_stride.begin(), axes_by_stride.end(), 0);
    std::stable_sort(
        axes_by_stride.begin(),
        axes_by_stride.end(),
        [&stride_key](int lhs, int rhs) {
          return stride_key(lhs) < stride_key(rhs);
        });

    // Hand out dense strides, innermost (smallest stride in q) first.
    Strides o_strides(q.ndim());
    int64_t running = 1;
    for (int axis : axes_by_stride) {
      o_strides[axis] = running;
      running *= o.shape(axis);
    }

    // o is a transposed contiguous array
    o.set_data(
        cu::malloc_async(o.nbytes(), encoder),
        o.size(),
        o_strides,
        {true, false, false});
    return;
  }

  o.set_data(cu::malloc_async(o.nbytes(), encoder));
}
constexpr int QKV_NDIM = 4;
struct SDPACacheKey {
int device_id;
cudnnDataType_t cudnn_dtype;
fe::DataType_t cudnn_dtype;
std::array<int, QKV_NDIM> q_shape;
std::array<int, QKV_NDIM> k_shape;
std::array<int, QKV_NDIM> v_shape;
@@ -71,11 +69,50 @@ struct SDPACacheKey {
std::array<int64_t, QKV_NDIM> k_strides;
std::array<int64_t, QKV_NDIM> v_strides;
bool do_causal;
std::array<int, QKV_NDIM> mask_shape;
std::array<int64_t, QKV_NDIM> mask_strides;
bool output_logsumexp;
};
// Builds the lookup key used for the cached SDPA graphs.
//
// The key captures the CUDA device, the cuDNN data type of `q`, the shapes
// and strides of q/k/v, the causal flag, the optional mask's shape and
// strides, and whether logsumexp output is requested.
// `output_logsumexp` defaults to true, so callers that omit it share keys
// with callers that explicitly request the statistics.
inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp = true) {
BytesKey<SDPACacheKey> cache_key;
// Positional initialization: the order of the values below has to follow
// the member order of SDPACacheKey.
cache_key.pod = {
encoder.device().cuda_device(),
dtype_to_cudnn_type(q.dtype()),
vector_key<QKV_NDIM>(q.shape()),
vector_key<QKV_NDIM>(k.shape()),
vector_key<QKV_NDIM>(v.shape()),
vector_key<QKV_NDIM>(q.strides()),
vector_key<QKV_NDIM>(k.strides()),
vector_key<QKV_NDIM>(v.strides()),
do_causal,
// Mask shape and strides stay value-initialized when there is no mask.
{},
{},
output_logsumexp,
};
if (mask_arr) {
cache_key.pod.mask_shape = vector_key<QKV_NDIM>(mask_arr->shape());
cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());
}
return cache_key;
}
auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 128);
static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
return cache;
}
// Returns the function-local static cache of backward SDPA graphs, keyed by
// SDPACacheKey, with a default capacity of 64 entries.
// NOTE(review): the string argument looks like the name of an environment
// variable that overrides the capacity — confirm against LRUBytesKeyCache.
auto& sdpa_backward_cache() {
static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
return cache;
}
@@ -84,59 +121,106 @@ enum UIDS {
K,
V,
SCALE,
BIAS,
O,
STATS,
// Backward graph:
D_Q,
D_K,
D_V,
D_O,
};
fe::graph::Graph build_sdpa_graph(
DnnGraph build_sdpa_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
const array& v,
bool do_causal,
const array& o) {
auto dtype = fe::DataType_t::HALF;
if (q.dtype() == bfloat16) {
dtype = fe::DataType_t::BFLOAT16;
const std::optional<array>& mask_arr,
bool output_logsumexp,
const array& o,
const array& stats) {
DnnGraph graph(handle, q.dtype());
auto q_ = graph.tensor("Q", Q, q);
auto k_ = graph.tensor("K", K, k);
auto v_ = graph.tensor("V", V, v);
auto options = fe::graph::SDPA_attributes()
.set_name("sdpa_cudnn")
.set_attn_scale(graph.scalar("Scale", SCALE, float32))
.set_generate_stats(output_logsumexp);
if (do_causal) {
if (q.shape(2) > k.shape(2)) {
options.set_causal_mask(do_causal);
} else {
options.set_causal_mask_bottom_right(do_causal);
}
}
if (mask_arr) {
options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
}
fe::graph::Graph graph;
graph.set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
graph.tensor(o_, O, o)->set_output(true);
if (output_logsumexp) {
graph.tensor(stats_, STATS, stats)->set_output(true);
}
auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
set_tensor_attrs(q_, Q, q);
set_tensor_attrs(k_, K, k);
set_tensor_attrs(v_, V, v);
auto scale = graph.tensor(fe::graph::Tensor_attributes()
.set_name("Scale")
.set_uid(SCALE)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto sdpa_options = fe::graph::SDPA_attributes()
.set_name("sdpa_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal)
.set_generate_stats(false);
auto [o_, _] = graph.sdpa(q_, k_, v_, sdpa_options);
o_->set_output(true);
set_tensor_attrs(o_, O, o);
CHECK_CUDNN_FE_ERROR(graph.validate());
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
CHECK_CUDNN_FE_ERROR(graph.prepare());
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
CHECK_CUDNN_FE_ERROR(graph.build());
return graph;
}
// Builds the cuDNN graph that computes the SDPA gradients d_q, d_k and d_v.
//
// Inputs registered on the graph: q, k, v, the forward output `o`, the
// incoming gradient `d_o`, the forward statistics `stats`, a by-value float32
// attention scale and, when `mask_arr` is set, the mask as a bias tensor.
// Each tensor is tagged with the matching UIDS value so the caller can bind
// device pointers by uid when executing the graph.
// Throws (via CHECK_CUDNN_FE_ERROR) if preparing or building the graph fails.
DnnGraph build_sdpa_backward_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
const array& o,
const array& d_o,
const array& stats,
array& d_q,
array& d_k,
array& d_v) {
DnnGraph graph(handle, q.dtype());
auto q_ = graph.tensor("Q", Q, q);
auto k_ = graph.tensor("K", K, k);
auto v_ = graph.tensor("V", V, v);
auto o_ = graph.tensor("O", O, o);
auto d_o_ = graph.tensor("D_O", D_O, d_o);
auto stats_ = graph.tensor("STATS", STATS, stats);
auto options = fe::graph::SDPA_backward_attributes()
.set_name("sdpa_backward_cudnn")
.set_attn_scale(graph.scalar("Scale", SCALE, float32));
// Pick the causal mask variant from the sizes of axis 2 of q and k
// (presumably the query / key sequence lengths — confirm): the plain causal
// mask when q is longer, the bottom-right aligned one otherwise.
if (do_causal) {
if (q.shape(2) > k.shape(2)) {
options.set_causal_mask(do_causal);
} else {
options.set_causal_mask_bottom_right(do_causal);
}
}
// An explicit mask is passed to cuDNN as a bias tensor.
if (mask_arr) {
options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
}
auto [d_q_, d_k_, d_v_] =
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
// Mark the three gradients as graph outputs and attach their uids/layouts.
graph.tensor(d_q_, D_Q, d_q)->set_output(true);
graph.tensor(d_k_, D_K, d_k)->set_output(true);
graph.tensor(d_v_, D_V, d_v)->set_output(true);
CHECK_CUDNN_FE_ERROR(graph.prepare());
// Restrict plan selection to the behavior note SUPPORTS_CUDA_GRAPH_NATIVE_API.
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.build());
return graph;
}
@@ -146,7 +230,6 @@ bool supports_sdpa_cudnn(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool do_causal,
Stream s) {
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
@@ -159,19 +242,8 @@ bool supports_sdpa_cudnn(
return false;
}
if (has_mask) {
// TODO: Support array masks.
if (!do_causal) {
return false;
}
// FIXME: Causal mask generates wrong results when L_Q != L_K.
if (q.shape(2) != k.shape(2)) {
return false;
}
}
// Only use cuDNN for prefilling.
if (q.shape(2) != k.shape(2)) {
// Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv).
if ((q.shape(2) == 1) && (q.shape(2) != k.shape(2))) {
return false;
}
@@ -191,66 +263,115 @@ void sdpa_cudnn(
const array& v,
float scale,
array& o,
array& stats,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
// TODO: Handle donation.
// TODO: Make O use same memory layout with Q.
o.set_data(cu::malloc_async(o.nbytes(), encoder.stream()));
auto handle = encoder.device().cudnn_handle();
malloc_with_same_layout(encoder, o, q);
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
if (mask_arr) {
encoder.set_input_array(*mask_arr);
}
if (output_logsumexp) {
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
encoder.set_output_array(stats);
}
// Search cache.
BytesKey<SDPACacheKey> cache_key;
cache_key.pod = {
encoder.device().cuda_device(),
dtype_to_cudnn_type(q.dtype()),
vector_key<QKV_NDIM>(q.shape()),
vector_key<QKV_NDIM>(k.shape()),
vector_key<QKV_NDIM>(v.shape()),
vector_key<QKV_NDIM>(q.strides()),
vector_key<QKV_NDIM>(k.strides()),
vector_key<QKV_NDIM>(v.strides()),
do_causal,
};
auto cache_key = build_sdpa_cache_key(
encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
auto it = sdpa_cache().find(cache_key);
if (it == sdpa_cache().end()) {
it =
sdpa_cache()
.emplace(cache_key, build_sdpa_graph(handle, q, k, v, do_causal, o))
.first;
auto graph = build_sdpa_graph(
handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
std::unordered_map<int64_t, void*> variant_pack{
{Q, const_cast<void*>(gpu_ptr<void>(q))},
{K, const_cast<void*>(gpu_ptr<void>(k))},
{V, const_cast<void*>(gpu_ptr<void>(v))},
{Q, gpu_ptr<void>(q)},
{K, gpu_ptr<void>(k)},
{V, gpu_ptr<void>(v)},
{SCALE, &scale},
{O, gpu_ptr<void>(o)}};
int64_t workspace_size = 0;
CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder.stream()),
{static_cast<int>(workspace_size)},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
if (mask_arr) {
variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
}
if (output_logsumexp) {
variant_pack[STATS] = gpu_ptr<void>(stats);
}
CudaGraph cuda_graph(encoder.device());
CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
handle, variant_pack, workspace_ptr, cuda_graph));
encoder.add_graph_node(cuda_graph);
CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
}
// Computes the SDPA gradients d_q, d_k and d_v with a cached cuDNN graph.
//
// Allocates the three outputs with the same layout as q, k and v
// respectively, registers all inputs/outputs with the command encoder, looks
// up (or builds and caches) the backward graph, and encodes it with the
// device pointers bound by uid.
// Throws (via CHECK_CUDNN_FE_ERROR) if encoding the graph fails.
void sdpa_backward_cudnn(
const array& q,
const array& k,
const array& v,
float scale,
const array& o,
const array& stats,
bool do_causal,
const std::optional<array>& mask_arr,
const array& d_o,
array& d_q,
array& d_k,
array& d_v,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
// Each gradient mirrors the memory layout of the tensor it belongs to.
malloc_with_same_layout(encoder, d_q, q);
malloc_with_same_layout(encoder, d_k, k);
malloc_with_same_layout(encoder, d_v, v);
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_input_array(o);
encoder.set_input_array(stats);
encoder.set_input_array(d_o);
encoder.set_output_array(d_q);
encoder.set_output_array(d_k);
encoder.set_output_array(d_v);
if (mask_arr) {
encoder.set_input_array(*mask_arr);
}
// Search cache.
// NOTE(review): the key is derived from q, k, v, the causal flag and the
// mask only (output_logsumexp takes its default). The graph below is also
// built from o, d_o and stats; if their layouts can vary independently of
// q/k/v, a cached graph could be reused with mismatching layouts — confirm.
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
auto it = sdpa_backward_cache().find(cache_key);
if (it == sdpa_backward_cache().end()) {
auto graph = build_sdpa_backward_graph(
handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
// Bind pointers by uid. SCALE points at the by-value `scale` parameter, so
// the pointer is only valid until this function returns.
std::unordered_map<int64_t, void*> variant_pack{
{Q, gpu_ptr<void>(q)},
{K, gpu_ptr<void>(k)},
{V, gpu_ptr<void>(v)},
{SCALE, &scale},
{O, gpu_ptr<void>(o)},
{STATS, gpu_ptr<void>(stats)},
{D_O, gpu_ptr<void>(d_o)},
{D_Q, gpu_ptr<void>(d_q)},
{D_K, gpu_ptr<void>(d_k)},
{D_V, gpu_ptr<void>(d_v)}};
if (mask_arr) {
variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
}
CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
}
// Defined in scaled_dot_product_attention.cu file.
@@ -260,7 +381,8 @@ bool supports_sdpa_vector(
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal);
bool do_causal,
bool output_logsumexp);
void sdpa_vector(
const array& q,
const array& k,
@@ -280,21 +402,25 @@ bool ScaledDotProductAttention::use_fallback(
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool is_training,
bool output_logsumexp,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
return !supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal) &&
!supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
return !supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
!supports_sdpa_cudnn(q, k, v, do_causal, s);
}
// Always returns false for this backend.
bool ScaledDotProductAttention::supports_bool_mask() {
return false;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
std::vector<array>& outputs) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
auto& s = stream();
@@ -302,20 +428,79 @@ void ScaledDotProductAttention::eval_gpu(
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
auto& out = outputs[0];
auto& stats = outputs[1];
bool has_mask = inputs.size() - has_sinks_ > 3;
bool has_arr_mask = has_mask && !do_causal_;
if (supports_sdpa_vector(q, k, v, has_mask, has_arr_mask, do_causal_)) {
std::optional<array> mask_arr;
if (has_arr_mask) {
mask_arr = prepare_sdpa_input(inputs[3], s);
}
if (supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
if (has_sinks_) {
sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s);
} else {
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
}
} else {
sdpa_cudnn(q, k, v, scale_, out, do_causal_, s);
sdpa_cudnn(
q,
k,
v,
scale_,
out,
stats,
do_causal_,
mask_arr,
output_logsumexp_,
s);
}
}
// Returns true when the VJP has to use the fallback path: either axis 2 of
// `q` is not a multiple of 128, or the stream runs on the CPU device.
bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
  // The frontend adds a padding mask when sequence length is not a multiple of
  // tile size.
  const bool unaligned_length = (q.shape(2) % 128) != 0;
  return unaligned_length || s.device == Device::cpu;
}
// Evaluates the SDPA backward pass on the GPU via sdpa_backward_cudnn.
//
// Input layout as read below: the primals come first (q, k, v, then an
// optional mask at index 3), and the last three inputs are the forward
// output `o`, the forward statistics `stats` and the cotangent `d_o`.
// Outputs are exactly d_q, d_k and d_v.
void ScaledDotProductAttentionVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ScaledDotProductAttentionVJP::eval_gpu");
auto& s = stream();
assert(inputs.size() >= 6);
// Everything except the trailing (o, stats, d_o) triple is a primal.
int primals_size = inputs.size() - 3;
// More primals than q, k, v (plus the sinks input when has_sinks_ is set)
// means a mask was supplied.
bool has_arr_mask = primals_size > 3 + has_sinks_;
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
array o = prepare_sdpa_input(inputs[primals_size], s);
array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);
std::optional<array> mask_arr;
if (has_arr_mask) {
mask_arr = prepare_sdpa_input(inputs[3], s);
}
assert(outputs.size() == 3);
auto& d_q = outputs[0];
auto& d_k = outputs[1];
auto& d_v = outputs[2];
sdpa_backward_cudnn(
q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
}
} // namespace fast
} // namespace mlx::core

View File

@@ -561,10 +561,9 @@ void sdpa_vector_2pass_fallback(
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
sums.set_data(cu::malloc_async(sums.nbytes(), encoder));
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder));
encoder.add_temporary(intermediate);
encoder.add_temporary(sums);
@@ -665,7 +664,12 @@ bool supports_sdpa_vector(
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal) {
bool do_causal,
bool output_logsumexp) {
if (output_logsumexp) {
return false;
}
const int value_head_dim = v.shape(-1);
const int query_head_dim = q.shape(-1);
const int query_sequence_length = q.shape(2);
@@ -769,7 +773,7 @@ void sdpa_vector(
};
o.set_data(
cu::malloc_async(o.nbytes(), encoder.stream()),
cu::malloc_async(o.nbytes(), encoder),
o.size(),
{str_oB, str_oH, str_oL, str_oD},
flags);

View File

@@ -374,7 +374,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(in);
} else {
out.set_data(
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder),
in.data_size(),
in.strides(),
in.flags());

View File

@@ -24,7 +24,7 @@ void concatenate_gpu(
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
auto strides = out.strides();
auto flags = out.flags();
@@ -89,7 +89,7 @@ array compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
offset.set_data(cu::malloc_async(offset.itemsize(), encoder));
}
encoder.add_temporary(offset);

View File

@@ -118,7 +118,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(x);
} else {
out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(),
x.strides(),
x.flags());

View File

@@ -49,14 +49,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
in = contiguous_copy_gpu(trans, s);
encoder.add_temporary(in);
out = array(
cu::malloc_async(out.nbytes(), encoder.stream()),
in.shape(),
out.dtype());
out =
array(cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
encoder.add_temporary(out);
} else {
out.set_data(
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder),
in.data_size(),
in.strides(),
in.flags());
@@ -74,17 +72,13 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
if (argsort) {
// Indices in the sorted dimension.
array indices(
cu::malloc_async(out.nbytes(), encoder.stream()),
in.shape(),
out.dtype());
cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
encoder.add_temporary(indices);
// In argsort though we don't need the result of sorted values, the
// API requires us to provide an array to store it.
array discard(
cu::malloc_async(in.nbytes(), encoder.stream()),
in.shape(),
in.dtype());
cu::malloc_async(in.nbytes(), encoder), in.shape(), in.dtype());
encoder.add_temporary(discard);
size_t size;
@@ -104,9 +98,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
stream));
array temp(
cu::malloc_async(size, encoder.stream()),
{static_cast<int>(size)},
uint8);
cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Start capturing after allocations
@@ -148,9 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
stream));
array temp(
cu::malloc_async(size, encoder.stream()),
{static_cast<int>(size)},
uint8);
cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Start capturing after allocations

View File

@@ -257,9 +257,8 @@ void ternary_op_gpu(
auto& c = inputs[2];
auto topt = get_ternary_op_type(a, b, c);
auto& encoder = cu::get_command_encoder(s);
set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_ternary_op_output_data(
a, b, c, out, topt, [&](auto n) { return cu::malloc_async(n, encoder); });
ternary_op_gpu_inplace<Op>(inputs, out, s);
}

View File

@@ -208,9 +208,8 @@ void unary_op_gpu(
const char* op,
const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
set_unary_output_data(inputs[0], out, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_unary_output_data(
inputs[0], out, [&](auto n) { return cu::malloc_async(n, encoder); });
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}

View File

@@ -5,6 +5,7 @@
#include "mlx/dtype_utils.h"
#include <fmt/format.h>
#include <vector>
namespace mlx::core {
@@ -31,6 +32,13 @@ void check_cuda_error(const char* name, CUresult err) {
}
}
void check_cudnn_error(const char* name, cudnnStatus_t err) {
if (err != CUDNN_STATUS_SUCCESS) {
throw std::runtime_error(
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
}
}
const char* dtype_to_cuda_type(const Dtype& dtype) {
switch (dtype) {
case bool_:
@@ -72,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) {
}
void CudaGraph::end_capture(cudaStream_t stream) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
}

View File

@@ -31,8 +31,10 @@ inline T* gpu_ptr(array& arr) {
arr.offset());
}
// For const array, keep constness in pointer unless it is untyped.
template <typename T>
inline const T* gpu_ptr(const array& arr) {
inline std::conditional_t<std::is_same_v<T, void>, void*, const T*> gpu_ptr(
const array& arr) {
return gpu_ptr<T>(const_cast<array&>(arr));
}

View File

@@ -7,8 +7,6 @@
namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
// Convenience overload: performs the copy on the stream of the primitive
// attached to `out`.
void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}

View File

@@ -28,6 +28,7 @@ make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(indexing/scatter kernels/indexing/indexing.h)
make_jit_source(indexing/masked_scatter)
make_jit_source(indexing/gather kernels/indexing/indexing.h)
make_jit_source(indexing/gather_front kernels/indexing/indexing.h)
make_jit_source(indexing/gather_axis)

View File

@@ -149,7 +149,9 @@ Buffer MetalAllocator::malloc(size_t size) {
buf = device_->newBuffer(size, resource_options);
}
if (!buf) {
return Buffer{nullptr};
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
lk.lock();
num_resources_++;
@@ -201,6 +203,32 @@ size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
}
// Creates a Metal buffer from caller-provided memory and registers it with
// the allocator's bookkeeping (residency set, active/peak memory, resource
// count).
// Returns Buffer{nullptr} when the device cannot create the buffer; unlike a
// failed allocation this is reported through the return value, not an
// exception.
// NOTE(review): newBuffer(ptr, size, options, nullptr) presumably wraps the
// memory without copying and without a deallocator, so the caller would have
// to keep `ptr` alive for the lifetime of the buffer — confirm.
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
if (!buf) {
return Buffer{nullptr};
}
// The statistics and the residency set are shared state guarded by mutex_.
std::unique_lock lk(mutex_);
residency_set_.insert(buf);
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
num_resources_++;
return Buffer{static_cast<void*>(buf)};
}
// Releases a buffer: updates the active-memory and resource counters under
// the lock, then releases the Metal buffer outside of it. A null buffer is
// ignored.
// NOTE(review): make_buffer() inserts the buffer into residency_set_, but it
// is not removed here — confirm that the set does not keep a stale pointer
// after the buffer is released.
void MetalAllocator::release(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_);
active_memory_ -= buf->length();
num_resources_--;
// Drop the lock before the actual release so it is not held during the
// Metal call.
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
}
MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This

View File

@@ -21,6 +21,9 @@ class MetalAllocator : public allocator::Allocator {
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
virtual Buffer make_buffer(void* ptr, size_t size) override;
virtual void release(Buffer buffer) override;
size_t get_active_memory() {
return active_memory_;
};

View File

@@ -265,4 +265,19 @@ Device& device(mlx::core::Device);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
// Reports whether NAX can be used. This requires both an OS of at least
// macOS / iOS / tvOS / visionOS 26.2 and a GPU whose architecture generation
// is at least 17. The result is computed on the first call and cached.
inline bool is_nax_available() {
  static bool cached_result = []() {
    bool os_supported = false;
    if (__builtin_available(
            macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
      os_supported = true;
    }
    // The device is queried regardless of the OS check result.
    const bool arch_supported =
        metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
    return os_supported && arch_supported;
  }();
  return cached_result;
}
} // namespace mlx::core::metal

View File

@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
@@ -8,7 +9,9 @@
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/scan.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/dtype.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -641,4 +644,84 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
// GPU evaluation of MaskedScatter.
//
// Inputs are (dst, mask, src). `out` starts as a copy of `dst`; a JIT-built
// "masked_assign" kernel is then dispatched with one thread per mask element,
// receiving the flattened mask, the exclusive prefix sum of the mask and
// `src`.
void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
const array& dst = inputs[0];
const array& mask = inputs[1];
const array& src = inputs[2];
auto& s = stream();
auto& d = metal::device(s.device);
const size_t total = mask.size();
// Pick the cheapest copy that is valid for dst's layout.
const CopyType ct = (total == 1)
? CopyType::Scalar
: (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_gpu(dst, out, ct, s);
// With an empty mask the copy of dst is already the final result.
if (total == 0) {
return;
}
// Flatten the mask starting at axis 1 (presumably keeping axis 0 as a batch
// axis — confirm against flatten_in_eval).
array mask_flat = flatten_in_eval(mask, 1, -1, s);
// Only track the flattened mask as a temporary when it does not share the
// mask's data pointer.
if (mask_flat.data<void>() != mask.data<void>()) {
d.add_temporary(mask_flat, s.index);
}
if (!mask_flat.flags().row_contiguous) {
mask_flat = contiguous_copy_gpu(mask_flat, s);
d.add_temporary(mask_flat, s.index);
}
// Prefix (exclusive) of mask → scatter_offsets
array scatter_offsets(mask_flat.shape(), uint32, nullptr, {});
scatter_offsets.set_data(allocator::malloc(scatter_offsets.nbytes()));
d.add_temporary(scatter_offsets, s.index);
scan_gpu_inplace(
mask_flat,
scatter_offsets,
Scan::Sum,
/*axis=*/1,
/*reverse=*/false,
/*inclusive=*/false,
s);
// Kernel selection/build
// The kernel name encodes the output dtype and whether src is row
// contiguous, so each combination gets its own library entry.
static constexpr std::string_view kBaseName = "masked_assign";
const std::string dtype_tag = type_to_name(out.dtype());
const std::string value_type = get_type_string(out.dtype());
const std::string contiguous =
(src.flags().row_contiguous) ? "true" : "false";
const std::string kernel_name =
fmt::format("{}_{}_{}", kBaseName, dtype_tag, contiguous);
// The source is only assembled inside the builder callback passed to
// get_library.
auto lib = d.get_library(kernel_name, [&]() {
std::string source = metal::utils();
source += metal::masked_scatter();
source += fmt::format(
std::string(masked_assign_kernel), kernel_name, value_type, contiguous);
return source;
});
auto kernel = d.get_kernel(kernel_name, lib);
// Binding
// Arguments are bound at consecutive indices; the order presumably has to
// match the kernel's buffer indices — confirm against the kernel source.
int bind_idx = 0;
const int ndim = static_cast<int>(src.ndim());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(mask_flat, bind_idx++);
compute_encoder.set_input_array(scatter_offsets, bind_idx++);
compute_encoder.set_input_array(src, bind_idx++);
compute_encoder.set_output_array(out, bind_idx++);
compute_encoder.set_vector_bytes(src.shape(), bind_idx++);
compute_encoder.set_vector_bytes(src.strides(), bind_idx++);
compute_encoder.set_bytes(ndim, bind_idx++);
// Number of src / mask elements per index of axis 0.
compute_encoder.set_bytes(src.size() / src.shape(0), bind_idx++);
compute_encoder.set_bytes(mask_flat.size() / mask.shape(0), bind_idx++);
// Dispatch
// One thread per mask element.
auto group_dims = get_block_dims(total, 1, 1);
MTL::Size grid_dims(total, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@@ -11,6 +11,7 @@ const char* ternary_ops();
const char* reduce_utils();
const char* gather();
const char* scatter();
const char* masked_scatter();
const char* arange();
const char* unary();

View File

@@ -70,3 +70,7 @@ constexpr std::string_view scatter_kernels = R"(
gid);
}}
)";
// Format template for an explicit instantiation of masked_assign_impl.
// {0} is the host-visible kernel name; {1} and {2} are the two template
// arguments of masked_assign_impl.
constexpr std::string_view masked_assign_kernel = R"(
template [[host_name("{0}")]] [[kernel]] decltype(masked_assign_impl<{1}, {2}>) masked_assign_impl<{1}, {2}>;
)";

View File

@@ -9,7 +9,14 @@ set(BASE_HEADERS
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
set(METAL_FLAGS
-x
metal
-Wall
-Wextra
-fno-fast-math
-Wno-c++17-extensions
-Wno-c++20-extensions)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
@@ -120,6 +127,30 @@ if(NOT MLX_METAL_JIT)
build_kernel(gemv_masked steel/utils.h)
endif()
# NAX kernels are only compiled when both the Metal language version is at
# least 400 and the macOS SDK is at least 26.2.
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
# Headers the NAX GEMM and quantized kernels depend on; passed to build_kernel
# as dependencies of each kernel below.
set(STEEL_NAX_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/transforms.h
steel/gemm/nax.h
steel/gemm/gemm_nax.h
steel/utils/type_traits.h
steel/utils/integral_constant.h)
build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
# The quantized kernels additionally depend on their own implementation header.
build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
build_kernel(fp_quantized_nax fp_quantized_nax.h ${STEEL_NAX_HEADERS})
# The attention kernel uses the attention-specific NAX header instead of the
# GEMM one.
set(STEEL_NAX_ATTN_HEADERS
steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h
steel/utils/integral_constant.h)
build_kernel(steel/attn/kernels/steel_attention_nax ${STEEL_NAX_ATTN_HEADERS})
endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,74 @@
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/fp_quantized_nax.h"
// Instantiation helpers for the fp-quantized NAX kernels.
// Each macro forwards to instantiate_kernel (defined elsewhere; presumably
// takes the host name, the function template, then its template arguments).
// The group size (32) and bit width (4) are hard-coded in every macro here.

// Instantiates fp_<name><type, 32, 4, batched>. `mode` and the tile sizes
// (bm, bn, bk, wm, wn) only contribute to the kernel's host name; they are not
// passed as template arguments.
#define instantiate_quantized_batched(mode, name, type, bm, bn, bk, wm, wn, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
batched)
// Instantiates fp_<name><type, 32, 4, aligned>; as above, the tile sizes appear
// only in the host name.
#define instantiate_quantized_aligned(mode, name, type, bm, bn, bk, wm, wn, aligned) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
fp_ ## name, \
type, \
32, \
4, \
aligned)
// Instantiates fp_<name><type, 32, 4, aligned, batched>; tile sizes appear only
// in the host name.
#define instantiate_quantized_aligned_batched(mode, name, type, bm, bn, bk, wm, wn, aligned, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
aligned, \
batched)
// Instantiates func<type, 32, 4, bm, bn, bk, wm, wn, transpose>. Unlike the
// macros above, the tile sizes ARE template arguments here, and the host name
// uses "_bm_64"-style separators (with an underscore before each value).
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
4, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
// All mxfp4 "aligned" variants for one element type: gather_qmm_t_nax for both
// aligned values, and qmm_t_nax for every (aligned, batched) combination.
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, true) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, false) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 0) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 0)
// Both transpose variants ("nt" = transpose true, "nn" = transpose false) of
// fp_gather_qmm_rhs_nax for one element type.
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, mxfp4_gather_qmm_rhs_nax_nt, type, 64, 64, 64, 2, 2, true) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, mxfp4_gather_qmm_rhs_nax_nn, type, 64, 64, 64, 2, 2, false)
// Everything instantiated for one element type.
// NOTE(review): instantiate_quantized_batched is defined above but not expanded
// by any macro in this file.
#define instantiate_quantized_types(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type)
// Emit the kernels for every supported element type.
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on

View File

@@ -0,0 +1,38 @@
// Copyright © 2025 Apple Inc.
#pragma once
template <typename T, bool src_contiguous>
[[kernel]] void masked_assign_impl(
const device bool* mask [[buffer(0)]],
const device uint* scatter_offsets [[buffer(1)]],
const device T* src [[buffer(2)]],
device T* out [[buffer(3)]],
const constant int* src_shapes [[buffer(4)]],
const constant int64_t* src_strides [[buffer(5)]],
const constant int& src_ndim [[buffer(6)]],
const constant int64_t& src_batch_size [[buffer(7)]],
const constant int64_t& mask_batch_size [[buffer(8)]],
uint idx [[thread_position_in_grid]]) {
const bool mask_value = mask[idx];
if (!mask_value) {
return;
}
const uint src_index = scatter_offsets[idx];
if (src_index >= src_batch_size) {
return;
}
const uint batch_idx = idx / mask_batch_size;
if (src_contiguous) {
out[idx] = src[batch_idx * src_batch_size + src_index];
} else {
out[idx] = src[elem_to_loc<uint>(
batch_idx * src_batch_size + src_index,
src_shapes,
src_strides,
src_ndim)];
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,106 @@
// Copyright © 2023-2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/quantized_nax.h"
// Instantiation helpers for the affine-quantized NAX kernels.
// Each macro forwards to instantiate_kernel (defined elsewhere; presumably
// takes the host name, the function template, then its template arguments).

// Instantiates name<type, group_size, bits, bm, bk, bn, wm, wn>. The host name
// only encodes type, group size and bits.
// NOTE(review): the tile sizes are forwarded in the order bm, bk, bn while the
// macro parameters are declared bm, bn, bk — confirm this matches the kernel's
// template parameter order (all callers here pass 64 for each of the three).
#define instantiate_quantized(name, type, group_size, bits, bm, bn, bk, wm, wn) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits, bm, bk, bn, wm, wn)
// Instantiates name<type, group_size, bits, batched, bm, bk, bn, wm, wn>; the
// host name also encodes the tile sizes and the batched flag.
#define instantiate_quantized_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
batched, bm, bk, bn, wm, wn)
// Instantiates name<type, group_size, bits, aligned, bm, bk, bn, wm, wn>.
#define instantiate_quantized_aligned(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
name, \
type, \
group_size, \
bits, \
aligned, bm, bk, bn, wm, wn)
// Instantiates name<type, group_size, bits, aligned, batched, bm, bk, bn, wm, wn>.
#define instantiate_quantized_aligned_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
aligned, \
batched, bm, bk, bn, wm, wn)
// Instantiates func<type, group_size, bits, bm, bn, bk, wm, wn, transpose>.
// Here the tile sizes are forwarded in declaration order (bm, bn, bk) and the
// host name uses "_bm_64"-style separators.
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
// Batched (1) and non-batched (0) variants of one kernel with 64x64x64 tiles
// and a 2x2 simdgroup layout.
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 0)
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qmm_n_nax, type, group_size, bits)
// NOTE(review): this macro is defined but not expanded by
// instantiate_quantized_funcs below, so affine_gather_qmm_n_nax is not
// instantiated by this file — confirm that is intended.
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_gather_qmm_n_nax, type, group_size, bits, 64, 64, 64, 2, 2)
// gather_qmm_t for both aligned values, and qmm_t for every (aligned, batched)
// combination.
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true) \
instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 0) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 0)
// Both transpose variants ("nt" = true, "nn" = false) of the gather-rhs kernel.
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nt, type, group_size, bits, 64, 64, 64, 2, 2, true) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nn, type, group_size, bits, 64, 64, 64, 2, 2, false)
// Every kernel family instantiated for one (type, group_size, bits) triple.
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)
// Fan out over element types ...
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
instantiate_quantized_funcs(float16_t, group_size, bits) \
instantiate_quantized_funcs(bfloat16_t, group_size, bits)
// ... over group sizes ...
#define instantiate_quantized_groups(bits) \
instantiate_quantized_types(128, bits) \
instantiate_quantized_types(64, bits) \
instantiate_quantized_types(32, bits)
// ... and over bit widths.
#define instantiate_quantized_all() \
instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(5) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on

View File

@@ -51,6 +51,7 @@ using namespace metal;
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_bool__uint32, bool, uint32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)

View File

@@ -0,0 +1,476 @@
// Copyright © 2024-25 Apple Inc.
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
// Function constants specializing attention_nax at pipeline creation time.
// align_Q / align_K: when false, the last Q / K block is read (and the output
// written) through the bounds-checked load_safe / store_safe paths.
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];
// has_mask: enables the mask / mask_params buffers and the explicit masking
// step. do_causal: enables causal masking and limits the K/V loop bound.
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
// has_sinks: enables the sinks buffer, which seeds the running max and sum.
constant bool has_sinks [[function_constant(302)]];
// Functor that multiplies its input by a fixed factor.
template <typename T>
struct TransformScale {
  // Factor applied by apply().
  T scale;

  // Store the factor to multiply by.
  METAL_FUNC TransformScale(T factor) : scale(factor) {}

  // Returns scale * value.
  METAL_FUNC T apply(T value) const {
    return scale * value;
  }
};
// Binary op: the larger of the two operands (via metal::max).
struct MaxOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return metal::max(lhs, rhs);
  }
};
// Binary op: lhs + rhs.
struct SumOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return lhs + rhs;
  }
};
// Binary op: lhs * rhs.
struct MulOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return lhs * rhs;
  }
};
// Binary op: lhs - rhs.
struct SubOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return lhs - rhs;
  }
};
// Binary op: 2^(lhs - rhs), computed with fast::exp2.
struct ExpSubOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return fast::exp2(lhs - rhs);
  }
};
// Binary op: lhs / rhs.
struct DivOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T lhs, T rhs) {
    return lhs / rhs;
  }
};
// Fused scaled-dot-product attention kernel built on NAX tiles.
//
// Each threadgroup handles one (batch, head, BQ-row query block) selected by
// tid.z / tid.y / tid.x, and each simdgroup owns UQ * TQ query rows of it. The
// kernel walks over the key/value sequence in blocks of BK rows and, per block:
//   1. computes S = Q @ K^T and scales it by params->scale * log2(e),
//   2. applies sequence-length, causal and explicit masks as enabled,
//   3. updates the running row max / row sum (base-2 online softmax),
//   4. rescales the accumulator and adds P @ V.
// Finally the accumulator is divided by the row sums and written to O.
//
// Template parameters:
//   T          element type of Q, K, V, O and sinks
//   BQ, BK     query / key rows handled per block
//   BD         head dimension
//   WM, WN     simdgroup layout; WM * WN simdgroups of 32 threads each
//   MaskType   element type of the explicit mask (bool selects keep/drop
//              semantics, any other type is added to the scores)
//   AccumType  accumulation type for scores and output
//
// Buffers: Q(0), K(1), V(2), O(3), params(4); mask_params(5) and mask(6) only
// exist when has_mask is set, sinks(7) only when has_sinks is set.
// clang-format off
template <
    typename T,
    int BQ,
    int BK,
    int BD,
    int WM,
    int WN,
    typename MaskType = float,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax(
    const device T* Q [[buffer(0)]],
    const device T* K [[buffer(1)]],
    const device T* V [[buffer(2)]],
    device T* O [[buffer(3)]],
    const constant AttnParams* params [[buffer(4)]],
    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
    const device T* sinks [[buffer(7), function_constant(has_sinks)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
  // Pacifying compiler
  (void)lid;
  (void)simd_lane_id;

  // Move to correct block (64-bit indices so the stride products cannot wrap)
  ulong3 tidl{tid.x, tid.y, tid.z};

  Q += tidl.z * params->Q_strides[0] + // Batch
      tidl.y * params->Q_strides[1] + // Head
      tidl.x * BQ * params->Q_strides[2]; // Sequence

  // Several query heads share one K/V head when gqa_factor > 1.
  ulong kv_head_idx = int(tid.y) / params->gqa_factor;
  K += tidl.z * params->K_strides[0] + // Batch
      kv_head_idx * params->K_strides[1]; // Head

  V += tidl.z * params->V_strides[0] + // Batch
      kv_head_idx * params->V_strides[1]; // Head

  O += tidl.z * params->O_strides[0] + // Batch
      tidl.y * params->O_strides[1] + // Head
      tidl.x * BQ * params->O_strides[2]; // Sequence

  if (has_mask) {
    mask += tidl.z * mask_params->M_strides[0] + // Batch
        tidl.y * mask_params->M_strides[1]; // Head
  }

  // Fold log2(e) into the scale so the softmax can use exp2 throughout.
  const metal::uniform<float> scale2 =
      make_uniform(params->scale) * make_uniform(1.44269504089f);

  // Prepare MMA tiles
  constexpr short UQ = 16;
  constexpr short UD = 32;

  constexpr int kNWarps = WM * WN;
  static_assert(
      BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0,
      "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");

  // Q seq frags per warp
  constexpr int TQ = BQ / (kNWarps * UQ);
  // HeadDim frags (all warps load the same frags)
  constexpr int TD = BD / UD;

  static_assert(TQ == 1, "Check TQ");

  using OSubTile = NAXSubTile<AccumType, UQ, UD>;
  NAXTile<AccumType, TQ, TD, OSubTile> Otile;

  Otile.clear();

  // Prepare mma tile offsets
  const short2 simd_coord = OSubTile::NAXFrag_t::get_coord();
  const short sm = simd_coord.y;
  const short sn = simd_coord.x;
  const short tm = UQ * TQ * simd_group_id;

  Q += (tm + sm) * int(params->Q_strides[2]) + sn;
  K += sm * int(params->K_strides[2]) + sn;
  V += sm * int(params->V_strides[2]) + sn;

  // Init row reduction variables
  constexpr short kRowsPT = decltype(Otile)::kRowsPerThread;

  metal::vec<AccumType, kRowsPT> max_score;
  metal::vec<AccumType, kRowsPT> sum_score{0};

  // Init to -Inf
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    max_score[i] = Limits<AccumType>::finite_min;
  }

  // With sinks, the sink logit seeds the running max and contributes
  // exp2(0) = 1 to the running sum.
  if (has_sinks) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
      sum_score[i] = 1;
    }
  }

  // Number of K/V blocks to visit; causal attention stops at the last block
  // that can contain a visible key for this query block.
  int kb_lim = params->NK;

  if (do_causal) {
    int q_max = (tid.x + 1) * BQ + params->qL_off;
    kb_lim = (q_max + BK - 1) / BK;
    kb_lim = min(params->NK, kb_lim);
  }

  const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
  const bool is_last_q = is_last_bq;

  // Rows still valid for this thread inside the trailing (partial) Q / K block.
  const short lim_rows_q = params->qL_rem - (tm + sm);
  const short lim_rows_k = params->kL_rem - sm;

  // Loop over KV seq length
  for (int kb = 0; kb < kb_lim; kb++) {
    const int is_last_k = (kb == (params->NK_aligned));

    // Do S = Q @ K.T
    constexpr short UDs = 16;
    constexpr short UKs = 32;

    constexpr short TDs = BD / UDs;
    constexpr short TKs = BK / UKs;

    using SSubTile = NAXSubTile<AccumType, UQ, UKs>;
    using QSubTile = NAXSubTile<T, UQ, UDs>;
    using KSubTile = NAXSubTile<T, UKs, UDs>;

    NAXTile<AccumType, TQ, TKs, SSubTile> Stile;

    Stile.clear();

    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short ik = 0; ik < TKs; ik++) {
        STEEL_PRAGMA_UNROLL
        for (short id = 0; id < TDs; id++) {
          NAXTile<T, 1, 1, QSubTile> Qtile;
          NAXTile<T, 1, 1, KSubTile> Ktile;

          const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs;
          const int K_load_off =
              ik * UKs * int(params->K_strides[2]) + id * UDs;

          // Bounds-checked loads are only needed on the trailing partial block.
          if (!align_Q && is_last_q) {
            Qtile.load_safe(
                Q + Q_load_off,
                int(params->Q_strides[2]),
                short2(BD, lim_rows_q - iq * UQ));
          } else {
            Qtile.load(Q + Q_load_off, int(params->Q_strides[2]));
          }

          if (!align_K && is_last_k) {
            Ktile.load_safe(
                K + K_load_off,
                int(params->K_strides[2]),
                short2(BD, lim_rows_k - ik * UKs));
          } else {
            Ktile.load(K + K_load_off, int(params->K_strides[2]));
          }

          subtile_matmad_nax(
              Stile.subtile_at(iq, ik),
              Qtile.subtile_at(0, 0),
              metal::false_type{},
              Ktile.subtile_at(0, 0),
              metal::true_type{});
        }
      }
    }

    // Scale S
    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
      Stile.elems()[ii] *= float(scale2);
    }

    // Scale and Retile S
    constexpr short UK = 16;
    constexpr short TK = BK / UK;
    using PSubTile = NAXSubTile<AccumType, UQ, UK>;

    NAXTile<AccumType, TQ, TK, PSubTile> Ptile;

    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
      Ptile.elems()[ii] = Stile.elems()[ii];
    }

    // Mask out length sequence
    if (!align_K && is_last_k) {
      constexpr auto neg_inf = Limits<AccumType>::finite_min;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          // Block-local column, so a short is wide enough here.
          const short col_pos = sn + ik * UK;

          thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

          STEEL_PRAGMA_UNROLL
          for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
            STEEL_PRAGMA_UNROLL
            for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
              const auto loc = ii * PSubTile::kFragThrCols + jj;
              fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc];
            }
          }
        }
      }
    }

    // Mask out if causal
    if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
      constexpr auto neg_inf = Limits<AccumType>::finite_min;

      const int base_row = tid.x * BQ + params->qL_off + tm;
      const int base_col = kb * BK;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          // Global sequence positions: keep them in int. A short would wrap
          // for positions above 32767 and mask the wrong elements.
          const int row_pos = base_row + iq * UQ;
          const int col_pos = base_col + ik * UK;

          thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

          STEEL_PRAGMA_UNROLL
          for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
            STEEL_PRAGMA_UNROLL
            for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
              const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm;
              const auto c = col_pos + jj + sn;
              const auto loc = ii * PSubTile::kFragThrCols + jj;
              fg[loc] = (r < c) ? neg_inf : fg[loc];
            }
          }
        }
      }
    }

    // Other masking as needed
    if (has_mask) {
      constexpr auto neg_inf = Limits<AccumType>::finite_min;

      const int base_row = tid.x * BQ + tm;
      const int base_col = kb * BK;

      constexpr bool is_bool = is_same_v<MaskType, bool>;
      using melem_t = typename metal::conditional_t<is_bool, bool, AccumType>;
      using MSubTile = NAXSubTile<melem_t, UQ, UK>;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          // Global positions into the mask: int for the same reason as in the
          // causal path above.
          const int row_pos = base_row + iq * UQ + sm;
          const int col_pos = base_col + ik * UK + sn;

          MSubTile mfrag;
          mfrag.load_safe(
              mask,
              int(mask_params->M_strides[2]),
              Int<1>{},
              params->qL,
              params->kL,
              row_pos,
              col_pos);

          thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) {
            if constexpr (is_bool) {
              // Boolean mask: keep the score where true, drop it otherwise.
              fg[jj] = mfrag.elems()[jj] ? fg[jj] : neg_inf;
            } else {
              // Additive mask, converted to the base-2 domain of the scores.
              fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]);
            }
          }
        }
      }
    }

    // Do softmax

    // Temp variables
    metal::vec<AccumType, kRowsPT> new_max;
    metal::vec<AccumType, kRowsPT> factor;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      new_max[i] = max_score[i];
    }

    // Row max
    Ptile.template row_reduce<MaxOp>(new_max);

    // exp(Si - rowmax(Si))
    Ptile.template row_bin_op<ExpSubOp>(new_max);

    // Factor exp(rowmax(Si) - rowmax(Si-1))
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      factor[i] = fast::exp2(max_score[i] - new_max[i]);
      max_score[i] = new_max[i];
    }

    // Row Sum
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      sum_score[i] = sum_score[i] * factor[i];
    }

    Ptile.template row_reduce<SumOp>(sum_score);

    // Update O
    Otile.template row_bin_op<MulOp>(factor);

    simdgroup_barrier(mem_flags::mem_none);

    // Do O = P @ V
    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short id = 0; id < TD; id++) {
        if constexpr (BD == 128) {
          if (id == 2) {
            threadgroup_barrier(mem_flags::mem_none);
          }
        }

        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          using VSubTile = NAXSubTile<T, UK, UD>;
          NAXTile<T, 1, 1, VSubTile> Vtile;

          const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD;

          if (!align_K && is_last_k) {
            Vtile.load_safe(
                V + V_load_off,
                int(params->V_strides[2]),
                short2(BD, lim_rows_k - ik * UK));
          } else {
            Vtile.load(V + V_load_off, int(params->V_strides[2]));
          }

          subtile_matmad_nax(
              Otile.subtile_at(iq, id),
              Ptile.subtile_at(iq, ik),
              metal::bool_constant<false>{},
              Vtile.subtile_at(0, 0),
              metal::bool_constant<false>{});
        }
      }
    }

    // Prepare for next iteration
    K += BK * int(params->K_strides[2]);
    V += BK * int(params->V_strides[2]);
  }

  // Normalize output

  threadgroup_barrier(mem_flags::mem_none);

  metal::vec<AccumType, kRowsPT> rcp;
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    rcp[i] = (1.f / sum_score[i]);
  }

  Otile.template row_bin_op<MulOp>(rcp);

  // Store results
  O += (tm + sm) * int(params->O_strides[2]) + sn;

  if (!align_Q && is_last_q) {
    // This thread's rows lie entirely past the end of the query sequence.
    if (lim_rows_q <= 0)
      return;

    Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q));
  } else {
    Otile.store(O, int(params->O_strides[2]));
  }
}

View File

@@ -0,0 +1,33 @@
// Copyright © 2024-25 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/nax.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h"
// Instantiates attention_nax<dtype, bq, bk, bd, wm, wn, mtype, float> under a
// host name that encodes the type name, the tile sizes and the mask name.
// The accumulation type is fixed to float.
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
instantiate_kernel( \
"steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
"_wm" #wm "_wn" #wn "_mask" #mname, \
attention_nax, dtype, bq, bk, bd, wm, wn, mtype, float)
// Tile shapes instantiated for each (input type, mask type) pair: BQ = 64 with
// BK in {32, 64} and head dim BD in {64, 128}, always 4x1 simdgroups.
#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
instantiate_attn(iname, itype, 64, 32, 128, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 64, 32, 64, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 64, 64, 128, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 64, 64, 64, 4, 1, mname, mtype)
// For each input type, instantiate with a mask of the same type and with a
// boolean mask.
#define instantiate_attn_mask_helper(iname, itype) \
instantiate_attn_shapes_helper(iname, itype, iname, itype) \
instantiate_attn_shapes_helper(iname, itype, bool_, bool)
instantiate_attn_mask_helper(float16, half);
instantiate_attn_mask_helper(bfloat16, bfloat);
instantiate_attn_mask_helper(float32, float);
// clang-format on

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,7 @@
// Copyright © 2024 Apple Inc.
#pragma once
// Qualifier bundle for compile-time constants in steel kernels.
#define STEEL_CONST static constant constexpr const
// Loop hint: ask clang to fully unroll the loop that follows.
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
// Loop hint: ask clang not to unroll the loop that follows.
#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)")

View File

@@ -0,0 +1,154 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
using namespace metal;
namespace mlx::steel {
// Per-simdgroup GEMM main loop: accumulates A @ B over the K dimension into an
// SM x SN accumulator tile and returns it.
//
// Template parameters:
//   T                      element type of A and B
//   SM, SN                 rows / columns of the output tile owned by the caller
//   SK                     K extent consumed per inner step
//   BK                     K extent consumed per outer iteration
//   transpose_a/b          whether A / B are stored transposed
//   kAlignedM/N/K          when true the corresponding extent is known to be
//                          full, so unchecked loads are used
//   UM, UN, UK             sub-tile (fragment) extents
//   AccumType              accumulator element type
//
// Arguments:
//   A, B                   pointers already offset to this tile's first element
//   params                 GEMM parameters (lda, ldb, K,
//                          gemm_k_iterations_aligned are read here)
//   sgp_sm, sgp_sn         valid rows / columns for this tile, used as limits
//                          for the bounds-checked loads
template <
typename T,
short SM,
short SN,
short SK,
short BK,
bool transpose_a,
bool transpose_b,
bool kAlignedM,
bool kAlignedN,
bool kAlignedK,
short UM,
short UN,
short UK,
typename AccumType = float>
auto gemm_loop(
const device T* A,
const device T* B,
const constant GEMMParams* params [[buffer(4)]],
const short sgp_sm,
const short sgp_sn) {
// Number of sub-tiles along each dimension of the per-step tiles.
constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
constexpr short TK = SK / UK;
// Row / column sub-tile counts of the A and B tiles as stored in memory
// (swapped when the operand is transposed).
constexpr int RA = transpose_a ? TK : TM;
constexpr int CA = transpose_a ? TM : TK;
constexpr int RB = transpose_b ? TN : TK;
constexpr int CB = transpose_b ? TK : TN;
using DSubTile = NAXSubTile<AccumType, UM, UN>;
using ASubTile =
NAXSubTile<T, (transpose_a ? UK : UM), (transpose_a ? UM : UK)>;
using BSubTile =
NAXSubTile<T, (transpose_b ? UN : UK), (transpose_b ? UK : UN)>;
NAXTile<AccumType, TM, TN, DSubTile> Dtile;
Dtile.clear();
// Full BK-sized iterations; any K remainder is handled after this loop.
int gemm_k_iterations_ = params->gemm_k_iterations_aligned;
STEEL_PRAGMA_NO_UNROLL
for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) {
threadgroup_barrier(mem_flags::mem_none);
STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, RA, CA, ASubTile> Atile;
NAXTile<T, RB, CB, BSubTile> Btile;
const int k = kk1;
// NOTE(review): never written; the name and the (void) use below suggest it
// exists only to influence compiler scheduling — confirm before removing.
volatile int compiler_barrier;
const int A_offset = transpose_a ? k * params->lda : k;
const int B_offset = transpose_b ? k : k * params->ldb;
if constexpr (kAlignedM) {
Atile.load(A + A_offset, params->lda);
} else {
// Limits are (columns, rows) of the stored tile; the K extent is full here.
const short rmax = transpose_a ? SK : sgp_sm;
const short cmax = transpose_a ? sgp_sm : SK;
Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
}
if constexpr (kAlignedN) {
Btile.load(B + B_offset, params->ldb);
} else {
const short rmax = transpose_b ? sgp_sn : SK;
const short cmax = transpose_b ? SK : sgp_sn;
Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
}
tile_matmad_nax(
Dtile,
Atile,
metal::bool_constant<transpose_a>{},
Btile,
metal::bool_constant<transpose_b>{});
(void)compiler_barrier;
}
// Advance both operands by BK along K.
A += transpose_a ? (BK * params->lda) : BK;
B += transpose_b ? BK : (BK * params->ldb);
}
// K remainder: processed one sub-tile at a time with bounds-checked loads.
if constexpr (!kAlignedK) {
simdgroup_barrier(mem_flags::mem_none);
const short rem_bk = params->K - gemm_k_iterations_ * BK;
STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) {
NAXTile<T, 1, 1, ASubTile> Atile;
NAXTile<T, 1, 1, BSubTile> Btile;
STEEL_PRAGMA_UNROLL
for (int mm = 0; mm < TM; mm++) {
STEEL_PRAGMA_UNROLL
for (int nn = 0; nn < TN; nn++) {
STEEL_PRAGMA_UNROLL
for (int kk = 0; kk < TK; kk++) {
const int m = mm * UM;
const int n = nn * UN;
const int k = kk1 + kk * UK;
// K elements still valid from position k (clamped at zero).
const short psk = max(0, rem_bk - k);
const int A_offset =
transpose_a ? (m + k * params->lda) : (m * params->lda + k);
const int B_offset =
transpose_b ? (k + n * params->ldb) : (k * params->ldb + n);
{
const short psm = kAlignedM ? SM : max(0, sgp_sm - m);
const short rmax = transpose_a ? psk : psm;
const short cmax = transpose_a ? psm : psk;
Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
}
{
const short psn = kAlignedN ? SN : max(0, sgp_sn - n);
const short rmax = transpose_b ? psn : psk;
const short cmax = transpose_b ? psk : psn;
Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
}
subtile_matmad_nax(
Dtile.subtile_at(mm, nn),
Atile.subtile_at(0, 0),
metal::bool_constant<transpose_a>{},
Btile.subtile_at(0, 0),
metal::bool_constant<transpose_b>{});
}
}
}
}
}
return Dtile;
}
} // namespace mlx::steel

View File

@@ -0,0 +1,207 @@
// Copyright © 2025 Apple Inc.
using namespace mlx::steel;
// Function constants specializing the fused NAX GEMM kernel.
// has_batch: use batch_shape / batch_strides to locate the batch element
// instead of the fixed per-matrix batch strides in params.
constant bool has_batch [[function_constant(10)]];
// use_out_source: an input matrix C (and addmm_params) is bound and combined
// with the product in the epilogue.
constant bool use_out_source [[function_constant(100)]];
// do_axpby: the epilogue computes alpha * D + beta * C instead of D + C.
constant bool do_axpby [[function_constant(110)]];
// align_M / align_N / align_K: when true the kernel takes the unchecked
// (fully aligned) path for that dimension.
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
// Epilogue of the fused GEMM: folds the matrix C into the accumulator tile.
//
// For every sub-tile of Dtile the matching region of C is loaded (unchecked
// when both kAlignedM and kAlignedN hold, bounds-checked against
// sgp_sm / sgp_sn otherwise) and combined element-wise:
//   do_axpby : D = alpha * D + beta * C
//   otherwise: D = D + C
// `params` is accepted for signature symmetry but not used.
// clang-format off
template <
    bool kAlignedM,
    bool kAlignedN,
    typename NAXTile_t,
    typename T>
void gemm_epilogue(
    thread NAXTile_t& Dtile,
    const device T* C,
    const constant GEMMParams* params,
    const constant GEMMAddMMParams* addmm_params,
    const short sgp_sm,
    const short sgp_sn) { // clang-format on
  (void)params;

  // Accumulator element type and tile geometry, taken from the tile type.
  using V = typename NAXTile_t::elem_type;
  constexpr short kSubRows = NAXTile_t::kSubTileRows;
  constexpr short kSubCols = NAXTile_t::kSubTileCols;
  constexpr short kRowTiles = NAXTile_t::kTileRows;
  constexpr short kColTiles = NAXTile_t::kTileCols;
  constexpr short kSubTileElems = NAXTile_t::kElemsPerSubTile;

  using CSubTile = NAXSubTile<T, kSubRows, kSubCols>;

  STEEL_PRAGMA_UNROLL
  for (short ti = 0; ti < kRowTiles; ti++) {
    STEEL_PRAGMA_UNROLL
    for (short tj = 0; tj < kColTiles; tj++) {
      // Offset of this sub-tile inside the simdgroup's tile.
      const short row_off = ti * kSubRows;
      const short col_off = tj * kSubCols;

      CSubTile c_sub;

      if constexpr (!(kAlignedM && kAlignedN)) {
        c_sub.load_safe(
            C,
            addmm_params->ldc,
            addmm_params->fdc,
            sgp_sm,
            sgp_sn,
            row_off,
            col_off);
      } else {
        c_sub.load(C, addmm_params->ldc, addmm_params->fdc, row_off, col_off);
      }

      auto acc = Dtile.subtile_at(ti, tj).elems();
      auto addend = c_sub.elems();

      STEEL_PRAGMA_UNROLL
      for (short e = 0; e < kSubTileElems; e++) {
        if (do_axpby) {
          acc[e] = addmm_params->alpha * acc[e] +
              addmm_params->beta * static_cast<V>(addend[e]);
        } else {
          acc[e] += static_cast<V>(addend[e]);
        }
      }
    }
  }
}
// Fused NAX GEMM kernel: D = A @ B, optionally combined with C in an epilogue.
//
// Each threadgroup computes one BM x BN block of D; the block is split across
// WM x WN simdgroups, each producing an (BM / WM) x (BN / WN) tile through
// gemm_loop.
//
// Template parameters:
//   T              element type of A, B, C and D
//   BM, BN, BK     block extents along M, N and K
//   WM, WN         simdgroup layout inside the threadgroup
//   transpose_a/b  whether A / B are stored transposed
//   AccumType      accumulator element type
//
// Buffers: A(0), B(1), D(3), params(4); C(2) and addmm_params(5) only exist
// when use_out_source is set; batch_shape(6) and batch_strides(7) only when
// has_batch is set. batch_strides holds the A strides, then the B strides,
// then (when use_out_source) the C strides, batch_ndim entries each.
// clang-format off
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on
  // Find block (undo the threadgroup swizzle applied at dispatch time)
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Exit early if out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Adjust for batch
  if (has_batch) {
    const constant auto* A_bstrides = batch_strides;
    const constant auto* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

    if (use_out_source) {
      const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
      C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
    }
  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;

    if (use_out_source) {
      C += addmm_params->batch_stride_c * tid.z;
    }
  }

  D += params->batch_stride_d * tid.z;

  // Prepare threadgroup memory
  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  // Widen before multiplying by the leading dimensions.
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }

  // Fragment extents
  constexpr short UM = 16;
  constexpr short UN = 32;
  constexpr short UK = 16;
  // Per-simdgroup tile extents
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;
  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;

  // Offset of this simdgroup's tile inside the threadgroup block.
  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  // Valid rows / columns of this simdgroup's tile. Clamp in int BEFORE
  // narrowing: converting the remaining extent to short first would wrap once
  // more than 32767 rows / columns remain and yield a bogus (possibly
  // negative) limit for an otherwise full tile.
  const short sgp_sm = align_M
      ? SM
      : short(min(int(SM), int(params->M - (c_row + tm))));
  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

  const short sgp_sn = align_N
      ? SN
      : short(min(int(SN), int(params->N - (c_col + tn))));
  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

  A += transpose_a ? tm : (tm * params->lda);
  B += transpose_b ? (tn * params->ldb) : tn;
  D += tm * params->ldd + tn;

  if (use_out_source) {
    C += tm * addmm_params->ldc + tn * addmm_params->fdc;
  }

  using DSubTile = NAXSubTile<AccumType, UM, UN>;
  NAXTile<AccumType, TM, TN, DSubTile> Dtile;

  // Turn the runtime alignment facts into compile-time flags so that full
  // tiles take the unchecked load / store paths.
  dispatch_bool(align_K, [&](auto kAlignedK) {
    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
      dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
        Dtile = gemm_loop<
            T,
            SM,
            SN,
            SK,
            BK,
            transpose_a,
            transpose_b,
            kAlignedM.value,
            kAlignedN.value,
            kAlignedK.value,
            UM,
            UN,
            UK,
            AccumType>(A, B, params, sgp_sm, sgp_sn);

        if (use_out_source) {
          gemm_epilogue<kAlignedM.value, kAlignedN.value>(
              Dtile, C, params, addmm_params, sgp_sm, sgp_sn);
        }

        if constexpr (kAlignedM && kAlignedN) {
          Dtile.store(D, int(params->ldd));
        } else {
          Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm));
        }
      });
    });
  });
}

View File

@@ -0,0 +1,35 @@
// Copyright © 2025 Apple Inc.
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h"
// clang-format off
// Instantiates one specialization of the `gemm` kernel template under the
// host-visible name
//   steel_gemm_fused_nax_<layout>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>
// `oname` only contributes to the kernel name and `otype` is not used by the
// expansion. The trailing `float` is forwarded as the last template argument
// (presumably the accumulation type -- confirm against the `gemm` template).
#define instantiate_gemm(layout, ta, tb, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_kernel( \
      "steel_gemm_fused_nax_" #layout "_" #iname "_" #oname "_bm" #bm \
      "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
      gemm, \
      itype, \
      bm, \
      bn, \
      bk, \
      wm, \
      wn, \
      ta, \
      tb, \
      float)
// Expands to the four layout variants (nn, nt, tn, tt) of instantiate_gemm for
// one type / tile configuration. The two booleans following the layout tag are
// the transpose flags handed to instantiate_gemm as its 2nd and 3rd arguments.
#define instantiate_gemm_transpose_helper(in_name, in_type, out_name, out_type, bm, bn, bk, wm, wn) \
  instantiate_gemm(nn, false, false, in_name, in_type, out_name, out_type, bm, bn, bk, wm, wn) \
  instantiate_gemm(nt, false, true, in_name, in_type, out_name, out_type, bm, bn, bk, wm, wn) \
  instantiate_gemm(tn, true, false, in_name, in_type, out_name, out_type, bm, bn, bk, wm, wn) \
  instantiate_gemm(tt, true, true, in_name, in_type, out_name, out_type, bm, bn, bk, wm, wn)
// Instantiates every layout variant for the two tile configurations
// (bm, bn, bk, wm, wn) = (64, 64, 256, 2, 2) and (128, 128, 512, 4, 4).
#define instantiate_gemm_shapes_helper(in_name, in_type, out_name, out_type) \
  instantiate_gemm_transpose_helper(in_name, in_type, out_name, out_type, 64, 64, 256, 2, 2) \
  instantiate_gemm_transpose_helper(in_name, in_type, out_name, out_type, 128, 128, 512, 4, 4)

// Emit the kernels for half, bfloat and float; in each case the same type is
// passed for the input and output slots.
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
instantiate_gemm_shapes_helper(float32, float, float32, float);
// clang-format on

View File

@@ -0,0 +1,132 @@
// Copyright © 2024 Apple Inc.
// Pull the mlx::steel names into scope so the kernel below can use them
// unqualified.
using namespace mlx::steel;
// Function constants at indices 200-202. gather_mm_rhs_nax uses them to choose
// between unchecked and bounds-aware code paths:
//   align_M / align_N: when set, the simdgroup tile extent is taken as the full
//                      SM / SN without clamping against params->M / params->N.
//   align_K:           forwarded to gemm_loop as a compile-time flag.
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
// Gather matmul in which every output row selects its own right-hand side.
//
// Template parameters:
//   T                         element type of A, B and C.
//   BM, BN                    rows / columns of C handled by one threadgroup.
//   BK                        forwarded to gemm_loop.
//   WM, WN                    simdgroup grid inside a threadgroup; each
//                             simdgroup owns an (BM / WM) x (BN / WN) block.
//   transpose_a, transpose_b  select how A / B are strided when offsetting.
//   AccumType                 element type of the accumulator tile.
//
// Buffers:
//   A            left-hand side.
//   B            stack of right-hand sides; matrix `i` starts at
//                B + i * params->batch_stride_b.
//   rhs_indices  one entry per output row, naming the B matrix for that row.
//   C            output, row stride params->ldd.
//   params       GEMM sizes, strides and tile counts.
//
// For each run of consecutive rows that share the same index, the simdgroup
// computes its whole block against the selected B and stores only the rows
// belonging to that run.
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void
gather_mm_rhs_nax(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device uint32_t* rhs_indices [[buffer(2)]],
    device T* C [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Sub-tile (U*) and per-simdgroup (S*) extents, plus the number of
  // sub-tiles per simdgroup (T*).
  constexpr short UM = 16;
  constexpr short UN = 32;
  constexpr short UK = 16;
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;
  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;

  // Threadgroups outside the tile grid have nothing to do.
  if (params->tiles_n <= static_cast<int>(tid.x) ||
      params->tiles_m <= static_cast<int>(tid.y)) {
    return;
  }

  // Find the block in A, B, C
  const int c_row = tid.y * BM;
  const int c_col = tid.x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  C += c_row_long * params->ldd + c_col_long;
  rhs_indices += c_row;

  // Offset of this simdgroup's block inside the threadgroup tile.
  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  // Valid rows / columns of this simdgroup's block. When the problem is not
  // aligned these are clamped against M / N and can be zero or negative for a
  // simdgroup that lies entirely past the edge.
  const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

  const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

  A += transpose_a ? tm : (tm * params->lda);
  B += transpose_b ? (tn * params->ldb) : tn;
  C += tm * params->ldd + tn;
  rhs_indices += tm;

  // Do as many matmuls as necessary
  uint32_t index;
  short offset;
  // Only prime the first index when this simdgroup owns at least one valid
  // row; otherwise rhs_indices already points past the last valid row and the
  // loop below never runs, so the load would be an out-of-range read.
  uint32_t index_next = (sgp_sm > 0) ? rhs_indices[0] : 0u;
  short offset_next = 0;
  int n = 0;
  while (n < sgp_sm) {
    // Rows [offset, offset_next) form the current run sharing `index`.
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = sgp_sm;
    for (; n < sgp_sm; n++) {
      if (rhs_indices[n] != index) {
        offset_next = n;
        index_next = rhs_indices[n];
        break;
      }
    }

    // NOTE(review): the trip count of this loop depends on this simdgroup's
    // slice of rhs_indices, so simdgroups with different `tm` may reach this
    // barrier a different number of times -- confirm that is acceptable for
    // threadgroup_barrier.
    threadgroup_barrier(mem_flags::mem_none);

    using DSubTile = NAXSubTile<AccumType, UM, UN>;
    NAXTile<AccumType, TM, TN, DSubTile> Ctile;

    // Turn the runtime alignment flags into compile-time booleans so that
    // gemm_loop is specialized for the aligned / unaligned combination.
    dispatch_bool(align_K, [&](auto kAlignedK) {
      dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
        dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
          auto do_gemm = gemm_loop<
              T,
              SM,
              SN,
              SK,
              BK,
              transpose_a,
              transpose_b,
              kAlignedM.value,
              kAlignedN.value,
              kAlignedK.value,
              UM,
              UN,
              UK,
              AccumType>;

          // Multiply against the B matrix selected by the current run.
          Ctile = do_gemm(
              A, B + index * params->batch_stride_b, params, sgp_sm, sgp_sn);

          // Store only the rows of the current run. The two short2 arguments
          // are presumably (column, row) start and stop coordinates --
          // confirm against NAXTile::store_slice.
          if constexpr (kAlignedN.value) {
            if (offset_next - offset == SM) {
              // The run spans the entire block: plain full-tile store.
              Ctile.store(C, int(params->ldd));
            } else {
              Ctile.store_slice(
                  C,
                  int(params->ldd),
                  short2(0, offset),
                  short2(SN, offset_next));
            }
          } else {
            Ctile.store_slice(
                C,
                int(params->ldd),
                short2(0, offset),
                short2(sgp_sn, offset_next));
          }
        });
      });
    });
  }
}

View File

@@ -0,0 +1,39 @@
// Copyright © 2024 Apple Inc.
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"
// clang-format off
// Instantiates one specialization of gather_mm_rhs_nax under the host-visible
// name
//   steel_gather_mm_rhs_nax_<layout>_<iname>_<oname>_bm<bm>_bn<bn>_bk<bk>_wm<wm>_wn<wn>
// `oname` only contributes to the kernel name and `otype` is not used by the
// expansion. The trailing `float` is passed as the kernel's final template
// argument.
#define instantiate_gather_mm_rhs(layout, ta, tb, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_kernel( \
      "steel_gather_mm_rhs_nax_" #layout "_" #iname "_" #oname \
      "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
      gather_mm_rhs_nax, itype, bm, bn, bk, wm, wn, ta, tb, float)
// Expands to the nn and nt layout variants of instantiate_gather_mm_rhs for
// one type / tile configuration; no variant with a transposed A is emitted.
#define instantiate_gather_mm_rhs_transpose_helper(in_name, in_type, out_name, out_type, bm, bn, bk, wm, wn) \
  instantiate_gather_mm_rhs(nn, false, false, in_name, in_type, out_name, out_type, bm, bn, bk, wm, wn) \
  instantiate_gather_mm_rhs(nt, false, true, in_name, in_type, out_name, out_type, bm, bn, bk, wm, wn)
// Instantiates both layout variants for the three tile configurations
// (bm, bn, bk, wm, wn) = (16, 128, 128, 1, 4), (32, 128, 128, 1, 4) and
// (64, 128, 128, 2, 4).
#define instantiate_gather_mm_shapes_helper(in_name, in_type, out_name, out_type) \
  instantiate_gather_mm_rhs_transpose_helper(in_name, in_type, out_name, out_type, 16, 128, 128, 1, 4) \
  instantiate_gather_mm_rhs_transpose_helper(in_name, in_type, out_name, out_type, 32, 128, 128, 1, 4) \
  instantiate_gather_mm_rhs_transpose_helper(in_name, in_type, out_name, out_type, 64, 128, 128, 2, 4)
// clang-format on

// Emit the kernels for half and bfloat; the same type is passed for the input
// and output slots.
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);

Some files were not shown because too many files have changed in this diff Show More