Compare commits

..

47 Commits

Author SHA1 Message Date
Andrey Portnoy
3e05cea9f8 Force cudaGraphExec reinstantiation when clusters are used (#2813)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-22 12:43:49 -08:00
CCYeh
5b0f047226 Fix mx.core.load type annotation (#2819) 2025-11-22 11:09:44 -08:00
Harsh Sutaria
618c87af8c Add float64 Eig and complex64 SVD/Eig support (Fixes #2708) (#2737)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-22 06:51:36 -08:00
Cheng
d5f61a93fa Fix typo: refs/head/main => refs/heads/main (#2818)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-22 09:43:35 +09:00
Awni Hannun
4a09264236 Tolerance for some ops tests on cuda (#2815) 2025-11-21 16:06:16 -08:00
Awni Hannun
0dbc7e5bee Centralize NAX condition (#2811)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-21 13:28:15 -08:00
Awni Hannun
0d68efd461 patch bump for future version (#2804)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-20 09:26:20 -08:00
Awni Hannun
f9e1a14135 [CUDA] Partly fix random for large sizes (#2798) 2025-11-20 07:27:50 -08:00
Awni Hannun
d8e9ded928 Fix cuda allocator copy condition (#2800) 2025-11-20 07:06:55 -08:00
Awni Hannun
60939d010c Fix macos release target and linux arm release (#2802)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-19 21:37:50 -08:00
Awni Hannun
fdcd2923fd patch + fix docs build (#2799) 2025-11-19 16:16:26 -08:00
Jagrit Digani
54f1cc6e3e Add Neural Accelerator Support (#2772) 2025-11-19 15:06:00 -08:00
CCYeh
b3825ac149 Add Masked Scatter (#2663)
Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-11-19 14:53:32 -08:00
Awni Hannun
7f4b7e553c version (#2797) 2025-11-19 14:11:16 -08:00
Awni Hannun
ad16f41a7f Fix version tag (#2790)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-19 08:55:57 -08:00
Awni Hannun
f46877bc08 more accurate rope fallback (#2792) 2025-11-19 06:07:21 -08:00
Cheng
6f35017d1b [CUDA] cuDNN backward attention (#2762)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-19 08:13:50 +09:00
Awni Hannun
b167f0df1c build docs on linux (#2787)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-18 08:01:03 -08:00
Cheng
a9f0d6b160 Avoid duplicate CI runs when starting a PR from upstream branch (#2788)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-18 15:16:25 +09:00
Cheng
940f4c7818 Fix building with CUDA < 12.8 (#2782)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-18 12:55:19 +09:00
Cheng
35f81728f1 Remove unneeded tests in nightly build (#2786) 2025-11-18 08:09:58 +09:00
Cheng
4442ed86c1 Fix nightly build (#2785) 2025-11-18 08:07:51 +09:00
Cheng
698559c231 Test every commit in main branch (#2781) 2025-11-18 08:07:22 +09:00
Cheng
ecc4879b07 Do not run CPU tests in CUDA builds (#2784) 2025-11-18 07:27:09 +09:00
Cheng
32b18d8b66 Use std::optional for mask_arr arg (#2763)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.8) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (cuda-12.8) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (cuda-12.9) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-17 10:43:33 +09:00
Cheng
472c43a0c8 Build and test with multiple CUDA versions (#2780) 2025-11-17 09:19:02 +09:00
Cheng
b7214ff01e Remove pip cache in GitHub Actions (#2776)
* Correctly set pip cache key

* [Debug] Try disabling pip cache
2025-11-17 08:19:59 +09:00
Cheng
76414c8971 Run CI for pushes (#2777) 2025-11-17 07:19:01 +09:00
Awni Hannun
49e4566df3 fix release 2 (#2767)
* fix release 2

* login

* fix
2025-11-16 11:39:53 -08:00
Awni Hannun
aad49f932f [CUDA] Tune ops per buffer based on device (#2761)
* tune ops per buffer based on device

* tune memory limit as well

* add tuning for spark
2025-11-16 06:29:49 -08:00
Cheng
86765cce34 Use ccache in GitHub Actions (#2773)
* Remove unnecessary steps

* Use ccache

* Log when using ccache

* Set max-size to 1GB

* Pass --no-build-isolation

* Remove more unused things
2025-11-16 07:58:14 +09:00
Cheng
1bedcbd556 Fix warnings with cmake 4.1 (#2774) 2025-11-16 07:12:47 +09:00
Cheng
9ac7dbe877 Fix MPI distributed tests with CUDA backend (#2775) 2025-11-16 07:12:18 +09:00
Awni Hannun
1bf605d56d use arch specific targets when possible (#2771) 2025-11-14 20:04:18 -08:00
Cheng
3c622ddd1d Separate test-linux from build-linux/cuda in GitHub Actions (#2765)
* Separate test-linux from build-linux/cuda in GitHub Actions

* Prefer unittest when possible

Co-authored-by: Mike Drob <mdrob@apache.org>

---------

Co-authored-by: Mike Drob <mdrob@apache.org>
2025-11-15 11:14:09 +09:00
Awni Hannun
27ff069175 Fix exporting with constants (#2769) 2025-11-14 12:52:08 -08:00
Cheng
3b2ffcefc3 [CUDA] cuDNN forward attention (#2743)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* Separate sdpa kernels in another file

* Initial support for cuDNN SDPA

* Diable a few corner cases

* Remove scaled_dot_product_attention.h

* Use cuDNN attention for prefilling

* cuDNN SDPA requires Ampere and later

* Address reviews

* Do contiguous copy of inputs
2025-11-14 09:23:56 +09:00
Awni Hannun
b65f882df3 fix release (#2759) 2025-11-13 15:34:01 -08:00
Cheng
b704e9e77a [CUDA] Check CUDA error in synchronize (#2757) 2025-11-14 07:10:23 +09:00
Awni Hannun
66519fb348 fix slice (#2758) 2025-11-13 11:30:02 -08:00
Awni Hannun
8973550ff3 export custom kernel (#2756) 2025-11-13 11:29:50 -08:00
Mike Drob
3f866be665 minor debugging for publishing (#2739)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* minor debugging for publishing

* fix logic
2025-11-12 06:33:39 -08:00
Awni Hannun
23f81ed1c1 Linux on arm (#2751)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* try linux on arm

* ssh

* fix
2025-11-11 11:41:14 -08:00
wrmsr
3fe2250c00 Fix irregular_strides benchmark shape type (#2754) 2025-11-11 11:40:22 -08:00
Awni Hannun
047114b988 remove circle (#2753) 2025-11-11 11:39:47 -08:00
wrmsr
9320eb89a8 Fix dequantize python sig (dtype default) (#2752) 2025-11-11 09:55:24 -08:00
Awni Hannun
75819d70ea patch bump (#2750) 2025-11-11 08:49:14 -08:00
139 changed files with 10465 additions and 1469 deletions

View File

@@ -1,579 +0,0 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "26.0.0"
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.10
brew install doxygen
python3.10 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
- run:
name: Install dependencies
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "26.0.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
- run:
name: Install Python package
command: |
uv venv --python 3.10
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build example extension
command: |
source .venv/bin/activate
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
- run:
name: Build small binary
command: |
source .venv/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: "linux-cuda-12:<< parameters.image_date >>"
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Set CCache size
command: ccache --max-size 1G
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release:
parameters:
python_version:
type: string
default: "3.10"
xcode_version:
type: string
default: "26.0.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m4pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
mkdir -p ~/miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install twine
pip install build
- run:
name: Install Python package
command: |
conda activate env
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
command: |
conda activate env
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
conda activate env
twine upload dist/*
- store_artifacts:
path: dist/
build_linux_release:
parameters:
python_version:
type: string
default: "3.10"
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: xlarge
steps:
- checkout
- run:
name: Build wheel
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
bash python/scripts/repair_cuda.sh
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
- build_cuda_release
build_dev_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]

View File

@@ -2,8 +2,8 @@ name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
toolkit:
description: 'The CUDA toolkit'
required: true
runs:
@@ -12,7 +12,7 @@ runs:
- name: Build package
shell: bash
env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all

View File

@@ -2,10 +2,9 @@ name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
toolkit:
description: 'The CUDA toolkit'
required: true
default: '/usr/local/cuda-12.9/bin/nvcc'
runs:
using: "composite"
@@ -14,32 +13,14 @@ runs:
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: pip install -e ".[dev]" -v
- name: Run Python tests - CPU
shell: bash
env:
LOW_MEMORY: 1
DEVICE: cpu
run: python -m unittest discover python/tests -v
- name: Run Python tests - GPU
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
run: python -m tests discover python/tests -v
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
run: pip install --no-build-isolation -e ".[dev]" -v
- name: Build CPP only
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
-DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)
- name: Run CPP tests
shell: bash
run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"

View File

@@ -1,19 +1,19 @@
name: 'Build Documentation'
description: 'Build documentation on a mac'
description: 'Build documentation'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-macos
uses: ./.github/actions/setup-linux
- name: Install dependencies
shell: sh
shell: bash
run: |
brew install doxygen
uv pip install --upgrade pip cmake
uv pip install -r docs/requirements.txt
uv pip install . -v
sudo apt-get install -y doxygen
source .venv/bin/activate
pip install -r docs/requirements.txt
pip install . -v
- name: Build documentation
shell: bash
@@ -24,8 +24,8 @@ runs:
make html O=-W
- name: Create artifact tar
shell: sh
run: tar -cf artifact.tar --cd docs --dereference build/html index.html
shell: bash
run: tar -cf artifact.tar -C docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
@@ -35,4 +35,4 @@ runs:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error
if-no-files-found: error

View File

@@ -7,6 +7,13 @@ inputs:
type: boolean
required: false
default: false
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs:
using: "composite"
@@ -23,11 +30,11 @@ runs:
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
bash python/scripts/repair_linux.sh
bash python/scripts/repair_linux.sh ${{ inputs.arch }}
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}

View File

@@ -9,33 +9,17 @@ runs:
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
DEBUG: 1
run: pip install -e ".[dev]" -v
run: pip install --no-build-isolation -e ".[dev]" -v
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Run Python tests
shell: bash
run: |
python -m unittest discover python/tests -v
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
- name: Build CPP only
shell: bash
run: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
- name: Run CPP tests
shell: sh
run: ./build/tests/tests

View File

@@ -16,18 +16,19 @@ runs:
using: "composite"
steps:
- name: Build Python package
shell: bash
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
uv pip install build
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=1 uv run -m build -w
pip install build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash
shell: bash -l {0}
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=2 uv run -m build -w
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w

View File

@@ -5,47 +5,47 @@ runs:
using: "composite"
steps:
- name: Install dependencies
shell: sh
env:
DEBUG: 1
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
shell: bash -l {0}
run: |
uv pip install --upgrade pip
uv pip install cmake setuptools nanobind==2.4.0
uv pip install -e . -v
pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash
shell: bash -l {0}
run: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
pip install typing_extensions
python setup.py generate_stubs
- name: Install tests dependencies
shell: sh
shell: bash -l {0}
run: |
uv pip install numpy torch tensorflow unittest-xml-reporting
pip install numpy torch tensorflow unittest-xml-reporting
- name: Run Python tests
shell: bash
shell: bash -l {0}
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
shell: bash
shell: bash -l {0}
run: |
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project test.py
pip install -r requirements.txt
python setup.py build_ext --inplace
python test.py
- name: Build CPP only
shell: bash
shell: bash -l {0}
run: |
mkdir -p build
cd build
@@ -53,7 +53,7 @@ runs:
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
shell: bash
shell: bash -l {0}
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
@@ -61,7 +61,7 @@ runs:
run: ./build/tests/tests
- name: Build small binary with JIT
shell: bash
shell: bash -l {0}
run: |
mkdir -p build
cd build
@@ -74,7 +74,7 @@ runs:
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
shell: bash
shell: bash -l {0}
env:
LOW_MEMORY: 1
DEVICE: gpu
@@ -82,7 +82,7 @@ runs:
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
uv run -m xmlrunner discover \
pip install -e . -v
python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit

View File

@@ -2,14 +2,10 @@ name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
runner-type:
description: 'Whether to set this up as a linux or CUDA runner'
toolkit:
description: 'Which toolkit to install'
required: false
default: 'linux'
type: choice
options:
- linux
- cuda
default: 'cpu'
python-version:
description: 'Version of python to set up'
required: false
@@ -18,56 +14,62 @@ inputs:
runs:
using: "composite"
steps:
- name: Free disk space
shell: sh
if: inputs.runner-type == 'linux'
run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Use ccache
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
max-size: 1GB
- name: Install common dependencies
env:
TZ: Etc/UTC
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
sudo apt autoremove -y
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
cache: 'pip'
- name: setup python venv
- name: Setup Python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
pip install setuptools cmake nanobind==2.4.0
echo PATH=$PATH >> $GITHUB_ENV
pip install --upgrade pip cmake
# Make cmake search .venv for nanobind
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
- name: Install MPI
if: inputs.runner-type == 'linux'
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Network CUDA installation from packages
id: install-cuda
if: inputs.runner-type == 'cuda'
- name: Install CUDA toolkit
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
env:
TZ: Etc/UTC
shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
# Note: the CI machine does not meet CUDA 13's driver requirement.
# Compatibility matrix:
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# The `nvcc` is installed into `/usr/local/cuda-VERSION/bin/nvcc` - but
# it's *not* on the default toolkit path.
PACKAGES: |
{
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
"cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
}
run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
export ARCH=${{ runner.arch == 'arm64' && 'arm64' || 'x86_64' }}
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
# Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
# cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
# Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
sudo apt-get install -y \
libnccl2 libnccl-dev \
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
- name: Package and Driver Report
if: inputs.runner-type == 'cuda'
- name: CUDA packages and driver report
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms

View File

@@ -17,9 +17,8 @@ runs:
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- name: Setup uv
uses: astral-sh/setup-uv@v6
- uses: conda-incubator/setup-miniconda@v3
with:
python-version: ${{ inputs.python-version }}
activate-environment: true
miniconda-version: "latest"
python-version: ${{ inputs.python-version }}

69
.github/actions/test-linux/action.yml vendored Normal file
View File

@@ -0,0 +1,69 @@
name: 'Run Linux tests'
inputs:
cpu-only:
description: 'Skip GPU tests'
required: false
default: false
runs:
using: "composite"
steps:
- name: Run MPI tests
shell: bash
run: |
echo "::group::MPI tests"
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
echo "::endgroup::"
- name: Run distributed tests
if: ${{ inputs.cpu-only == 'true' }}
shell: bash
run: |
echo "::group::Distributed tests"
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
echo "::endgroup::"
- name: Run Python tests - CPU
if: ${{ inputs.cpu-only == 'true' }}
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::Python tests - CPU"
python -m unittest discover python/tests -v
echo "::endgroup::"
- name: Run Python tests - GPU
if: ${{ inputs.cpu-only == 'false' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::Python tests - GPU"
python -m tests discover python/tests -v
echo "::endgroup::"
- name: Run CPP tests - CPU
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::CPP tests - CPU"
./build/tests/tests
echo "::endgroup::"
- name: Run CPP tests - GPU
if: ${{ inputs.cpu-only == 'false' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::CPP tests - GPU"
./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
echo "::endgroup::"

View File

@@ -8,7 +8,7 @@ permissions:
jobs:
build:
runs-on: [self-hosted, macos]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
@@ -25,4 +25,4 @@ jobs:
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
uses: actions/deploy-pages@v4

View File

@@ -21,6 +21,7 @@ jobs:
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: "x86_64"
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
@@ -34,19 +35,25 @@ jobs:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
python_version: ["3.11", "3.12", "3.13", "3.14"]
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
with:
cpu-only: true
build_mac_release:
if: github.repository == 'ml-explore/mlx'
@@ -60,7 +67,6 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
@@ -72,16 +78,6 @@ jobs:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
build_cuda_with_tests:
if: github.repository == 'ml-explore/mlx'
runs-on: gpu-t4-4-core
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
@@ -89,36 +85,14 @@ jobs:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
toolkit: 'cuda-12.9'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh

View File

@@ -1,28 +1,52 @@
name: Build and Test
on: pull_request
on:
pull_request:
push:
branches:
- main
# For testing CI without starting a pull request:
- test/*
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
check_lint:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
runs-on: ubuntu-22.04
needs: check_lint
strategy:
matrix:
runner:
- ubuntu-22.04
- ubuntu-22.04-arm
fail-fast: false
runs-on: ${{ matrix.runner }}
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
with:
cpu-only: true
mac_build_and_test:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
macos-target: ["14.0", "15.0"]
runs-on: [self-hosted, macos]
env:
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
needs: check_lint
steps:
- uses: actions/checkout@v5
@@ -31,18 +55,25 @@ jobs:
cuda_build_and_test:
if: github.repository == 'ml-explore/mlx'
strategy:
fail-fast: false
matrix:
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: gpu-t4-4-core
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-cuda
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
runs-on: ubuntu-22.04
needs: check_lint
steps:
- uses: actions/checkout@v5
@@ -50,6 +81,7 @@ jobs:
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:

View File

@@ -5,6 +5,11 @@ on:
tags:
- 'v*'
workflow_dispatch:
inputs:
dev_release:
description: "Do a dev release or regular release"
required: true
default: "false"
permissions:
contents: read
@@ -12,16 +17,13 @@ permissions:
jobs:
setup:
runs-on: ubuntu-latest
outputs:
pypi_env: ${{ github.event_name == 'push' && 'pypi' || 'test-pypi' }}
pypi_url: ${{ github.event_name == 'push' && 'https://upload.pypi.org/legacy/' || 'https://test.pypi.org/legacy/' }}
steps:
- name: Set publishing variables
run: echo "Publishing setup complete"
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
@@ -45,9 +47,11 @@ jobs:
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
@@ -56,16 +60,19 @@ jobs:
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: ${{ matrix.arch }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
overwrite: true
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
overwrite: true
name: mlx-cpu-${{ matrix.arch }}
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
@@ -76,22 +83,25 @@ jobs:
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
shell: sh
shell: bash -l {0}
run: |
uv pip install --upgrade pip
uv pip install cmake setuptools nanobind==2.4.0
uv pip install -e . -v
pip install --upgrade pip
pip install cmake setuptools nanobind==2.4.0
pip install -e . -v
- name: Generate package stubs
shell: bash
shell: bash -l {0}
run: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
pip install typing_extensions
python setup.py generate_stubs
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
@@ -105,12 +115,14 @@ jobs:
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-metal
path: dist/mlx_metal-*.whl
@@ -119,18 +131,20 @@ jobs:
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
toolkit: 'cuda-12.9'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
toolkit: 'cuda-12.9'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
overwrite: true
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
@@ -141,7 +155,7 @@ jobs:
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
@@ -159,7 +173,7 @@ jobs:
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cuda:
name: Upload CUDA release to PyPI
@@ -168,7 +182,7 @@ jobs:
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
@@ -180,7 +194,7 @@ jobs:
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
repository-url: https://upload.pypi.org/legacy/
pypi-publish-cpu:
name: Upload CPU release to PyPI
@@ -189,19 +203,20 @@ jobs:
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cpu
pattern: mlx-cpu-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
repository-url: https://upload.pypi.org/legacy/
pypi-publish-metal:
name: Upload Metal release to PyPI
@@ -210,7 +225,7 @@ jobs:
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
@@ -222,5 +237,4 @@ jobs:
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
repository-url: https://upload.pypi.org/legacy/

View File

@@ -74,6 +74,7 @@ endif()
if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")

View File

@@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() {
void time_irregular_binary_ops_4D() {
auto device = mx::default_device();
std::vector<int> shape = {8, 8, 512, 512};
mx::Shape shape = {8, 8, 512, 512};
auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape);
@@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() {
void time_irregular_reshape() {
auto device = mx::default_device();
std::vector<int> shape;
mx::Shape shape;
auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device);
};
@@ -170,7 +170,7 @@ void time_irregular_astype_1D() {
void time_irregular_astype_2D() {
auto device = mx::default_device();
int size = 2048;
std::vector<int> shape = {size, size};
mx::Shape shape = {size, size};
auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device);

View File

@@ -1,6 +1,5 @@
# Copyright © 2023 Apple Inc.
import argparse
import os
import subprocess
import time

View File

@@ -0,0 +1,212 @@
import math
import os
import subprocess
import time
from copy import copy
from functools import partial
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from matplotlib.ticker import FuncFormatter
RESULTS_DIR = "./results"
if not os.path.isdir(RESULTS_DIR):
os.mkdir(RESULTS_DIR)
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
TORCH_DEVICE = torch.device(
"mps"
if torch.backends.mps.is_available()
else ("cuda" if torch.cuda.is_available() else "cpu")
)
N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")
def _power_of_two_formatter(value, _position):
if value <= 0:
return ""
exponent = int(round(math.log2(value)))
if abs(value - (1 << exponent)) / value > 1e-6:
return f"{value:g}"
return f"$2^{{{exponent}}}$"
def torch_sync():
if TORCH_DEVICE.type == "cuda":
torch.cuda.synchronize()
elif TORCH_DEVICE.type == "mps":
torch.mps.synchronize()
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
outs = []
for _ in range(N_ITER_FUNC):
out = copy(self_arr)
out[mask_arr] = src_arr
outs.append(out)
mx.eval(outs)
return outs
@torch.no_grad()
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
outs = []
for _ in range(N_ITER_FUNC):
out = self_tensor.clone()
out.masked_scatter_(mask_tensor, src_tensor)
outs.append(out)
torch_sync()
return outs
def measure(fn):
for _ in range(N_WARMUP):
fn()
start = time.perf_counter_ns()
for _ in range(N_ITER_BENCH):
fn()
end = time.perf_counter_ns()
return (end - start) * 1e-9
def bytes_touched(length, true_count, item_size):
mask_bytes = length
self_bytes = length * item_size * 2 # read + write
src_bytes = true_count * item_size
return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
def build_case(length, density, np_dtype, torch_dtype):
true_count = max(1, int(round(length * density)))
rng = np.random.default_rng()
self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
mask_np = np.zeros(length, dtype=bool)
mask_np[:true_count] = True
rng.shuffle(mask_np)
src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
self_mlx = mx.array(self_np)
mask_mlx = mx.array(mask_np)
src_mlx = mx.array(src_np)
self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
# Correctness check once per configuration
mx_out = mx.array(self_np)
mx_out[mask_mlx] = src_mlx
mx.eval(mx_out)
torch_out = self_torch.clone()
torch_out.masked_scatter_(mask_torch, src_torch)
atol = 5e-3 if np_dtype == np.float16 else 1e-5
if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
raise AssertionError("masked_scatter results diverged between MLX and Torch")
return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
def bench_case(length, density, dtype):
np_dtype = getattr(np, dtype)
torch_dtype = getattr(torch, dtype)
(
self_mlx,
mask_mlx,
src_mlx,
self_torch,
mask_torch,
src_torch,
true_count,
) = build_case(length, density, np_dtype, torch_dtype)
time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
time_torch = measure(
partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
)
total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
bytes_per_gb = float(1024**3)
mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
torch_gbps = (total_bytes / bytes_per_gb) / time_torch
return time_mlx, time_torch, mlx_gbps, torch_gbps
def plot_density(ax_perf, ax_speedup, density, dtype):
mlx_gbps = []
torch_gbps = []
mlx_times = []
torch_times = []
for length in VECTOR_LENGTHS:
t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
mlx_gbps.append(gbps_mlx)
torch_gbps.append(gbps_torch)
mlx_times.append(t_mlx)
torch_times.append(t_torch)
ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
ax_perf.set_xscale("log", base=2)
ax_perf.set_xticks(VECTOR_LENGTHS)
formatter = FuncFormatter(_power_of_two_formatter)
ax_perf.xaxis.set_major_formatter(formatter)
ax_perf.set_title(f"density={density:.2f}")
ax_perf.set_ylabel("GB/s")
ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
ax_perf.legend()
speedup = np.array(torch_times) / np.array(mlx_times)
ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
ax_speedup.set_xscale("log", base=2)
ax_speedup.set_xticks(VECTOR_LENGTHS)
ax_speedup.xaxis.set_major_formatter(formatter)
ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
def main():
for dtype in D_TYPES:
fig, axs = plt.subplots(
len(MASK_DENSITIES),
2,
figsize=(10, 12),
layout="constrained",
sharex=True,
)
for i, density in enumerate(MASK_DENSITIES):
plot_density(axs[i][0], axs[i][1], density, dtype)
axs[i][0].set_xlabel("vector length")
axs[i][1].set_xlabel("vector length")
fig.suptitle(
f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
)
output_path = os.path.join(
RESULTS_DIR,
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
)
fig.savefig(output_path)
plt.close(fig)
if __name__ == "__main__":
main()

3
cmake/Findnvpl.cmake Normal file
View File

@@ -0,0 +1,3 @@
# This file does nothing but to suppress the cmake warning: "By not providing
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.

View File

@@ -70,7 +70,8 @@ Differences from NumPy
* Indexing does not perform bounds checking. Indexing out of bounds is
undefined behavior.
* Boolean mask based indexing is not yet supported.
* Boolean mask based indexing is supported for assignment only (see
:ref:`boolean-mask-assignment`).
The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
@@ -143,3 +144,51 @@ expected. For example:
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.
.. _boolean-mask-assignment:
Boolean Mask Assignment
-----------------------
MLX supports boolean indices using NumPy syntax. A mask must already be
a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
Other index types are routed through the standard scatter code.
.. code-block:: shell
>>> a = mx.array([1.0, 2.0, 3.0])
>>> mask = mx.array([True, False, True])
>>> updates = mx.array([5.0, 6.0])
>>> a[mask] = updates
>>> a
array([5.0, 2.0, 6.0], dtype=float32)
Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
assignments, ``updates`` must provide at least as many elements as there are
``True`` entries in ``mask``.
.. code-block:: shell
>>> a = mx.zeros((2, 3))
>>> mask = mx.array([[True, False, True],
[False, False, True]])
>>> a[mask] = 1.0
>>> a
array([[1.0, 0.0, 1.0],
[0.0, 0.0, 1.0]], dtype=float32)
Boolean masks follow NumPy semantics:
- The mask shape must match the shape of the axes it indexes exactly. No mask
broadcasting occurs.
- Any axes not covered by the mask are taken in full.
.. code-block:: shell
>>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
axes and therefore raise errors.

View File

@@ -167,7 +167,7 @@ void array::copy_shared_buffer(
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
int64_t offset /* = 0 */) {
array_desc_->data = other.array_desc_->data;
array_desc_->strides = strides;
array_desc_->flags = flags;

View File

@@ -439,7 +439,7 @@ class array {
const Strides& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
int64_t offset = 0);
void copy_shared_buffer(const array& other);

View File

@@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
}
// Normalize the offset
if (data_offset < 0) {
data_offset += in.data_size();
}
return std::make_tuple(data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const Strides& out_strides,
size_t data_offset,
int64_t data_offset,
size_t data_size,
array& out) {
// Compute row/col contiguity
@@ -51,17 +47,24 @@ void slice(
// Calculate out strides, initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
int64_t data_end = 1;
for (int i = 0; i < start_indices.size(); ++i) {
if (in.shape()[i] > 1) {
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
data_end += end_idx * in.strides()[i];
// Get the location of the end based on the inp strides and out.shape()
int64_t low_idx = 0;
int64_t high_idx = 0;
for (int i = 0; i < inp_strides.size(); ++i) {
auto delta = inp_strides[i] * (out.shape()[i] - 1);
if (inp_strides[i] > 0) {
high_idx += delta;
} else {
low_idx += delta;
}
}
if (data_end < 0) {
data_end += in.data_size();
int64_t data_size = (high_idx - low_idx) + 1;
if (data_size < 0) {
std::ostringstream msg;
msg << "[slice] Computed invalid data size: " << data_size << ".";
throw std::runtime_error(msg.str());
}
size_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
}

View File

@@ -12,6 +12,167 @@ namespace mlx::core {
namespace {
template <typename T>
complex64_t to_complex(T r, T i) {
return {static_cast<float>(r), static_cast<float>(i)};
}
template <typename T, class Enable = void>
struct EigWork {};
template <typename T>
struct EigWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using O = complex64_t;
char jobl;
char jobr;
int N;
int lwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
T work;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
if (compute_eigenvectors) {
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
}
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, O* values, O* vectors) {
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
T* vec_tmp = nullptr;
if (vectors) {
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
}
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
eig_tmp,
eig_tmp + N,
vectors ? vec_tmp : nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
work,
&lwork,
&info);
for (int i = 0; i < N; ++i) {
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
}
if (vectors) {
for (int i = 0; i < N; ++i) {
if (values[i].imag() != 0) {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] =
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
vectors[(i + 1) * N + j] =
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
}
}
}
}
}
};
template <>
struct EigWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
using O = T;
char jobl;
char jobr;
int N;
int lwork;
int lrwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
T work;
R rwork;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&rwork,
&info);
lwork = static_cast<int>(work.real());
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
}
void run(T* a, T* values, T* vectors) {
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
values,
vectors,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&info);
}
};
template <typename T>
void eig_impl(
array& a,
@@ -19,101 +180,39 @@ void eig_impl(
array& values,
bool compute_eigenvectors,
Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>();
auto val_ptr = values.data<complex64_t>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_output_array(values);
OT* vec_ptr = nullptr;
complex64_t* vec_ptr = nullptr;
if (compute_eigenvectors) {
encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>();
vec_ptr = vectors.data<complex64_t>();
}
encoder.dispatch([a_ptr,
val_ptr,
vec_ptr,
eig_ptr,
compute_eigenvectors,
N = vectors.shape(-1),
size = vectors.size()]() mutable {
// Work query
char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
for (size_t i = 0; i < size / (N * N); ++i) {
geev<T>(
&jobl,
&jobr,
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
work.run(a_ptr, val_ptr, vec_ptr);
a_ptr += N * N;
val_ptr += N;
if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N;
}
a_ptr += N * N;
eig_ptr += N;
if (info != 0) {
if (work.info != 0) {
std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
<< work.info;
throw std::runtime_error(msg.str());
}
}
@@ -165,8 +264,17 @@ void Eig::eval_cpu(
case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case float64:
eig_impl<double>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case complex64:
eig_impl<std::complex<float>>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
throw std::runtime_error(
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
}
}

View File

@@ -747,4 +747,108 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
});
}
template <typename T>
void masked_scatter_impl(const array& mask, const array& src, array& out) {
ContiguousIterator mask_it(mask);
ContiguousIterator src_it(src);
ContiguousIterator out_it(out);
const bool* mask_ptr = mask.data<bool>();
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
const size_t batch_count = mask.shape(0);
const size_t mask_batch_size = mask.size() / batch_count;
const size_t src_batch_size = src.size() / batch_count;
for (uint b = 0; b < batch_count; ++b) {
size_t src_consumed = 0;
src_it.seek(b * src_batch_size);
for (size_t i = 0; i < mask_batch_size; ++i) {
if (mask_ptr[mask_it.loc]) {
if (src_consumed >= src_batch_size) {
throw std::runtime_error(
"[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
}
dst_ptr[out_it.loc] = src_ptr[src_it.loc];
src_it.step();
++src_consumed;
}
mask_it.step();
out_it.step();
}
}
}
void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& dst = inputs[0];
auto& mask = inputs[1];
auto& src = inputs[2];
// Copy src into out (copy allocates memory for out)
auto ctype =
dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(dst, out, ctype, stream());
if (mask.size() == 0) {
return;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(mask);
encoder.set_input_array(src);
encoder.set_output_array(out);
encoder.dispatch([mask = array::unsafe_weak_copy(mask),
src = array::unsafe_weak_copy(src),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
masked_scatter_impl<bool>(mask, src, out);
break;
case uint8:
masked_scatter_impl<uint8_t>(mask, src, out);
break;
case uint16:
masked_scatter_impl<uint16_t>(mask, src, out);
break;
case uint32:
masked_scatter_impl<uint32_t>(mask, src, out);
break;
case uint64:
masked_scatter_impl<uint64_t>(mask, src, out);
break;
case int8:
masked_scatter_impl<int8_t>(mask, src, out);
break;
case int16:
masked_scatter_impl<int16_t>(mask, src, out);
break;
case int32:
masked_scatter_impl<int32_t>(mask, src, out);
break;
case int64:
masked_scatter_impl<int64_t>(mask, src, out);
break;
case float16:
masked_scatter_impl<float16_t>(mask, src, out);
break;
case float32:
masked_scatter_impl<float>(mask, src, out);
break;
case float64:
masked_scatter_impl<double>(mask, src, out);
break;
case bfloat16:
masked_scatter_impl<bfloat16_t>(mask, src, out);
break;
case complex64:
masked_scatter_impl<complex64_t>(mask, src, out);
break;
}
});
}
} // namespace mlx::core

View File

@@ -45,9 +45,7 @@
INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesdd)
INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri)
@@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri)
}
INSTANTIATE_LAPACK_COMPLEX(heevd)
#define INSTANTIATE_LAPACK_ALL(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, double>) { \
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_ALL(geev)
INSTANTIATE_LAPACK_ALL(gesdd)

View File

@@ -8,6 +8,183 @@
namespace mlx::core {
template <typename T, class Enable = void>
struct SVDWork {};
template <typename T>
struct SVDWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <>
struct SVDWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
const int lrwork =
jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
int lwork_query = -1;
int work_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension.real();
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <typename T>
void svd_impl(
const array& a,
@@ -27,6 +204,8 @@ void svd_impl(
const int N = a.shape(-1);
const int K = std::min(M, N);
using R = typename SVDWork<T>::R;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
@@ -42,7 +221,7 @@ void svd_impl(
encoder.set_input_array(a);
auto in_ptr = in.data<T>();
T* u_ptr;
T* s_ptr;
R* s_ptr;
T* vt_ptr;
if (compute_uv) {
@@ -58,7 +237,7 @@ void svd_impl(
encoder.set_output_array(s);
encoder.set_output_array(vt);
s_ptr = s.data<T>();
s_ptr = s.data<R>();
u_ptr = u.data<T>();
vt_ptr = vt.data<T>();
} else {
@@ -68,96 +247,26 @@ void svd_impl(
encoder.set_output_array(s);
s_ptr = s.data<T>();
s_ptr = s.data<R>();
u_ptr = nullptr;
vt_ptr = nullptr;
}
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto jobz = (u_ptr) ? "A" : "N";
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
static const int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
auto jobz = (u_ptr) ? 'A' : 'N';
SVDWork<T> svd_work(N, M, K, jobz);
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
svd_work.run(
in_ptr + M * N * i,
s_ptr + K * i,
vt_ptr ? vt_ptr + N * N * i : nullptr,
u_ptr ? u_ptr + M * M * i : nullptr);
}
});
encoder.add_temporary(in);
}
template <typename T>
void compute_svd(
const array& a,
bool compute_uv,
std::vector<array>& outputs,
Stream stream) {}
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -168,9 +277,12 @@ void SVD::eval_cpu(
case float64:
svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
break;
case complex64:
svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
break;
default:
throw std::runtime_error(
"[SVD::eval_cpu] only supports float32 or float64.");
"[SVD::eval_cpu] only supports float32, float64, or complex64.");
}
}

View File

@@ -44,6 +44,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
@@ -125,7 +126,11 @@ endif()
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES "native")
execute_process(
COMMAND bash detect_cuda_arch.sh
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
OUTPUT_STRIP_TRAILING_WHITESPACE)
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -137,6 +142,7 @@ FetchContent_Declare(
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
set_target_properties(mlx PROPERTIES CCCL_DIR "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX.
FetchContent_Declare(
@@ -162,7 +168,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare(
cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.14.0
GIT_TAG v1.16.0
GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)

View File

@@ -92,7 +92,7 @@ CudaAllocator::CudaAllocator()
[this](CudaBuffer* buf) { cuda_free(buf); }) {
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.95;
memory_limit_ = total * 0.9;
max_pool_size_ = memory_limit_;
int device_count = 0;
@@ -119,7 +119,8 @@ void copy_to_managed(CudaBuffer& buf) {
buf.data = new_data;
}
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
Buffer
CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
if (size == 0) {
return Buffer{new CudaBuffer{nullptr, 0, -1}};
}
@@ -134,9 +135,8 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
size = page_size * ((size + page_size - 1) / page_size);
}
int device = -1;
if (size > small_block_size && stream != nullptr) {
CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
if (size <= small_block_size || stream == nullptr) {
device = -1;
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
@@ -176,18 +176,14 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
// Copy to managed here if the buffer is not on the right device
if (buf->device != device) {
if (buf->device >= 0 && buf->device != device) {
copy_to_managed(*buf);
}
return Buffer{buf};
}
Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
return malloc_impl(size, stream);
}
Buffer CudaAllocator::malloc(size_t size) {
return malloc_impl(size, nullptr);
return malloc_async(size, -1, nullptr);
}
void CudaAllocator::free(Buffer buffer) {
@@ -223,9 +219,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
scalar_pool_.free(buf);
} else {
if (buf->device >= 0) {
cudaFreeAsync(buf->data, free_streams_[buf->device]);
CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
} else {
cudaFree(buf->data);
CHECK_CUDA_ERROR(cudaFree(buf->data));
}
delete buf;
}
@@ -277,8 +273,9 @@ CudaAllocator& allocator() {
return *allocator_;
}
Buffer malloc_async(size_t size, cudaStream_t stream) {
auto buffer = allocator().malloc_async(size, stream);
Buffer malloc_async(size_t size, CommandEncoder& encoder) {
auto buffer = allocator().malloc_async(
size, encoder.device().cuda_device(), encoder.stream());
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_async] Unable to allocate " << size << " bytes.";

View File

@@ -13,6 +13,8 @@
namespace mlx::core::cu {
class CommandEncoder;
using allocator::Buffer;
// Stores cuda-managed unified memory.
@@ -48,7 +50,7 @@ class SmallSizePool {
class CudaAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size) override;
Buffer malloc_async(size_t size, cudaStream_t stream);
Buffer malloc_async(size_t size, int device, cudaStream_t stream);
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
@@ -62,7 +64,6 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache();
private:
Buffer malloc_impl(size_t size, cudaStream_t stream);
void cuda_free(CudaBuffer* buf);
CudaAllocator();
@@ -80,6 +81,6 @@ class CudaAllocator : public allocator::Allocator {
CudaAllocator& allocator();
Buffer malloc_async(size_t size, cudaStream_t stream);
Buffer malloc_async(size_t size, CommandEncoder& encoder);
} // namespace mlx::core::cu

View File

@@ -42,7 +42,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
auto& encoder = cu::get_command_encoder(stream());
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {

View File

@@ -143,7 +143,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
// Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_);

View File

@@ -367,9 +367,8 @@ void binary_op_gpu(
auto bopt = get_binary_op_type(a, b);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(
a, b, out, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
binary_op_gpu_inplace<Op>(inputs, out, op, s);
}

View File

@@ -246,12 +246,10 @@ void binary_two_op_gpu_inplace(
auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b);
auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(
a, b, out_a, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
set_binary_op_output_data(
a, b, out_b, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
if (out_a.size() == 0) {
return;

View File

@@ -298,7 +298,7 @@ void Compiled::eval_gpu(
// Put outputs.
compiled_allocate_outputs(
inputs, outputs, is_constant_, contiguous, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
return cu::malloc_async(n, encoder);
});
for (auto& x : outputs) {
args.append(x);

View File

@@ -277,11 +277,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
array in = inputs[0];
array wt = inputs[1];
array out = out_;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
Dtype dtype = out.dtype();
// Search cache.
ConvCacheKey cache_key{
BytesKey<ConvCacheKey> cache_key;
cache_key.pod = {
encoder.device().cuda_device(),
dtype_to_cudnn_type(dtype),
vector_key(in.shape()),

View File

@@ -86,7 +86,7 @@ array unfold_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
encoder.add_temporary(unfolded);
int filter_size = params.C;

View File

@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
int mat_N,
ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
encoder.add_temporary(unfolded);
int filter_size = params.C;

View File

@@ -7,9 +7,8 @@ namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
bool donated = set_copy_output_data(
in, out, ctype, [&](auto n) { return cu::malloc_async(n, encoder); });
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
@@ -104,7 +103,7 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
return;
}
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
encoder.set_input_array(in);
encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
@@ -114,7 +113,7 @@ void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
copy_gpu_inplace(
in,
out,

View File

@@ -135,9 +135,7 @@ bool prepare_cudnn_plan(
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder.stream()),
{workspace_size},
uint8);
cu::malloc_async(workspace_size, encoder), {workspace_size}, uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}

View File

@@ -44,13 +44,13 @@ inline SmallVector<T> convert_vector(const Vec& vec) {
// There are 2 differences from the const_param util from kernel_utils.cuh:
// 1. The rest of array is filled with 0.
// 2. This util can be used in .cpp files.
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
template <int NDIM = MAX_NDIM, typename T, template <typename U> class Vec>
inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
fmt::format("ndim can not be larger than {}.", NDIM));
}
std::array<T, MAX_NDIM> result = {};
std::array<T, NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin());
return result;
}

View File

@@ -57,7 +57,7 @@ std::string build_kernel(
const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) {
const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header;
@@ -81,17 +81,17 @@ std::string build_kernel(
kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
if (std::get<0>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Shape ";
kernel_source += name;
kernel_source += "_shape,\n";
}
if (shape_infos[i].strides) {
if (std::get<1>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Strides ";
kernel_source += name;
kernel_source += "_strides,\n";
}
if (shape_infos[i].ndim) {
if (std::get<2>(shape_infos[i])) {
kernel_source += " const __grid_constant__ int ";
kernel_source += name;
kernel_source += "_ndim,\n";
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
"[custom_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
std::vector<std::tuple<bool, bool, bool>> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
std::tuple<bool, bool, bool> shape_info;
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
std::optional<float> init_value,
bool ensure_row_contiguous,
StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false});
std::vector<std::tuple<bool, bool, bool>> shape_infos(
inputs.size(), {false, false, false});
return array::make_arrays(
output_shapes,
output_dtypes,
@@ -289,7 +289,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s);
} else {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
}
}
@@ -327,13 +327,13 @@ void CustomKernel::eval_gpu(
const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i];
args.append(in);
if (shape_info.shape) {
if (std::get<0>(shape_info)) {
args.append_ndim(in.shape());
}
if (shape_info.strides) {
if (std::get<1>(shape_info)) {
args.append_ndim(in.strides());
}
if (shape_info.ndim) {
if (std::get<2>(shape_info)) {
args.append<int32_t>(in.ndim());
}
}

View File

@@ -0,0 +1,13 @@
#!/bin/bash
arch=`__nvcc_device_query`
case "$arch" in
"90")
echo "90a" ;;
"100")
echo "100a" ;;
"121")
echo "121a" ;;
*)
echo "native" ;;
esac

View File

@@ -46,6 +46,7 @@ Device::Device(int device) : device_(device) {
"Device {} does not support synchronization in managed memory.",
device_));
}
// The cublasLt handle is used by matmul.
make_current();
CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
@@ -114,18 +115,17 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
}
// Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
enc.empty_node_count_++;
CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
// Insert the concurrent -> empty node dependencies
for (auto& from : enc.concurrent_nodes_) {
enc.from_nodes_.push_back(from.node);
enc.to_nodes_.push_back(empty.node);
enc.graph_key_ += from.id;
enc.graph_key_ += from.node_type;
enc.graph_key_ += empty.id;
enc.graph_key_ += empty.node_type;
enc.graph_deps_key_ += from.id;
enc.graph_deps_key_ += "-";
enc.graph_deps_key_ += empty.id;
enc.graph_deps_key_ += "-";
}
// Insert the input -> concurrent node dependencies without updating output
@@ -140,9 +140,6 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
}
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
if (node.node_type == 'G') {
graph_node_count_++;
}
node.id = std::to_string(node_count_++);
if (in_concurrent_) {
concurrent_nodes_.push_back(std::move(node));
@@ -154,6 +151,10 @@ void CommandEncoder::insert_graph_dependencies(GraphNode node) {
}
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
for (auto& node : nodes) {
graph_nodes_key_ += node.node_type;
graph_nodes_key_ += "-";
}
std::vector<GraphNode> deps;
{
// Dependencies must be added in the same order to produce a consistent
@@ -181,20 +182,49 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
for (auto& to : nodes) {
from_nodes_.push_back(from.node);
to_nodes_.push_back(to.node);
graph_key_ += from.id;
graph_key_ += from.node_type;
graph_key_ += to.id;
graph_key_ += to.node_type;
graph_deps_key_ += from.id;
graph_deps_key_ += "-";
graph_deps_key_ += to.id;
graph_deps_key_ += "-";
}
}
}
// Can be tuned with MLX_MAX_OPS_PER_BUFFER, MLX_MAX_MB_PER_BUFFER
std::pair<int, int> get_graph_limits(Device& d) {
auto cc =
d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;
int ops = 20;
int mb = 100;
switch (cc) {
case 800: // A100
ops = 20;
mb = 400;
break;
case 900: // H100
ops = 30;
mb = 400;
break;
case 1000: // B200
ops = 50;
mb = 500;
break;
case 1210: // DGX Spark
ops = 20;
mb = 25;
break;
}
return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)};
}
CommandEncoder::CommandEncoder(Device& d)
: device_(d),
stream_(d),
graph_(d),
worker_(d),
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {
std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);
}
void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task));
@@ -204,6 +234,7 @@ void CommandEncoder::set_input_array(const array& arr) {
if (!use_cuda_graphs()) {
return;
}
bytes_in_graph_ += arr.data_size();
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id);
}
@@ -278,13 +309,46 @@ void CommandEncoder::add_kernel_node(
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
insert_graph_dependencies(GraphNode{node, "K"});
}
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
CUgraphNode node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'});
insert_graph_dependencies(GraphNode{node, "K"});
}
bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
// CUDA graphs do not get updated correctly if a kernel node getting updated
// has a different cluster shape than the node it's being updated with.
size_t num_nodes = 0;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) {
return true;
}
std::vector<cudaGraphNode_t> nodes(num_nodes);
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
for (const auto& node : nodes) {
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
if (type != cudaGraphNodeTypeKernel) {
return false;
}
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only dim.x can be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
return false;
}
// Only one child node allowed when subgraph uses clusters
if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
return false;
}
cluster_dim_x = cluster_dim.clusterDim.x;
}
return true;
}
void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -297,12 +361,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return;
}
cudaGraphNode_t node;
int cluster_dim_x = 0;
is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'});
insert_graph_dependencies(
GraphNode{node, "G" + std::to_string(cluster_dim_x)});
}
int CommandEncoder::get_num_ops() {
return node_count_;
bool CommandEncoder::needs_commit() {
return (node_count_ > max_ops_per_graph_) ||
((bytes_in_graph_ >> 20) > max_mb_per_graph_);
}
void CommandEncoder::commit() {
@@ -322,53 +390,55 @@ void CommandEncoder::commit() {
from_nodes_.size()));
}
graph_key_ += ".";
graph_key_ += std::to_string(node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(graph_node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(empty_node_count_);
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo info;
cudaGraphExecUpdate(graph_exec, graph_, &info);
update_result = info.result;
#else
cudaGraphNode_t error_node;
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
graph_exec.reset();
}
}
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
}
device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
if (!is_graph_updatable_) {
CudaGraphExec graph_exec;
graph_exec.instantiate(graph_);
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
} else {
auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_;
auto& graph_exec = graph_cache_[graph_key];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo info;
cudaGraphExecUpdate(graph_exec, graph_, &info);
update_result = info.result;
#else
cudaGraphNode_t error_node;
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
graph_exec.reset();
}
}
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
}
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
}
// Reset state
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear();
to_nodes_.clear();
graph_key_.clear();
graph_deps_key_.clear();
graph_nodes_key_.clear();
node_map_.clear();
graph_ = CudaGraph(device_);
is_graph_updatable_ = true;
}
// Put completion handlers in a batch.
worker_.commit(stream_);
node_count_ = 0;
bytes_in_graph_ = 0;
}
void CommandEncoder::synchronize() {
cudaStreamSynchronize(stream_);
CHECK_CUDA_ERROR(cudaStreamSynchronize(stream_));
auto p = std::make_shared<std::promise<void>>();
std::future<void> f = p->get_future();
add_completed_handler([p = std::move(p)]() { p->set_value(); });

View File

@@ -84,7 +84,7 @@ class CommandEncoder {
}
void add_completed_handler(std::function<void()> task);
int get_num_ops();
bool needs_commit();
void commit();
Device& device() {
@@ -106,8 +106,9 @@ class CommandEncoder {
cudaGraphNode_t node;
// K = kernel
// E = empty
// G = subgraph
char node_type;
// G* = subgraph (with metadata)
// Symbols ':', '-' are reserved as separators
std::string node_type;
std::string id;
};
@@ -119,18 +120,21 @@ class CommandEncoder {
CudaGraph graph_;
Worker worker_;
char node_count_{0};
char graph_node_count_{0};
char empty_node_count_{0};
bool in_concurrent_{false};
std::vector<cudaGraphNode_t> from_nodes_;
std::vector<cudaGraphNode_t> to_nodes_;
std::string graph_key_;
std::string graph_nodes_key_;
std::string graph_deps_key_;
std::vector<GraphNode> concurrent_nodes_;
std::vector<std::shared_ptr<array::Data>> temporaries_;
LRUCache<std::string, CudaGraphExec> graph_cache_;
std::vector<std::uintptr_t> active_deps_;
std::vector<std::uintptr_t> active_outputs_;
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
size_t bytes_in_graph_{0};
bool is_graph_updatable_{true};
int max_ops_per_graph_;
int max_mb_per_graph_;
};
class Device {
@@ -166,6 +170,7 @@ class Device {
int device_;
int compute_capability_major_;
int compute_capability_minor_;
std::string device_name_;
cublasLtHandle_t lt_;
cudnnHandle_t cudnn_;
std::unordered_map<int, CommandEncoder> encoders_;

View File

@@ -26,7 +26,7 @@ void AllReduce::eval_gpu(
out.copy_shared_buffer(in);
return {in, out};
} else {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
return {in, out};
}
};
@@ -74,7 +74,7 @@ void AllGather::eval_gpu(
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
@@ -103,7 +103,7 @@ void ReduceScatter::eval_gpu(
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);

View File

@@ -11,9 +11,6 @@
namespace mlx::core::gpu {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;
bool is_available() {
return true;
}
@@ -53,8 +50,7 @@ void eval(array& arr) {
encoder.add_temporary(s);
}
if (encoder.get_num_ops() >=
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
if (encoder.needs_commit()) {
scheduler::notify_new_task(stream);
encoder.add_completed_handler(
[stream]() { scheduler::notify_task_completion(stream); });

View File

@@ -370,7 +370,7 @@ void CublasGemm::execute(
// Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace(
cu::malloc_async(nbytes, encoder.stream()),
cu::malloc_async(nbytes, encoder),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);

View File

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder),
{batch_count * 3},
uint64);
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets
auto pointers = array(
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder),
{batch_count * 4},
uint64);

View File

@@ -61,7 +61,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
if (out.size() == 0) {
return;
}
@@ -241,7 +241,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
if (out.size() == 0) {
return;
}

View File

@@ -279,11 +279,14 @@ void compile(
// Compile program.
std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device);
auto cc = device.compute_capability_major();
std::string arch_tag = (cc == 90 || cc == 100 || cc == 121) ? "a" : "";
std::string compute = fmt::format(
"--gpu-architecture={}_{}{}",
"--gpu-architecture={}_{}{}{}",
use_sass ? "sm" : "compute",
device.compute_capability_major(),
device.compute_capability_minor());
cc,
device.compute_capability_minor(),
arch_tag);
args.push_back(compute.c_str());
std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) {

View File

@@ -244,7 +244,7 @@ void LayerNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(),
x.strides(),
x.flags());
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
g_in_gw = true;
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
encoder.add_temporary(gw_temp);
}
}

View File

@@ -32,7 +32,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream());
auto size = out.size();
auto nbytes = size * out.itemsize();
out.set_data(cu::malloc_async(nbytes, encoder.stream()));
out.set_data(cu::malloc_async(nbytes, encoder));
auto out_ptr = malloc(nbytes);
reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
if (swap_endianness_) {

View File

@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else {
auto n = in.shape(-1);
auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
}
flags.col_contiguous = col_contig;
out.set_data(
cu::malloc_async(in.nbytes() / n, encoder.stream()),
cu::malloc_async(in.nbytes() / n, encoder),
in.data_size() / n,
std::move(strides),
flags);

View File

@@ -135,12 +135,19 @@ class LRUCache {
};
// Turn a POD struct into a container key by doing bytes compare.
//
// Usage:
// BytesKey<MyKey> key;
// key.pod = { ... };
template <typename T>
struct BytesKey {
T pod;
static_assert(std::is_standard_layout_v<T>, "T is not POD");
BytesKey(T pod) : pod(std::move(pod)) {}
BytesKey() {
// Make sure the paddings between members are filled with 0.
memset(&pod, 0, sizeof(T));
}
BytesKey(const BytesKey& other) {
memcpy(&pod, &other.pod, sizeof(T));

View File

@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
gemm_and_bias(
encoder,
M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);

View File

@@ -37,6 +37,7 @@ NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
NO_GPU(MaskedScatter)
namespace distributed {
NO_GPU_MULTI(Send)

View File

@@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu(
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto& w = outputs[0];
w.set_data(cu::malloc_async(w.nbytes(), enc.stream()));
w.set_data(cu::malloc_async(w.nbytes(), enc));
if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s);
@@ -72,11 +72,11 @@ void fast::Quantize::eval_gpu(
auto& wq = outputs[0];
auto& scales = outputs[1];
wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream()));
scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
wq.set_data(cu::malloc_async(wq.nbytes(), enc));
scales.set_data(cu::malloc_async(scales.nbytes(), enc));
if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
biases.set_data(cu::malloc_async(biases.nbytes(), enc));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);

View File

@@ -139,30 +139,36 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
uint32_t num_keys = keys.size() / 2;
size_t num_keys = keys.size() / 2;
uint32_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
if (out.size() == 0) {
return;
}
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
uint32_t half_size = out_per_key / 2;
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
size_t half_size = out_per_key / 2;
bool odd = out_per_key % 2;
if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
}
encoder.set_input_array(keys);
encoder.set_output_array(out);
dim3 grid_dims{num_keys, half_size + odd};
int64_t total = grid_dims.x * grid_dims.y;
int32_t threads_y = 1;
while ((total / threads_y) >= (1U << 31)) {
int64_t total = num_keys * (half_size + odd);
uint32_t threads_y = 1;
while ((total / threads_y) >= UINT_MAX) {
threads_y *= 2;
}
int32_t threads_x = cuda::ceil_div(total, threads_y);
uint32_t threads_x = cuda::ceil_div(total, threads_y);
dim3 grid_dims{
static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
auto& stream = encoder.stream();
if (keys.flags().row_contiguous) {

View File

@@ -66,7 +66,7 @@ void all_reduce(
Reduce::ReduceType reduce_type) {
constexpr int N_READS = 8;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
auto get_args = [](size_t size, int N) {
int threads = std::min(512UL, (size + N - 1) / N);
@@ -107,8 +107,7 @@ void all_reduce(
encoder.set_input_array(in);
if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate);
dispatch_all_types(dt, [&](auto type_tag) {

View File

@@ -28,7 +28,7 @@ void init_reduce(
Reduce::ReduceType reduce_type) {
// Allocate if needed
if (out.data_shared_ptr() == nullptr) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
}
encoder.set_output_array(out);

View File

@@ -96,7 +96,7 @@ inline void allocate_same_layout(
const std::vector<int>& axes,
cu::CommandEncoder& encoder) {
if (in.flags().row_contiguous) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
return;
}
@@ -135,7 +135,7 @@ inline void allocate_same_layout(
fl.col_contiguous = cc;
fl.contiguous = true;
out.set_data(
cu::malloc_async(out.nbytes(), encoder.stream()),
cu::malloc_async(out.nbytes(), encoder),
data_size,
final_strides,
fl,

View File

@@ -190,7 +190,7 @@ void RMSNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(),
x.strides(),
x.flags());
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
encoder.add_temporary(gw_temp);
}
}

View File

@@ -292,14 +292,14 @@ void RoPE::eval_gpu(
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];

View File

@@ -0,0 +1,537 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace fe = cudnn_frontend;
namespace {
#define CHECK_CUDNN_FE_ERROR(cmd) \
do { \
auto error = cmd; \
if (!error.is_good()) { \
throw std::runtime_error( \
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
} \
} while (0)
std::vector<int64_t> normalized_strides(const array& x) {
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
if (std::all_of(
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
strides.back() = 1;
return strides;
}
if (!x.flags().row_contiguous || x.ndim() < 2) {
return strides;
}
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
}
}
return strides;
}
void set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
tensor->set_uid(uid)
.set_dim({x.shape().begin(), x.shape().end()})
.set_stride(normalized_strides(x));
}
array prepare_sdpa_input(const array& x, Stream s) {
// SDPA kernel's requirements on inputs:
// 1. last dim's stride be 1;
// 2. pointer be aligned.
if (x.strides(-1) != 1 || get_alignment(x) < 16) {
array x_copy = contiguous_copy_gpu(x, s);
auto& encoder = cu::get_command_encoder(s);
encoder.add_temporary(x_copy);
return x_copy;
}
return x;
}
constexpr int QKV_NDIM = 4;
struct SDPACacheKey {
int device_id;
cudnnDataType_t cudnn_dtype;
std::array<int, QKV_NDIM> q_shape;
std::array<int, QKV_NDIM> k_shape;
std::array<int, QKV_NDIM> v_shape;
std::array<int64_t, QKV_NDIM> q_strides;
std::array<int64_t, QKV_NDIM> k_strides;
std::array<int64_t, QKV_NDIM> v_strides;
bool do_causal;
bool output_logsumexp;
};
inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
bool do_causal,
bool output_logsumexp = true) {
BytesKey<SDPACacheKey> cache_key;
cache_key.pod = {
encoder.device().cuda_device(),
dtype_to_cudnn_type(q.dtype()),
vector_key<QKV_NDIM>(q.shape()),
vector_key<QKV_NDIM>(k.shape()),
vector_key<QKV_NDIM>(v.shape()),
vector_key<QKV_NDIM>(q.strides()),
vector_key<QKV_NDIM>(k.strides()),
vector_key<QKV_NDIM>(v.strides()),
do_causal,
output_logsumexp,
};
return cache_key;
}
auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
return cache;
}
auto& sdpa_backward_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
return cache;
}
enum UIDS {
Q,
K,
V,
SCALE,
O,
STATS,
// Backward graph:
D_Q,
D_K,
D_V,
D_O,
};
fe::graph::Graph build_sdpa_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
const array& v,
bool do_causal,
bool output_logsumexp,
const array& o,
const array& stats) {
auto dtype = fe::DataType_t::HALF;
if (q.dtype() == bfloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
fe::graph::Graph graph;
graph.set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
set_tensor_attrs(q_, Q, q);
set_tensor_attrs(k_, K, k);
set_tensor_attrs(v_, V, v);
auto scale = graph.tensor(fe::graph::Tensor_attributes()
.set_name("Scale")
.set_uid(SCALE)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto options = fe::graph::SDPA_attributes()
.set_name("sdpa_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal)
.set_generate_stats(output_logsumexp);
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
o_->set_output(true);
set_tensor_attrs(o_, O, o);
if (output_logsumexp) {
stats_->set_output(true).set_data_type(fe::DataType_t::FLOAT);
set_tensor_attrs(stats_, STATS, stats);
}
CHECK_CUDNN_FE_ERROR(graph.validate());
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
return graph;
}
fe::graph::Graph build_sdpa_backward_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
const array& v,
bool do_causal,
const array& o,
const array& d_o,
const array& stats,
array& d_q,
array& d_k,
array& d_v) {
auto dtype = fe::DataType_t::HALF;
if (q.dtype() == bfloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
fe::graph::Graph graph;
graph.set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
auto o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("O"));
auto d_o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("D_O"));
auto stats_ = graph.tensor(fe::graph::Tensor_attributes().set_name("STATS"));
set_tensor_attrs(q_, Q, q);
set_tensor_attrs(k_, K, k);
set_tensor_attrs(v_, V, v);
set_tensor_attrs(o_, O, o);
set_tensor_attrs(d_o_, D_O, d_o);
set_tensor_attrs(stats_, STATS, stats);
stats_->set_data_type(fe::DataType_t::FLOAT);
auto scale = graph.tensor(fe::graph::Tensor_attributes()
.set_name("Scale")
.set_uid(SCALE)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto options = fe::graph::SDPA_backward_attributes()
.set_name("sdpa_backward_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal);
auto [d_q_, d_k_, d_v_] =
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
d_q_->set_output(true);
d_k_->set_output(true);
d_v_->set_output(true);
set_tensor_attrs(d_q_, D_Q, d_q);
set_tensor_attrs(d_k_, D_K, d_k);
set_tensor_attrs(d_v_, D_V, d_v);
CHECK_CUDNN_FE_ERROR(graph.validate());
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));
return graph;
}
void execute_graph(
cu::CommandEncoder& encoder,
cudnnHandle_t handle,
fe::graph::Graph& graph,
std::unordered_map<int64_t, void*>& variant_pack) {
int64_t workspace_size = 0;
CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder),
{static_cast<int>(workspace_size)},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
cudnnSetStream(handle, encoder.stream());
CudaGraph cuda_graph(encoder.device());
CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
handle, variant_pack, workspace_ptr, cuda_graph));
encoder.add_graph_node(cuda_graph);
}
} // namespace
bool supports_sdpa_cudnn(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool do_causal,
Stream s) {
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
if (!enabled) {
return false;
}
// cuDNN SDPA requires Ampere and later.
if (cu::device(s.device).compute_capability_major() < 8) {
return false;
}
if (has_mask) {
// TODO: Support array masks.
if (!do_causal) {
return false;
}
// FIXME: Causal mask generates wrong results when L_Q != L_K.
if (q.shape(2) != k.shape(2)) {
return false;
}
}
// Only use cuDNN for prefilling and training.
if (q.shape(2) != k.shape(2)) {
return false;
}
// D_qk and D_v must be a multiple of 8 with maximum value 128.
if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) ||
(v.shape(-1) > 128)) {
return false;
}
Dtype dtype = q.dtype();
return dtype == float16 || dtype == bfloat16;
}
void sdpa_cudnn(
const array& q,
const array& k,
const array& v,
float scale,
array& o,
array& stats,
bool do_causal,
bool output_logsumexp,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
// TODO: Handle donation.
// TODO: Make O use same memory layout with Q.
o.set_data(cu::malloc_async(o.nbytes(), encoder));
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
if (output_logsumexp) {
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
encoder.set_output_array(stats);
}
// Search cache.
auto cache_key =
build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
auto it = sdpa_cache().find(cache_key);
if (it == sdpa_cache().end()) {
auto graph = build_sdpa_graph(
handle, q, k, v, do_causal, output_logsumexp, o, stats);
it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
std::unordered_map<int64_t, void*> variant_pack{
{Q, const_cast<void*>(gpu_ptr<void>(q))},
{K, const_cast<void*>(gpu_ptr<void>(k))},
{V, const_cast<void*>(gpu_ptr<void>(v))},
{SCALE, &scale},
{O, gpu_ptr<void>(o)}};
if (output_logsumexp) {
variant_pack[STATS] = gpu_ptr<void>(stats);
}
execute_graph(encoder, handle, graph, variant_pack);
}
void sdpa_backward_cudnn(
const array& q,
const array& k,
const array& v,
float scale,
const array& o,
const array& stats,
bool do_causal,
const array& d_o,
array& d_q,
array& d_k,
array& d_v,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
// TODO: Handle donation.
d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_input_array(o);
encoder.set_input_array(stats);
encoder.set_input_array(d_o);
encoder.set_output_array(d_q);
encoder.set_output_array(d_k);
encoder.set_output_array(d_v);
// Search cache.
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
auto it = sdpa_backward_cache().find(cache_key);
if (it == sdpa_backward_cache().end()) {
auto graph = build_sdpa_backward_graph(
handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
std::unordered_map<int64_t, void*> variant_pack{
{Q, const_cast<void*>(gpu_ptr<void>(q))},
{K, const_cast<void*>(gpu_ptr<void>(k))},
{V, const_cast<void*>(gpu_ptr<void>(v))},
{SCALE, &scale},
{O, const_cast<void*>(gpu_ptr<void>(o))},
{STATS, const_cast<void*>(gpu_ptr<void>(stats))},
{D_O, const_cast<void*>(gpu_ptr<void>(d_o))},
{D_Q, gpu_ptr<void>(d_q)},
{D_K, gpu_ptr<void>(d_k)},
{D_V, gpu_ptr<void>(d_v)}};
execute_graph(encoder, handle, graph, variant_pack);
}
// Defined in scaled_dot_product_attention.cu file.
bool supports_sdpa_vector(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool output_logsumexp);
void sdpa_vector(
const array& q,
const array& k,
const array& v,
float scale,
array& o,
bool do_causal,
const std::optional<array>& sinks,
Stream s);
namespace fast {
bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool is_training,
bool output_logsumexp,
Stream s) {
if (s.device == Device::cpu) {
return true;
}
return !supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
!supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
auto& s = stream();
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
auto& out = outputs[0];
auto& stats = outputs[1];
bool has_mask = inputs.size() - has_sinks_ > 3;
bool has_arr_mask = has_mask && !do_causal_;
if (supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
if (has_sinks_) {
sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s);
} else {
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
}
} else {
sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
}
}
bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
// The frontend adds a padding mask when sequence length is not a multiple of
// tile size.
if (q.shape(2) % 128 != 0) {
return true;
}
return s.device == Device::cpu;
}
void ScaledDotProductAttentionVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ScaledDotProductAttentionVJP::eval_gpu");
auto& s = stream();
assert(inputs.size() == 6);
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
array o = prepare_sdpa_input(inputs[3], s);
array stats = prepare_sdpa_input(inputs[4], s);
array d_o = prepare_sdpa_input(inputs[5], s);
assert(outputs.size() == 3);
auto& d_q = outputs[0];
auto& d_k = outputs[1];
auto& d_v = outputs[2];
sdpa_backward_cudnn(
q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
}
} // namespace fast
} // namespace mlx::core

View File

@@ -6,10 +6,6 @@
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include <nvtx3/nvtx3.hpp>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
@@ -565,10 +561,9 @@ void sdpa_vector_2pass_fallback(
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
sums.set_data(cu::malloc_async(sums.nbytes(), encoder));
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder));
encoder.add_temporary(intermediate);
encoder.add_temporary(sums);
@@ -663,21 +658,16 @@ void sdpa_vector_fallback(
} // namespace
namespace fast {
bool ScaledDotProductAttention::use_fallback(
bool supports_sdpa_vector(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
bool output_logsumexp) {
if (output_logsumexp) {
return false;
}
const int value_head_dim = v.shape(-1);
@@ -691,29 +681,24 @@ bool ScaledDotProductAttention::use_fallback(
const bool supported_vector_config =
sdpa_supported_head_dim && query_sequence_length < 4;
const bool supported_config = supported_vector_config;
return has_arr_mask || !supported_config;
return supported_vector_config && !has_arr_mask;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
auto& s = stream();
void sdpa_vector(
const array& q_pre,
const array& k_pre,
const array& v_pre,
float scale,
array& o,
bool do_causal,
const std::optional<array>& sinks_pre,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
std::vector<array> copies;
// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(inputs.size());
copies.reserve(4);
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
@@ -731,8 +716,8 @@ void ScaledDotProductAttention::eval_gpu(
};
std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
if (sinks_pre) {
sinks = copy_unless(is_matrix_contiguous, sinks_pre.value());
}
// We are in vector mode ie single query
@@ -788,7 +773,7 @@ void ScaledDotProductAttention::eval_gpu(
};
o.set_data(
cu::malloc_async(o.nbytes(), encoder.stream()),
cu::malloc_async(o.nbytes(), encoder),
o.size(),
{str_oB, str_oH, str_oL, str_oD},
flags);
@@ -798,8 +783,7 @@ void ScaledDotProductAttention::eval_gpu(
encoder.add_temporary(cp);
}
return sdpa_vector_fallback(
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks);
}
// Full attention mode should never reach here
@@ -808,6 +792,4 @@ void ScaledDotProductAttention::eval_gpu(
}
}
} // namespace fast
} // namespace mlx::core

View File

@@ -374,7 +374,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(in);
} else {
out.set_data(
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder),
in.data_size(),
in.strides(),
in.flags());

View File

@@ -24,7 +24,7 @@ void concatenate_gpu(
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
auto strides = out.strides();
auto flags = out.flags();
@@ -89,7 +89,7 @@ array compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
offset.set_data(cu::malloc_async(offset.itemsize(), encoder));
}
encoder.add_temporary(offset);

View File

@@ -118,7 +118,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(x);
} else {
out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(),
x.strides(),
x.flags());

View File

@@ -49,14 +49,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
in = contiguous_copy_gpu(trans, s);
encoder.add_temporary(in);
out = array(
cu::malloc_async(out.nbytes(), encoder.stream()),
in.shape(),
out.dtype());
out =
array(cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
encoder.add_temporary(out);
} else {
out.set_data(
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder),
in.data_size(),
in.strides(),
in.flags());
@@ -74,17 +72,13 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
if (argsort) {
// Indices in the sorted dimension.
array indices(
cu::malloc_async(out.nbytes(), encoder.stream()),
in.shape(),
out.dtype());
cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
encoder.add_temporary(indices);
// In argsort though we don't need the result of sorted values, the
// API requires us to provide an array to store it.
array discard(
cu::malloc_async(in.nbytes(), encoder.stream()),
in.shape(),
in.dtype());
cu::malloc_async(in.nbytes(), encoder), in.shape(), in.dtype());
encoder.add_temporary(discard);
size_t size;
@@ -104,9 +98,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
stream));
array temp(
cu::malloc_async(size, encoder.stream()),
{static_cast<int>(size)},
uint8);
cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Start capturing after allocations
@@ -148,9 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
stream));
array temp(
cu::malloc_async(size, encoder.stream()),
{static_cast<int>(size)},
uint8);
cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);
// Start capturing after allocations

View File

@@ -257,9 +257,8 @@ void ternary_op_gpu(
auto& c = inputs[2];
auto topt = get_ternary_op_type(a, b, c);
auto& encoder = cu::get_command_encoder(s);
set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_ternary_op_output_data(
a, b, c, out, topt, [&](auto n) { return cu::malloc_async(n, encoder); });
ternary_op_gpu_inplace<Op>(inputs, out, s);
}

View File

@@ -208,9 +208,8 @@ void unary_op_gpu(
const char* op,
const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
set_unary_output_data(inputs[0], out, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_unary_output_data(
inputs[0], out, [&](auto n) { return cu::malloc_async(n, encoder); });
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}

View File

@@ -5,6 +5,7 @@
#include "mlx/dtype_utils.h"
#include <fmt/format.h>
#include <vector>
namespace mlx::core {
@@ -60,7 +61,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
case float64:
return "double";
case complex64:
return "complex64_t";
return "mlx::core::cu::complex64_t";
default:
return "unknown";
}

View File

@@ -44,7 +44,7 @@ void Worker::commit(cudaStream_t stream) {
}
signal_event_.record(stream);
signal_event_.wait(signal_stream_);
cudaLaunchHostFunc(signal_stream_, signal, this);
CHECK_CUDA_ERROR(cudaLaunchHostFunc(signal_stream_, signal, this));
}
void Worker::thread_fn() {

View File

@@ -11,7 +11,7 @@ void slice_gpu(
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
const Stream&) {
slice(in, out, start_indices, strides);
}

View File

@@ -28,6 +28,7 @@ make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(indexing/scatter kernels/indexing/indexing.h)
make_jit_source(indexing/masked_scatter)
make_jit_source(indexing/gather kernels/indexing/indexing.h)
make_jit_source(indexing/gather_front kernels/indexing/indexing.h)
make_jit_source(indexing/gather_axis)

View File

@@ -32,7 +32,7 @@ std::string write_signature(
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
const std::vector<std::tuple<bool, bool, bool>>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
@@ -88,19 +88,19 @@ std::string write_signature(
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
if (std::get<0>(shape_infos[i])) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
if (std::get<1>(shape_infos[i])) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
if (std::get<2>(shape_infos[i])) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
@@ -184,12 +184,12 @@ CustomKernelFunction metal_kernel(
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
std::vector<std::tuple<bool, bool, bool>> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
std::tuple<bool, bool, bool> shape_info;
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
@@ -388,15 +388,15 @@ void CustomKernel::eval_gpu(
index++;
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
if (std::get<0>(shape_info)) {
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
if (std::get<1>(shape_info)) {
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
if (std::get<2>(shape_info)) {
compute_encoder.set_bytes(ndim, index);
index++;
}

View File

@@ -382,11 +382,8 @@ MTL::CommandQueue* Device::get_queue(Stream stream) {
bool Device::command_buffer_needs_commit(int index) {
auto& stream = get_stream_(index);
if (stream.buffer_ops > max_ops_per_buffer_ ||
(stream.buffer_sizes >> 20) > max_mb_per_buffer_) {
return true;
}
return false;
return (stream.buffer_ops > max_ops_per_buffer_) ||
((stream.buffer_sizes >> 20) > max_mb_per_buffer_);
}
MTL::CommandBuffer* Device::get_command_buffer(int index) {

View File

@@ -265,4 +265,19 @@ Device& device(mlx::core::Device);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
inline bool is_nax_available() {
auto _check_nax = []() {
bool can_use_nax = false;
if (__builtin_available(
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
can_use_nax = true;
}
can_use_nax &=
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
return can_use_nax;
};
static bool is_nax_available_ = _check_nax();
return is_nax_available_;
}
} // namespace mlx::core::metal

View File

@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
@@ -8,7 +9,9 @@
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/scan.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/dtype.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
@@ -641,4 +644,84 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
const array& dst = inputs[0];
const array& mask = inputs[1];
const array& src = inputs[2];
auto& s = stream();
auto& d = metal::device(s.device);
const size_t total = mask.size();
const CopyType ct = (total == 1)
? CopyType::Scalar
: (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_gpu(dst, out, ct, s);
if (total == 0) {
return;
}
array mask_flat = flatten_in_eval(mask, 1, -1, s);
if (mask_flat.data<void>() != mask.data<void>()) {
d.add_temporary(mask_flat, s.index);
}
if (!mask_flat.flags().row_contiguous) {
mask_flat = contiguous_copy_gpu(mask_flat, s);
d.add_temporary(mask_flat, s.index);
}
// Prefix (exclusive) of mask → scatter_offsets
array scatter_offsets(mask_flat.shape(), uint32, nullptr, {});
scatter_offsets.set_data(allocator::malloc(scatter_offsets.nbytes()));
d.add_temporary(scatter_offsets, s.index);
scan_gpu_inplace(
mask_flat,
scatter_offsets,
Scan::Sum,
/*axis=*/1,
/*reverse=*/false,
/*inclusive=*/false,
s);
// Kernel selection/build
static constexpr std::string_view kBaseName = "masked_assign";
const std::string dtype_tag = type_to_name(out.dtype());
const std::string value_type = get_type_string(out.dtype());
const std::string contiguous =
(src.flags().row_contiguous) ? "true" : "false";
const std::string kernel_name =
fmt::format("{}_{}_{}", kBaseName, dtype_tag, contiguous);
auto lib = d.get_library(kernel_name, [&]() {
std::string source = metal::utils();
source += metal::masked_scatter();
source += fmt::format(
std::string(masked_assign_kernel), kernel_name, value_type, contiguous);
return source;
});
auto kernel = d.get_kernel(kernel_name, lib);
// Binding
int bind_idx = 0;
const int ndim = static_cast<int>(src.ndim());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(mask_flat, bind_idx++);
compute_encoder.set_input_array(scatter_offsets, bind_idx++);
compute_encoder.set_input_array(src, bind_idx++);
compute_encoder.set_output_array(out, bind_idx++);
compute_encoder.set_vector_bytes(src.shape(), bind_idx++);
compute_encoder.set_vector_bytes(src.strides(), bind_idx++);
compute_encoder.set_bytes(ndim, bind_idx++);
compute_encoder.set_bytes(src.size() / src.shape(0), bind_idx++);
compute_encoder.set_bytes(mask_flat.size() / mask.shape(0), bind_idx++);
// Dispatch
auto group_dims = get_block_dims(total, 1, 1);
MTL::Size grid_dims(total, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@@ -11,6 +11,7 @@ const char* ternary_ops();
const char* reduce_utils();
const char* gather();
const char* scatter();
const char* masked_scatter();
const char* arange();
const char* unary();

View File

@@ -70,3 +70,7 @@ constexpr std::string_view scatter_kernels = R"(
gid);
}}
)";
constexpr std::string_view masked_assign_kernel = R"(
template [[host_name("{0}")]] [[kernel]] decltype(masked_assign_impl<{1}, {2}>) masked_assign_impl<{1}, {2}>;
)";

View File

@@ -9,7 +9,14 @@ set(BASE_HEADERS
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
set(METAL_FLAGS
-x
metal
-Wall
-Wextra
-fno-fast-math
-Wno-c++17-extensions
-Wno-c++20-extensions)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
@@ -120,6 +127,30 @@ if(NOT MLX_METAL_JIT)
build_kernel(gemv_masked steel/utils.h)
endif()
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
set(STEEL_NAX_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/transforms.h
steel/gemm/nax.h
steel/gemm/gemm_nax.h
steel/utils/type_traits.h
steel/utils/integral_constant.h)
build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
build_kernel(fp_quantized_nax fp_quantized_nax.h ${STEEL_NAX_HEADERS})
set(STEEL_NAX_ATTN_HEADERS
steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h
steel/utils/integral_constant.h)
build_kernel(steel/attn/kernels/steel_attention_nax ${STEEL_NAX_ATTN_HEADERS})
endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,74 @@
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized_utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/fp_quantized_nax.h"
#define instantiate_quantized_batched(mode, name, type, bm, bn, bk, wm, wn, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
batched)
#define instantiate_quantized_aligned(mode, name, type, bm, bn, bk, wm, wn, aligned) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
fp_ ## name, \
type, \
32, \
4, \
aligned)
#define instantiate_quantized_aligned_batched(mode, name, type, bm, bn, bk, wm, wn, aligned, batched) \
instantiate_kernel( \
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
fp_ ## name, \
type, \
32, \
4, \
aligned, \
batched)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
4, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, true) \
instantiate_quantized_aligned(mxfp4, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, false) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 0) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 1) \
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 0)
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, mxfp4_gather_qmm_rhs_nax_nt, type, 64, 64, 64, 2, 2, true) \
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, mxfp4_gather_qmm_rhs_nax_nn, type, 64, 64, 64, 2, 2, false)
#define instantiate_quantized_types(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type)
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on

View File

@@ -0,0 +1,38 @@
// Copyright © 2025 Apple Inc.
#pragma once
template <typename T, bool src_contiguous>
[[kernel]] void masked_assign_impl(
const device bool* mask [[buffer(0)]],
const device uint* scatter_offsets [[buffer(1)]],
const device T* src [[buffer(2)]],
device T* out [[buffer(3)]],
const constant int* src_shapes [[buffer(4)]],
const constant int64_t* src_strides [[buffer(5)]],
const constant int& src_ndim [[buffer(6)]],
const constant int64_t& src_batch_size [[buffer(7)]],
const constant int64_t& mask_batch_size [[buffer(8)]],
uint idx [[thread_position_in_grid]]) {
const bool mask_value = mask[idx];
if (!mask_value) {
return;
}
const uint src_index = scatter_offsets[idx];
if (src_index >= src_batch_size) {
return;
}
const uint batch_idx = idx / mask_batch_size;
if (src_contiguous) {
out[idx] = src[batch_idx * src_batch_size + src_index];
} else {
out[idx] = src[elem_to_loc<uint>(
batch_idx * src_batch_size + src_index,
src_shapes,
src_strides,
src_ndim)];
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,106 @@
// Copyright © 2023-2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/quantized_nax.h"
#define instantiate_quantized(name, type, group_size, bits, bm, bn, bk, wm, wn) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits, bm, bk, bn, wm, wn)
#define instantiate_quantized_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
batched, bm, bk, bn, wm, wn)
#define instantiate_quantized_aligned(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
name, \
type, \
group_size, \
bits, \
aligned, bm, bk, bn, wm, wn)
#define instantiate_quantized_aligned_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
aligned, \
batched, bm, bk, bn, wm, wn)
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
group_size, \
bits, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 0)
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qmm_n_nax, type, group_size, bits)
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_gather_qmm_n_nax, type, group_size, bits, 64, 64, 64, 2, 2)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true) \
instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 0) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 0)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nt, type, group_size, bits, 64, 64, 64, 2, 2, true) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nn, type, group_size, bits, 64, 64, 64, 2, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_rhs(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
instantiate_quantized_funcs(float16_t, group_size, bits) \
instantiate_quantized_funcs(bfloat16_t, group_size, bits)
#define instantiate_quantized_groups(bits) \
instantiate_quantized_types(128, bits) \
instantiate_quantized_types(64, bits) \
instantiate_quantized_types(32, bits)
#define instantiate_quantized_all() \
instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(5) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on

View File

@@ -51,6 +51,7 @@ using namespace metal;
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_bool__uint32, bool, uint32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)

View File

@@ -0,0 +1,476 @@
// Copyright © 2024-25 Apple Inc.
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
template <typename T>
struct TransformScale {
T scale;
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
METAL_FUNC T apply(T x) const {
return scale * x;
}
};
struct MaxOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return metal::max(x, y);
}
};
struct SumOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x + y;
}
};
struct MulOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x * y;
}
};
struct SubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x - y;
}
};
struct ExpSubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return fast::exp2(x - y);
}
};
struct DivOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x / y;
}
};
// clang-format off
template <
typename T,
int BQ,
int BK,
int BD,
int WM,
int WN,
typename MaskType = float,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant AttnParams* params [[buffer(4)]],
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
(void)simd_lane_id;
// Move to correct block
ulong3 tidl{tid.x, tid.y, tid.z};
Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Sequence
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch
kv_head_idx * params->K_strides[1]; // Head
V += tidl.z * params->V_strides[0] + // Batch
kv_head_idx * params->V_strides[1]; // Head
O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Sequence
if (has_mask) {
mask += tidl.z * mask_params->M_strides[0] + // Batch
tidl.y * mask_params->M_strides[1]; // Head
}
const metal::uniform<float> scale2 =
make_uniform(params->scale) * make_uniform(1.44269504089f);
// Prepare MMA tiles
constexpr short UQ = 16;
constexpr short UD = 32;
constexpr int kNWarps = WM * WN;
static_assert(
BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0,
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
// Q seq frags per warp
constexpr int TQ = BQ / (kNWarps * UQ);
// HeadDim frags (all warps load the same frags)
constexpr int TD = BD / UD;
static_assert(TQ == 1, "Check TQ");
using OSubTile = NAXSubTile<AccumType, UQ, UD>;
NAXTile<AccumType, TQ, TD, OSubTile> Otile;
Otile.clear();
// Prepare mma tile offsets
const short2 simd_coord = OSubTile::NAXFrag_t::get_coord();
const short sm = simd_coord.y;
const short sn = simd_coord.x;
const short tm = UQ * TQ * simd_group_id;
Q += (tm + sm) * int(params->Q_strides[2]) + sn;
K += sm * int(params->K_strides[2]) + sn;
V += sm * int(params->V_strides[2]) + sn;
// Init row reduction variables
constexpr short kRowsPT = decltype(Otile)::kRowsPerThread;
metal::vec<AccumType, kRowsPT> max_score;
metal::vec<AccumType, kRowsPT> sum_score{0};
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::finite_min;
}
if (has_sinks) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
sum_score[i] = 1;
}
}
int kb_lim = params->NK;
if (do_causal) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);
}
const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
// const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
const bool is_last_q = is_last_bq;
const short lim_rows_q = params->qL_rem - (tm + sm);
const short lim_rows_k = params->kL_rem - sm;
// Loop over KV seq length
for (int kb = 0; kb < kb_lim; kb++) {
const int is_last_k = (kb == (params->NK_aligned));
// Do S = Q @ K.T
constexpr short UDs = 16;
constexpr short UKs = 32;
constexpr short TDs = BD / UDs;
constexpr short TKs = BK / UKs;
using SSubTile = NAXSubTile<AccumType, UQ, UKs>;
using QSubTile = NAXSubTile<T, UQ, UDs>;
using KSubTile = NAXSubTile<T, UKs, UDs>;
NAXTile<AccumType, TQ, TKs, SSubTile> Stile;
Stile.clear();
STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TKs; ik++) {
STEEL_PRAGMA_UNROLL
for (short id = 0; id < TDs; id++) {
NAXTile<T, 1, 1, QSubTile> Qtile;
NAXTile<T, 1, 1, KSubTile> Ktile;
const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs;
const int K_load_off =
ik * UKs * int(params->K_strides[2]) + id * UDs;
if (!align_Q && is_last_q) {
// Qtile.load_rows(
// Q + Q_load_off,
// int(params->Q_strides[2]),
// lim_rows_q - iq * UQ);
Qtile.load_safe(
Q + Q_load_off,
int(params->Q_strides[2]),
short2(BD, lim_rows_q - iq * UQ));
} else {
Qtile.load(Q + Q_load_off, int(params->Q_strides[2]));
}
if (!align_K && is_last_k) {
// Ktile.load_rows(
// K + K_load_off,
// int(params->K_strides[2]),
// lim_rows_k - ik * UKs);
Ktile.load_safe(
K + K_load_off,
int(params->K_strides[2]),
short2(BD, lim_rows_k - ik * UKs));
} else {
Ktile.load(K + K_load_off, int(params->K_strides[2]));
}
subtile_matmad_nax(
Stile.subtile_at(iq, ik),
Qtile.subtile_at(0, 0),
metal::false_type{},
Ktile.subtile_at(0, 0),
metal::true_type{});
}
}
}
// Scale S
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
Stile.elems()[ii] *= float(scale2);
}
// Scale and Retile S
constexpr short UK = 16;
constexpr short TK = BK / UK;
using PSubTile = NAXSubTile<AccumType, UQ, UK>;
NAXTile<AccumType, TQ, TK, PSubTile> Ptile;
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
Ptile.elems()[ii] = Stile.elems()[ii];
}
// Mask out length sequence
if (!align_K && is_last_k) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;
STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
const short col_pos = sn + ik * UK;
thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
const auto loc = ii * PSubTile::kFragThrCols + jj;
fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc];
}
}
}
}
}
// Mask out if causal
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;
const int base_row = tid.x * BQ + params->qL_off + tm;
const int base_col = kb * BK;
STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
const short row_pos = base_row + iq * UQ;
const short col_pos = base_col + ik * UK;
thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm;
const auto c = col_pos + jj + sn;
const auto loc = ii * PSubTile::kFragThrCols + jj;
fg[loc] = (r < c) ? neg_inf : fg[loc];
}
}
}
}
}
// Other masking as needed
if (has_mask) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;
const int base_row = tid.x * BQ + tm;
const int base_col = kb * BK;
constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, AccumType>;
using MSubTile = NAXSubTile<melem_t, UQ, UK>;
STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
const short row_pos = base_row + iq * UQ + sm;
const short col_pos = base_col + ik * UK + sn;
MSubTile mfrag;
mfrag.load_safe(
mask,
int(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,
row_pos,
col_pos);
thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) {
if constexpr (is_bool) {
fg[jj] = mfrag.elems()[jj] ? fg[jj] : neg_inf;
} else {
fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]);
}
}
}
}
}
// Do softmax
// Temp variables
metal::vec<AccumType, kRowsPT> new_max;
metal::vec<AccumType, kRowsPT> factor;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
new_max[i] = max_score[i];
}
// Row max
Ptile.template row_reduce<MaxOp>(new_max);
// exp(Si - rowmax(Si))
Ptile.template row_bin_op<ExpSubOp>(new_max);
// Factor exp(rowmax(Si) - rowmax(Si-1))
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp2(max_score[i] - new_max[i]);
max_score[i] = new_max[i];
}
// Row Sum
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
sum_score[i] = sum_score[i] * factor[i];
}
Ptile.template row_reduce<SumOp>(sum_score);
// Update O
Otile.template row_bin_op<MulOp>(factor);
simdgroup_barrier(mem_flags::mem_none);
// Do O = P @ V
STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short id = 0; id < TD; id++) {
if constexpr (BD == 128) {
if (id == 2) {
threadgroup_barrier(mem_flags::mem_none);
}
}
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
using VSubTile = NAXSubTile<T, UK, UD>;
NAXTile<T, 1, 1, VSubTile> Vtile;
const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD;
if (!align_K && is_last_k) {
// Vtile.load_rows(
// V + V_load_off,
// int(params->V_strides[2]),
// lim_rows_k - ik * UK);
Vtile.load_safe(
V + V_load_off,
int(params->V_strides[2]),
short2(BD, lim_rows_k - ik * UK));
} else {
Vtile.load(V + V_load_off, int(params->V_strides[2]));
}
subtile_matmad_nax(
Otile.subtile_at(iq, id),
Ptile.subtile_at(iq, ik),
metal::bool_constant<false>{},
Vtile.subtile_at(0, 0),
metal::bool_constant<false>{});
}
}
}
// Prepare for next iteration
K += BK * int(params->K_strides[2]);
V += BK * int(params->V_strides[2]);
}
// Normalize output
threadgroup_barrier(mem_flags::mem_none);
metal::vec<AccumType, kRowsPT> rcp;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
rcp[i] = (1.f / sum_score[i]);
}
Otile.template row_bin_op<MulOp>(rcp);
// Store results
O += (tm + sm) * int(params->O_strides[2]) + sn;
if (!align_Q && is_last_q) {
if (lim_rows_q <= 0)
return;
// Otile.store_rows(O, params->O_strides[2], lim_rows_q);
Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q));
} else {
Otile.store(O, int(params->O_strides[2]));
}
}

View File

@@ -0,0 +1,33 @@
// Copyright © 2024-25 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/nax.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h"
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
instantiate_kernel( \
"steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
"_wm" #wm "_wn" #wn "_mask" #mname, \
attention_nax, dtype, bq, bk, bd, wm, wn, mtype, float)
#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
instantiate_attn(iname, itype, 64, 32, 128, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 64, 32, 64, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 64, 64, 128, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 64, 64, 64, 4, 1, mname, mtype)
#define instantiate_attn_mask_helper(iname, itype) \
instantiate_attn_shapes_helper(iname, itype, iname, itype) \
instantiate_attn_shapes_helper(iname, itype, bool_, bool)
instantiate_attn_mask_helper(float16, half);
instantiate_attn_mask_helper(bfloat16, bfloat);
instantiate_attn_mask_helper(float32, float);
// clang-format on

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,7 @@
// Copyright © 2024 Apple Inc.
#pragma once
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)")

View File

@@ -0,0 +1,154 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
using namespace metal;
namespace mlx::steel {
template <
typename T,
short SM,
short SN,
short SK,
short BK,
bool transpose_a,
bool transpose_b,
bool kAlignedM,
bool kAlignedN,
bool kAlignedK,
short UM,
short UN,
short UK,
typename AccumType = float>
auto gemm_loop(
const device T* A,
const device T* B,
const constant GEMMParams* params [[buffer(4)]],
const short sgp_sm,
const short sgp_sn) {
constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
constexpr short TK = SK / UK;
constexpr int RA = transpose_a ? TK : TM;
constexpr int CA = transpose_a ? TM : TK;
constexpr int RB = transpose_b ? TN : TK;
constexpr int CB = transpose_b ? TK : TN;
using DSubTile = NAXSubTile<AccumType, UM, UN>;
using ASubTile =
NAXSubTile<T, (transpose_a ? UK : UM), (transpose_a ? UM : UK)>;
using BSubTile =
NAXSubTile<T, (transpose_b ? UN : UK), (transpose_b ? UK : UN)>;
NAXTile<AccumType, TM, TN, DSubTile> Dtile;
Dtile.clear();
int gemm_k_iterations_ = params->gemm_k_iterations_aligned;
STEEL_PRAGMA_NO_UNROLL
for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) {
threadgroup_barrier(mem_flags::mem_none);
STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
NAXTile<T, RA, CA, ASubTile> Atile;
NAXTile<T, RB, CB, BSubTile> Btile;
const int k = kk1;
volatile int compiler_barrier;
const int A_offset = transpose_a ? k * params->lda : k;
const int B_offset = transpose_b ? k : k * params->ldb;
if constexpr (kAlignedM) {
Atile.load(A + A_offset, params->lda);
} else {
const short rmax = transpose_a ? SK : sgp_sm;
const short cmax = transpose_a ? sgp_sm : SK;
Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
}
if constexpr (kAlignedN) {
Btile.load(B + B_offset, params->ldb);
} else {
const short rmax = transpose_b ? sgp_sn : SK;
const short cmax = transpose_b ? SK : sgp_sn;
Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
}
tile_matmad_nax(
Dtile,
Atile,
metal::bool_constant<transpose_a>{},
Btile,
metal::bool_constant<transpose_b>{});
(void)compiler_barrier;
}
A += transpose_a ? (BK * params->lda) : BK;
B += transpose_b ? BK : (BK * params->ldb);
}
if constexpr (!kAlignedK) {
simdgroup_barrier(mem_flags::mem_none);
const short rem_bk = params->K - gemm_k_iterations_ * BK;
STEEL_PRAGMA_NO_UNROLL
for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) {
NAXTile<T, 1, 1, ASubTile> Atile;
NAXTile<T, 1, 1, BSubTile> Btile;
STEEL_PRAGMA_UNROLL
for (int mm = 0; mm < TM; mm++) {
STEEL_PRAGMA_UNROLL
for (int nn = 0; nn < TN; nn++) {
STEEL_PRAGMA_UNROLL
for (int kk = 0; kk < TK; kk++) {
const int m = mm * UM;
const int n = nn * UN;
const int k = kk1 + kk * UK;
const short psk = max(0, rem_bk - k);
const int A_offset =
transpose_a ? (m + k * params->lda) : (m * params->lda + k);
const int B_offset =
transpose_b ? (k + n * params->ldb) : (k * params->ldb + n);
{
const short psm = kAlignedM ? SM : max(0, sgp_sm - m);
const short rmax = transpose_a ? psk : psm;
const short cmax = transpose_a ? psm : psk;
Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
}
{
const short psn = kAlignedN ? SN : max(0, sgp_sn - n);
const short rmax = transpose_b ? psn : psk;
const short cmax = transpose_b ? psk : psn;
Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
}
subtile_matmad_nax(
Dtile.subtile_at(mm, nn),
Atile.subtile_at(0, 0),
metal::bool_constant<transpose_a>{},
Btile.subtile_at(0, 0),
metal::bool_constant<transpose_b>{});
}
}
}
}
}
return Dtile;
}
} // namespace mlx::steel

View File

@@ -0,0 +1,207 @@
// Copyright © 2025 Apple Inc.
using namespace mlx::steel;
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
// clang-format off
template <
bool kAlignedM,
bool kAlignedN,
typename NAXTile_t,
typename T>
void gemm_epilogue(
thread NAXTile_t& Dtile,
const device T* C,
const constant GEMMParams* params,
const constant GEMMAddMMParams* addmm_params,
const short sgp_sm,
const short sgp_sn) { // clang-format on
(void)params;
constexpr short UM = NAXTile_t::kSubTileRows;
constexpr short UN = NAXTile_t::kSubTileCols;
using CSubTile = NAXSubTile<T, UM, UN>;
using V = typename NAXTile_t::elem_type;
constexpr short TM = NAXTile_t::kTileRows;
constexpr short TN = NAXTile_t::kTileCols;
constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile;
STEEL_PRAGMA_UNROLL
for (short mm = 0; mm < TM; mm++) {
STEEL_PRAGMA_UNROLL
for (short nn = 0; nn < TN; nn++) {
const short m = mm * UM;
const short n = nn * UN;
CSubTile CTile;
if constexpr (kAlignedM && kAlignedN) {
CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n);
} else {
CTile.load_safe(
C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n);
}
auto delems = Dtile.subtile_at(mm, nn).elems();
auto celems = CTile.elems();
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemsPerSubTile; i++) {
if (do_axpby) {
delems[i] = addmm_params->alpha * delems[i] +
addmm_params->beta * static_cast<V>(celems[i]);
} else {
delems[i] += static_cast<V>(celems[i]);
}
}
}
}
}
// clang-format off
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2), function_constant(use_out_source)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on
// Find block
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
// Exit early if out of bounds
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Adjust for batch
if (has_batch) {
const constant auto* A_bstrides = batch_strides;
const constant auto* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
}
D += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
if (use_out_source) {
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
}
constexpr short UM = 16;
constexpr short UN = 32;
constexpr short UK = 16;
constexpr short SM = BM / WM;
constexpr short SN = BN / WN;
constexpr short SK = 32;
constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
const short tm = SM * (simd_group_id / WN);
const short tn = SN * (simd_group_id % WN);
const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
A += transpose_a ? tm : (tm * params->lda);
B += transpose_b ? (tn * params->ldb) : tn;
D += tm * params->ldd + tn;
if (use_out_source) {
C += tm * addmm_params->ldc + tn * addmm_params->fdc;
}
using DSubTile = NAXSubTile<AccumType, UM, UN>;
NAXTile<AccumType, TM, TN, DSubTile> Dtile;
dispatch_bool(align_K, [&](auto kAlignedK) {
dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
Dtile = gemm_loop<
T,
SM,
SN,
SK,
BK,
transpose_a,
transpose_b,
kAlignedM.value,
kAlignedN.value,
kAlignedK.value,
UM,
UN,
UK,
AccumType>(A, B, params, sgp_sm, sgp_sn);
if (use_out_source) {
gemm_epilogue<kAlignedM.value, kAlignedN.value>(
Dtile, C, params, addmm_params, sgp_sm, sgp_sn);
}
if constexpr (kAlignedM && kAlignedN) {
Dtile.store(D, int(params->ldd));
} else {
Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm));
}
});
});
});
}

View File

@@ -0,0 +1,35 @@
// Copyright © 2025 Apple Inc.
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h"
// clang-format off
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gemm_fused_nax_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
instantiate_gemm_shapes_helper(float32, float, float32, float);
// clang-format on

View File

@@ -0,0 +1,132 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
gather_mm_rhs_nax(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device uint32_t* rhs_indices [[buffer(2)]],
device T* C [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr short UM = 16;
constexpr short UN = 32;
constexpr short UK = 16;
constexpr short SM = BM / WM;
constexpr short SN = BN / WN;
constexpr short SK = 32;
constexpr short TM = SM / UM;
constexpr short TN = SN / UN;
if (params->tiles_n <= static_cast<int>(tid.x) ||
params->tiles_m <= static_cast<int>(tid.y)) {
return;
}
// Find the block in A, B, C
const int c_row = tid.y * BM;
const int c_col = tid.x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
C += c_row_long * params->ldd + c_col_long;
rhs_indices += c_row;
const short tm = SM * (simd_group_id / WN);
const short tn = SN * (simd_group_id % WN);
const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);
A += transpose_a ? tm : (tm * params->lda);
B += transpose_b ? (tn * params->ldb) : tn;
C += tm * params->ldd + tn;
rhs_indices += tm;
// Do as many matmuls as necessary
uint32_t index;
short offset;
uint32_t index_next = rhs_indices[0];
short offset_next = 0;
int n = 0;
while (n < sgp_sm) {
n++;
offset = offset_next;
index = index_next;
offset_next = sgp_sm;
for (; n < sgp_sm; n++) {
if (rhs_indices[n] != index) {
offset_next = n;
index_next = rhs_indices[n];
break;
}
}
threadgroup_barrier(mem_flags::mem_none);
using DSubTile = NAXSubTile<AccumType, UM, UN>;
NAXTile<AccumType, TM, TN, DSubTile> Ctile;
dispatch_bool(align_K, [&](auto kAlignedK) {
dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
auto do_gemm = gemm_loop<
T,
SM,
SN,
SK,
BK,
transpose_a,
transpose_b,
kAlignedM.value,
kAlignedN.value,
kAlignedK.value,
UM,
UN,
UK,
AccumType>;
Ctile = do_gemm(
A, B + index * params->batch_stride_b, params, sgp_sm, sgp_sn);
if constexpr (kAlignedN.value) {
if (offset_next - offset == SM) {
Ctile.store(C, int(params->ldd));
} else {
Ctile.store_slice(
C,
int(params->ldd),
short2(0, offset),
short2(SN, offset_next));
}
} else {
Ctile.store_slice(
C,
int(params->ldd),
short2(0, offset),
short2(sgp_sn, offset_next));
}
});
});
});
}
}

View File

@@ -0,0 +1,39 @@
// Copyright © 2024 Apple Inc.
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"
// clang-format off
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_kernel( \
"steel_gather_mm_rhs_nax_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn, \
gather_mm_rhs_nax, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
float)
#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 128, 128, 1, 4) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 32, 128, 128, 1, 4) \
instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 64, 128, 128, 2, 4)
// clang-format on
instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More