Compare commits

..

89 Commits

Author SHA1 Message Date
Awni Hannun
1b591ec736 No VJP for mask or sinks in attention (#2909)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-13 19:48:39 -08:00
Awni Hannun
47d2505ea9 Fix attention for large sizes (#2903)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-13 06:54:30 -08:00
Cheng
bedefed784 Fix ccache getting disabled (#2905)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-13 13:00:51 +09:00
Melissa Kilby
ccaaa7d6df fix: possible heap-buffer-overflow in RandomBits::eval_cpu (#2877)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-12 02:11:18 -08:00
Awni Hannun
f3e5ca5414 [CUDA] Add host nodes to subgraph types for graph update (#2901)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-11 19:13:44 -08:00
Awni Hannun
81dfe5f137 Fix grad in place updates (#2899)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-11 14:44:58 -08:00
Anastasiia Filippova
012fb220a1 fp quantize (#2892)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-11 06:11:25 -08:00
Nathan Goldbaum
e1fee0074b Update nanobind pin to most recent version (#2896) 2025-12-11 06:07:36 -08:00
CCYeh
3c8ce9b00e Fix input buffer donation in compile (#2897) 2025-12-11 06:07:03 -08:00
David Koski
937ce79660 do not use simd neon intrinsics on x86 (#2893)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-10 12:23:28 -08:00
Nathan Goldbaum
208f5441a7 bump minimum required Python version (#2891)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-09 16:54:38 -08:00
Awni Hannun
b862d842e1 Allow events in sub graph to be updatable (#2886)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-09 12:34:37 -08:00
Satyam singh
f7a400951a Fix docs: replace mx.random.randn with mx.random.normal (#2890) 2025-12-09 11:46:30 -08:00
Awni Hannun
27232db1ba [CUDA] Enable more graphs to be updatable (#2883)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-08 06:18:01 -08:00
Awni Hannun
a4b3bc969b Try not to fail when there should be memory available (#2869)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-07 06:11:00 -08:00
Awni Hannun
667c0f3bb9 [Metal] No copy array init (#2875)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 13:36:45 -08:00
Cheng
6245824d42 Make allocator::malloc throw on allocation failure (#2874)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-05 17:44:38 +09:00
Awni Hannun
39289ef025 [CUDA] Release build for cuda 13 (#2872) 2025-12-04 21:42:26 -08:00
Awni Hannun
aefc9bd3f6 [CUDA] Faster general copy (#2873) 2025-12-04 21:42:15 -08:00
Angelos Katharopoulos
997cfc7699 Add a 2-pass col reduce for CUDA (#2863)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-04 15:53:59 -08:00
Awni Hannun
1fa8dc5797 Do a PyPi release for cuda on arm (#2866) 2025-12-04 15:28:29 -08:00
Awni Hannun
a6d6717181 fix compile copying (#2871) 2025-12-04 12:32:56 -08:00
Awni Hannun
941cfe23d7 Layer norm throws on dimension mismatch (#2870) 2025-12-04 11:21:05 -08:00
romanoneg
9abb0b8123 Added support for pytree types that inherit from tuple and typing.namedtuple (#2845) 2025-12-04 11:06:45 -08:00
Tian En "TianHeng
50d3914c67 Update gumbel function signature parameters (#2868)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-03 15:37:35 -08:00
Awni Hannun
cacbdbf995 Fix init from double (#2861)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-03 06:08:11 -08:00
Awni Hannun
193cdcd81a Fix graph updating (#2857)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-02 17:12:24 -08:00
Awni Hannun
d8ceae7b77 Reduce JVP (#2854) 2025-12-02 16:17:47 -08:00
Awni Hannun
eff0e31f00 Fix export scatters (#2852)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-02 11:24:40 -08:00
Awni Hannun
6c5785bc2f use thread local cpature mode (#2850)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-12-01 19:02:47 -08:00
CCYeh
8879ee00eb Support more Numpy interfaces for masked_scatter (#2832) 2025-12-01 17:51:02 -08:00
Cheng
6e762fe2e2 [CUDA] Migrate conv code to new cuDNN APIs (#2847)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-12-02 07:55:43 +09:00
Cheng
2b95d0c270 [CUDA] Use cuDNN attention when T_q != T_kv (#2843)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-11-27 09:58:43 +09:00
Chaoran Yu
b054838780 Added clarification to apply_fn parameter of apply_to_modules (#2831)
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-26 15:40:56 -08:00
Awni Hannun
dd79d3c465 [CUDA] Faster rms norm for small dimension (#2838) 2025-11-26 15:10:41 -08:00
Cheng
704fd1ae28 [CUDA] Support array mask in SDPA (#2822)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-26 11:08:58 +09:00
Cheng
c9f4dc851f Merge build-cuda and build-linux actions (#2783)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
2025-11-25 20:06:42 +09:00
Cheng
f8bd675655 [CUDA] Output of SDPA should have same layout with inputs (#2826)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-25 15:22:58 +09:00
Cheng
23a9168d34 [CUDA] Add debug env to save cuda graphs to dot files (#2825) 2025-11-25 15:22:36 +09:00
Awni Hannun
bca205e287 [CUDA] Exit on crash and more helpful errors (#2830) 2025-11-24 19:46:03 -08:00
CCYeh
1d4eacb737 Fix mx.core.linspace type annotation (#2820)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-11-24 14:15:08 -08:00
dependabot[bot]
8abd37ad05 Bump actions/checkout from 5 to 6 (#2828)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-24 06:04:46 -08:00
Andrey Portnoy
3e05cea9f8 Force cudaGraphExec reinstantiation when clusters are used (#2813)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-22 12:43:49 -08:00
CCYeh
5b0f047226 Fix mx.core.load type annotation (#2819) 2025-11-22 11:09:44 -08:00
Harsh Sutaria
618c87af8c Add float64 Eig and complex64 SVD/Eig support (Fixes #2708) (#2737)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-22 06:51:36 -08:00
Cheng
d5f61a93fa Fix typo: refs/head/main => refs/heads/main (#2818)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-22 09:43:35 +09:00
Awni Hannun
4a09264236 Tolerance for some ops tests on cuda (#2815) 2025-11-21 16:06:16 -08:00
Awni Hannun
0dbc7e5bee Centralize NAX condition (#2811)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-21 13:28:15 -08:00
Awni Hannun
0d68efd461 patch bump for future version (#2804)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-20 09:26:20 -08:00
Awni Hannun
f9e1a14135 [CUDA] Partly fix random for large sizes (#2798) 2025-11-20 07:27:50 -08:00
Awni Hannun
d8e9ded928 Fix cuda allocator copy condition (#2800) 2025-11-20 07:06:55 -08:00
Awni Hannun
60939d010c Fix macos release target and linux arm release (#2802)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-19 21:37:50 -08:00
Awni Hannun
fdcd2923fd patch + fix docs build (#2799) 2025-11-19 16:16:26 -08:00
Jagrit Digani
54f1cc6e3e Add Neural Accelerator Support (#2772) 2025-11-19 15:06:00 -08:00
CCYeh
b3825ac149 Add Masked Scatter (#2663)
Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <katharas@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2025-11-19 14:53:32 -08:00
Awni Hannun
7f4b7e553c version (#2797) 2025-11-19 14:11:16 -08:00
Awni Hannun
ad16f41a7f Fix version tag (#2790)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-19 08:55:57 -08:00
Awni Hannun
f46877bc08 more accurate rope fallback (#2792) 2025-11-19 06:07:21 -08:00
Cheng
6f35017d1b [CUDA] cuDNN backward attention (#2762)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-19 08:13:50 +09:00
Awni Hannun
b167f0df1c build docs on linux (#2787)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-18 08:01:03 -08:00
Cheng
a9f0d6b160 Avoid duplicate CI runs when starting a PR from upstream branch (#2788)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-18 15:16:25 +09:00
Cheng
940f4c7818 Fix building with CUDA < 12.8 (#2782)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.6) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
2025-11-18 12:55:19 +09:00
Cheng
35f81728f1 Remove unneeded tests in nightly build (#2786) 2025-11-18 08:09:58 +09:00
Cheng
4442ed86c1 Fix nightly build (#2785) 2025-11-18 08:07:51 +09:00
Cheng
698559c231 Test every commit in main branch (#2781) 2025-11-18 08:07:22 +09:00
Cheng
ecc4879b07 Do not run CPU tests in CUDA builds (#2784) 2025-11-18 07:27:09 +09:00
Cheng
32b18d8b66 Use std::optional for mask_arr arg (#2763)
Some checks failed
Build and Test / check_lint (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04) (push) Has been cancelled
Build and Test / linux_build_and_test (ubuntu-22.04-arm) (push) Has been cancelled
Build and Test / mac_build_and_test (14.0) (push) Has been cancelled
Build and Test / mac_build_and_test (15.0) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.8) (push) Has been cancelled
Build and Test / cuda_build_and_test (cuda-12.9) (push) Has been cancelled
Build and Test / build_documentation (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (cuda-12.8) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (cuda-12.9) (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
2025-11-17 10:43:33 +09:00
Cheng
472c43a0c8 Build and test with multiple CUDA versions (#2780) 2025-11-17 09:19:02 +09:00
Cheng
b7214ff01e Remove pip cache in GitHub Actions (#2776)
* Correctly set pip cache key

* [Debug] Try disabling pip cache
2025-11-17 08:19:59 +09:00
Cheng
76414c8971 Run CI for pushes (#2777) 2025-11-17 07:19:01 +09:00
Awni Hannun
49e4566df3 fix release 2 (#2767)
* fix release 2

* login

* fix
2025-11-16 11:39:53 -08:00
Awni Hannun
aad49f932f [CUDA] Tune ops per buffer based on device (#2761)
* tune ops per buffer based on device

* tune memory limit as well

* add tuning for spark
2025-11-16 06:29:49 -08:00
Cheng
86765cce34 Use ccache in GitHub Actions (#2773)
* Remove unnecessary steps

* Use ccache

* Log when using ccache

* Set max-size to 1GB

* Pass --no-build-isolation

* Remove more unused things
2025-11-16 07:58:14 +09:00
Cheng
1bedcbd556 Fix warnings with cmake 4.1 (#2774) 2025-11-16 07:12:47 +09:00
Cheng
9ac7dbe877 Fix MPI distributed tests with CUDA backend (#2775) 2025-11-16 07:12:18 +09:00
Awni Hannun
1bf605d56d use arch specific targets when possible (#2771) 2025-11-14 20:04:18 -08:00
Cheng
3c622ddd1d Separate test-linux from build-linux/cuda in GitHub Actions (#2765)
* Separate test-linux from build-linux/cuda in GitHub Actions

* Prefer unittest when possible

Co-authored-by: Mike Drob <mdrob@apache.org>

---------

Co-authored-by: Mike Drob <mdrob@apache.org>
2025-11-15 11:14:09 +09:00
Awni Hannun
27ff069175 Fix exporting with constants (#2769) 2025-11-14 12:52:08 -08:00
Cheng
3b2ffcefc3 [CUDA] cuDNN forward attention (#2743)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* Separate sdpa kernels in another file

* Initial support for cuDNN SDPA

* Diable a few corner cases

* Remove scaled_dot_product_attention.h

* Use cuDNN attention for prefilling

* cuDNN SDPA requires Ampere and later

* Address reviews

* Do contiguous copy of inputs
2025-11-14 09:23:56 +09:00
Awni Hannun
b65f882df3 fix release (#2759) 2025-11-13 15:34:01 -08:00
Cheng
b704e9e77a [CUDA] Check CUDA error in synchronize (#2757) 2025-11-14 07:10:23 +09:00
Awni Hannun
66519fb348 fix slice (#2758) 2025-11-13 11:30:02 -08:00
Awni Hannun
8973550ff3 export custom kernel (#2756) 2025-11-13 11:29:50 -08:00
Mike Drob
3f866be665 minor debugging for publishing (#2739)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* minor debugging for publishing

* fix logic
2025-11-12 06:33:39 -08:00
Awni Hannun
23f81ed1c1 Linux on arm (#2751)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14, ubuntu-22.04-arm) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* try linux on arm

* ssh

* fix
2025-11-11 11:41:14 -08:00
wrmsr
3fe2250c00 Fix irregular_strides benchmark shape type (#2754) 2025-11-11 11:40:22 -08:00
Awni Hannun
047114b988 remove circle (#2753) 2025-11-11 11:39:47 -08:00
wrmsr
9320eb89a8 Fix dequantize python sig (dtype default) (#2752) 2025-11-11 09:55:24 -08:00
Awni Hannun
75819d70ea patch bump (#2750) 2025-11-11 08:49:14 -08:00
182 changed files with 12656 additions and 2291 deletions

View File

@@ -1,579 +0,0 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
parameters:
nightly_build:
type: boolean
default: false
test_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "26.0.0"
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install
command: |
xcodebuild -downloadComponent MetalToolchain
brew install python@3.10
brew install doxygen
python3.10 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
- run:
name: Install dependencies
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
python -m unittest discover python/tests -v
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "26.0.0"
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
resource_class: m4pro.medium
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
brew install openmpi uv
- run:
name: Install Python package
command: |
uv venv --python 3.10
uv pip install \
nanobind==2.4.0 \
cmake \
numpy \
torch \
tensorflow \
unittest-xml-reporting
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e . -v
- run:
name: Generate package stubs
command: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- run:
name: Build example extension
command: |
source .venv/bin/activate
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project python test.py
- store_test_results:
path: test-results
- run:
name: Build CPP only
command: |
source .venv/bin/activate
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
- run:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
- run:
name: Build small binary
command: |
source .venv/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j `sysctl -n hw.ncpu`
- run:
name: Run Python tests with JIT
command: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
METAL_DEBUG_ERROR_MODE=0 \
uv run --no-project python -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit
cuda_build_and_test:
parameters:
image_date:
type: string
default: "2023.11.1"
machine:
image: "linux-cuda-12:<< parameters.image_date >>"
resource_class: gpu.nvidia.small.gen2
steps:
- checkout
- restore_cache:
keys:
- cuda-<< parameters.image_date >>-{{ arch }}-
- run:
name: Install dependencies
command: |
sudo apt-get update
sudo apt-get install libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install libnccl2 libnccl-dev
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
rm -rf ccache-4.11.3-linux-x86_64
curl -LsSf https://astral.sh/uv/install.sh | sh
- run:
name: Set CCache size
command: ccache --max-size 1G
- run:
name: Install Python package
command: |
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
uv pip install -e ".[dev]" -v
- run:
name: Run Python tests
command: |
source .venv/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
- run:
name: Build CPP only
command: |
source .venv/bin/activate
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=`which nvcc` \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j `nproc`
- run:
name: Run CPP tests
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
- run:
name: CCache report
command: |
ccache --show-stats
ccache --zero-stats
ccache --cleanup
- save_cache:
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
paths:
- /home/circleci/.cache/ccache
build_release:
parameters:
python_version:
type: string
default: "3.10"
xcode_version:
type: string
default: "26.0.0"
build_env:
type: string
default: ""
macosx_deployment_target:
type: string
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: m4pro.medium
environment:
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
steps:
- checkout
- run:
name: Install dependencies
command: |
xcodebuild -downloadComponent MetalToolchain
mkdir -p ~/miniconda3
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate
conda init --all
conda create -n env python=<< parameters.python_version >> -y
conda activate env
pip install --upgrade cmake
pip install nanobind==2.4.0
pip install --upgrade setuptools
pip install numpy
pip install twine
pip install build
- run:
name: Install Python package
command: |
conda activate env
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
pip install . -v
- run:
name: Generate package stubs
command: |
conda activate env
pip install typing_extensions
python setup.py generate_stubs
- run:
name: Build Python package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
- when:
condition:
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
conda activate env
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
conda activate env
twine upload dist/*
- store_artifacts:
path: dist/
build_linux_release:
parameters:
python_version:
type: string
default: "3.10"
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: large
steps:
- checkout
- run:
name: Build wheel
command: |
PYTHON=python<< parameters.python_version >>
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
sudo apt-get update
TZ=Etc/UTC sudo apt-get -y install tzdata
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
$PYTHON -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.build_env >> pip install ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
python setup.py clean --all
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
bash python/scripts/repair_linux.sh
- when:
condition:
equal: ["3.10", << parameters.python_version >>]
steps:
- run:
name: Build common package
command: |
source env/bin/activate
python setup.py clean --all
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload packages
command: |
source env/bin/activate
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
build_cuda_release:
parameters:
build_env:
type: string
default: ""
machine:
image: ubuntu-2204:current
resource_class: xlarge
steps:
- checkout
- run:
name: Build wheel
command: |
export DEBIAN_FRONTEND=noninteractive
export NEEDRESTART_MODE=a
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
sudo apt-get install zip
pip install auditwheel
pip install patchelf
pip install build
pip install twine
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
python -m build -w
bash python/scripts/repair_cuda.sh
- when:
condition: << parameters.build_env >>
steps:
- run:
name: Upload package
command: |
twine upload wheelhouse/*.whl
- store_artifacts:
path: wheelhouse/
workflows:
build_and_test:
when:
and:
- matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test:
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test
- cuda_build_and_test:
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
- build_documentation
build_pypi_release:
when:
and:
- not: << pipeline.parameters.nightly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- build_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["PYPI_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
- build_linux_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["PYPI_RELEASE=1"]
- build_cuda_release:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
matrix:
parameters:
build_env: ["PYPI_RELEASE=1"]
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
macosx_deployment_target: ["13.5", "15.0"]
- linux_build_and_test:
requires: [ hold ]
- cuda_build_and_test:
requires: [ hold ]
matrix:
parameters:
image_date: ["2023.11.1", "2025.05.1"]
nightly_build:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.nightly_build >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
- build_cuda_release
build_dev_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
jobs:
- build_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
macosx_deployment_target: ["13.5", "14.0", "15.0"]
build_env: ["DEV_RELEASE=1"]
xcode_version: ["26.0.0"]
- build_linux_release:
matrix:
parameters:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
build_env: ["DEV_RELEASE=1"]
- build_cuda_release:
matrix:
parameters:
build_env: ["DEV_RELEASE=1"]

View File

@@ -2,9 +2,13 @@ name: 'Build CUDA wheel'
description: 'Build CUDA wheel' description: 'Build CUDA wheel'
inputs: inputs:
nvcc-location: arch:
description: 'Location of nvcc compiler' description: 'Platform architecture tag'
required: true required: true
type: choice
options:
- x86_64
- aarch64
runs: runs:
using: "composite" using: "composite"
@@ -12,9 +16,9 @@ runs:
- name: Build package - name: Build package
shell: bash shell: bash
env: env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
run: | run: |
pip install auditwheel build patchelf setuptools pip install auditwheel build patchelf setuptools
python setup.py clean --all python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh bash python/scripts/repair_cuda.sh ${{ inputs.arch }}

View File

@@ -1,45 +0,0 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
required: true
default: '/usr/local/cuda-12.9/bin/nvcc'
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: pip install -e ".[dev]" -v
- name: Run Python tests - CPU
shell: bash
env:
LOW_MEMORY: 1
DEVICE: cpu
run: python -m unittest discover python/tests -v
- name: Run Python tests - GPU
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
run: python -m tests discover python/tests -v
- name: Build CPP only
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)
- name: Run CPP tests
shell: bash
run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"

View File

@@ -1,19 +1,19 @@
name: 'Build Documentation' name: 'Build Documentation'
description: 'Build documentation on a mac' description: 'Build documentation'
runs: runs:
using: "composite" using: "composite"
steps: steps:
- name: Setup machine - name: Setup machine
uses: ./.github/actions/setup-macos uses: ./.github/actions/setup-linux
- name: Install dependencies - name: Install dependencies
shell: sh shell: bash
run: | run: |
brew install doxygen sudo apt-get install -y doxygen
uv pip install --upgrade pip cmake source .venv/bin/activate
uv pip install -r docs/requirements.txt pip install -r docs/requirements.txt
uv pip install . -v pip install . -v
- name: Build documentation - name: Build documentation
shell: bash shell: bash
@@ -24,8 +24,8 @@ runs:
make html O=-W make html O=-W
- name: Create artifact tar - name: Create artifact tar
shell: sh shell: bash
run: tar -cf artifact.tar --cd docs --dereference build/html index.html run: tar -cf artifact.tar -C docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar # Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact - name: Upload artifact

View File

@@ -7,6 +7,13 @@ inputs:
type: boolean type: boolean
required: false required: false
default: false default: false
arch:
description: 'Platform architecture tag'
required: true
type: choice
options:
- x86_64
- aarch64
runs: runs:
using: "composite" using: "composite"
@@ -23,11 +30,11 @@ runs:
pip install auditwheel patchelf build pip install auditwheel patchelf build
python setup.py clean --all python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w MLX_BUILD_STAGE=1 python -m build -w
bash python/scripts/repair_linux.sh bash python/scripts/repair_linux.sh ${{ inputs.arch }}
- name: Build backend package - name: Build backend package
if: ${{ inputs.build-backend }} if: ${{ inputs.build-backend }}
shell: bash shell: bash
run: | run: |
python setup.py clean --all python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64 auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}

View File

@@ -1,15 +1,32 @@
name: 'Build and Test on Linux' name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'
inputs:
toolkit:
description: 'The toolkit to build with'
required: false
default: 'cpu'
runs: runs:
using: "composite" using: "composite"
steps: steps:
- name: Install Python package - name: Install Python package
id: python_build
shell: sh shell: sh
env: env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
DEBUG: 1 DEBUG: 1
run: pip install -e ".[dev]" -v CMAKE_ARGS: >-
-DCMAKE_COMPILE_WARNING_AS_ERROR=ON
-DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
run: |
if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
# There is no GPU in arm64 runner, use a common arch.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
# Can not build tests when the built executables can not run.
CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
fi
pip install --no-build-isolation -e ".[dev]" -v
# Pass the CMAKE_ARGS to following steps.
echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT
- name: Generate package stubs - name: Generate package stubs
shell: sh shell: sh
@@ -17,25 +34,8 @@ runs:
pip install typing_extensions pip install typing_extensions
python setup.py generate_stubs python setup.py generate_stubs
- name: Run Python tests
shell: bash
run: |
python -m unittest discover python/tests -v
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
- name: Build CPP only - name: Build CPP only
shell: bash shell: bash
run: | run: |
mkdir -p build && cd build cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG cmake --build build -j $(nproc)
make -j $(nproc)
- name: Run CPP tests
shell: sh
run: ./build/tests/tests

View File

@@ -16,18 +16,19 @@ runs:
using: "composite" using: "composite"
steps: steps:
- name: Build Python package - name: Build Python package
shell: bash shell: bash -l {0}
env: env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }} MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: | run: |
uv pip install build pip install build
uv run --no-project setup.py clean --all python setup.py clean --all
MLX_BUILD_STAGE=1 uv run -m build -w MLX_BUILD_STAGE=1 python -m build -w
- name: Build backend package - name: Build backend package
if: ${{ inputs.build-backend }} if: ${{ inputs.build-backend }}
shell: bash shell: bash -l {0}
env: env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }} MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: | run: |
uv run --no-project setup.py clean --all python setup.py clean --all
MLX_BUILD_STAGE=2 uv run -m build -w MLX_BUILD_STAGE=2 python -m build -w

View File

@@ -5,47 +5,47 @@ runs:
using: "composite" using: "composite"
steps: steps:
- name: Install dependencies - name: Install dependencies
shell: sh
env: env:
DEBUG: 1 DEBUG: 1
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
shell: bash -l {0}
run: | run: |
uv pip install --upgrade pip pip install --upgrade pip
uv pip install cmake setuptools nanobind==2.4.0 pip install cmake setuptools nanobind==2.10.2
uv pip install -e . -v pip install -e . -v
- name: Generate package stubs - name: Generate package stubs
shell: bash shell: bash -l {0}
run: | run: |
uv pip install typing_extensions pip install typing_extensions
uv run --no-project setup.py generate_stubs python setup.py generate_stubs
- name: Install tests dependencies - name: Install tests dependencies
shell: sh shell: bash -l {0}
run: | run: |
uv pip install numpy torch tensorflow unittest-xml-reporting pip install numpy torch tensorflow unittest-xml-reporting
- name: Run Python tests - name: Run Python tests
shell: bash shell: bash -l {0}
env: env:
LOW_MEMORY: 1 LOW_MEMORY: 1
run: | run: |
DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2) mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension - name: Build example extension
shell: bash shell: bash -l {0}
run: | run: |
cd examples/extensions cd examples/extensions
uv pip install -r requirements.txt pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace python setup.py build_ext --inplace
uv run --no-project test.py python test.py
- name: Build CPP only - name: Build CPP only
shell: bash shell: bash -l {0}
run: | run: |
mkdir -p build mkdir -p build
cd build cd build
@@ -53,7 +53,7 @@ runs:
make -j $(sysctl -n hw.ncpu) make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests - name: Run CPP tests
shell: bash shell: bash -l {0}
env: env:
DEVICE: gpu DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1 METAL_DEVICE_WRAPPER_TYPE: 1
@@ -61,7 +61,7 @@ runs:
run: ./build/tests/tests run: ./build/tests/tests
- name: Build small binary with JIT - name: Build small binary with JIT
shell: bash shell: bash -l {0}
run: | run: |
mkdir -p build mkdir -p build
cd build cd build
@@ -74,7 +74,7 @@ runs:
make -j $(sysctl -n hw.ncpu) make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT - name: Run Python tests with JIT
shell: bash shell: bash -l {0}
env: env:
LOW_MEMORY: 1 LOW_MEMORY: 1
DEVICE: gpu DEVICE: gpu
@@ -82,7 +82,7 @@ runs:
METAL_DEBUG_ERROR_MODE: 0 METAL_DEBUG_ERROR_MODE: 0
run: | run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \ CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v pip install -e . -v
uv run -m xmlrunner discover \ python -m xmlrunner discover \
-v python/tests \ -v python/tests \
-o test-results/gpu_jit -o test-results/gpu_jit

View File

@@ -2,72 +2,82 @@ name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds' description: 'Install dependencies for Linux builds'
inputs: inputs:
runner-type: toolkit:
description: 'Whether to set this up as a linux or CUDA runner' description: 'Which toolkit to install'
required: false required: false
default: 'linux' default: 'cpu'
type: choice
options:
- linux
- cuda
python-version: python-version:
description: 'Version of python to set up' description: 'Version of python to set up'
required: false required: false
default: '3.10' default: '3.10'
use-ccache:
description: 'Whether to enable ccache'
required: false
default: 'true'
runs: runs:
using: "composite" using: "composite"
steps: steps:
- name: Free disk space
shell: sh
if: inputs.runner-type == 'linux'
run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Install common dependencies - name: Install common dependencies
env:
TZ: Etc/UTC
shell: bash shell: bash
run: | run: |
sudo apt-get update sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
sudo apt autoremove -y
- name: Use ccache
if: ${{ inputs.use-ccache == 'true' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}
max-size: 1GB
# ccache-action bug: running "apt-get update" fails on large arm runner.
update-package-index: false
- uses: actions/setup-python@v6 - uses: actions/setup-python@v6
with: with:
python-version: ${{ inputs.python-version }} python-version: ${{ inputs.python-version }}
cache: 'pip'
- name: setup python venv - name: Setup Python venv
shell: bash shell: bash
run: | run: |
python -m venv .venv python -m venv .venv
source .venv/bin/activate source .venv/bin/activate
pip install setuptools cmake nanobind==2.10.2
echo PATH=$PATH >> $GITHUB_ENV echo PATH=$PATH >> $GITHUB_ENV
pip install --upgrade pip cmake # Make cmake search .venv for nanobind
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
- name: Install MPI - name: Install MPI
if: inputs.runner-type == 'linux'
shell: bash shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Network CUDA installation from packages - name: Install CUDA toolkit
id: install-cuda if: ${{ startsWith(inputs.toolkit, 'cuda') }}
if: inputs.runner-type == 'cuda' shell: bash
env: env:
TZ: Etc/UTC # Note: the CI machine does not meet CUDA 13's driver requirement.
shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64 # Compatibility matrix:
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
PACKAGES: |
{
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
"cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
}
run: | run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb # The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
# Jetson specific. SBSA means Arm Server Base System Architecture.
ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update sudo apt-get update
sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9 sudo apt-get install -y \
# Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers libnccl2 libnccl-dev \
# cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly. ${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
# Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
# This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
- name: Package and Driver Report - name: CUDA packages and driver report
if: inputs.runner-type == 'cuda' if: ${{ startsWith(inputs.toolkit, 'cuda') }}
shell: bash shell: bash
run: | run: |
sudo apt-get install -y ubuntu-drivers-common dkms sudo apt-get install -y ubuntu-drivers-common dkms

View File

@@ -18,8 +18,7 @@ runs:
shell: bash shell: bash
run: xcodebuild -showComponent MetalToolchain run: xcodebuild -showComponent MetalToolchain
- name: Setup uv - uses: conda-incubator/setup-miniconda@v3
uses: astral-sh/setup-uv@v6
with: with:
python-version: ${{ inputs.python-version }} miniconda-version: "latest"
activate-environment: true python-version: ${{ inputs.python-version }}

69
.github/actions/test-linux/action.yml vendored Normal file
View File

@@ -0,0 +1,69 @@
name: 'Run Linux tests'
inputs:
has-gpu:
description: 'Run GPU tests'
required: false
default: false
runs:
using: "composite"
steps:
- name: Run MPI tests
shell: bash
run: |
echo "::group::MPI tests"
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
echo "::endgroup::"
- name: Run distributed tests
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
run: |
echo "::group::Distributed tests"
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
echo "::endgroup::"
- name: Run Python tests - CPU
if: ${{ inputs.has-gpu == 'false' }}
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::Python tests - CPU"
python -m unittest discover python/tests -v
echo "::endgroup::"
- name: Run Python tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::Python tests - GPU"
python -m tests discover python/tests -v
echo "::endgroup::"
- name: Run CPP tests - CPU
shell: bash
env:
DEVICE: cpu
run: |
echo "::group::CPP tests - CPU"
./build/tests/tests
echo "::endgroup::"
- name: Run CPP tests - GPU
if: ${{ inputs.has-gpu == 'true' }}
shell: bash
env:
DEVICE: gpu
run: |
echo "::group::CPP tests - GPU"
./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
echo "::endgroup::"

108
.github/workflows/build_and_test.yml vendored Normal file
View File

@@ -0,0 +1,108 @@
name: Build and Test
on:
pull_request:
push:
branches:
- main
# For testing CI without starting a pull request:
- test/*
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
check_lint:
name: Check Lint
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
name: Linux (cpu, ${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
cuda_build_and_test:
name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
if: github.repository == 'ml-explore/mlx'
needs: check_lint
strategy:
fail-fast: false
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.6', 'cuda-12.9']
runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/build-linux
with:
toolkit: ${{ matrix.toolkit }}
- uses: ./.github/actions/test-linux
if: matrix.arch == 'x86_64'
with:
has-gpu: true
mac_build_and_test:
name: macOS (${{ matrix.macos-target }})
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
macos-target: ["14.0", "15.0"]
runs-on: [self-hosted, macos]
env:
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
build_documentation:
name: Build Documentation
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22.04
needs: check_lint
steps:
- uses: actions/checkout@v6
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora (${{ matrix.arch }})
needs: check_lint
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v6
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh

View File

@@ -8,9 +8,9 @@ permissions:
jobs: jobs:
build: build:
runs-on: [self-hosted, macos] runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: ./.github/actions/build-docs - uses: ./.github/actions/build-docs
deploy: deploy:

View File

@@ -16,11 +16,12 @@ jobs:
python_version: ["3.10", "3.14"] python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux - uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux-release - uses: ./.github/actions/build-linux-release
with: with:
build-backend: ${{ matrix.python-version == '3.10' }} build-backend: ${{ matrix.python-version == '3.10' }}
arch: "x86_64"
- name: Upload mlx artifacts - name: Upload mlx artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v5
with: with:
@@ -39,14 +40,18 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"] python_version: ["3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04 runner:
- ubuntu-22.04
- ubuntu-22.04-arm
runs-on: ${{ matrix.runner }}
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux - uses: ./.github/actions/setup-linux
with: with:
python-version: ${{ matrix.python_version }} python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux - uses: ./.github/actions/build-linux
- uses: ./.github/actions/test-linux
build_mac_release: build_mac_release:
if: github.repository == 'ml-explore/mlx' if: github.repository == 'ml-explore/mlx'
@@ -55,12 +60,11 @@ jobs:
python-version: ["3.10", "3.13"] python-version: ["3.10", "3.13"]
runs-on: [self-hosted, macos] runs-on: [self-hosted, macos]
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos - uses: ./.github/actions/setup-macos
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos - uses: ./.github/actions/build-macos
- name: Build macOS 15 package - name: Build macOS 15 package
uses: ./.github/actions/build-macos-release uses: ./.github/actions/build-macos-release
with: with:
@@ -72,53 +76,21 @@ jobs:
macos-target: 14.0 macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }} build-backend: ${{ matrix.python-version == '3.10' }}
build_cuda_with_tests:
if: github.repository == 'ml-explore/mlx'
runs-on: gpu-t4-4-core
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_cuda_release: build_cuda_release:
if: github.repository == 'ml-explore/mlx' if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large runs-on: ubuntu-22-large
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux - uses: ./.github/actions/setup-linux
with: with:
runner-type: 'cuda' toolkit: 'cuda-12.9'
- name: Build Python package - name: Build Python package
uses: ./.github/actions/build-cuda-release uses: ./.github/actions/build-cuda-release
with: with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc' toolkit: 'cuda-12.9'
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v5
with: with:
name: mlx-cuda name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl path: wheelhouse/mlx_cuda-*.whl
retention-days: 7 retention-days: 7
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh

View File

@@ -1,71 +0,0 @@
name: Build and Test
on: pull_request
permissions:
contents: read
jobs:
check_lint:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: pre-commit/action@v3.0.1
linux_build_and_test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux
mac_build_and_test:
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
- uses: ./.github/actions/build-macos
cuda_build_and_test:
if: github.repository == 'ml-explore/mlx'
runs-on: gpu-t4-4-core
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
needs: check_lint
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh

View File

@@ -5,6 +5,11 @@ on:
tags: tags:
- 'v*' - 'v*'
workflow_dispatch: workflow_dispatch:
inputs:
dev_release:
description: "Do a dev release or regular release"
required: true
default: "false"
permissions: permissions:
contents: read contents: read
@@ -12,18 +17,15 @@ permissions:
jobs: jobs:
setup: setup:
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs:
pypi_env: ${{ github.event_name == 'push' && 'pypi' || 'test-pypi' }}
pypi_url: ${{ github.event_name == 'push' && 'https://upload.pypi.org/legacy/' || 'https://test.pypi.org/legacy/' }}
steps: steps:
- name: Set publishing variables - name: Set publishing variables
run: echo "Publishing setup complete" run: echo "Publishing setup complete"
build_documentation: build_documentation:
if: github.repository == 'ml-explore/mlx' if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos] runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: ./.github/actions/build-docs - uses: ./.github/actions/build-docs
deploy_documentation: deploy_documentation:
@@ -45,27 +47,33 @@ jobs:
strategy: strategy:
matrix: matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"] python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04 arch: ['x86_64', 'aarch64']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
env: env:
PYPI_RELEASE: 1 PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux - uses: ./.github/actions/setup-linux
with: with:
python-version: ${{ matrix.python_version }} python-version: ${{ matrix.python_version }}
use-ccache: false
- uses: ./.github/actions/build-linux-release - uses: ./.github/actions/build-linux-release
with: with:
build-backend: ${{ matrix.python-version == '3.10' }} build-backend: ${{ matrix.python-version == '3.10' }}
arch: ${{ matrix.arch }}
- name: Upload MLX artifacts - name: Upload MLX artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v5
with: with:
name: linux-wheels-${{ matrix.python_version }} overwrite: true
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts - name: Upload CPU artifacts
if: matrix.python_version == '3.10' if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v5
with: with:
name: mlx-cpu overwrite: true
name: mlx-cpu-${{ matrix.arch }}
path: wheelhouse/mlx_cpu-*.whl path: wheelhouse/mlx_cpu-*.whl
build_mac_release: build_mac_release:
@@ -76,22 +84,25 @@ jobs:
runs-on: [self-hosted, macos] runs-on: [self-hosted, macos]
env: env:
PYPI_RELEASE: 1 PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-macos - uses: ./.github/actions/setup-macos
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
shell: sh shell: bash -l {0}
run: | run: |
uv pip install --upgrade pip pip install --upgrade pip
uv pip install cmake setuptools nanobind==2.4.0 pip install cmake setuptools nanobind==2.10.2
uv pip install -e . -v pip install -e . -v
- name: Generate package stubs - name: Generate package stubs
shell: bash shell: bash -l {0}
run: | run: |
uv pip install typing_extensions pip install typing_extensions
uv run --no-project setup.py generate_stubs python setup.py generate_stubs
- name: Build macOS 14 package - name: Build macOS 14 package
uses: ./.github/actions/build-macos-release uses: ./.github/actions/build-macos-release
with: with:
@@ -105,32 +116,41 @@ jobs:
- name: Upload MLX artifacts - name: Upload MLX artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v5
with: with:
overwrite: true
name: mac-wheels-${{ matrix.python-version }} name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl path: dist/mlx-*.whl
- name: Upload Metal artifacts - name: Upload Metal artifacts
if: matrix.python-version == '3.10' if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v5
with: with:
overwrite: true
name: mlx-metal name: mlx-metal
path: dist/mlx_metal-*.whl path: dist/mlx_metal-*.whl
build_cuda_release: build_cuda_release:
if: github.repository == 'ml-explore/mlx' if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large strategy:
matrix:
arch: ['x86_64', 'aarch64']
toolkit: ['cuda-12.9', 'cuda-13.0']
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
env: env:
PYPI_RELEASE: 1 PYPI_RELEASE: 1
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v6
- uses: ./.github/actions/setup-linux - uses: ./.github/actions/setup-linux
with: with:
runner-type: 'cuda' toolkit: ${{ matrix.toolkit }}
use-ccache: false
- name: Build Python package - name: Build Python package
uses: ./.github/actions/build-cuda-release uses: ./.github/actions/build-cuda-release
with: with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc' arch: ${{ matrix.arch }}
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v5
with: with:
overwrite: true
name: mlx-cuda name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl path: wheelhouse/mlx_cuda-*.whl
@@ -141,7 +161,7 @@ jobs:
permissions: permissions:
id-token: write id-token: write
environment: environment:
name: ${{ needs.setup.outputs.pypi_env }} name: pypi
url: https://pypi.org/p/mlx url: https://pypi.org/p/mlx
steps: steps:
- uses: actions/download-artifact@v6 - uses: actions/download-artifact@v6
@@ -159,7 +179,7 @@ jobs:
- name: Publish package distributions to PyPI - name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1 uses: pypa/gh-action-pypi-publish@release/v1
with: with:
repository-url: ${{ needs.setup.outputs.pypi_url }} repository-url: https://upload.pypi.org/legacy/
pypi-publish-cuda: pypi-publish-cuda:
name: Upload CUDA release to PyPI name: Upload CUDA release to PyPI
@@ -168,7 +188,7 @@ jobs:
permissions: permissions:
id-token: write id-token: write
environment: environment:
name: ${{ needs.setup.outputs.pypi_env }} name: pypi
url: https://pypi.org/p/mlx-cuda url: https://pypi.org/p/mlx-cuda
steps: steps:
- uses: actions/download-artifact@v6 - uses: actions/download-artifact@v6
@@ -180,7 +200,7 @@ jobs:
- name: Publish package distributions to PyPI - name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1 uses: pypa/gh-action-pypi-publish@release/v1
with: with:
repository-url: ${{ needs.setup.outputs.pypi_url }} repository-url: https://upload.pypi.org/legacy/
pypi-publish-cpu: pypi-publish-cpu:
name: Upload CPU release to PyPI name: Upload CPU release to PyPI
@@ -189,19 +209,20 @@ jobs:
permissions: permissions:
id-token: write id-token: write
environment: environment:
name: ${{ needs.setup.outputs.pypi_env }} name: pypi
url: https://pypi.org/p/mlx-cpu url: https://pypi.org/p/mlx-cpu
steps: steps:
- uses: actions/download-artifact@v6 - uses: actions/download-artifact@v6
with: with:
name: mlx-cpu pattern: mlx-cpu-*
merge-multiple: true
path: dist path: dist
- name: Display structure of downloaded files - name: Display structure of downloaded files
run: ls -R dist run: ls -R dist
- name: Publish package distributions to PyPI - name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1 uses: pypa/gh-action-pypi-publish@release/v1
with: with:
repository-url: ${{ needs.setup.outputs.pypi_url }} repository-url: https://upload.pypi.org/legacy/
pypi-publish-metal: pypi-publish-metal:
name: Upload Metal release to PyPI name: Upload Metal release to PyPI
@@ -210,7 +231,7 @@ jobs:
permissions: permissions:
id-token: write id-token: write
environment: environment:
name: ${{ needs.setup.outputs.pypi_env }} name: pypi
url: https://pypi.org/p/mlx-metal url: https://pypi.org/p/mlx-metal
steps: steps:
- uses: actions/download-artifact@v6 - uses: actions/download-artifact@v6
@@ -222,5 +243,4 @@ jobs:
- name: Publish package distributions to PyPI - name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1 uses: pypa/gh-action-pypi-publish@release/v1
with: with:
repository-url: ${{ needs.setup.outputs.pypi_url }} repository-url: https://upload.pypi.org/legacy/

View File

@@ -74,6 +74,7 @@ endif()
if(MLX_USE_CCACHE) if(MLX_USE_CCACHE)
find_program(CCACHE_PROGRAM ccache) find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM) if(CCACHE_PROGRAM)
message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
@@ -272,7 +273,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS) if(MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.") message(STATUS "Building Python bindings.")
find_package( find_package(
Python 3.8 Python 3.10
COMPONENTS Interpreter Development.Module COMPONENTS Interpreter Development.Module
REQUIRED) REQUIRED)
execute_process( execute_process(

View File

@@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() {
void time_irregular_binary_ops_4D() { void time_irregular_binary_ops_4D() {
auto device = mx::default_device(); auto device = mx::default_device();
std::vector<int> shape = {8, 8, 512, 512}; mx::Shape shape = {8, 8, 512, 512};
auto a = mx::random::uniform(shape); auto a = mx::random::uniform(shape);
auto b = mx::random::uniform(shape); auto b = mx::random::uniform(shape);
@@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() {
void time_irregular_reshape() { void time_irregular_reshape() {
auto device = mx::default_device(); auto device = mx::default_device();
std::vector<int> shape; mx::Shape shape;
auto reshape_fn = [&shape, device](const mx::array& a) { auto reshape_fn = [&shape, device](const mx::array& a) {
return mx::reshape(a, shape, device); return mx::reshape(a, shape, device);
}; };
@@ -170,7 +170,7 @@ void time_irregular_astype_1D() {
void time_irregular_astype_2D() { void time_irregular_astype_2D() {
auto device = mx::default_device(); auto device = mx::default_device();
int size = 2048; int size = 2048;
std::vector<int> shape = {size, size}; mx::Shape shape = {size, size};
auto a = mx::random::uniform(shape); auto a = mx::random::uniform(shape);
TIMEM("2D regular", mx::astype, a, mx::int32, device); TIMEM("2D regular", mx::astype, a, mx::int32, device);

View File

@@ -1,6 +1,5 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023 Apple Inc.
import argparse
import os import os
import subprocess import subprocess
import time import time

View File

@@ -0,0 +1,212 @@
import math
import os
import subprocess
import time
from copy import copy
from functools import partial
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from matplotlib.ticker import FuncFormatter
RESULTS_DIR = "./results"
if not os.path.isdir(RESULTS_DIR):
os.mkdir(RESULTS_DIR)
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
TORCH_DEVICE = torch.device(
"mps"
if torch.backends.mps.is_available()
else ("cuda" if torch.cuda.is_available() else "cpu")
)
N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")
def _power_of_two_formatter(value, _position):
if value <= 0:
return ""
exponent = int(round(math.log2(value)))
if abs(value - (1 << exponent)) / value > 1e-6:
return f"{value:g}"
return f"$2^{{{exponent}}}$"
def torch_sync():
if TORCH_DEVICE.type == "cuda":
torch.cuda.synchronize()
elif TORCH_DEVICE.type == "mps":
torch.mps.synchronize()
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
outs = []
for _ in range(N_ITER_FUNC):
out = copy(self_arr)
out[mask_arr] = src_arr
outs.append(out)
mx.eval(outs)
return outs
@torch.no_grad()
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
outs = []
for _ in range(N_ITER_FUNC):
out = self_tensor.clone()
out.masked_scatter_(mask_tensor, src_tensor)
outs.append(out)
torch_sync()
return outs
def measure(fn):
for _ in range(N_WARMUP):
fn()
start = time.perf_counter_ns()
for _ in range(N_ITER_BENCH):
fn()
end = time.perf_counter_ns()
return (end - start) * 1e-9
def bytes_touched(length, true_count, item_size):
mask_bytes = length
self_bytes = length * item_size * 2 # read + write
src_bytes = true_count * item_size
return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
def build_case(length, density, np_dtype, torch_dtype):
true_count = max(1, int(round(length * density)))
rng = np.random.default_rng()
self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
mask_np = np.zeros(length, dtype=bool)
mask_np[:true_count] = True
rng.shuffle(mask_np)
src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
self_mlx = mx.array(self_np)
mask_mlx = mx.array(mask_np)
src_mlx = mx.array(src_np)
self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
# Correctness check once per configuration
mx_out = mx.array(self_np)
mx_out[mask_mlx] = src_mlx
mx.eval(mx_out)
torch_out = self_torch.clone()
torch_out.masked_scatter_(mask_torch, src_torch)
atol = 5e-3 if np_dtype == np.float16 else 1e-5
if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
raise AssertionError("masked_scatter results diverged between MLX and Torch")
return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
def bench_case(length, density, dtype):
np_dtype = getattr(np, dtype)
torch_dtype = getattr(torch, dtype)
(
self_mlx,
mask_mlx,
src_mlx,
self_torch,
mask_torch,
src_torch,
true_count,
) = build_case(length, density, np_dtype, torch_dtype)
time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
time_torch = measure(
partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
)
total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
bytes_per_gb = float(1024**3)
mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
torch_gbps = (total_bytes / bytes_per_gb) / time_torch
return time_mlx, time_torch, mlx_gbps, torch_gbps
def plot_density(ax_perf, ax_speedup, density, dtype):
mlx_gbps = []
torch_gbps = []
mlx_times = []
torch_times = []
for length in VECTOR_LENGTHS:
t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
mlx_gbps.append(gbps_mlx)
torch_gbps.append(gbps_torch)
mlx_times.append(t_mlx)
torch_times.append(t_torch)
ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
ax_perf.set_xscale("log", base=2)
ax_perf.set_xticks(VECTOR_LENGTHS)
formatter = FuncFormatter(_power_of_two_formatter)
ax_perf.xaxis.set_major_formatter(formatter)
ax_perf.set_title(f"density={density:.2f}")
ax_perf.set_ylabel("GB/s")
ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
ax_perf.legend()
speedup = np.array(torch_times) / np.array(mlx_times)
ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
ax_speedup.set_xscale("log", base=2)
ax_speedup.set_xticks(VECTOR_LENGTHS)
ax_speedup.xaxis.set_major_formatter(formatter)
ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
def main():
for dtype in D_TYPES:
fig, axs = plt.subplots(
len(MASK_DENSITIES),
2,
figsize=(10, 12),
layout="constrained",
sharex=True,
)
for i, density in enumerate(MASK_DENSITIES):
plot_density(axs[i][0], axs[i][1], density, dtype)
axs[i][0].set_xlabel("vector length")
axs[i][1].set_xlabel("vector length")
fig.suptitle(
f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
)
output_path = os.path.join(
RESULTS_DIR,
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
)
fig.savefig(output_path)
plt.close(fig)
if __name__ == "__main__":
main()

3
cmake/Findnvpl.cmake Normal file
View File

@@ -0,0 +1,3 @@
# This file does nothing but to suppress the cmake warning: "By not providing
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.

View File

@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
.. code-block:: shell .. code-block:: shell
pip install mlx[cuda] pip install mlx[cuda12]
To install the CUDA package from PyPi your system must meet the following To install the CUDA package from PyPi your system must meet the following
requirements: requirements:
- Nvidia architecture >= SM 7.0 (Volta) - Nvidia architecture >= SM 7.5
- Nvidia driver >= 550.54.14 - Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0 - CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35 - Linux distribution with glibc >= 2.35
- Python >= 3.10 - Python >= 3.10
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
CPU-only (Linux) CPU-only (Linux)
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^

View File

@@ -70,7 +70,8 @@ Differences from NumPy
* Indexing does not perform bounds checking. Indexing out of bounds is * Indexing does not perform bounds checking. Indexing out of bounds is
undefined behavior. undefined behavior.
* Boolean mask based indexing is not yet supported. * Boolean mask based indexing is supported for assignment only (see
:ref:`boolean-mask-assignment`).
The reason for the lack of bounds checking is that exceptions cannot propagate The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the from the GPU. Performing bounds checking for array indices before launching the
@@ -143,3 +144,51 @@ expected. For example:
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx`` In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere. and ones elsewhere.
.. _boolean-mask-assignment:
Boolean Mask Assignment
-----------------------
MLX supports boolean indices using NumPy syntax. A mask must already be
a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
Other index types are routed through the standard scatter code.
.. code-block:: shell
>>> a = mx.array([1.0, 2.0, 3.0])
>>> mask = mx.array([True, False, True])
>>> updates = mx.array([5.0, 6.0])
>>> a[mask] = updates
>>> a
array([5.0, 2.0, 6.0], dtype=float32)
Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
assignments, ``updates`` must provide at least as many elements as there are
``True`` entries in ``mask``.
.. code-block:: shell
>>> a = mx.zeros((2, 3))
>>> mask = mx.array([[True, False, True],
[False, False, True]])
>>> a[mask] = 1.0
>>> a
array([[1.0, 0.0, 1.0],
[0.0, 0.0, 1.0]], dtype=float32)
Boolean masks follow NumPy semantics:
- The mask shape must match the shape of the axes it indexes exactly. The only
exception is a scalar boolean mask, which broadcasts to the full array.
- Any axes not covered by the mask are taken in full.
.. code-block:: shell
>>> a = mx.arange(1000).reshape(10, 10, 10)
>>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
axes and therefore raise errors.

View File

@@ -3,6 +3,6 @@ requires = [
"setuptools>=42", "setuptools>=42",
"cmake>=3.25", "cmake>=3.25",
"mlx>=0.18.0", "mlx>=0.18.0",
"nanobind==2.4.0", "nanobind==2.10.2",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"

View File

@@ -1,4 +1,4 @@
setuptools>=42 setuptools>=42
cmake>=3.25 cmake>=3.25
mlx>=0.21.0 mlx>=0.21.0
nanobind==2.4.0 nanobind==2.10.2

View File

@@ -1,7 +1,6 @@
target_sources( target_sources(
mlx mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp

View File

@@ -1,24 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include "mlx/allocator.h"
namespace mlx::core::allocator {
Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
void free(Buffer buffer) {
allocator().free(buffer);
}
} // namespace mlx::core::allocator

View File

@@ -28,16 +28,16 @@ class Buffer {
}; };
}; };
Buffer malloc(size_t size);
void free(Buffer buffer);
class Allocator { class Allocator {
/** Abstract base class for a memory allocator. */ /** Abstract base class for a memory allocator. */
public: public:
virtual Buffer malloc(size_t size) = 0; virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0; virtual size_t size(Buffer buffer) const = 0;
virtual Buffer make_buffer(void* ptr, size_t size) {
return Buffer{nullptr};
};
virtual void release(Buffer buffer) {}
Allocator() = default; Allocator() = default;
Allocator(const Allocator& other) = delete; Allocator(const Allocator& other) = delete;
@@ -49,4 +49,25 @@ class Allocator {
Allocator& allocator(); Allocator& allocator();
inline Buffer malloc(size_t size) {
return allocator().malloc(size);
}
inline void free(Buffer buffer) {
allocator().free(buffer);
}
// Make a Buffer from a raw pointer of the given size without a copy. If a
// no-copy conversion is not possible then the returned buffer.ptr() will be
// nullptr. Any buffer created with this function must be released with
// release(buffer)
inline Buffer make_buffer(void* ptr, size_t size) {
return allocator().make_buffer(ptr, size);
};
// Release a buffer from the allocator made with make_buffer
inline void release(Buffer buffer) {
allocator().release(buffer);
}
} // namespace mlx::core::allocator } // namespace mlx::core::allocator

View File

@@ -82,6 +82,28 @@ array::array(std::initializer_list<int> data, Dtype dtype)
init(data.begin()); init(data.begin());
} }
array::array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
auto buffer = allocator::make_buffer(data, nbytes());
if (buffer.ptr() == nullptr) {
set_data(allocator::malloc(nbytes()));
auto ptr = static_cast<char*>(data);
std::copy(ptr, ptr + nbytes(), this->data<char>());
deleter(data);
} else {
auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
auto ptr = buffer.ptr();
allocator::release(buffer);
return deleter(ptr);
};
set_data(buffer, std::move(wrapped_deleter));
}
}
/* Build an array from a shared buffer */ /* Build an array from a shared buffer */
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter) array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
@@ -167,7 +189,7 @@ void array::copy_shared_buffer(
const Strides& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
size_t offset /* = 0 */) { int64_t offset /* = 0 */) {
array_desc_->data = other.array_desc_->data; array_desc_->data = other.array_desc_->data;
array_desc_->strides = strides; array_desc_->strides = strides;
array_desc_->flags = flags; array_desc_->flags = flags;

View File

@@ -57,6 +57,16 @@ class array {
Shape shape, Shape shape,
Dtype dtype = TypeToDtype<T>()); Dtype dtype = TypeToDtype<T>());
/* Build an array from a raw pointer. The constructor will attempt to use the
* input data without a copy. The deleter will be called when the array no
* longer needs the underlying memory - after the array is destroyed in the
* no-copy case and after the copy otherwise. */
explicit array(
void* data,
Shape shape,
Dtype dtype,
const std::function<void(void*)>& deleter);
/* Build an array from a buffer */ /* Build an array from a buffer */
explicit array( explicit array(
allocator::Buffer data, allocator::Buffer data,
@@ -439,7 +449,7 @@ class array {
const Strides& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
size_t offset = 0); int64_t offset = 0);
void copy_shared_buffer(const array& other); void copy_shared_buffer(const array& other);

View File

@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
// - Donatable // - Donatable
// - Not a constant // - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() && is_constant(i)) { in.is_donatable() && !is_constant(i)) {
outputs[o++].copy_shared_buffer(in); outputs[o++].copy_shared_buffer(in);
} }
// Get representative input flags to properly set non-donated outputs // Get representative input flags to properly set non-donated outputs
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
// - Not a constant // - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() && if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
is_constant(i)) { !is_constant(i)) {
outputs[o].copy_shared_buffer( outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size()); in, outputs[o].strides(), in.flags(), in.data_size());
o++; o++;

View File

@@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
data_offset += start_indices[i] * in.strides()[i]; data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i]; inp_strides[i] = in.strides()[i] * strides[i];
} }
// Normalize the offset
if (data_offset < 0) {
data_offset += in.data_size();
}
return std::make_tuple(data_offset, inp_strides); return std::make_tuple(data_offset, inp_strides);
} }
void shared_buffer_slice( void shared_buffer_slice(
const array& in, const array& in,
const Strides& out_strides, const Strides& out_strides,
size_t data_offset, int64_t data_offset,
size_t data_size, size_t data_size,
array& out) { array& out) {
// Compute row/col contiguity // Compute row/col contiguity
@@ -51,17 +47,24 @@ void slice(
// Calculate out strides, initial offset // Calculate out strides, initial offset
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides); auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
int64_t data_end = 1;
for (int i = 0; i < start_indices.size(); ++i) { // Get the location of the end based on the inp strides and out.shape()
if (in.shape()[i] > 1) { int64_t low_idx = 0;
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1; int64_t high_idx = 0;
data_end += end_idx * in.strides()[i]; for (int i = 0; i < inp_strides.size(); ++i) {
auto delta = inp_strides[i] * (out.shape()[i] - 1);
if (inp_strides[i] > 0) {
high_idx += delta;
} else {
low_idx += delta;
} }
} }
if (data_end < 0) { int64_t data_size = (high_idx - low_idx) + 1;
data_end += in.data_size(); if (data_size < 0) {
std::ostringstream msg;
msg << "[slice] Computed invalid data size: " << data_size << ".";
throw std::runtime_error(msg.str());
} }
size_t data_size = (data_end - data_offset);
shared_buffer_slice(in, inp_strides, data_offset, data_size, out); shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
} }

View File

@@ -12,6 +12,167 @@ namespace mlx::core {
namespace { namespace {
template <typename T>
complex64_t to_complex(T r, T i) {
return {static_cast<float>(r), static_cast<float>(i)};
}
template <typename T, class Enable = void>
struct EigWork {};
template <typename T>
struct EigWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using O = complex64_t;
char jobl;
char jobr;
int N;
int lwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
T work;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
if (compute_eigenvectors) {
buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
}
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, O* values, O* vectors) {
auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
T* vec_tmp = nullptr;
if (vectors) {
vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
}
auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
eig_tmp,
eig_tmp + N,
vectors ? vec_tmp : nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
work,
&lwork,
&info);
for (int i = 0; i < N; ++i) {
values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
}
if (vectors) {
for (int i = 0; i < N; ++i) {
if (values[i].imag() != 0) {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] =
to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
vectors[(i + 1) * N + j] =
to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
}
}
}
}
}
};
template <>
struct EigWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
using O = T;
char jobl;
char jobr;
int N;
int lwork;
int lrwork;
int info;
std::vector<array::Data> buffers;
EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
: jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
T work;
R rwork;
int n_vecs_l = compute_eigenvectors ? N_ : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&rwork,
&info);
lwork = static_cast<int>(work.real());
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
}
void run(T* a, T* values, T* vectors) {
int n_vecs_l = vectors ? N : 1;
int n_vecs_r = 1;
geev<T>(
&jobl,
&jobr,
&N,
a,
&N,
values,
vectors,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(buffers[0].buffer.raw_ptr()),
&lwork,
static_cast<R*>(buffers[1].buffer.raw_ptr()),
&info);
}
};
template <typename T> template <typename T>
void eig_impl( void eig_impl(
array& a, array& a,
@@ -19,101 +180,39 @@ void eig_impl(
array& values, array& values,
bool compute_eigenvectors, bool compute_eigenvectors,
Stream stream) { Stream stream) {
using OT = std::complex<T>;
auto a_ptr = a.data<T>(); auto a_ptr = a.data<T>();
auto eig_ptr = values.data<OT>(); auto val_ptr = values.data<complex64_t>();
auto& encoder = cpu::get_command_encoder(stream); auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_output_array(values); encoder.set_output_array(values);
OT* vec_ptr = nullptr; complex64_t* vec_ptr = nullptr;
if (compute_eigenvectors) { if (compute_eigenvectors) {
encoder.set_output_array(vectors); encoder.set_output_array(vectors);
vec_ptr = vectors.data<OT>(); vec_ptr = vectors.data<complex64_t>();
} }
encoder.dispatch([a_ptr, encoder.dispatch([a_ptr,
val_ptr,
vec_ptr, vec_ptr,
eig_ptr,
compute_eigenvectors, compute_eigenvectors,
N = vectors.shape(-1), N = vectors.shape(-1),
size = vectors.size()]() mutable { size = vectors.size()]() mutable {
// Work query
char jobr = 'N'; char jobr = 'N';
char jobl = compute_eigenvectors ? 'V' : 'N'; char jobl = compute_eigenvectors ? 'V' : 'N';
int n_vecs_r = 1;
int n_vecs_l = compute_eigenvectors ? N : 1;
int lwork = -1;
int info;
{
T work;
geev<T>(
&jobl,
&jobr,
&N,
nullptr,
&N,
nullptr,
nullptr,
nullptr,
&n_vecs_l,
nullptr,
&n_vecs_r,
&work,
&lwork,
&info);
lwork = static_cast<int>(work);
}
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)}; EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
auto vec_tmp_data =
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
for (size_t i = 0; i < size / (N * N); ++i) { for (size_t i = 0; i < size / (N * N); ++i) {
geev<T>( work.run(a_ptr, val_ptr, vec_ptr);
&jobl, a_ptr += N * N;
&jobr, val_ptr += N;
&N,
a_ptr,
&N,
eig_tmp,
eig_tmp + N,
vec_tmp,
&n_vecs_l,
nullptr,
&n_vecs_r,
static_cast<T*>(work_buf.buffer.raw_ptr()),
&lwork,
&info);
for (int i = 0; i < N; ++i) {
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
}
if (vec_ptr) { if (vec_ptr) {
for (int i = 0; i < N; ++i) {
if (eig_ptr[i].imag() != 0) {
// This vector and the next are a pair
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
vec_ptr[(i + 1) * N + j] = {
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
}
i += 1;
} else {
for (int j = 0; j < N; ++j) {
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
}
}
}
vec_ptr += N * N; vec_ptr += N * N;
} }
a_ptr += N * N; if (work.info != 0) {
eig_ptr += N;
if (info != 0) {
std::stringstream msg; std::stringstream msg;
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code " msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
<< info; << work.info;
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
} }
@@ -165,8 +264,17 @@ void Eig::eval_cpu(
case float32: case float32:
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream()); eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
break; break;
case float64:
eig_impl<double>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
case complex64:
eig_impl<std::complex<float>>(
a_copy, vectors, values, compute_eigenvectors_, stream());
break;
default: default:
throw std::runtime_error("[Eig::eval_cpu] only supports float32."); throw std::runtime_error(
"[Eig::eval_cpu] only supports float32, float64, or complex64.");
} }
} }

View File

@@ -747,4 +747,108 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
}); });
} }
template <typename T>
void masked_scatter_impl(const array& mask, const array& src, array& out) {
ContiguousIterator mask_it(mask);
ContiguousIterator src_it(src);
ContiguousIterator out_it(out);
const bool* mask_ptr = mask.data<bool>();
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
const size_t batch_count = mask.shape(0);
const size_t mask_batch_size = mask.size() / batch_count;
const size_t src_batch_size = src.size() / batch_count;
for (uint b = 0; b < batch_count; ++b) {
size_t src_consumed = 0;
src_it.seek(b * src_batch_size);
for (size_t i = 0; i < mask_batch_size; ++i) {
if (mask_ptr[mask_it.loc]) {
if (src_consumed >= src_batch_size) {
throw std::runtime_error(
"[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
}
dst_ptr[out_it.loc] = src_ptr[src_it.loc];
src_it.step();
++src_consumed;
}
mask_it.step();
out_it.step();
}
}
}
void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& dst = inputs[0];
auto& mask = inputs[1];
auto& src = inputs[2];
// Copy src into out (copy allocates memory for out)
auto ctype =
dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
copy_cpu(dst, out, ctype, stream());
if (mask.size() == 0) {
return;
}
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(mask);
encoder.set_input_array(src);
encoder.set_output_array(out);
encoder.dispatch([mask = array::unsafe_weak_copy(mask),
src = array::unsafe_weak_copy(src),
out = array::unsafe_weak_copy(out)]() mutable {
switch (out.dtype()) {
case bool_:
masked_scatter_impl<bool>(mask, src, out);
break;
case uint8:
masked_scatter_impl<uint8_t>(mask, src, out);
break;
case uint16:
masked_scatter_impl<uint16_t>(mask, src, out);
break;
case uint32:
masked_scatter_impl<uint32_t>(mask, src, out);
break;
case uint64:
masked_scatter_impl<uint64_t>(mask, src, out);
break;
case int8:
masked_scatter_impl<int8_t>(mask, src, out);
break;
case int16:
masked_scatter_impl<int16_t>(mask, src, out);
break;
case int32:
masked_scatter_impl<int32_t>(mask, src, out);
break;
case int64:
masked_scatter_impl<int64_t>(mask, src, out);
break;
case float16:
masked_scatter_impl<float16_t>(mask, src, out);
break;
case float32:
masked_scatter_impl<float>(mask, src, out);
break;
case float64:
masked_scatter_impl<double>(mask, src, out);
break;
case bfloat16:
masked_scatter_impl<bfloat16_t>(mask, src, out);
break;
case complex64:
masked_scatter_impl<complex64_t>(mask, src, out);
break;
}
});
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -45,9 +45,7 @@
INSTANTIATE_LAPACK_REAL(geqrf) INSTANTIATE_LAPACK_REAL(geqrf)
INSTANTIATE_LAPACK_REAL(orgqr) INSTANTIATE_LAPACK_REAL(orgqr)
INSTANTIATE_LAPACK_REAL(syevd) INSTANTIATE_LAPACK_REAL(syevd)
INSTANTIATE_LAPACK_REAL(geev)
INSTANTIATE_LAPACK_REAL(potrf) INSTANTIATE_LAPACK_REAL(potrf)
INSTANTIATE_LAPACK_REAL(gesdd)
INSTANTIATE_LAPACK_REAL(getrf) INSTANTIATE_LAPACK_REAL(getrf)
INSTANTIATE_LAPACK_REAL(getri) INSTANTIATE_LAPACK_REAL(getri)
INSTANTIATE_LAPACK_REAL(trtri) INSTANTIATE_LAPACK_REAL(trtri)
@@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri)
} }
INSTANTIATE_LAPACK_COMPLEX(heevd) INSTANTIATE_LAPACK_COMPLEX(heevd)
#define INSTANTIATE_LAPACK_ALL(FUNC) \
template <typename T, typename... Args> \
void FUNC(Args... args) { \
if constexpr (std::is_same_v<T, float>) { \
MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, double>) { \
MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<float>>) { \
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
} \
}
INSTANTIATE_LAPACK_ALL(geev)
INSTANTIATE_LAPACK_ALL(gesdd)

View File

@@ -291,6 +291,17 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
num_keys, num_keys,
kshape = keys.shape(), kshape = keys.shape(),
kstrides = keys.strides()]() mutable { kstrides = keys.strides()]() mutable {
auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
if (4 * loc + 4 <= bytes_per_key) {
reinterpret_cast<uint32_t*>(cptr)[loc] = v;
} else {
std::copy(
reinterpret_cast<char*>(&v),
reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,
cptr + 4 * loc);
}
};
size_t out_skip = (bytes_per_key + 4 - 1) / 4; size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2; auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0; bool even = out_skip % 2 == 0;
@@ -310,18 +321,12 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
if (count.first < half_size) { if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count); auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first; ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) { copy_remaining(cptr, count.second, rb.second);
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
} }
if (!even) { if (!even) {
count.second = 0; count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first; copy_remaining(
cptr, half_size, random::threefry2x32_hash(key, count).first);
} }
} }
}); });

View File

@@ -3,5 +3,9 @@
#include "mlx/backend/cpu/simd/base_simd.h" #include "mlx/backend/cpu/simd/base_simd.h"
#ifdef MLX_USE_ACCELERATE #ifdef MLX_USE_ACCELERATE
#if defined(__x86_64__)
// the accelerate_simd implementation require neon -- use base implementation
#else
#include "mlx/backend/cpu/simd/accelerate_simd.h" #include "mlx/backend/cpu/simd/accelerate_simd.h"
#endif #endif
#endif

View File

@@ -8,6 +8,183 @@
namespace mlx::core { namespace mlx::core {
template <typename T, class Enable = void>
struct SVDWork {};
template <typename T>
struct SVDWork<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using R = T;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension;
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <>
struct SVDWork<std::complex<float>> {
using T = std::complex<float>;
using R = float;
int N;
int M;
int K;
int lda;
int ldu;
int ldvt;
char jobz;
std::vector<array::Data> buffers;
int lwork;
SVDWork(int N, int M, int K, char jobz)
: N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
const int lrwork =
jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
int lwork_query = -1;
int work_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
lwork = workspace_dimension.real();
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
}
void run(T* a, R* s, T* u, T* vt) {
int info;
gesdd<T>(
/* jobz = */ &jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ a,
/* lda = */ &lda,
/* s = */ s,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ u,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ vt,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
/* lwork = */ &lwork,
/* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
/* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
}
};
template <typename T> template <typename T>
void svd_impl( void svd_impl(
const array& a, const array& a,
@@ -27,6 +204,8 @@ void svd_impl(
const int N = a.shape(-1); const int N = a.shape(-1);
const int K = std::min(M, N); const int K = std::min(M, N);
using R = typename SVDWork<T>::R;
size_t num_matrices = a.size() / (M * N); size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy. // lapack clobbers the input, so we have to make a copy.
@@ -42,7 +221,7 @@ void svd_impl(
encoder.set_input_array(a); encoder.set_input_array(a);
auto in_ptr = in.data<T>(); auto in_ptr = in.data<T>();
T* u_ptr; T* u_ptr;
T* s_ptr; R* s_ptr;
T* vt_ptr; T* vt_ptr;
if (compute_uv) { if (compute_uv) {
@@ -58,7 +237,7 @@ void svd_impl(
encoder.set_output_array(s); encoder.set_output_array(s);
encoder.set_output_array(vt); encoder.set_output_array(vt);
s_ptr = s.data<T>(); s_ptr = s.data<R>();
u_ptr = u.data<T>(); u_ptr = u.data<T>();
vt_ptr = vt.data<T>(); vt_ptr = vt.data<T>();
} else { } else {
@@ -68,96 +247,26 @@ void svd_impl(
encoder.set_output_array(s); encoder.set_output_array(s);
s_ptr = s.data<T>(); s_ptr = s.data<R>();
u_ptr = nullptr; u_ptr = nullptr;
vt_ptr = nullptr; vt_ptr = nullptr;
} }
encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() { encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ. auto jobz = (u_ptr) ? 'A' : 'N';
const int lda = N; SVDWork<T> svd_work(N, M, K, jobz);
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
auto jobz = (u_ptr) ? "A" : "N";
T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not
// used here but required by lapack).
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
static const int lwork_query = -1;
int info;
// Compute workspace size.
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Loop over matrices. // Loop over matrices.
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {
gesdd<T>( svd_work.run(
/* jobz = */ jobz, in_ptr + M * N * i,
// M and N are swapped since lapack expects column-major. s_ptr + K * i,
/* m = */ &N, vt_ptr ? vt_ptr + N * N * i : nullptr,
/* n = */ &M, u_ptr ? u_ptr + M * M * i : nullptr);
/* a = */ in_ptr + M * N * i,
/* lda = */ &lda,
/* s = */ s_ptr + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
} }
}); });
encoder.add_temporary(in); encoder.add_temporary(in);
} }
template <typename T>
void compute_svd(
const array& a,
bool compute_uv,
std::vector<array>& outputs,
Stream stream) {}
void SVD::eval_cpu( void SVD::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs) { std::vector<array>& outputs) {
@@ -168,9 +277,12 @@ void SVD::eval_cpu(
case float64: case float64:
svd_impl<double>(inputs[0], outputs, compute_uv_, stream()); svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
break; break;
case complex64:
svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
break;
default: default:
throw std::runtime_error( throw std::runtime_error(
"[SVD::eval_cpu] only supports float32 or float64."); "[SVD::eval_cpu] only supports float32, float64, or complex64.");
} }
} }

View File

@@ -44,6 +44,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
@@ -122,10 +123,21 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>") mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif() endif()
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with # Use native CUDA arch by default.
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES) if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
set(MLX_CUDA_ARCHITECTURES "native") execute_process(
COMMAND __nvcc_device_query
OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
OUTPUT_STRIP_TRAILING_WHITESPACE)
set(UPGRADABLE_ARCHITECTURES "90;100;121")
if(MLX_CUDA_ARCHITECTURES STREQUAL "")
message(
FATAL_ERROR
"Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
# Use arch-specific compute capability whenever possible.
set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
endif()
endif() endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -137,6 +149,7 @@ FetchContent_Declare(
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip") URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl) FetchContent_MakeAvailable(cccl)
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include") target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
set_target_properties(mlx PROPERTIES CCCL_DIR "${cccl_SOURCE_DIR}/include")
# Use fixed version of NVTX. # Use fixed version of NVTX.
FetchContent_Declare( FetchContent_Declare(
@@ -162,7 +175,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
FetchContent_Declare( FetchContent_Declare(
cudnn cudnn
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
GIT_TAG v1.14.0 GIT_TAG v1.16.0
GIT_SHALLOW TRUE GIT_SHALLOW TRUE
EXCLUDE_FROM_ALL) EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON) set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)

View File

@@ -20,6 +20,19 @@ constexpr int page_size = 16384;
// Any allocations smaller than this will try to use the small pool // Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8; constexpr int small_block_size = 8;
#if CUDART_VERSION >= 13000
inline cudaMemLocation cuda_mem_loc(int i) {
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
return loc;
}
#else
inline int cuda_mem_loc(int i) {
return i;
}
#endif // CUDART_VERSION >= 13000
// The small pool size in bytes. This should be a multiple of the host page // The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size. // size and small_block_size.
constexpr int small_pool_size = 4 * page_size; constexpr int small_pool_size = 4 * page_size;
@@ -35,13 +48,7 @@ SmallSizePool::SmallSizePool() {
int device_count = 0; int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count)); CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) { for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000 auto loc = cuda_mem_loc(i);
cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice;
loc.id = i;
#else
int loc = i;
#endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc)); cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
} }
@@ -90,9 +97,10 @@ CudaAllocator::CudaAllocator()
page_size, page_size,
[](CudaBuffer* buf) { return buf->size; }, [](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) { [this](CudaBuffer* buf) { cuda_free(buf); }) {
size_t free, total; size_t free;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
memory_limit_ = total * 0.95; memory_limit_ = total_memory_ * 0.95;
free_limit_ = total_memory_ - memory_limit_;
max_pool_size_ = memory_limit_; max_pool_size_ = memory_limit_;
int device_count = 0; int device_count = 0;
@@ -104,6 +112,10 @@ CudaAllocator::CudaAllocator()
cudaStream_t s; cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking)); CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s); free_streams_.push_back(s);
cudaMemPool_t mem_pool;
CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
mem_pools_.push_back(mem_pool);
} }
CHECK_CUDA_ERROR(cudaSetDevice(curr)); CHECK_CUDA_ERROR(cudaSetDevice(curr));
} }
@@ -119,7 +131,8 @@ void copy_to_managed(CudaBuffer& buf) {
buf.data = new_data; buf.data = new_data;
} }
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) { Buffer
CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
if (size == 0) { if (size == 0) {
return Buffer{new CudaBuffer{nullptr, 0, -1}}; return Buffer{new CudaBuffer{nullptr, 0, -1}};
} }
@@ -134,9 +147,8 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
size = page_size * ((size + page_size - 1) / page_size); size = page_size * ((size + page_size - 1) / page_size);
} }
int device = -1; if (size <= small_block_size || stream == nullptr) {
if (size > small_block_size && stream != nullptr) { device = -1;
CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
} }
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size); CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
@@ -154,19 +166,35 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
} }
lock.unlock(); lock.unlock();
if (!buf) { if (!buf) {
buf = new CudaBuffer{nullptr, size, device}; void* data = nullptr;
cudaError_t err;
if (device == -1) { if (device == -1) {
err = cudaMallocManaged(&buf->data, size); CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
} else { } else {
err = cudaMallocAsync(&buf->data, size, stream); CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
} }
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { if (!data) {
throw std::runtime_error(fmt::format( std::ostringstream msg;
"cudaMallocManaged failed: {}.", cudaGetErrorString(err))); msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
} }
buf = new CudaBuffer{data, size, device};
} }
lock.lock(); lock.lock();
// If any cuda memory pool has too much reserved memory, clear some
// memory from the cache. This prevents graph / kernel execution failing
// from OOM
if (get_cache_memory() > 0) {
for (auto p : mem_pools_) {
size_t used = 0;
CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
p, cudaMemPoolAttrReservedMemCurrent, &used));
if (used > (total_memory_ - free_limit_)) {
buffer_cache_.release_cached_buffers(free_limit_);
break;
}
}
}
} }
active_memory_ += buf->size; active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_); peak_memory_ = std::max(active_memory_, peak_memory_);
@@ -176,18 +204,14 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
} }
// Copy to managed here if the buffer is not on the right device // Copy to managed here if the buffer is not on the right device
if (buf->device != device) { if (buf->device >= 0 && buf->device != device) {
copy_to_managed(*buf); copy_to_managed(*buf);
} }
return Buffer{buf}; return Buffer{buf};
} }
Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
return malloc_impl(size, stream);
}
Buffer CudaAllocator::malloc(size_t size) { Buffer CudaAllocator::malloc(size_t size) {
return malloc_impl(size, nullptr); return malloc_async(size, -1, nullptr);
} }
void CudaAllocator::free(Buffer buffer) { void CudaAllocator::free(Buffer buffer) {
@@ -223,9 +247,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
scalar_pool_.free(buf); scalar_pool_.free(buf);
} else { } else {
if (buf->device >= 0) { if (buf->device >= 0) {
cudaFreeAsync(buf->data, free_streams_[buf->device]); CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
} else { } else {
cudaFree(buf->data); CHECK_CUDA_ERROR(cudaFree(buf->data));
} }
delete buf; delete buf;
} }
@@ -277,8 +301,9 @@ CudaAllocator& allocator() {
return *allocator_; return *allocator_;
} }
Buffer malloc_async(size_t size, cudaStream_t stream) { Buffer malloc_async(size_t size, CommandEncoder& encoder) {
auto buffer = allocator().malloc_async(size, stream); auto buffer = allocator().malloc_async(
size, encoder.device().cuda_device(), encoder.stream());
if (size && !buffer.ptr()) { if (size && !buffer.ptr()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[malloc_async] Unable to allocate " << size << " bytes."; msg << "[malloc_async] Unable to allocate " << size << " bytes.";

View File

@@ -13,6 +13,8 @@
namespace mlx::core::cu { namespace mlx::core::cu {
class CommandEncoder;
using allocator::Buffer; using allocator::Buffer;
// Stores cuda-managed unified memory. // Stores cuda-managed unified memory.
@@ -48,7 +50,7 @@ class SmallSizePool {
class CudaAllocator : public allocator::Allocator { class CudaAllocator : public allocator::Allocator {
public: public:
Buffer malloc(size_t size) override; Buffer malloc(size_t size) override;
Buffer malloc_async(size_t size, cudaStream_t stream); Buffer malloc_async(size_t size, int device, cudaStream_t stream);
void free(Buffer buffer) override; void free(Buffer buffer) override;
size_t size(Buffer buffer) const override; size_t size(Buffer buffer) const override;
@@ -62,7 +64,6 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache(); void clear_cache();
private: private:
Buffer malloc_impl(size_t size, cudaStream_t stream);
void cuda_free(CudaBuffer* buf); void cuda_free(CudaBuffer* buf);
CudaAllocator(); CudaAllocator();
@@ -70,16 +71,19 @@ class CudaAllocator : public allocator::Allocator {
std::mutex mutex_; std::mutex mutex_;
size_t memory_limit_; size_t memory_limit_;
size_t free_limit_;
size_t total_memory_;
size_t max_pool_size_; size_t max_pool_size_;
BufferCache<CudaBuffer> buffer_cache_; BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0}; size_t active_memory_{0};
size_t peak_memory_{0}; size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_; std::vector<cudaStream_t> free_streams_;
std::vector<cudaMemPool_t> mem_pools_;
SmallSizePool scalar_pool_; SmallSizePool scalar_pool_;
}; };
CudaAllocator& allocator(); CudaAllocator& allocator();
Buffer malloc_async(size_t size, cudaStream_t stream); Buffer malloc_async(size_t size, CommandEncoder& encoder);
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -42,7 +42,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
auto& encoder = cu::get_command_encoder(stream()); auto& encoder = cu::get_command_encoder(stream());
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
encoder.set_output_array(out); encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {

View File

@@ -143,7 +143,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream(); auto& s = stream();
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
// Prepare the shapes, strides and axis arguments. // Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_); Shape shape = remove_index(in.shape(), axis_);

View File

@@ -367,9 +367,8 @@ void binary_op_gpu(
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out, bopt, [&](auto n) { set_binary_op_output_data(
return cu::malloc_async(n, encoder.stream()); a, b, out, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
});
binary_op_gpu_inplace<Op>(inputs, out, op, s); binary_op_gpu_inplace<Op>(inputs, out, op, s);
} }

View File

@@ -246,12 +246,10 @@ void binary_two_op_gpu_inplace(
auto& out_b = outputs[1]; auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) { set_binary_op_output_data(
return cu::malloc_async(n, encoder.stream()); a, b, out_a, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
}); set_binary_op_output_data(
set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) { a, b, out_b, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
return cu::malloc_async(n, encoder.stream());
});
if (out_a.size() == 0) { if (out_a.size() == 0) {
return; return;

View File

@@ -298,7 +298,7 @@ void Compiled::eval_gpu(
// Put outputs. // Put outputs.
compiled_allocate_outputs( compiled_allocate_outputs(
inputs, outputs, is_constant_, contiguous, [&](auto n) { inputs, outputs, is_constant_, contiguous, [&](auto n) {
return cu::malloc_async(n, encoder.stream()); return cu::malloc_async(n, encoder);
}); });
for (auto& x : outputs) { for (auto& x : outputs) {
args.append(x); args.append(x);

View File

@@ -15,19 +15,16 @@ namespace mlx::core {
namespace { namespace {
// Alias for better readability. enum ConvBackendType {
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR CONV_FALLBACK,
#define CONV_BACKWARD_INPUT \ CONV_FORWARD,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR CONV_BACKWARD_INPUT,
#define CONV_BACKWARD_WEIGHT \ CONV_BACKWARD_WEIGHT,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR };
// Custom placeholder representing fallback kernel.
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
struct ConvCacheKey { struct ConvCacheKey {
int device_id; int device_id;
cudnnDataType_t cudnn_dtype; fe::DataType_t cudnn_dtype;
std::array<int, MAX_NDIM> input_shape; std::array<int, MAX_NDIM> input_shape;
std::array<int, MAX_NDIM> weight_shape; std::array<int, MAX_NDIM> weight_shape;
std::array<int, MAX_NDIM> stride; std::array<int, MAX_NDIM> stride;
@@ -44,15 +41,13 @@ struct ConvCacheKey {
auto& conv_cache() { auto& conv_cache() {
static LRUBytesKeyCache< static LRUBytesKeyCache<
ConvCacheKey, ConvCacheKey,
std::pair< std::pair<ConvBackendType, std::optional<DnnGraph>>>
cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>>
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128); cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
return cache; return cache;
} }
auto get_conv_op_settings( auto get_conv_settings(
cudnnBackendDescriptorType_t backend_type, ConvBackendType backend_type,
array& x, array& x,
array& w, array& w,
array& y, array& y,
@@ -68,8 +63,8 @@ auto get_conv_op_settings(
for (int i = 0; i < padding_lo.size(); ++i) { for (int i = 0; i < padding_lo.size(); ++i) {
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1); int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding_lo[i] - 1; padding_lo[i] = wt_size - padding_lo[i] - 1;
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1); int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1); int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + padding_hi[i]; padding_hi[i] = out_size - in_size + padding_hi[i];
} }
return std::make_tuple( return std::make_tuple(
@@ -95,49 +90,57 @@ auto get_conv_op_settings(
} }
} }
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph( std::optional<DnnGraph> build_conv_graph(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type, ConvBackendType backend_type,
Dtype dtype, Dtype dtype,
array& x, array& x,
array& w, array& w,
array& y, array& y,
const SmallVector<int64_t>& stride, const std::vector<int64_t>& stride,
const SmallVector<int64_t>& padding_lo, const std::vector<int64_t>& padding_lo,
const SmallVector<int64_t>& padding_hi, const std::vector<int64_t>& padding_hi,
const SmallVector<int64_t>& dilation) { const std::vector<int64_t>& dilation) {
try { auto compute_dtype =
auto compute_dtype = (dtype == float16 || dtype == bfloat16) (dtype == float16 || dtype == bfloat16) ? float32 : dtype;
? CUDNN_DATA_FLOAT DnnGraph graph(encoder.device().cudnn_handle(), dtype, compute_dtype);
: dtype_to_cudnn_type(dtype); auto x_ = graph.tensor_nchw("X", 'x', x);
auto conv_desc = cudnn_frontend::ConvDescBuilder() auto w_ = graph.tensor_nchw("W", 'w', w);
.setDataType(compute_dtype)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(stride.size())
.setStrides(stride.size(), stride.data())
.setPrePadding(padding_lo.size(), padding_lo.data())
.setPostPadding(padding_hi.size(), padding_hi.data())
.setDilation(dilation.size(), dilation.data())
.build();
auto op = cudnn_frontend::OperationBuilder(backend_type) auto set_options = [&](auto& options) {
.setxDesc(build_cudnn_tensor_nchw('x', x)) options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))
.setwDesc(build_cudnn_tensor_nchw('w', w)) .set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
.setyDesc(build_cudnn_tensor_nchw('y', y)) .set_stride(stride)
.setcDesc(conv_desc) .set_pre_padding(padding_lo)
.build(); .set_post_padding(padding_hi)
.set_dilation(dilation);
};
std::array<cudnn_frontend::Operation const*, 1> ops = {&op}; std::shared_ptr<fe::graph::Tensor_attributes> y_;
return cudnn_frontend::OperationGraphBuilder() if (backend_type == CONV_FORWARD) {
.setHandle(encoder.device().cudnn_handle()) auto options = fe::graph::Conv_fprop_attributes();
.setOperationGraph(ops.size(), ops.data()) set_options(options);
.build(); y_ = graph.conv_fprop(x_, w_, options);
} catch (cudnn_frontend::cudnnException& error) { } else if (backend_type == CONV_BACKWARD_INPUT) {
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) { auto options = fe::graph::Conv_dgrad_attributes();
throw; set_options(options);
} y_ = graph.conv_dgrad(x_, w_, options);
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
auto options = fe::graph::Conv_wgrad_attributes();
set_options(options);
y_ = graph.conv_wgrad(w_, x_, options);
}
graph.tensor_nchw(y_, 'y', y)->set_output(true);
if (graph.prepare().is_bad()) {
return std::nullopt; return std::nullopt;
} }
graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
if (dtype == float32 && !env::enable_tf32()) {
graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});
}
CHECK_CUDNN_FE_ERROR(graph.build());
return graph;
} }
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups). // Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
@@ -181,7 +184,7 @@ array group_transpose(
// eval_gpu, with cost of possible redundant copies. // eval_gpu, with cost of possible redundant copies.
std::tuple<array, array, array> prepare_args( std::tuple<array, array, array> prepare_args(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type, ConvBackendType backend_type,
array in, array in,
array wt, array wt,
array out, array out,
@@ -221,27 +224,11 @@ std::tuple<array, array, array> prepare_args(
return {std::move(in), std::move(wt), std::move(out)}; return {std::move(in), std::move(wt), std::move(out)};
} }
// Get the x/w/y args from the in/wt/out args depending on backend type.
inline std::tuple<array&, array&, array&> dispatch_args(
cudnnBackendDescriptorType_t backend_type,
array& in,
array& wt,
array& out) {
switch (backend_type) {
case CONV_BACKWARD_INPUT:
return {out, wt, in};
case CONV_BACKWARD_WEIGHT:
return {in, out, wt};
default:
return {in, wt, out};
}
}
// Register inputs and outputs before actually running conv op. Can only be // Register inputs and outputs before actually running conv op. Can only be
// called once per eval_gpu. // called once per eval_gpu.
void register_args( void register_args(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type, ConvBackendType backend_type,
array& in, array& in,
array& wt, array& wt,
array& intermediate_out, array& intermediate_out,
@@ -277,11 +264,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
array in = inputs[0]; array in = inputs[0];
array wt = inputs[1]; array wt = inputs[1];
array out = out_; array out = out_;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
Dtype dtype = out.dtype(); Dtype dtype = out.dtype();
// Search cache. // Search cache.
ConvCacheKey cache_key{ BytesKey<ConvCacheKey> cache_key;
cache_key.pod = {
encoder.device().cuda_device(), encoder.device().cuda_device(),
dtype_to_cudnn_type(dtype), dtype_to_cudnn_type(dtype),
vector_key(in.shape()), vector_key(in.shape()),
@@ -296,16 +284,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
get_alignment(wt), get_alignment(wt),
get_alignment(out)}; get_alignment(out)};
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) { if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
auto& [backend_type, plan] = it->second; auto& [backend_type, graph] = it->second;
if (plan) { if (graph) {
// Run cached plan. // Run cached graph.
std::tie(in, wt, out) = std::tie(in, wt, out) =
prepare_args(encoder, backend_type, in, wt, out, groups_, s); prepare_args(encoder, backend_type, in, wt, out, groups_, s);
register_args(encoder, backend_type, in, wt, out, out_); register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out); CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) { encoder,
throw std::runtime_error("[conv] Cached plan failed to execute."); {
} {'x', gpu_ptr<void>(in)},
{'w', gpu_ptr<void>(wt)},
{'y', gpu_ptr<void>(out)},
}));
} else { } else {
// Run fallback kernel. // Run fallback kernel.
gemm_conv( gemm_conv(
@@ -326,7 +317,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
// There is no reliable way to deduce the proper cuDNN backend for the // There is no reliable way to deduce the proper cuDNN backend for the
// convolution, so we make a best guess and then try. // convolution, so we make a best guess and then try.
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends; SmallVector<ConvBackendType, 2> try_backends;
if (flip_) { if (flip_) {
// When weight is flipped, we assume it is backward input convolution. // When weight is flipped, we assume it is backward input convolution.
try_backends.push_back(CONV_BACKWARD_INPUT); try_backends.push_back(CONV_BACKWARD_INPUT);
@@ -344,13 +335,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
} }
// Try to build op graph. // Try to build op graph.
cudnnBackendDescriptorType_t backend_type; ConvBackendType backend_type;
std::optional<cudnn_frontend::OperationGraph> op_graph; std::optional<DnnGraph> graph;
for (auto try_backend : try_backends) { for (auto try_backend : try_backends) {
auto [in_copy, wt_copy, out_copy] = auto [x, w, y] =
prepare_args(encoder, try_backend, in, wt, out, groups_, s); prepare_args(encoder, try_backend, in, wt, out, groups_, s);
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy); auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
try_backend, try_backend,
x, x,
w, w,
@@ -360,7 +350,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
padding_hi_, padding_hi_,
kernel_dilation_, kernel_dilation_,
input_dilation_); input_dilation_);
op_graph = build_conv_op_graph( graph = build_conv_graph(
encoder, encoder,
try_backend, try_backend,
dtype, dtype,
@@ -371,30 +361,27 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
padding_lo, padding_lo,
padding_hi, padding_hi,
dilation); dilation);
if (op_graph) { if (graph) {
backend_type = try_backend; backend_type = try_backend;
in = std::move(in_copy); in = std::move(x);
wt = std::move(wt_copy); wt = std::move(w);
out = std::move(out_copy); out = std::move(y);
break; break;
} }
} }
if (op_graph) { if (graph) {
// Find a plan for the graph and execute it. register_args(encoder, backend_type, in, wt, out, out_);
auto plan = find_cudnn_plan_from_op_graph( CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph); encoder,
if (plan) { {
// Setup inputs and outputs. {'x', gpu_ptr<void>(in)},
register_args(encoder, backend_type, in, wt, out, out_); {'w', gpu_ptr<void>(wt)},
{'y', gpu_ptr<void>(out)},
auto [x, w, y] = dispatch_args(backend_type, in, wt, out); }));
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) { conv_cache().emplace(
conv_cache().emplace( cache_key, std::make_pair(backend_type, std::move(*graph)));
cache_key, std::make_pair(backend_type, std::move(*plan))); return;
return;
}
}
} }
// Use fallback kernel for settings not supported by cuDNN. // Use fallback kernel for settings not supported by cuDNN.

View File

@@ -86,7 +86,7 @@ array unfold_inputs_nd(
int mat_N, int mat_N,
ConvParams<NDIM>& params) { ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {}); array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream())); unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
encoder.add_temporary(unfolded); encoder.add_temporary(unfolded);
int filter_size = params.C; int filter_size = params.C;

View File

@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
int mat_N, int mat_N,
ConvParams<NDIM>& params) { ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {}); array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream())); unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
encoder.add_temporary(unfolded); encoder.add_temporary(unfolded);
int filter_size = params.C; int filter_size = params.C;

View File

@@ -7,9 +7,8 @@ namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
bool donated = set_copy_output_data(in, out, ctype, [&](auto n) { bool donated = set_copy_output_data(
return cu::malloc_async(n, encoder.stream()); in, out, ctype, [&](auto n) { return cu::malloc_async(n, encoder); });
});
if (donated && in.dtype() == out.dtype()) { if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to // If the output has the same type as the input then there is nothing to
// copy, just use the buffer. // copy, just use the buffer.
@@ -104,7 +103,7 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
return; return;
} }
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0); copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
@@ -114,7 +113,7 @@ void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out); auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) { if (copy_necessary) {
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
copy_gpu_inplace( copy_gpu_inplace(
in, in,
out, out,

View File

@@ -95,11 +95,14 @@ void copy_general_input(
const InType* in_ptr = gpu_ptr<InType>(in) + offset_in; const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out; OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size(); int ndim = shape.size();
int work_per_thread = 1;
int work_per_thread = 8;
auto dim0 = ndim > 0 ? shape.back() : 1; auto dim0 = ndim > 0 ? shape.back() : 1;
auto rest = out.size() / dim0; auto rest = out.size() / dim0;
if (dim0 >= 4) { if (dim0 >= 4 && dim0 < 8) {
work_per_thread = 4; work_per_thread = 4;
} else if (dim0 < 4) {
work_per_thread = 1;
} }
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
auto block_dims = get_block_dims(dim0, rest, 1); auto block_dims = get_block_dims(dim0, rest, 1);
@@ -110,7 +113,10 @@ void copy_general_input(
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto kernel = auto kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>; cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
if (work_per_thread == 4) { if (work_per_thread == 8) {
kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
} else if (work_per_thread == 4) {
kernel = kernel =
cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>; cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
} }
@@ -127,7 +133,9 @@ void copy_general_input(
}); });
} else { // ndim >= 4 } else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT, 1>; auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
if (work_per_thread == 4) { if (work_per_thread == 8) {
kernel = cu::copy_g<InType, OutType, IdxT, 8>;
} else if (work_per_thread == 4) {
kernel = cu::copy_g<InType, OutType, IdxT, 4>; kernel = cu::copy_g<InType, OutType, IdxT, 4>;
} }
encoder.add_kernel_node( encoder.add_kernel_node(

View File

@@ -5,6 +5,7 @@
#include <cublasLt.h> #include <cublasLt.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cudnn.h>
namespace mlx::core { namespace mlx::core {
@@ -12,10 +13,12 @@ namespace mlx::core {
void check_cublas_error(const char* name, cublasStatus_t err); void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err); void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err); void check_cuda_error(const char* name, CUresult err);
void check_cudnn_error(const char* name, cudnnStatus_t err);
// The macro version that prints the command that failed. // The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd)) #define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
// Base class for RAII managed CUDA resources. // Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)> template <typename Handle, cudaError_t (*Destroy)(Handle)>
@@ -29,6 +32,10 @@ class CudaHandle {
} }
~CudaHandle() { ~CudaHandle() {
// Skip if there was an error to avoid throwing in the destructors
if (cudaPeekAtLastError() != cudaSuccess) {
return;
}
reset(); reset();
} }

View File

@@ -7,32 +7,26 @@ namespace mlx::core {
namespace { namespace {
// Create a cudnn tensor descriptor. #define RETURN_IF_ERROR(cmd) \
template <typename Vec> if (auto ret = cmd; ret.is_bad()) { \
inline cudnn_frontend::Tensor build_cudnn_tensor( return ret; \
int64_t id, }
const array& x,
const Vec& shape,
const Vec& strides) {
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(get_alignment(x))
.setDataType(dtype_to_cudnn_type(x.dtype()))
.build();
}
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN // In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
// whether a tensor is contiguous is determined with: // whether a tensor is contiguous is determined with:
// shape[dim] == shape[dim + 1] * strides[dim + 1] // shape[dim] == shape[dim + 1] * strides[dim + 1]
// So a contiguous array with singleton dims in MLX may be mistakenly treated // So a contiguous array with singleton dims in MLX may be mistakenly treated
// as strided in cuDNN, and we work around it by normalizing the strides. // as strided in cuDNN, and we work around it by normalizing the strides.
Strides normalized_strides(const array& x) { std::vector<int64_t> normalized_strides(const array& x) {
if (!x.flags().row_contiguous || x.ndim() < 2) { std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
return x.strides(); if (std::all_of(
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
strides.back() = 1;
return strides;
}
if (!x.flags().row_contiguous || x.ndim() < 2) {
return strides;
} }
Strides strides = x.strides();
for (int i = x.ndim() - 2; i >= 0; --i) { for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) { if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1]; strides[i] = x.shape(i + 1) * strides[i + 1];
@@ -42,7 +36,9 @@ Strides normalized_strides(const array& x) {
} }
// Return the shape and strides after transposing from NHWC to NCHW. // Return the shape and strides after transposing from NHWC to NCHW.
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) { inline auto nhwc_to_nchw(const array& x) {
auto shape = convert_vector<int64_t>(x.shape());
auto strides = normalized_strides(x);
assert(shape.size() >= 3); assert(shape.size() >= 3);
shape.insert(shape.begin() + 1, shape.back()); shape.insert(shape.begin() + 1, shape.back());
shape.erase(shape.end() - 1); shape.erase(shape.end() - 1);
@@ -51,228 +47,95 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
return std::make_tuple(std::move(shape), std::move(strides)); return std::make_tuple(std::move(shape), std::move(strides));
} }
inline auto nhwc_to_nchw(const array& x) {
return nhwc_to_nchw(
convert_vector<int64_t>(x.shape()), normalized_strides(x));
}
// Return available engines for a |op_graph|.
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph,
bool use_fallback = true) {
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
sources.push_back([](auto& op_graph) {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph)
.setHeurMode(CUDNN_HEUR_MODE_A)
.build();
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
});
if (use_fallback) {
sources.push_back([&backend_type](auto& op_graph) {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph)
.setOperation(backend_type)
.build();
return fallback.getFallbackList();
});
}
auto configs =
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
.generate_engine_config(op_graph);
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
if (cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
dtype == float32 && !env::enable_tf32()) {
return true;
}
return false;
});
return filtered_configs;
}
// Take |engine_configs| and |op_graph| and find a working execution plans
// from them.
std::optional<cudnn_frontend::ExecutionPlan>
find_cudnn_plan_from_engine_configs(
cudnnHandle_t handle,
const cudnn_frontend::EngineConfigList& engine_configs,
const cudnn_frontend::OperationGraph& op_graph) {
auto op_graph_tag = op_graph.getTag();
for (const auto& config : engine_configs) {
try {
return cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, op_graph_tag)
.build();
} catch (cudnn_frontend::cudnnException& error) {
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
throw;
}
}
}
return std::nullopt;
}
// Prepare workspace and args to execute plan.
template <typename F>
bool prepare_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs,
F&& execute) {
int workspace_size = plan.getWorkspaceSize();
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder.stream()),
{workspace_size},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids)
.build();
auto handle = encoder.device().cudnn_handle();
cudnnSetStream(handle, encoder.stream());
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
return false;
}
return true;
}
} // namespace } // namespace
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) { fe::error_t DnnGraph::prepare() {
auto shape = convert_vector<int64_t>(x.shape()); RETURN_IF_ERROR(validate());
return build_cudnn_tensor(id, x, shape, normalized_strides(x)); try {
RETURN_IF_ERROR(build_operation_graph(handle_));
} catch (cudnn_frontend::cudnnException& error) {
// cuDNN bug: they did not catch all exceptions in the API.
return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()};
}
RETURN_IF_ERROR(create_execution_plans({fe::HeurMode_t::A}));
return {};
} }
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) { fe::error_t DnnGraph::build() {
RETURN_IF_ERROR(check_support(handle_));
RETURN_IF_ERROR(build_plans(handle_));
return {};
}
fe::error_t DnnGraph::encode_graph(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack) {
cudnnSetStream(handle_, encoder.stream());
CudaGraph cuda_graph(encoder.device());
RETURN_IF_ERROR(populate_cuda_graph(
handle_, variant_pack, prepare_workspace(encoder), cuda_graph));
encoder.add_graph_node(cuda_graph);
return {};
}
fe::error_t DnnGraph::encode_capturing(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack) {
auto* workspace_ptr = prepare_workspace(encoder);
auto capture = encoder.capture_context();
cudnnSetStream(handle_, encoder.stream());
auto ret = execute(handle_, variant_pack, workspace_ptr);
if (ret.is_bad()) {
capture.discard = true;
}
return ret;
}
void* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) {
int64_t workspace_size = 0;
CHECK_CUDNN_FE_ERROR(get_workspace_size(workspace_size));
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder),
{static_cast<int>(workspace_size)},
uint8);
encoder.add_temporary(workspace);
return gpu_ptr<void>(workspace);
}
return nullptr;
}
void DnnGraph::set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& strides) {
tensor->set_uid(uid)
.set_alignment(get_alignment(x))
.set_data_type(dtype_to_cudnn_type(x.dtype()))
.set_dim(shape)
.set_stride(strides);
}
void DnnGraph::set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
set_tensor_attrs(
tensor,
uid,
x,
convert_vector<int64_t>(x.shape()),
normalized_strides(x));
}
void DnnGraph::set_tensor_attrs_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
auto [shape, strides] = nhwc_to_nchw(x); auto [shape, strides] = nhwc_to_nchw(x);
return build_cudnn_tensor(id, x, shape, strides); set_tensor_attrs(tensor, uid, x, shape, strides);
} }
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
if (x.ndim() == 0) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
}
if (x.ndim() == 1) {
int64_t s = x.shape(0);
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
SmallVector<int64_t, 4> strides = {s, 1, s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 2) {
int64_t s =
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
return build_cudnn_tensor(id, x, shape, strides);
}
if (x.ndim() == 3 || x.ndim() == 4) {
return build_cudnn_tensor_nchw(id, x);
}
throw std::runtime_error(
fmt::format("Unsupported array with {} dims.", x.ndim()));
}
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
return cudnn_frontend::TensorBuilder()
.setDim(scalar_dims.size(), scalar_dims.data())
.setStrides(scalar_dims.size(), scalar_dims.data())
.setId(id)
.setAlignment(16)
.setDataType(dtype_to_cudnn_type(dtype))
.setByValue(true)
.build();
}
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
cudnnHandle_t handle,
cudnnBackendDescriptorType_t backend_type,
Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
if (engine_configs.empty()) {
return std::nullopt;
}
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
}
bool encode_cudnn_plan_with_capturing(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
auto capture = encoder.capture_context();
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
// Discard the captured graph when failed.
capture.discard = true;
return false;
}
return true;
});
}
#if CUDNN_VERSION >= 90500
bool encode_cudnn_plan_with_graph_api(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
int num_args,
const int64_t* uids,
void** data_ptrs) {
return prepare_cudnn_plan(
encoder,
plan,
num_args,
uids,
data_ptrs,
[&](auto handle, auto plan, auto args) {
if (!graph) {
graph = CudaGraph(encoder.device());
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
} else {
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
CUDNN_STATUS_SUCCESS) {
return false;
}
}
encoder.add_graph_node(graph);
return true;
});
}
#endif
} // namespace mlx::core } // namespace mlx::core

View File

@@ -2,25 +2,30 @@
#pragma once #pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include <cudnn_frontend.h> #include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h> #include <fmt/format.h>
#include <algorithm>
#include <array>
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
class CommandEncoder; class CommandEncoder;
} }
namespace fe = cudnn_frontend;
#define CHECK_CUDNN_FE_ERROR(cmd) \
do { \
auto error = cmd; \
if (!error.is_good()) { \
throw std::runtime_error( \
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
} \
} while (0)
// Return pointer alignment of |x|'s data. // Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) { inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1; uint8_t alignment = 1;
@@ -35,8 +40,31 @@ inline uint8_t get_alignment(const array& x) {
// Convert the type of elements in |vec| to |T|. // Convert the type of elements in |vec| to |T|.
template <typename T, typename Vec> template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) { inline std::vector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end()); return std::vector<T>(vec.begin(), vec.end());
}
// Map dtype to cudnn data type.
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return fe::DataType_t::INT8;
case int32:
return fe::DataType_t::INT32;
case uint8:
return fe::DataType_t::UINT8;
case float16:
return fe::DataType_t::HALF;
case bfloat16:
return fe::DataType_t::BFLOAT16;
case float32:
return fe::DataType_t::FLOAT;
case float64:
return fe::DataType_t::DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype)));
}
} }
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM. // Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
@@ -44,122 +72,100 @@ inline SmallVector<T> convert_vector(const Vec& vec) {
// There are 2 differences from the const_param util from kernel_utils.cuh: // There are 2 differences from the const_param util from kernel_utils.cuh:
// 1. The rest of array is filled with 0. // 1. The rest of array is filled with 0.
// 2. This util can be used in .cpp files. // 2. This util can be used in .cpp files.
template <typename T, template <typename U> class Vec> template <int NDIM = MAX_NDIM, typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) { inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) { if (vec.size() > NDIM) {
throw std::runtime_error( throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM)); fmt::format("ndim can not be larger than {}.", NDIM));
} }
std::array<T, MAX_NDIM> result = {}; std::array<T, NDIM> result = {};
std::copy_n(vec.begin(), vec.size(), result.begin()); std::copy_n(vec.begin(), vec.size(), result.begin());
return result; return result;
} }
// Helpers used by get_data_ptrs to get pointers. // Extends cuDNN graph with helpers.
inline void* get_data_ptr(const array& arr) { class DnnGraph : public fe::graph::Graph {
return const_cast<void*>(gpu_ptr<void>(arr)); public:
} DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32)
: handle_(handle) {
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>> set_io_data_type(dtype_to_cudnn_type(io_dtype));
inline void* get_data_ptr(T& scalar) { set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype));
return &scalar; set_compute_data_type(dtype_to_cudnn_type(compute_dtype));
}
// Return an array filled with data pointers of args.
template <typename... Args>
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
return {get_data_ptr(args)...};
}
// Map dtype to cudnn data type.
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return CUDNN_DATA_INT8;
case int32:
return CUDNN_DATA_INT32;
case uint8:
return CUDNN_DATA_UINT8;
case float16:
return CUDNN_DATA_HALF;
case bfloat16:
return CUDNN_DATA_BFLOAT16;
case float32:
return CUDNN_DATA_FLOAT;
case float64:
return CUDNN_DATA_DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
} }
}
// Create a tensor descriptor from |x|. // Create a cuDNN tensor description from MLX array |x|.
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x); auto& tensor(
std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
int64_t uid,
const array& x) {
set_tensor_attrs(attrs, uid, x);
return attrs;
}
auto tensor(const char* name, int64_t uid, const array& x) {
auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
tensor(attrs, uid, x);
return attrs;
}
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW. // Create a cuDNN tensor description from MLX array |x|, and transpose it from
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x); // NHWC layout to NCHW.
auto& tensor_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
int64_t uid,
const array& x) {
set_tensor_attrs_nchw(attrs, uid, x);
return attrs;
}
auto tensor_nchw(const char* name, int64_t uid, const array& x) {
auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
tensor_nchw(attrs, uid, x);
return attrs;
}
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it // Create a cuDNN tensor for scalar.
// from NHWC to NCHW. auto scalar(const char* name, int64_t uid, Dtype dtype) {
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x); return Graph::tensor(fe::graph::Tensor_attributes()
.set_name(name)
.set_uid(uid)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(dtype_to_cudnn_type(dtype)));
}
// Create a 4D scalar tensor descriptor, which is passed by value. // Call this before setting notes.
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype); fe::error_t prepare();
// Call this after setting notes.
fe::error_t build();
// Find a working plan for |op_graph|. // Add cuDNN graph to CUDA graph, using native CUDA graph API.
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph( fe::error_t encode_graph(
cudnnHandle_t handle, cu::CommandEncoder& encoder,
cudnnBackendDescriptorType_t backend_type, std::unordered_map<int64_t, void*> variant_pack);
Dtype dtype, // Add cuDNN graph to CUDA graph, using stream capture.
cudnn_frontend::OperationGraph& op_graph); fe::error_t encode_capturing(
cu::CommandEncoder& encoder,
std::unordered_map<int64_t, void*> variant_pack);
// Encode the plan to command buffer by capturing. private:
bool encode_cudnn_plan_with_capturing( void* prepare_workspace(cu::CommandEncoder& encoder);
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
int num_args,
const int64_t* uids,
void** data_ptrs);
#if CUDNN_VERSION >= 90500 void set_tensor_attrs(
// Encode the plan to command buffer by using native graph api of cudnn. If the std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
// |graph| is empty it will be populated, otherwise it will be updated. int64_t uid,
bool encode_cudnn_plan_with_graph_api( const array& x,
cu::CommandEncoder& encoder, const std::vector<int64_t>& shape,
cudnn_frontend::ExecutionPlan& plan, const std::vector<int64_t>& strides);
CudaGraph& graph, void set_tensor_attrs(
int num_args, std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
const int64_t* uids, int64_t uid,
void** data_ptrs); const array& x);
#endif void set_tensor_attrs_nchw(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x);
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z). cudnnHandle_t handle_;
template <typename... Args> };
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_capturing(
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
}
#if CUDNN_VERSION >= 90500
template <typename... Args>
bool encode_cudnn_plan(
cu::CommandEncoder& encoder,
cudnn_frontend::ExecutionPlan& plan,
CudaGraph& graph,
std::initializer_list<int64_t> uids,
Args&... args) {
assert(uids.size() == sizeof...(args));
auto data_ptrs = get_data_ptrs(args...);
return encode_cudnn_plan_with_graph_api(
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
}
#endif
} // namespace mlx::core } // namespace mlx::core

View File

@@ -57,7 +57,7 @@ std::string build_kernel(
const std::vector<std::string>& output_names, const std::vector<std::string>& output_names,
const std::vector<Dtype>& output_dtypes, const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args, const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<CustomKernelShapeInfo>& shape_infos) { const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
std::string kernel_source; std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 8192); kernel_source.reserve(header.size() + source.size() + 8192);
kernel_source += default_header; kernel_source += default_header;
@@ -81,17 +81,17 @@ std::string build_kernel(
kernel_source += ",\n"; kernel_source += ",\n";
// Add input shape, strides and ndim if present in the source // Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) { if (arr.ndim() > 0) {
if (shape_infos[i].shape) { if (std::get<0>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Shape "; kernel_source += " const __grid_constant__ Shape ";
kernel_source += name; kernel_source += name;
kernel_source += "_shape,\n"; kernel_source += "_shape,\n";
} }
if (shape_infos[i].strides) { if (std::get<1>(shape_infos[i])) {
kernel_source += " const __grid_constant__ Strides "; kernel_source += " const __grid_constant__ Strides ";
kernel_source += name; kernel_source += name;
kernel_source += "_strides,\n"; kernel_source += "_strides,\n";
} }
if (shape_infos[i].ndim) { if (std::get<2>(shape_infos[i])) {
kernel_source += " const __grid_constant__ int "; kernel_source += " const __grid_constant__ int ";
kernel_source += name; kernel_source += name;
kernel_source += "_ndim,\n"; kernel_source += "_ndim,\n";
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
"[custom_kernel] Must specify at least one output."); "[custom_kernel] Must specify at least one output.");
} }
std::vector<CustomKernelShapeInfo> shape_infos; std::vector<std::tuple<bool, bool, bool>> shape_infos;
for (auto& n : input_names) { for (auto& n : input_names) {
CustomKernelShapeInfo shape_info; std::tuple<bool, bool, bool> shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos; std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos; std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos; std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info); shape_infos.push_back(shape_info);
} }
@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
std::optional<float> init_value, std::optional<float> init_value,
bool ensure_row_contiguous, bool ensure_row_contiguous,
StreamOrDevice s) { StreamOrDevice s) {
std::vector<CustomKernelShapeInfo> shape_infos( std::vector<std::tuple<bool, bool, bool>> shape_infos(
inputs.size(), CustomKernelShapeInfo{false, false, false}); inputs.size(), {false, false, false});
return array::make_arrays( return array::make_arrays(
output_shapes, output_shapes,
output_dtypes, output_dtypes,
@@ -289,7 +289,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype()); copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s); fill_gpu(copies.back(), out, s);
} else { } else {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
} }
} }
@@ -327,13 +327,13 @@ void CustomKernel::eval_gpu(
const array& in = checked_inputs[i]; const array& in = checked_inputs[i];
auto& shape_info = shape_infos_[i]; auto& shape_info = shape_infos_[i];
args.append(in); args.append(in);
if (shape_info.shape) { if (std::get<0>(shape_info)) {
args.append_ndim(in.shape()); args.append_ndim(in.shape());
} }
if (shape_info.strides) { if (std::get<1>(shape_info)) {
args.append_ndim(in.strides()); args.append_ndim(in.strides());
} }
if (shape_info.ndim) { if (std::get<2>(shape_info)) {
args.append<int32_t>(in.ndim()); args.append<int32_t>(in.ndim());
} }
} }

View File

@@ -14,20 +14,20 @@ namespace mlx::core::cu {
namespace { namespace {
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd)) bool use_cuda_graphs() {
static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
void check_cudnn_error(const char* name, cudnnStatus_t err) { return use_graphs;
if (err != CUDNN_STATUS_SUCCESS) {
throw std::runtime_error(
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
}
} }
bool use_cuda_graphs() { const char* save_cuda_graphs_dot_file() {
static bool use_graphs = []() { static const char* filename = []() -> const char* {
return env::get_var("MLX_USE_CUDA_GRAPHS", true); const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
if (env && std::strlen(env) == 0) {
return nullptr;
}
return env;
}(); }();
return use_graphs; return filename;
} }
} // namespace } // namespace
@@ -46,6 +46,7 @@ Device::Device(int device) : device_(device) {
"Device {} does not support synchronization in managed memory.", "Device {} does not support synchronization in managed memory.",
device_)); device_));
} }
// The cublasLt handle is used by matmul. // The cublasLt handle is used by matmul.
make_current(); make_current();
CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_)); CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
@@ -86,7 +87,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
return; return;
} }
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal)); cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeThreadLocal));
} }
CommandEncoder::CaptureContext::~CaptureContext() { CommandEncoder::CaptureContext::~CaptureContext() {
@@ -114,18 +115,17 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
} }
// Use an empty graph node for synchronization // Use an empty graph node for synchronization
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)}; CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
enc.empty_node_count_++;
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0)); CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
// Insert the concurrent -> empty node dependencies // Insert the concurrent -> empty node dependencies
for (auto& from : enc.concurrent_nodes_) { for (auto& from : enc.concurrent_nodes_) {
enc.from_nodes_.push_back(from.node); enc.from_nodes_.push_back(from.node);
enc.to_nodes_.push_back(empty.node); enc.to_nodes_.push_back(empty.node);
enc.graph_key_ += from.id; enc.graph_deps_key_ += from.id;
enc.graph_key_ += from.node_type; enc.graph_deps_key_ += "-";
enc.graph_key_ += empty.id; enc.graph_deps_key_ += empty.id;
enc.graph_key_ += empty.node_type; enc.graph_deps_key_ += "-";
} }
// Insert the input -> concurrent node dependencies without updating output // Insert the input -> concurrent node dependencies without updating output
@@ -140,9 +140,6 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
} }
void CommandEncoder::insert_graph_dependencies(GraphNode node) { void CommandEncoder::insert_graph_dependencies(GraphNode node) {
if (node.node_type == 'G') {
graph_node_count_++;
}
node.id = std::to_string(node_count_++); node.id = std::to_string(node_count_++);
if (in_concurrent_) { if (in_concurrent_) {
concurrent_nodes_.push_back(std::move(node)); concurrent_nodes_.push_back(std::move(node));
@@ -154,6 +151,10 @@ void CommandEncoder::insert_graph_dependencies(GraphNode node) {
} }
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) { void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
for (auto& node : nodes) {
graph_nodes_key_ += node.node_type;
graph_nodes_key_ += "-";
}
std::vector<GraphNode> deps; std::vector<GraphNode> deps;
{ {
// Dependencies must be added in the same order to produce a consistent // Dependencies must be added in the same order to produce a consistent
@@ -181,20 +182,49 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
for (auto& to : nodes) { for (auto& to : nodes) {
from_nodes_.push_back(from.node); from_nodes_.push_back(from.node);
to_nodes_.push_back(to.node); to_nodes_.push_back(to.node);
graph_key_ += from.id; graph_deps_key_ += from.id;
graph_key_ += from.node_type; graph_deps_key_ += "-";
graph_key_ += to.id; graph_deps_key_ += to.id;
graph_key_ += to.node_type; graph_deps_key_ += "-";
} }
} }
} }
// Can be tuned with MLX_MAX_OPS_PER_BUFFER, MLX_MAX_MB_PER_BUFFER
std::pair<int, int> get_graph_limits(Device& d) {
auto cc =
d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;
int ops = 20;
int mb = 100;
switch (cc) {
case 800: // A100
ops = 20;
mb = 400;
break;
case 900: // H100
ops = 30;
mb = 400;
break;
case 1000: // B200
ops = 50;
mb = 500;
break;
case 1210: // DGX Spark
ops = 20;
mb = 25;
break;
}
return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)};
}
CommandEncoder::CommandEncoder(Device& d) CommandEncoder::CommandEncoder(Device& d)
: device_(d), : device_(d),
stream_(d), stream_(d),
graph_(d), graph_(d),
worker_(d), worker_(d),
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {} graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {
std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);
}
void CommandEncoder::add_completed_handler(std::function<void()> task) { void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task)); worker_.add_task(std::move(task));
@@ -204,6 +234,7 @@ void CommandEncoder::set_input_array(const array& arr) {
if (!use_cuda_graphs()) { if (!use_cuda_graphs()) {
return; return;
} }
bytes_in_graph_ += arr.data_size();
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr()); auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
active_deps_.push_back(id); active_deps_.push_back(id);
} }
@@ -278,13 +309,76 @@ void CommandEncoder::add_kernel_node(
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) { void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
cudaGraphNode_t node; cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params)); CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'}); insert_graph_dependencies(GraphNode{node, "K"});
} }
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) { void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
CUgraphNode node; CUgraphNode node;
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params)); CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
insert_graph_dependencies(GraphNode{node, 'K'}); insert_graph_dependencies(GraphNode{node, "K"});
}
std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
// Constructs a key representing the nodes of a sub-graph.
// Also checks if the sub-graph is updatable as CUDA graphs do not get
// updated correctly if a kernel node getting updated has a different cluster
// shape than the node it's being updated with.
std::string key = "(";
size_t num_nodes = 0;
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
if (num_nodes == 0) {
return {key + ")", true};
}
bool is_updatable = true;
std::vector<cudaGraphNode_t> nodes(num_nodes);
CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
for (const auto& node : nodes) {
if (!is_updatable) {
break;
}
cudaGraphNodeType type;
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
switch (type) {
case cudaGraphNodeTypeGraph: {
// Try to be updatable for a structure like graph -> graph -> kernel
cudaGraph_t child;
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
is_updatable &= sub_is_updatable;
key += subkey;
break;
}
case cudaGraphNodeTypeHost:
key += "H";
break;
case cudaGraphNodeTypeMemset:
key += "M";
break;
case cudaGraphNodeTypeKernel: {
cudaLaunchAttributeValue cluster_dim;
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
// Only allow dim.x to be greater than 1
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
is_updatable = false;
} else {
key += "K";
key += std::to_string(cluster_dim.clusterDim.x);
}
break;
}
case cudaGraphNodeTypeWaitEvent:
key += "W";
break;
case cudaGraphNodeTypeEventRecord:
key += "R";
break;
default:
is_updatable = false;
}
}
key += ")";
return {key, is_updatable};
} }
void CommandEncoder::add_graph_node(cudaGraph_t child) { void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -297,12 +391,15 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
return; return;
} }
cudaGraphNode_t node; cudaGraphNode_t node;
auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
is_graph_updatable_ &= is_updatable;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child)); CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'}); insert_graph_dependencies(GraphNode{node, sub_graph_key});
} }
int CommandEncoder::get_num_ops() { bool CommandEncoder::needs_commit() {
return node_count_; return (node_count_ > max_ops_per_graph_) ||
((bytes_in_graph_ >> 20) > max_mb_per_graph_);
} }
void CommandEncoder::commit() { void CommandEncoder::commit() {
@@ -322,53 +419,63 @@ void CommandEncoder::commit() {
from_nodes_.size())); from_nodes_.size()));
} }
graph_key_ += ".";
graph_key_ += std::to_string(node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(graph_node_count_);
graph_key_ += ".";
graph_key_ += std::to_string(empty_node_count_);
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo info;
cudaGraphExecUpdate(graph_exec, graph_, &info);
update_result = info.result;
#else
cudaGraphNode_t error_node;
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
graph_exec.reset();
}
}
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
}
device_.make_current(); device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
if (!is_graph_updatable_) {
CudaGraphExec graph_exec;
graph_exec.instantiate(graph_);
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
} else {
auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_;
auto& graph_exec = graph_cache_[graph_key];
if (graph_exec != nullptr) {
cudaGraphExecUpdateResult update_result;
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo info;
cudaGraphExecUpdate(graph_exec, graph_, &info);
update_result = info.result;
#else
cudaGraphNode_t error_node;
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
#endif // CUDART_VERSION >= 12000
if (update_result != cudaGraphExecUpdateSuccess) {
cudaGetLastError(); // reset error
graph_exec.reset();
}
}
if (graph_exec == nullptr) {
graph_exec.instantiate(graph_);
}
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
}
// Save cuda graph to dot file
if (const char* filename = save_cuda_graphs_dot_file(); filename) {
static int count = 0;
auto path = fmt::format("{}_{}.dot", filename, ++count);
CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
}
// Reset state // Reset state
graph_node_count_ = 0;
empty_node_count_ = 0;
from_nodes_.clear(); from_nodes_.clear();
to_nodes_.clear(); to_nodes_.clear();
graph_key_.clear(); graph_deps_key_.clear();
graph_nodes_key_.clear();
node_map_.clear(); node_map_.clear();
graph_ = CudaGraph(device_); graph_ = CudaGraph(device_);
is_graph_updatable_ = true;
} }
// Put completion handlers in a batch. // Put completion handlers in a batch.
worker_.commit(stream_); worker_.commit(stream_);
node_count_ = 0; node_count_ = 0;
bytes_in_graph_ = 0;
} }
void CommandEncoder::synchronize() { void CommandEncoder::synchronize() {
cudaStreamSynchronize(stream_); CHECK_CUDA_ERROR(cudaStreamSynchronize(stream_));
auto p = std::make_shared<std::promise<void>>(); auto p = std::make_shared<std::promise<void>>();
std::future<void> f = p->get_future(); std::future<void> f = p->get_future();
add_completed_handler([p = std::move(p)]() { p->set_value(); }); add_completed_handler([p = std::move(p)]() { p->set_value(); });

View File

@@ -84,7 +84,7 @@ class CommandEncoder {
} }
void add_completed_handler(std::function<void()> task); void add_completed_handler(std::function<void()> task);
int get_num_ops(); bool needs_commit();
void commit(); void commit();
Device& device() { Device& device() {
@@ -106,8 +106,9 @@ class CommandEncoder {
cudaGraphNode_t node; cudaGraphNode_t node;
// K = kernel // K = kernel
// E = empty // E = empty
// G = subgraph // () = subgraph (with metadata)
char node_type; // Symbols ':', '-' are reserved as separators
std::string node_type;
std::string id; std::string id;
}; };
@@ -119,18 +120,21 @@ class CommandEncoder {
CudaGraph graph_; CudaGraph graph_;
Worker worker_; Worker worker_;
char node_count_{0}; char node_count_{0};
char graph_node_count_{0};
char empty_node_count_{0};
bool in_concurrent_{false}; bool in_concurrent_{false};
std::vector<cudaGraphNode_t> from_nodes_; std::vector<cudaGraphNode_t> from_nodes_;
std::vector<cudaGraphNode_t> to_nodes_; std::vector<cudaGraphNode_t> to_nodes_;
std::string graph_key_; std::string graph_nodes_key_;
std::string graph_deps_key_;
std::vector<GraphNode> concurrent_nodes_; std::vector<GraphNode> concurrent_nodes_;
std::vector<std::shared_ptr<array::Data>> temporaries_; std::vector<std::shared_ptr<array::Data>> temporaries_;
LRUCache<std::string, CudaGraphExec> graph_cache_; LRUCache<std::string, CudaGraphExec> graph_cache_;
std::vector<std::uintptr_t> active_deps_; std::vector<std::uintptr_t> active_deps_;
std::vector<std::uintptr_t> active_outputs_; std::vector<std::uintptr_t> active_outputs_;
std::unordered_map<std::uintptr_t, GraphNode> node_map_; std::unordered_map<std::uintptr_t, GraphNode> node_map_;
size_t bytes_in_graph_{0};
bool is_graph_updatable_{true};
int max_ops_per_graph_;
int max_mb_per_graph_;
}; };
class Device { class Device {
@@ -166,6 +170,7 @@ class Device {
int device_; int device_;
int compute_capability_major_; int compute_capability_major_;
int compute_capability_minor_; int compute_capability_minor_;
std::string device_name_;
cublasLtHandle_t lt_; cublasLtHandle_t lt_;
cudnnHandle_t cudnn_; cudnnHandle_t cudnn_;
std::unordered_map<int, CommandEncoder> encoders_; std::unordered_map<int, CommandEncoder> encoders_;

View File

@@ -26,7 +26,7 @@ void AllReduce::eval_gpu(
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
return {in, out}; return {in, out};
} else { } else {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
return {in, out}; return {in, out};
} }
}; };
@@ -74,7 +74,7 @@ void AllGather::eval_gpu(
}; };
auto input = ensure_contiguous(inputs[0]); auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream())); outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));
encoder.set_input_array(input); encoder.set_input_array(input);
encoder.set_output_array(outputs[0]); encoder.set_output_array(outputs[0]);
@@ -103,7 +103,7 @@ void ReduceScatter::eval_gpu(
}; };
auto input = ensure_contiguous(inputs[0]); auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream())); outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));
encoder.set_input_array(input); encoder.set_input_array(input);
encoder.set_output_array(outputs[0]); encoder.set_output_array(outputs[0]);

View File

@@ -11,9 +11,6 @@
namespace mlx::core::gpu { namespace mlx::core::gpu {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;
bool is_available() { bool is_available() {
return true; return true;
} }
@@ -53,8 +50,7 @@ void eval(array& arr) {
encoder.add_temporary(s); encoder.add_temporary(s);
} }
if (encoder.get_num_ops() >= if (encoder.needs_commit()) {
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
scheduler::notify_new_task(stream); scheduler::notify_new_task(stream);
encoder.add_completed_handler( encoder.add_completed_handler(
[stream]() { scheduler::notify_task_completion(stream); }); [stream]() { scheduler::notify_task_completion(stream); });

View File

@@ -305,6 +305,7 @@ void Event::wait() {
} else { } else {
event->atomic->wait(value()); event->atomic->wait(value());
} }
CHECK_CUDA_ERROR(cudaPeekAtLastError());
} }
void Event::wait(Stream s) { void Event::wait(Stream s) {

View File

@@ -370,7 +370,7 @@ void CublasGemm::execute(
// Ensure workspace is 256-byte aligned // Ensure workspace is 256-byte aligned
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256; int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
array workspace( array workspace(
cu::malloc_async(nbytes, encoder.stream()), cu::malloc_async(nbytes, encoder),
{static_cast<int>(heuristic_.workspaceSize)}, {static_cast<int>(heuristic_.workspaceSize)},
int8); int8);
encoder.add_temporary(workspace); encoder.add_temporary(workspace);

View File

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets // Launch kernel to set device offsets
auto pointers = array( auto pointers = array(
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()), cu::malloc_async(batch_count * sizeof(void*) * 3, encoder),
{batch_count * 3}, {batch_count * 3},
uint64); uint64);
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
// Launch kernel to set device offsets // Launch kernel to set device offsets
auto pointers = array( auto pointers = array(
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()), cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder),
{batch_count * 4}, {batch_count * 4},
uint64); uint64);

View File

@@ -61,7 +61,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream(); auto& s = stream();
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }
@@ -241,7 +241,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream(); auto& s = stream();
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }

View File

@@ -279,11 +279,14 @@ void compile(
// Compile program. // Compile program.
std::vector<const char*> args; std::vector<const char*> args;
bool use_sass = compiler_supports_device_sass(device); bool use_sass = compiler_supports_device_sass(device);
auto cc = device.compute_capability_major();
std::string arch_tag = (cc == 90 || cc == 100 || cc == 121) ? "a" : "";
std::string compute = fmt::format( std::string compute = fmt::format(
"--gpu-architecture={}_{}{}", "--gpu-architecture={}_{}{}{}",
use_sass ? "sm" : "compute", use_sass ? "sm" : "compute",
device.compute_capability_major(), cc,
device.compute_capability_minor()); device.compute_capability_minor(),
arch_tag);
args.push_back(compute.c_str()); args.push_back(compute.c_str());
std::string cccl_include = cccl_dir(); std::string cccl_include = cccl_dir();
if (!cccl_include.empty()) { if (!cccl_include.empty()) {

View File

@@ -244,7 +244,7 @@ void LayerNorm::eval_gpu(
out.copy_shared_buffer(x); out.copy_shared_buffer(x);
} else { } else {
out.set_data( out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()), cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(), x.data_size(),
x.strides(), x.strides(),
x.flags()); x.flags());
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
gx.copy_shared_buffer(g); gx.copy_shared_buffer(g);
g_in_gx = true; g_in_gx = true;
} else { } else {
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream())); gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
} }
if (g_copied && !g_in_gx) { if (g_copied && !g_in_gx) {
encoder.add_temporary(g); encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
g_in_gw = true; g_in_gw = true;
gw_temp.copy_shared_buffer(g); gw_temp.copy_shared_buffer(g);
} else { } else {
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream())); gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
encoder.add_temporary(gw_temp); encoder.add_temporary(gw_temp);
} }
} }

View File

@@ -32,7 +32,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream()); auto& encoder = cu::get_command_encoder(stream());
auto size = out.size(); auto size = out.size();
auto nbytes = size * out.itemsize(); auto nbytes = size * out.itemsize();
out.set_data(cu::malloc_async(nbytes, encoder.stream())); out.set_data(cu::malloc_async(nbytes, encoder));
auto out_ptr = malloc(nbytes); auto out_ptr = malloc(nbytes);
reader_->read(static_cast<char*>(out_ptr), nbytes, offset_); reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
if (swap_endianness_) { if (swap_endianness_) {

View File

@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = ensure_contiguous(inputs[0]); auto in = ensure_contiguous(inputs[0]);
if (in.flags().row_contiguous) { if (in.flags().row_contiguous) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else { } else {
auto n = in.shape(-1); auto n = in.shape(-1);
auto flags = in.flags(); auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
flags.col_contiguous = col_contig; flags.col_contiguous = col_contig;
out.set_data( out.set_data(
cu::malloc_async(in.nbytes() / n, encoder.stream()), cu::malloc_async(in.nbytes() / n, encoder),
in.data_size() / n, in.data_size() / n,
std::move(strides), std::move(strides),
flags); flags);

View File

@@ -135,12 +135,19 @@ class LRUCache {
}; };
// Turn a POD struct into a container key by doing bytes compare. // Turn a POD struct into a container key by doing bytes compare.
//
// Usage:
// BytesKey<MyKey> key;
// key.pod = { ... };
template <typename T> template <typename T>
struct BytesKey { struct BytesKey {
T pod; T pod;
static_assert(std::is_standard_layout_v<T>, "T is not POD"); static_assert(std::is_standard_layout_v<T>, "T is not POD");
BytesKey(T pod) : pod(std::move(pod)) {} BytesKey() {
// Make sure the paddings between members are filled with 0.
memset(&pod, 0, sizeof(T));
}
BytesKey(const BytesKey& other) { BytesKey(const BytesKey& other) {
memcpy(&pod, &other.pod, sizeof(T)); memcpy(&pod, &other.pod, sizeof(T));

View File

@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return; return;
} }
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
int M = a_pre.shape(-2); int M = a_pre.shape(-2);
int N = b_pre.shape(-1); int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 && if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) { c.data_size() == out.shape(-1)) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
gemm_and_bias( gemm_and_bias(
encoder, encoder,
M, M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto sty = c.strides()[c.ndim() - 1]; auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) { if (sty == 1 && stx == c.shape(-1)) {
ldc = stx; ldc = stx;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else if (sty == 1 && stx == 0) { } else if (sty == 1 && stx == 0) {
ldc = 0; ldc = 0;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else { } else {
// Copy C into out and set C to out // Copy C into out and set C to out
ldc = c.shape(-1); ldc = c.shape(-1);

View File

@@ -37,6 +37,7 @@ NO_GPU(Inverse)
NO_GPU(Cholesky) NO_GPU(Cholesky)
NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh) NO_GPU_MULTI(Eigh)
NO_GPU(MaskedScatter)
namespace distributed { namespace distributed {
NO_GPU_MULTI(Send) NO_GPU_MULTI(Send)

View File

@@ -2,7 +2,11 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
#include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/quantized/quantized.h"
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
#include "mlx/backend/cuda/vector_types.cuh"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
@@ -13,17 +17,6 @@
namespace mlx::core { namespace mlx::core {
namespace cu { namespace cu {
template <int bits>
struct Quantize {
__device__ uint8_t operator()(float x) {
if constexpr (bits == 8) {
return __nv_fp8_e4m3(x).__x;
} else {
return __nv_fp4_e2m1(x).__x;
}
}
};
template <int bits> template <int bits>
struct Dequantize { struct Dequantize {
__device__ float operator()(uint8_t x) { __device__ float operator()(uint8_t x) {
@@ -37,29 +30,40 @@ struct Dequantize {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <typename T, int group_size, int bits, bool use_mx_scale> template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
__global__ void __global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) { using Tx2 = Vector2_t<T>;
using Tx4 = Vector4_t<T>;
uint32_t rbits = 0; // reserved bits for future use
auto block_size = cg::this_thread_block().dim_threads(); auto block_size = cg::this_thread_block().dim_threads();
auto block_idx = cg::this_thread_block().group_index(); auto block_idx = cg::this_thread_block().group_index();
auto idx_in_block = cg::this_thread_block().thread_index(); auto idx_in_block = cg::this_thread_block().thread_index();
auto tidx = block_idx.x * block_size.x + idx_in_block.x; auto tidx = block_idx.x * block_size.x + idx_in_block.x;
auto tidy = block_idx.y * block_size.y + idx_in_block.y; auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
auto grid_dim_x = size_t thread_idx = tidx + grid_dim_x * size_t(tidy);
cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x; size_t base_idx = thread_idx * group_size;
size_t index = tidx + grid_dim_x * size_t(tidy);
if (index >= size) { if (base_idx >= size) {
return; return;
} }
float w_thread = w[index]; auto w_tile = load_vector<group_size, T>(w, thread_idx);
float scale = 0.0f;
cg::greater<float> max_op; Tx2 amax_2x = Tx2{0.0f, 0.0f};
auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
#pragma unroll
for (int i = 0; i < group_size; i += 2) {
auto pair = Tx2{w_tile[i], w_tile[i + 1]};
abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
}
scale = static_cast<float>(
max(fabsf(static_cast<float>(amax_2x.x)),
fabsf(static_cast<float>(amax_2x.y))));
float scale = cg::reduce(warp, abs(w_thread), max_op);
scale /= bits == 4 ? 6.0f : 448.0f; scale /= bits == 4 ? 6.0f : 448.0f;
// Convert to mx scale or nv scale // Convert to mx scale or nv scale
using ScaleType = using ScaleType =
@@ -68,21 +72,24 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
uint8_t q_scale = s.__x; uint8_t q_scale = s.__x;
scale = float(s); scale = float(s);
// Write out the scales scales[thread_idx] = q_scale;
size_t gindex = index / group_size; constexpr int elem_per_byte = bits == 8 ? 1 : 2;
if (index % group_size == 0) { AlignedVector<uint8_t, group_size / elem_per_byte> quantized;
scales[gindex] = q_scale;
}
uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale); #pragma unroll
if (bits == 4) { for (int i = 0; i < group_size / 4; i++) {
uint8_t sval = warp.shfl_down(output, 1); Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
output |= sval << bits; if constexpr (bits == 8) {
} uint32_t quantized_val =
constexpr int pack_factor = bits == 8 ? 1 : 2; scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
if (index % pack_factor == 0) { *reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
out[index / pack_factor] = output; } else {
uint16_t quantized_val =
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
*reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
}
} }
store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);
} }
template <typename T, int group_size, int bits, bool use_mx_scale> template <typename T, int group_size, int bits, bool use_mx_scale>
@@ -142,15 +149,16 @@ void fp_quantize(
dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) { dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if constexpr (!std::is_same_v<T, double>) { if constexpr (!std::is_same_v<T, double>) {
auto kernel = cu::fp_quantize<T, 32, 4, true>; auto kernel = cu::fp_quantize<T, 32, 4, true, false>;
if (bits == 8) { if (bits == 8) {
kernel = cu::fp_quantize<T, 32, 8, true>; kernel = cu::fp_quantize<T, 32, 8, true, false>;
} else if (group_size == 16) { } else if (group_size == 16) {
kernel = cu::fp_quantize<T, 16, 4, false>; kernel = cu::fp_quantize<T, 16, 4, false, false>;
} }
bool large = w.size() > UINT_MAX; bool large = w.size() > UINT_MAX;
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(w.size(), w.shape(), w.strides(), large); get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);
enc.add_kernel_node( enc.add_kernel_node(
kernel, kernel,
num_blocks, num_blocks,

View File

@@ -0,0 +1,32 @@
#pragma once
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
// TODO implement fast path
template <typename T>
__device__ __forceinline__ uint32_t
scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t<T> input, const float scale) {
uint32_t out_fp8x4 = 0;
float4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x;
return out_fp8x4;
}
// Place holder for future fast path implementation
template <typename T, bool USE_SR>
__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
return scale_cvt_Tx4_to_fp8x4_fallback(input, scale);
}
} // namespace mlx::core::cu

View File

@@ -0,0 +1,334 @@
#pragma once
#include <cuda.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu {
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using f32x4 = Vector4_t<float>;
template <typename T>
__device__ __forceinline__ uint16_t
scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
// Fallback implementation for architectures that do not support cvt
// instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
uint16_t out_fp4x4 = 0;
fp32x4 scaled;
scaled.x = static_cast<float>(input.x) * scale;
scaled.y = static_cast<float>(input.y) * scale;
scaled.z = static_cast<float>(input.z) * scale;
scaled.w = static_cast<float>(input.w) * scale;
uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
(static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
static_cast<uint16_t>(q0);
return out_fp4x4;
}
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
defined(__CUDA_ARCH_SPECIFIC__)
__device__ __forceinline__ uint16_t
scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t" // first bf16
".reg.b16 x1_bf16; \n\t" // second bf16
".reg.b16 x2_bf16; \n\t" // third bf16
".reg.b16 x3_bf16; \n\t" // fourth bf16
".reg.b32 x0; \n\t" // to hold scaled first
".reg.b32 x1; \n\t" // to hold scaled second
".reg.b32 x2; \n\t" // to hold scaled third
".reg.b32 x3; \n\t" // to hold scaled fourth
".reg.b64 x01; \n\t" // to hold vector mul
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
"cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t" // scale first pair
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t" // scale second pair
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first
// pair
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2 second
// pair
"mov.b16 %0, {q0, q1}; \n\t" // pack to output
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(
scale))); // here cast is needed becuase an asm operand must have
// scalar type
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
const bf16x4 input_bf16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_bf16; \n\t"
".reg.b16 x1_bf16; \n\t"
".reg.b16 x2_bf16; \n\t"
".reg.b16 x3_bf16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
"cvt.f32.bf16 x0, x0_bf16; \n\t"
"cvt.f32.bf16 x1, x1_bf16; \n\t"
"cvt.f32.bf16 x2, x2_bf16; \n\t"
"cvt.f32.bf16 x3, x3_bf16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
const float2 input_fp32x2_0,
const float2 input_fp32x2_1,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 x01, {%1, %2}; \n\t"
"mul.f32x2 x01, x01, %5; \n\t"
"mov.b64 x23, {%3, %4}; \n\t"
"mul.f32x2 x23, x23, %5; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
"}"
: "=h"(out_fp4x4)
: "f"(input_fp32x2_0.x),
"f"(input_fp32x2_0.y),
"f"(input_fp32x2_1.x),
"f"(input_fp32x2_1.y),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t
scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b8 q0; \n\t"
".reg.b8 q1; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
"mov.b16 %0, {q0, q1}; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)));
return out_fp4x4;
}
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
const fp16x4 input_fp16x4,
const float2 scale,
uint32_t rbits) {
uint16_t out_fp4x4 = 0;
asm volatile(
"{\n"
".reg.b16 x0_fp16; \n\t"
".reg.b16 x1_fp16; \n\t"
".reg.b16 x2_fp16; \n\t"
".reg.b16 x3_fp16; \n\t"
".reg.b32 x0; \n\t"
".reg.b32 x1; \n\t"
".reg.b32 x2; \n\t"
".reg.b32 x3; \n\t"
".reg.b64 x01; \n\t"
".reg.b64 x23; \n\t"
".reg.b16 q0; \n\t"
"mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
"cvt.f32.f16 x0, x0_fp16; \n\t"
"cvt.f32.f16 x1, x1_fp16; \n\t"
"cvt.f32.f16 x2, x2_fp16; \n\t"
"cvt.f32.f16 x3, x3_fp16; \n\t"
"mov.b64 x01, {x0, x1}; \n\t"
"mul.f32x2 x01, x01, %2; \n\t"
"mov.b64 x23, {x2, x3}; \n\t"
"mul.f32x2 x23, x23, %2; \n\t"
"mov.b64 {x0, x1}, x01; \n\t"
"mov.b64 {x2, x3}, x23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
"}"
: "=h"(out_fp4x4)
: "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
"l"(reinterpret_cast<const uint64_t&>(scale)),
"r"(rbits));
return out_fp4x4;
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
const bf16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
const fp16x4 input,
const float scale,
uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
if constexpr (USE_SR) {
return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
} else {
return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
}
}
template <bool USE_SR>
__device__ __forceinline__ uint16_t
scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
float2 scale_fp32x2 = make_float2(scale, scale);
float2 input_fp32x2_0 = make_float2(input.x, input.y);
float2 input_fp32x2_1 = make_float2(input.z, input.w);
if constexpr (USE_SR) {
return scale_cvt_fp32x4_to_fp4x4_rs(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
} else {
return scale_cvt_fp32x4_to_fp4x4_rn(
input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
}
}
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else if constexpr (std::is_same<T, __half>::value) {
return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
} else {
return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
}
}
#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
// (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
template <typename T, bool USE_SR>
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
const Vector4_t<T> input,
const float scale,
uint32_t rbits) {
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
(__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
#else
static_assert(
!USE_SR,
"Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
#endif
}
} // namespace mlx::core::cu

View File

@@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu(
auto scales = ensure_row_contiguous(inputs[1], enc, s); auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto& w = outputs[0]; auto& w = outputs[0];
w.set_data(cu::malloc_async(w.nbytes(), enc.stream())); w.set_data(cu::malloc_async(w.nbytes(), enc));
if (mode_ == QuantizationMode::Affine) { if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s); auto biases = ensure_row_contiguous(inputs[2], enc, s);
@@ -72,11 +72,11 @@ void fast::Quantize::eval_gpu(
auto& wq = outputs[0]; auto& wq = outputs[0];
auto& scales = outputs[1]; auto& scales = outputs[1];
wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream())); wq.set_data(cu::malloc_async(wq.nbytes(), enc));
scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream())); scales.set_data(cu::malloc_async(scales.nbytes(), enc));
if (mode_ == QuantizationMode::Affine) { if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2]; auto& biases = outputs[2];
biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream())); biases.set_data(cu::malloc_async(biases.nbytes(), enc));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else { } else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s); fp_quantize(w, wq, scales, group_size_, bits_, enc, s);

View File

@@ -15,6 +15,22 @@ inline constexpr __device__ short get_bytes_per_pack() {
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
} }
template <typename T>
__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
if constexpr (
(std::is_same<T, __nv_bfloat162>::value) ||
(std::is_same<T, __half2>::value)) {
T a = x1;
T b = x2;
out = __hmax2(__habs2(a), __habs2(b));
} else if constexpr (std::is_same<T, float2>::value) {
float2 a = x1;
float2 b = x2;
out.x = fmaxf(fabsf(a.x), fabsf(b.x));
out.y = fmaxf(fabsf(a.y), fabsf(b.y));
}
}
} // namespace cu } // namespace cu
template <typename F> template <typename F>

View File

@@ -139,30 +139,36 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// keys has shape (N1, ..., NK, 2) // keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...) // out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0]; auto& keys = inputs[0];
uint32_t num_keys = keys.size() / 2; size_t num_keys = keys.size() / 2;
uint32_t elems_per_key = out.size() / num_keys; size_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key; size_t bytes_per_key = out.itemsize() * elems_per_key;
auto& s = stream(); auto& s = stream();
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }
uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4; size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
uint32_t half_size = out_per_key / 2; size_t half_size = out_per_key / 2;
bool odd = out_per_key % 2; bool odd = out_per_key % 2;
if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
}
encoder.set_input_array(keys); encoder.set_input_array(keys);
encoder.set_output_array(out); encoder.set_output_array(out);
dim3 grid_dims{num_keys, half_size + odd}; int64_t total = num_keys * (half_size + odd);
int64_t total = grid_dims.x * grid_dims.y; uint32_t threads_y = 1;
int32_t threads_y = 1; while ((total / threads_y) >= UINT_MAX) {
while ((total / threads_y) >= (1U << 31)) {
threads_y *= 2; threads_y *= 2;
} }
int32_t threads_x = cuda::ceil_div(total, threads_y); uint32_t threads_x = cuda::ceil_div(total, threads_y);
dim3 grid_dims{
static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1); auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
auto& stream = encoder.stream(); auto& stream = encoder.stream();
if (keys.flags().row_contiguous) { if (keys.flags().row_contiguous) {

View File

@@ -66,7 +66,7 @@ void all_reduce(
Reduce::ReduceType reduce_type) { Reduce::ReduceType reduce_type) {
constexpr int N_READS = 8; constexpr int N_READS = 8;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
auto get_args = [](size_t size, int N) { auto get_args = [](size_t size, int N) {
int threads = std::min(512UL, (size + N - 1) / N); int threads = std::min(512UL, (size + N - 1) / N);
@@ -107,8 +107,7 @@ void all_reduce(
encoder.set_input_array(in); encoder.set_input_array(in);
if (blocks > 1) { if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {}); array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data( intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
encoder.add_temporary(intermediate); encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate); encoder.set_output_array(intermediate);
dispatch_all_types(dt, [&](auto type_tag) { dispatch_all_types(dt, [&](auto type_tag) {

View File

@@ -89,9 +89,13 @@ template <
int NDIM, int NDIM,
int BM, int BM,
int BN, int BN,
int N_READS = 4> int N_READS = 4,
__global__ void int BLOCKS = 1>
col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) { __global__ void col_reduce_looped(
T* in,
U* out,
const __grid_constant__ ColReduceArgs args,
int64_t out_size) {
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
size_t tile_idx = grid.block_rank(); size_t tile_idx = grid.block_rank();
size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
size_t tile_out = tile_y / out_size;
tile_y = tile_y % out_size;
// Compute the indices for the thread within the tile // Compute the indices for the thread within the tile
short thread_x = block.thread_rank() % threads_per_row; short thread_x = block.thread_rank() % threads_per_row;
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
totals[i] = ReduceInit<Op, T>::value(); totals[i] = ReduceInit<Op, T>::value();
} }
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
size_t total = args.non_col_reductions * args.reduction_size; size_t total = args.non_col_reductions * args.reduction_size;
size_t per_block, start, end;
if constexpr (BLOCKS > 1) {
per_block = (total + BLOCKS - 1) / BLOCKS;
start = tile_out * per_block + thread_y;
end = min((tile_out + 1) * per_block, total);
} else {
per_block = total;
start = thread_y;
end = total;
}
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
if (tile_x * BN + BN <= args.reduction_stride) { if (tile_x * BN + BN <= args.reduction_stride) {
if (args.reduction_stride % N_READS == 0) { if (args.reduction_stride % N_READS == 0) {
for (size_t r = thread_y; r < total; r += BM) { for (size_t r = start; r < end; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals); cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
} }
} else { } else {
for (size_t r = thread_y; r < total; r += BM) { for (size_t r = start; r < end; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlocked(thread_x, in + loop.location(), vals); cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
} }
} }
} else { } else {
for (size_t r = thread_y; r < total; r += BM) { for (size_t r = start; r < end; r += BM) {
T vals[N_READS]; T vals[N_READS];
cub::LoadDirectBlocked( cub::LoadDirectBlocked(
thread_x, thread_x,
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
// Write result. // Write result.
if (warp.thread_rank() == 0) { if (warp.thread_rank() == 0) {
if (BLOCKS > 1) {
out += tile_out * out_size * args.reduction_stride;
}
cub::StoreDirectBlocked( cub::StoreDirectBlocked(
warp.meta_group_rank(), warp.meta_group_rank(),
out + tile_y * args.reduction_stride + tile_x * BN, out + tile_y * args.reduction_stride + tile_x * BN,
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
inline auto output_grid_for_col_reduce( inline auto output_grid_for_col_reduce(
const array& out, const array& out,
const cu::ColReduceArgs& args, const cu::ColReduceArgs& args,
int bn) { int bn,
int outer = 1) {
int gx, gy = 1; int gx, gy = 1;
size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn); size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
size_t n_outer_blocks = out.size() / args.reduction_stride; size_t n_outer_blocks = out.size() / args.reduction_stride;
size_t n_blocks = n_outer_blocks * n_inner_blocks; size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
while (n_blocks / gy > INT32_MAX) { while (n_blocks / gy > INT32_MAX) {
gy *= 2; gy *= 2;
} }
@@ -277,7 +298,8 @@ void col_reduce_looped(
0, 0,
indata, indata,
gpu_ptr<U>(out), gpu_ptr<U>(out),
static_cast<cu::ColReduceArgs>(args)); static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
}); });
}); });
}); });
@@ -320,6 +342,117 @@ void col_reduce_small(
}); });
} }
void col_reduce_two_pass(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan,
const cu::ColReduceArgs& args) {
// Allocate data for the output using in's layout to access them as
// contiguously as possible.
allocate_same_layout(out, in, axes, encoder);
// Allocate an intermediate array to hold the 1st pass result
constexpr int outer = 32;
Shape intermediate_shape;
intermediate_shape.push_back(outer);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
Strides intermediate_strides;
intermediate_strides.push_back(out.size());
intermediate_strides.insert(
intermediate_strides.end(), out.strides().begin(), out.strides().end());
array intermediate(intermediate_shape, out.dtype(), nullptr, {});
auto [data_size, rc, cc] =
check_contiguity(intermediate_shape, intermediate_strides);
auto fl = out.flags();
fl.row_contiguous = rc;
fl.col_contiguous = cc;
fl.contiguous = true;
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder),
data_size,
intermediate_strides,
fl,
allocator::free);
encoder.add_temporary(intermediate);
encoder.set_input_array(in);
encoder.set_output_array(intermediate);
dispatch_all_types(in.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(gpu_ptr<T>(in));
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
int blocks = BM * BN / N_READS;
auto kernel = cu::
col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
indata,
gpu_ptr<U>(intermediate),
static_cast<cu::ColReduceArgs>(args),
out.size() / args.reduction_stride);
});
});
});
// Prepare the reduction arguments for the 2nd pass
cu::ColReduceArgs second_args = args;
second_args.reduction_size = outer;
second_args.reduction_stride = out.size();
second_args.ndim = 0;
second_args.reduce_shape[0] = outer;
second_args.reduce_strides[0] = out.size();
second_args.reduce_ndim = 1;
second_args.non_col_reductions = 1;
encoder.set_input_array(intermediate);
encoder.set_output_array(out);
dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
constexpr int N_READS = 4;
constexpr int BM = 32;
constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
int blocks = BM * BN / N_READS;
auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
encoder.add_kernel_node(
kernel,
grid,
blocks,
0,
gpu_ptr<T>(intermediate),
gpu_ptr<U>(out),
second_args,
second_args.reduction_stride);
});
});
});
}
void col_reduce( void col_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,
const array& in, const array& in,
@@ -334,6 +467,18 @@ void col_reduce(
// It is a general strided reduce. Each threadblock computes the output for // It is a general strided reduce. Each threadblock computes the output for
// a subrow of the fast moving axis. For instance 32 elements. // a subrow of the fast moving axis. For instance 32 elements.
// //
// - col_reduce_small
//
// It is a column reduce for small columns. Each thread loops over the whole
// column without communicating with any other thread.
//
// - col_reduce_two_pass
//
// It is a reduce for long columns. To increase parallelism, we split the
// reduction in two passes. First we do a column reduce where many
// threadblocks operate on different parts of the reduced axis. Then we
// perform a final column reduce.
//
// Notes: As in row reduce we opt to read as much in order as possible and // Notes: As in row reduce we opt to read as much in order as possible and
// leave transpositions as they are (contrary to our Metal backend). // leave transpositions as they are (contrary to our Metal backend).
// //
@@ -349,6 +494,14 @@ void col_reduce(
return; return;
} }
// Long column with smallish row
size_t total_sums = args.non_col_reductions * args.reduction_size;
size_t approx_threads = out.size();
if (total_sums / approx_threads > 32) {
col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
return;
}
// Fallback col reduce // Fallback col reduce
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
} }

View File

@@ -28,7 +28,7 @@ void init_reduce(
Reduce::ReduceType reduce_type) { Reduce::ReduceType reduce_type) {
// Allocate if needed // Allocate if needed
if (out.data_shared_ptr() == nullptr) { if (out.data_shared_ptr() == nullptr) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
} }
encoder.set_output_array(out); encoder.set_output_array(out);

View File

@@ -96,7 +96,7 @@ inline void allocate_same_layout(
const std::vector<int>& axes, const std::vector<int>& axes,
cu::CommandEncoder& encoder) { cu::CommandEncoder& encoder) {
if (in.flags().row_contiguous) { if (in.flags().row_contiguous) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
return; return;
} }
@@ -135,7 +135,7 @@ inline void allocate_same_layout(
fl.col_contiguous = cc; fl.col_contiguous = cc;
fl.contiguous = true; fl.contiguous = true;
out.set_data( out.set_data(
cu::malloc_async(out.nbytes(), encoder.stream()), cu::malloc_async(out.nbytes(), encoder),
data_size, data_size,
final_strides, final_strides,
fl, fl,

View File

@@ -22,26 +22,28 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
} }
// Similar to cub::BlockReduce, but result is broadcasted to every thread. // Similar to cub::BlockReduce, but result is broadcasted to every thread.
template <typename T, int BLOCK_DIM> template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
struct BlockBroadcastReduce { struct BlockBroadcastReduce {
static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE); using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
static_assert(BLOCK_DIM % WARP_SIZE == 0);
using TempStorage = T[BLOCK_DIM / WARP_SIZE];
cg::thread_block& block; cg::thread_block& block;
TempStorage& temp; TempStorage& temp;
template <typename Op> template <typename Op>
__device__ T Reduce(const T& input, const Op& op, const T& init_value) { __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
auto warp = cg::tiled_partition<WARP_SIZE>(block); auto warp = cg::tiled_partition<GROUP_DIM>(block);
T x = cg::reduce(warp, input, op); T x = cg::reduce(warp, input, op);
if (warp.thread_rank() == 0) { if constexpr (BLOCK_DIM > GROUP_DIM) {
temp[warp.meta_group_rank()] = x; if (warp.thread_rank() == 0) {
temp[warp.meta_group_rank()] = x;
}
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
} else {
return x;
} }
block.sync();
x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
: init_value;
return cg::reduce(warp, x, op);
} }
__device__ T Sum(const T& input) { __device__ T Sum(const T& input) {
@@ -49,6 +51,52 @@ struct BlockBroadcastReduce {
} }
}; };
template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
__global__ void rms_norm_small(
const T* x,
const T* w,
T* out,
float eps,
uint32_t axis_size,
uint32_t n_rows,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
__shared__ typename BlockReduceT::TempStorage temp;
auto row =
(grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
if (row >= n_rows) {
return;
}
x += row * axis_size;
out += row * axis_size;
// Normalizer.
float normalizer = 0;
auto index = block.thread_index().x;
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float t = static_cast<float>(xn[i]);
normalizer += t * t;
}
normalizer = BlockReduceT{block, temp}.Sum(normalizer);
normalizer = rsqrt(normalizer / axis_size + eps);
// Outputs.
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
#pragma unroll
for (int i = 0; i < N_READS; ++i) {
float y = static_cast<float>(xn[i]) * normalizer;
xn[i] = wn[i] * static_cast<T>(y);
}
store_vector<N_READS>(out, index, xn, axis_size);
}
template <typename T, int BLOCK_DIM, int N_READS = 4> template <typename T, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm( __global__ void rms_norm(
const T* x, const T* x,
@@ -94,6 +142,74 @@ __global__ void rms_norm(
} }
} }
template <
typename T,
bool HAS_W,
int BLOCK_DIM,
int REDUCE_DIM,
int N_READS = 4>
__global__ void rms_norm_vjp_small(
const T* x,
const T* w,
const T* g,
T* gx,
T* gw,
float eps,
int32_t axis_size,
int32_t n_rows,
int64_t w_stride) {
auto grid = cg::this_grid();
auto block = cg::this_thread_block();
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
__shared__ typename BlockReduceF2::TempStorage temp;
auto row =
(grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
if (row >= n_rows) {
return;
}
x += row * axis_size;
g += row * axis_size;
gx += row * axis_size;
gw += row * axis_size;
// Normalizer.
float2 factors = {};
auto index = block.thread_index().x;
auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
for (int i = 0; i < N_READS; i++) {
float t = static_cast<float>(xn[i]);
float wi = wn[i];
float gi = gn[i];
float wg = wi * gi;
factors = plus_f2(factors, {wg * t, t * t});
}
factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
float meangwx = factors.x / axis_size;
float normalizer = rsqrt(factors.y / axis_size + eps);
float normalizer3 = normalizer * normalizer * normalizer;
// Outputs.
for (int i = 0; i < N_READS; i++) {
float xi = xn[i];
float wi = wn[i];
float gi = gn[i];
xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
if constexpr (HAS_W) {
wn[i] = static_cast<T>(gi * xi * normalizer);
}
}
store_vector<N_READS>(gx, index, xn, axis_size);
if constexpr (HAS_W) {
store_vector<N_READS>(gw, index, wn, axis_size);
}
}
template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4> template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
__global__ void rms_norm_vjp( __global__ void rms_norm_vjp(
const T* x, const T* x,
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
auto grid = cg::this_grid(); auto grid = cg::this_grid();
auto block = cg::this_thread_block(); auto block = cg::this_thread_block();
using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>; using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
__shared__ union { __shared__ typename BlockReduceF2::TempStorage temp;
typename BlockReduceF::TempStorage f;
typename BlockReduceF2::TempStorage f2;
} temp;
x += grid.block_rank() * axis_size; x += grid.block_rank() * axis_size;
g += grid.block_rank() * axis_size; g += grid.block_rank() * axis_size;
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
factors = plus_f2(factors, {wg * t, t * t}); factors = plus_f2(factors, {wg * t, t * t});
} }
} }
factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {}); factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
float meangwx = factors.x / axis_size; float meangwx = factors.x / axis_size;
float normalizer = rsqrt(factors.y / axis_size + eps); float normalizer = rsqrt(factors.y / axis_size + eps);
float normalizer3 = normalizer * normalizer * normalizer; float normalizer3 = normalizer * normalizer * normalizer;
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
return s.device == Device::cpu; return s.device == Device::cpu;
} }
template <int n_per_thread, typename F>
void dispatch_group_dim(int axis_size, F&& f) {
if (axis_size <= n_per_thread * 8) {
f(std::integral_constant<int, 8>{},
std::integral_constant<int, 1>(),
std::integral_constant<int, 16>());
} else if (axis_size <= n_per_thread * 16) {
f(std::integral_constant<int, 16>{},
std::integral_constant<int, 1>(),
std::integral_constant<int, 8>());
} else if (axis_size <= n_per_thread * 32) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 1>(),
std::integral_constant<int, 4>());
} else if (axis_size <= n_per_thread * 32 * 2) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 2>(),
std::integral_constant<int, 2>());
} else if (axis_size <= n_per_thread * 32 * 4) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 4>(),
std::integral_constant<int, 1>());
} else if (axis_size <= n_per_thread * 32 * 8) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 8>(),
std::integral_constant<int, 1>());
} else if (axis_size <= n_per_thread * 32 * 16) {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 16>(),
std::integral_constant<int, 1>());
} else {
f(std::integral_constant<int, 32>{},
std::integral_constant<int, 32>(),
std::integral_constant<int, 1>());
}
}
// TODO: There are duplicate code with backend/metal/normalization.cpp // TODO: There are duplicate code with backend/metal/normalization.cpp
void RMSNorm::eval_gpu( void RMSNorm::eval_gpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
@@ -190,7 +339,7 @@ void RMSNorm::eval_gpu(
out.copy_shared_buffer(x); out.copy_shared_buffer(x);
} else { } else {
out.set_data( out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()), cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(), x.data_size(),
x.strides(), x.strides(),
x.flags()); x.flags());
@@ -216,12 +365,33 @@ void RMSNorm::eval_gpu(
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType); constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { if (axis_size <= N_READS * 1024) {
auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>; dispatch_group_dim<N_READS>(
axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
constexpr int block_dim = n_groups() * group_dim();
auto kernel =
cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
auto n_blocks =
(n_rows + groups_per_block() - 1) / groups_per_block();
encoder.add_kernel_node(
kernel,
n_blocks,
{block_dim, groups_per_block()},
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(out),
eps_,
axis_size,
n_rows,
w_stride);
});
} else {
auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
encoder.add_kernel_node( encoder.add_kernel_node(
kernel, kernel,
n_rows, n_rows,
block_dim(), 1024,
0, 0,
gpu_ptr<DataType>(x), gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w), gpu_ptr<DataType>(w),
@@ -229,7 +399,7 @@ void RMSNorm::eval_gpu(
eps_, eps_,
axis_size, axis_size,
w_stride); w_stride);
}); }
}); });
} }
@@ -274,7 +444,7 @@ void RMSNormVJP::eval_gpu(
gx.copy_shared_buffer(g); gx.copy_shared_buffer(g);
g_in_gx = true; g_in_gx = true;
} else { } else {
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream())); gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
} }
if (g_copied && !g_in_gx) { if (g_copied && !g_in_gx) {
encoder.add_temporary(g); encoder.add_temporary(g);
@@ -292,7 +462,7 @@ void RMSNormVJP::eval_gpu(
if (!g_in_gx && donate_g) { if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g); gw_temp.copy_shared_buffer(g);
} else { } else {
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream())); gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
encoder.add_temporary(gw_temp); encoder.add_temporary(gw_temp);
} }
} }
@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
dispatch_bool(has_w, [&](auto has_w_constant) { dispatch_bool(has_w, [&](auto has_w_constant) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 16 / sizeof(DataType); constexpr int N_READS = 16 / sizeof(DataType);
dispatch_block_dim( if (axis_size <= N_READS * 1024) {
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { dispatch_group_dim<N_READS>(
auto kernel = cu::rms_norm_vjp< axis_size,
DataType, [&](auto group_dim, auto n_groups, auto groups_per_block) {
has_w_constant.value, constexpr int block_dim = group_dim() * n_groups();
block_dim(), auto kernel = cu::rms_norm_vjp_small<
N_READS>; DataType,
encoder.add_kernel_node( has_w_constant.value,
kernel, block_dim,
n_rows, group_dim(),
block_dim(), N_READS>;
0, auto n_blocks =
gpu_ptr<DataType>(x), (n_rows + groups_per_block() - 1) / groups_per_block();
gpu_ptr<DataType>(w), encoder.add_kernel_node(
gpu_ptr<DataType>(g), kernel,
gpu_ptr<DataType>(gx), n_blocks,
gpu_ptr<DataType>(gw_temp), {block_dim, groups_per_block()},
eps_, 0,
axis_size, gpu_ptr<DataType>(x),
w_stride); gpu_ptr<DataType>(w),
}); gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
n_rows,
w_stride);
});
} else {
auto kernel =
cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
encoder.add_kernel_node(
kernel,
n_rows,
1024,
0,
gpu_ptr<DataType>(x),
gpu_ptr<DataType>(w),
gpu_ptr<DataType>(g),
gpu_ptr<DataType>(gx),
gpu_ptr<DataType>(gw_temp),
eps_,
axis_size,
w_stride);
}
}); });
}); });

View File

@@ -292,14 +292,14 @@ void RoPE::eval_gpu(
donated = true; donated = true;
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
} }
strides[0] = mat_size; strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2]; strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1]; strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) { } else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs // Handle non-contiguous 3D inputs
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
strides[0] = in.strides()[ndim - 3]; strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2]; strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1]; strides[2] = in.strides()[ndim - 1];

View File

@@ -0,0 +1,506 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
array prepare_sdpa_input(const array& x, Stream s) {
// SDPA kernel's requirements on inputs:
// 1. last dim's stride be 1;
// 2. pointer be aligned.
if (x.strides(-1) != 1 || get_alignment(x) < 16) {
array x_copy = contiguous_copy_gpu(x, s);
auto& encoder = cu::get_command_encoder(s);
encoder.add_temporary(x_copy);
return x_copy;
}
return x;
}
void malloc_with_same_layout(
cu::CommandEncoder& encoder,
array& o,
const array& q) {
if (q.flags().row_contiguous) {
o.set_data(cu::malloc_async(o.nbytes(), encoder));
return;
}
// fill_order = argsort(q.strides())
Shape fill_order(q.ndim());
std::iota(fill_order.begin(), fill_order.end(), 0);
std::stable_sort(
fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) {
auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1;
auto s2 = q.strides(idx2) > 0 ? q.strides(idx2) : 1;
return s1 < s2;
});
// Generate o_strides with fill_order
Strides o_strides(q.ndim());
int64_t stride = 1;
for (int i : fill_order) {
o_strides[i] = stride;
stride *= o.shape(i);
}
// o is a transposed contiguous array
o.set_data(
cu::malloc_async(o.nbytes(), encoder),
o.size(),
o_strides,
{true, false, false});
}
constexpr int QKV_NDIM = 4;
struct SDPACacheKey {
int device_id;
fe::DataType_t cudnn_dtype;
std::array<int, QKV_NDIM> q_shape;
std::array<int, QKV_NDIM> k_shape;
std::array<int, QKV_NDIM> v_shape;
std::array<int64_t, QKV_NDIM> q_strides;
std::array<int64_t, QKV_NDIM> k_strides;
std::array<int64_t, QKV_NDIM> v_strides;
bool do_causal;
std::array<int, QKV_NDIM> mask_shape;
std::array<int64_t, QKV_NDIM> mask_strides;
bool output_logsumexp;
};
inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp = true) {
BytesKey<SDPACacheKey> cache_key;
cache_key.pod = {
encoder.device().cuda_device(),
dtype_to_cudnn_type(q.dtype()),
vector_key<QKV_NDIM>(q.shape()),
vector_key<QKV_NDIM>(k.shape()),
vector_key<QKV_NDIM>(v.shape()),
vector_key<QKV_NDIM>(q.strides()),
vector_key<QKV_NDIM>(k.strides()),
vector_key<QKV_NDIM>(v.strides()),
do_causal,
{},
{},
output_logsumexp,
};
if (mask_arr) {
cache_key.pod.mask_shape = vector_key<QKV_NDIM>(mask_arr->shape());
cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());
}
return cache_key;
}
auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
return cache;
}
auto& sdpa_backward_cache() {
static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
return cache;
}
enum UIDS {
Q,
K,
V,
SCALE,
BIAS,
O,
STATS,
// Backward graph:
D_Q,
D_K,
D_V,
D_O,
};
DnnGraph build_sdpa_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp,
const array& o,
const array& stats) {
DnnGraph graph(handle, q.dtype());
auto q_ = graph.tensor("Q", Q, q);
auto k_ = graph.tensor("K", K, k);
auto v_ = graph.tensor("V", V, v);
auto options = fe::graph::SDPA_attributes()
.set_name("sdpa_cudnn")
.set_attn_scale(graph.scalar("Scale", SCALE, float32))
.set_generate_stats(output_logsumexp);
if (do_causal) {
if (q.shape(2) > k.shape(2)) {
options.set_causal_mask(do_causal);
} else {
options.set_causal_mask_bottom_right(do_causal);
}
}
if (mask_arr) {
options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
}
auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
graph.tensor(o_, O, o)->set_output(true);
if (output_logsumexp) {
graph.tensor(stats_, STATS, stats)->set_output(true);
}
CHECK_CUDNN_FE_ERROR(graph.prepare());
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.build());
return graph;
}
DnnGraph build_sdpa_backward_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
const array& v,
bool do_causal,
const std::optional<array>& mask_arr,
const array& o,
const array& d_o,
const array& stats,
array& d_q,
array& d_k,
array& d_v) {
DnnGraph graph(handle, q.dtype());
auto q_ = graph.tensor("Q", Q, q);
auto k_ = graph.tensor("K", K, k);
auto v_ = graph.tensor("V", V, v);
auto o_ = graph.tensor("O", O, o);
auto d_o_ = graph.tensor("D_O", D_O, d_o);
auto stats_ = graph.tensor("STATS", STATS, stats);
auto options = fe::graph::SDPA_backward_attributes()
.set_name("sdpa_backward_cudnn")
.set_attn_scale(graph.scalar("Scale", SCALE, float32));
if (do_causal) {
if (q.shape(2) > k.shape(2)) {
options.set_causal_mask(do_causal);
} else {
options.set_causal_mask_bottom_right(do_causal);
}
}
if (mask_arr) {
options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
}
auto [d_q_, d_k_, d_v_] =
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
graph.tensor(d_q_, D_Q, d_q)->set_output(true);
graph.tensor(d_k_, D_K, d_k)->set_output(true);
graph.tensor(d_v_, D_V, d_v)->set_output(true);
CHECK_CUDNN_FE_ERROR(graph.prepare());
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.build());
return graph;
}
} // namespace
bool supports_sdpa_cudnn(
const array& q,
const array& k,
const array& v,
bool do_causal,
Stream s) {
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
if (!enabled) {
return false;
}
// cuDNN SDPA requires Ampere and later.
if (cu::device(s.device).compute_capability_major() < 8) {
return false;
}
// Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv).
if ((q.shape(2) == 1) && (q.shape(2) != k.shape(2))) {
return false;
}
// D_qk and D_v must be a multiple of 8 with maximum value 128.
if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) ||
(v.shape(-1) > 128)) {
return false;
}
Dtype dtype = q.dtype();
return dtype == float16 || dtype == bfloat16;
}
void sdpa_cudnn(
const array& q,
const array& k,
const array& v,
float scale,
array& o,
array& stats,
bool do_causal,
const std::optional<array>& mask_arr,
bool output_logsumexp,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
malloc_with_same_layout(encoder, o, q);
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);
if (mask_arr) {
encoder.set_input_array(*mask_arr);
}
if (output_logsumexp) {
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
encoder.set_output_array(stats);
}
// Search cache.
auto cache_key = build_sdpa_cache_key(
encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
auto it = sdpa_cache().find(cache_key);
if (it == sdpa_cache().end()) {
auto graph = build_sdpa_graph(
handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
std::unordered_map<int64_t, void*> variant_pack{
{Q, gpu_ptr<void>(q)},
{K, gpu_ptr<void>(k)},
{V, gpu_ptr<void>(v)},
{SCALE, &scale},
{O, gpu_ptr<void>(o)}};
if (mask_arr) {
variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
}
if (output_logsumexp) {
variant_pack[STATS] = gpu_ptr<void>(stats);
}
CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
}
void sdpa_backward_cudnn(
const array& q,
const array& k,
const array& v,
float scale,
const array& o,
const array& stats,
bool do_causal,
const std::optional<array>& mask_arr,
const array& d_o,
array& d_q,
array& d_k,
array& d_v,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();
malloc_with_same_layout(encoder, d_q, q);
malloc_with_same_layout(encoder, d_k, k);
malloc_with_same_layout(encoder, d_v, v);
encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_input_array(o);
encoder.set_input_array(stats);
encoder.set_input_array(d_o);
encoder.set_output_array(d_q);
encoder.set_output_array(d_k);
encoder.set_output_array(d_v);
if (mask_arr) {
encoder.set_input_array(*mask_arr);
}
// Search cache.
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
auto it = sdpa_backward_cache().find(cache_key);
if (it == sdpa_backward_cache().end()) {
auto graph = build_sdpa_backward_graph(
handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;
std::unordered_map<int64_t, void*> variant_pack{
{Q, gpu_ptr<void>(q)},
{K, gpu_ptr<void>(k)},
{V, gpu_ptr<void>(v)},
{SCALE, &scale},
{O, gpu_ptr<void>(o)},
{STATS, gpu_ptr<void>(stats)},
{D_O, gpu_ptr<void>(d_o)},
{D_Q, gpu_ptr<void>(d_q)},
{D_K, gpu_ptr<void>(d_k)},
{D_V, gpu_ptr<void>(d_v)}};
if (mask_arr) {
variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
}
CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
}
// Defined in scaled_dot_product_attention.cu file.
bool supports_sdpa_vector(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool output_logsumexp);
void sdpa_vector(
const array& q,
const array& k,
const array& v,
float scale,
array& o,
bool do_causal,
const std::optional<array>& sinks,
Stream s);
namespace fast {
bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool is_training,
bool output_logsumexp,
Stream s) {
if (s.device == Device::cpu) {
return true;
}
return !supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
!supports_sdpa_cudnn(q, k, v, do_causal, s);
}
bool ScaledDotProductAttention::supports_bool_mask() {
return false;
}
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
auto& s = stream();
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
auto& out = outputs[0];
auto& stats = outputs[1];
bool has_mask = inputs.size() - has_sinks_ > 3;
bool has_arr_mask = has_mask && !do_causal_;
std::optional<array> mask_arr;
if (has_arr_mask) {
mask_arr = prepare_sdpa_input(inputs[3], s);
}
if (supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
if (has_sinks_) {
sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s);
} else {
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
}
} else {
sdpa_cudnn(
q,
k,
v,
scale_,
out,
stats,
do_causal_,
mask_arr,
output_logsumexp_,
s);
}
}
bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
// The frontend adds a padding mask when sequence length is not a multiple of
// tile size.
if (q.shape(2) % 128 != 0) {
return true;
}
return s.device == Device::cpu;
}
void ScaledDotProductAttentionVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ScaledDotProductAttentionVJP::eval_gpu");
auto& s = stream();
assert(inputs.size() >= 6);
int primals_size = inputs.size() - 3;
bool has_arr_mask = primals_size > 3 + has_sinks_;
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
array o = prepare_sdpa_input(inputs[primals_size], s);
array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);
std::optional<array> mask_arr;
if (has_arr_mask) {
mask_arr = prepare_sdpa_input(inputs[3], s);
}
assert(outputs.size() == 3);
auto& d_q = outputs[0];
auto& d_k = outputs[1];
auto& d_v = outputs[2];
sdpa_backward_cudnn(
q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
}
} // namespace fast
} // namespace mlx::core

View File

@@ -6,10 +6,6 @@
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"
#include <nvtx3/nvtx3.hpp>
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include <cooperative_groups/reduce.h> #include <cooperative_groups/reduce.h>
@@ -565,10 +561,9 @@ void sdpa_vector_2pass_fallback(
array sums(intermediate_shape, float32, nullptr, {}); array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {}); array maxs(std::move(intermediate_shape), float32, nullptr, {});
intermediate.set_data( intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
cu::malloc_async(intermediate.nbytes(), encoder.stream())); sums.set_data(cu::malloc_async(sums.nbytes(), encoder));
sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream())); maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder));
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
encoder.add_temporary(intermediate); encoder.add_temporary(intermediate);
encoder.add_temporary(sums); encoder.add_temporary(sums);
@@ -663,21 +658,16 @@ void sdpa_vector_fallback(
} // namespace } // namespace
namespace fast { bool supports_sdpa_vector(
bool ScaledDotProductAttention::use_fallback(
const array& q, const array& q,
const array& k, const array& k,
const array& v, const array& v,
bool has_mask, bool has_mask,
bool has_arr_mask, bool has_arr_mask,
bool do_causal, bool do_causal,
Stream s) { bool output_logsumexp) {
if (detail::in_grad_tracing()) { if (output_logsumexp) {
return true; return false;
}
if (s.device == Device::cpu) {
return true;
} }
const int value_head_dim = v.shape(-1); const int value_head_dim = v.shape(-1);
@@ -691,29 +681,24 @@ bool ScaledDotProductAttention::use_fallback(
const bool supported_vector_config = const bool supported_vector_config =
sdpa_supported_head_dim && query_sequence_length < 4; sdpa_supported_head_dim && query_sequence_length < 4;
const bool supported_config = supported_vector_config; return supported_vector_config && !has_arr_mask;
return has_arr_mask || !supported_config;
} }
void ScaledDotProductAttention::eval_gpu( void sdpa_vector(
const std::vector<array>& inputs, const array& q_pre,
array& out) { const array& k_pre,
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu"); const array& v_pre,
float scale,
auto& s = stream(); array& o,
bool do_causal,
const std::optional<array>& sinks_pre,
Stream s) {
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
std::vector<array> copies; std::vector<array> copies;
// Define some copy functions to ensure the layout of the inputs is as // Define some copy functions to ensure the layout of the inputs is as
// expected. // expected.
copies.reserve(inputs.size()); copies.reserve(4);
auto copy_unless = [&copies, &s]( auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& { auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) { if (!predicate(arr)) {
@@ -731,8 +716,8 @@ void ScaledDotProductAttention::eval_gpu(
}; };
std::optional<array> sinks = std::nullopt; std::optional<array> sinks = std::nullopt;
if (has_sinks_) { if (sinks_pre) {
sinks = copy_unless(is_matrix_contiguous, inputs.back()); sinks = copy_unless(is_matrix_contiguous, sinks_pre.value());
} }
// We are in vector mode ie single query // We are in vector mode ie single query
@@ -788,7 +773,7 @@ void ScaledDotProductAttention::eval_gpu(
}; };
o.set_data( o.set_data(
cu::malloc_async(o.nbytes(), encoder.stream()), cu::malloc_async(o.nbytes(), encoder),
o.size(), o.size(),
{str_oB, str_oH, str_oL, str_oD}, {str_oB, str_oH, str_oL, str_oD},
flags); flags);
@@ -798,8 +783,7 @@ void ScaledDotProductAttention::eval_gpu(
encoder.add_temporary(cp); encoder.add_temporary(cp);
} }
return sdpa_vector_fallback( sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks);
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
} }
// Full attention mode should never reach here // Full attention mode should never reach here
@@ -808,6 +792,4 @@ void ScaledDotProductAttention::eval_gpu(
} }
} }
} // namespace fast
} // namespace mlx::core } // namespace mlx::core

View File

@@ -374,7 +374,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
out.set_data( out.set_data(
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()), cu::malloc_async(in.data_size() * out.itemsize(), encoder),
in.data_size(), in.data_size(),
in.strides(), in.strides(),
in.flags()); in.flags());

View File

@@ -24,7 +24,7 @@ void concatenate_gpu(
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream())); out.set_data(cu::malloc_async(out.nbytes(), encoder));
auto strides = out.strides(); auto strides = out.strides();
auto flags = out.flags(); auto flags = out.flags();
@@ -89,7 +89,7 @@ array compute_dynamic_offset(
if (donate) { if (donate) {
offset.copy_shared_buffer(indices); offset.copy_shared_buffer(indices);
} else { } else {
offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream())); offset.set_data(cu::malloc_async(offset.itemsize(), encoder));
} }
encoder.add_temporary(offset); encoder.add_temporary(offset);

View File

@@ -118,7 +118,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(x); out.copy_shared_buffer(x);
} else { } else {
out.set_data( out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()), cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(), x.data_size(),
x.strides(), x.strides(),
x.flags()); x.flags());

View File

@@ -49,14 +49,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array trans = swapaxes_in_eval(in, axis, last_dim); array trans = swapaxes_in_eval(in, axis, last_dim);
in = contiguous_copy_gpu(trans, s); in = contiguous_copy_gpu(trans, s);
encoder.add_temporary(in); encoder.add_temporary(in);
out = array( out =
cu::malloc_async(out.nbytes(), encoder.stream()), array(cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
in.shape(),
out.dtype());
encoder.add_temporary(out); encoder.add_temporary(out);
} else { } else {
out.set_data( out.set_data(
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()), cu::malloc_async(in.data_size() * out.itemsize(), encoder),
in.data_size(), in.data_size(),
in.strides(), in.strides(),
in.flags()); in.flags());
@@ -74,17 +72,13 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
if (argsort) { if (argsort) {
// Indices in the sorted dimension. // Indices in the sorted dimension.
array indices( array indices(
cu::malloc_async(out.nbytes(), encoder.stream()), cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
in.shape(),
out.dtype());
encoder.add_temporary(indices); encoder.add_temporary(indices);
// In argsort though we don't need the result of sorted values, the // In argsort though we don't need the result of sorted values, the
// API requires us to provide an array to store it. // API requires us to provide an array to store it.
array discard( array discard(
cu::malloc_async(in.nbytes(), encoder.stream()), cu::malloc_async(in.nbytes(), encoder), in.shape(), in.dtype());
in.shape(),
in.dtype());
encoder.add_temporary(discard); encoder.add_temporary(discard);
size_t size; size_t size;
@@ -104,9 +98,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
stream)); stream));
array temp( array temp(
cu::malloc_async(size, encoder.stream()), cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
{static_cast<int>(size)},
uint8);
encoder.add_temporary(temp); encoder.add_temporary(temp);
// Start capturing after allocations // Start capturing after allocations
@@ -148,9 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
stream)); stream));
array temp( array temp(
cu::malloc_async(size, encoder.stream()), cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
{static_cast<int>(size)},
uint8);
encoder.add_temporary(temp); encoder.add_temporary(temp);
// Start capturing after allocations // Start capturing after allocations

View File

@@ -3,31 +3,10 @@
#pragma once #pragma once
#include "mlx/backend/cuda/steel/utils.cuh" #include "mlx/backend/cuda/steel/utils.cuh"
#include "mlx/backend/cuda/vector_types.cuh"
namespace mlx::core::cu { namespace mlx::core::cu {
// Map types to their vector of 2 type float -> float2, double -> double2 etc
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
/** /**
* The basic building block for Ampere mmas. A 16x16 tile distributed across * The basic building block for Ampere mmas. A 16x16 tile distributed across
* the warp. * the warp.

View File

@@ -257,9 +257,8 @@ void ternary_op_gpu(
auto& c = inputs[2]; auto& c = inputs[2];
auto topt = get_ternary_op_type(a, b, c); auto topt = get_ternary_op_type(a, b, c);
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) { set_ternary_op_output_data(
return cu::malloc_async(n, encoder.stream()); a, b, c, out, topt, [&](auto n) { return cu::malloc_async(n, encoder); });
});
ternary_op_gpu_inplace<Op>(inputs, out, s); ternary_op_gpu_inplace<Op>(inputs, out, s);
} }

View File

@@ -208,9 +208,8 @@ void unary_op_gpu(
const char* op, const char* op,
const Stream& s) { const Stream& s) {
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
set_unary_output_data(inputs[0], out, [&](auto n) { set_unary_output_data(
return cu::malloc_async(n, encoder.stream()); inputs[0], out, [&](auto n) { return cu::malloc_async(n, encoder); });
});
unary_op_gpu_inplace<Op>(inputs, out, op, s); unary_op_gpu_inplace<Op>(inputs, out, op, s);
} }

View File

@@ -5,6 +5,7 @@
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include <fmt/format.h> #include <fmt/format.h>
#include <vector>
namespace mlx::core { namespace mlx::core {
@@ -31,6 +32,13 @@ void check_cuda_error(const char* name, CUresult err) {
} }
} }
void check_cudnn_error(const char* name, cudnnStatus_t err) {
if (err != CUDNN_STATUS_SUCCESS) {
throw std::runtime_error(
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
}
}
const char* dtype_to_cuda_type(const Dtype& dtype) { const char* dtype_to_cuda_type(const Dtype& dtype) {
switch (dtype) { switch (dtype) {
case bool_: case bool_:
@@ -60,7 +68,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
case float64: case float64:
return "double"; return "double";
case complex64: case complex64:
return "complex64_t"; return "mlx::core::cu::complex64_t";
default: default:
return "unknown"; return "unknown";
} }
@@ -72,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) {
} }
void CudaGraph::end_capture(cudaStream_t stream) { void CudaGraph::end_capture(cudaStream_t stream) {
assert(handle_ == nullptr);
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_)); CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
} }

View File

@@ -31,8 +31,10 @@ inline T* gpu_ptr(array& arr) {
arr.offset()); arr.offset());
} }
// For const array, keep constness in pointer unless it is untyped.
template <typename T> template <typename T>
inline const T* gpu_ptr(const array& arr) { inline std::conditional_t<std::is_same_v<T, void>, void*, const T*> gpu_ptr(
const array& arr) {
return gpu_ptr<T>(const_cast<array&>(arr)); return gpu_ptr<T>(const_cast<array&>(arr));
} }

View File

@@ -0,0 +1,48 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace mlx::core::cu {
template <typename T>
struct Vector2;
template <>
struct Vector2<double> {
using type = double2;
};
template <>
struct Vector2<float> {
using type = float2;
};
template <>
struct Vector2<__half> {
using type = __half2;
};
template <>
struct Vector2<__nv_bfloat16> {
using type = __nv_bfloat162;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;
template <typename T>
struct Vector4 {
T x, y, z, w;
};
template <typename T>
using Vector4_t = Vector4<T>;
using bf16x4 = Vector4_t<__nv_bfloat16>;
using fp16x4 = Vector4_t<__half>;
using fp32x4 = Vector4_t<float>;
} // namespace mlx::core::cu

View File

@@ -44,7 +44,7 @@ void Worker::commit(cudaStream_t stream) {
} }
signal_event_.record(stream); signal_event_.record(stream);
signal_event_.wait(signal_stream_); signal_event_.wait(signal_stream_);
cudaLaunchHostFunc(signal_stream_, signal, this); CHECK_CUDA_ERROR(cudaLaunchHostFunc(signal_stream_, signal, this));
} }
void Worker::thread_fn() { void Worker::thread_fn() {

View File

@@ -7,8 +7,6 @@
namespace mlx::core { namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& in, array& out, CopyType ctype) { void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream()); copy_gpu(in, out, ctype, out.primitive().stream());
} }

View File

@@ -11,7 +11,7 @@ void slice_gpu(
array& out, array& out,
const Shape& start_indices, const Shape& start_indices,
const Shape& strides, const Shape& strides,
const Stream& s) { const Stream&) {
slice(in, out, start_indices, strides); slice(in, out, start_indices, strides);
} }

View File

@@ -28,6 +28,7 @@ make_jit_source(binary_ops)
make_jit_source(ternary_ops) make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h) make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(indexing/scatter kernels/indexing/indexing.h) make_jit_source(indexing/scatter kernels/indexing/indexing.h)
make_jit_source(indexing/masked_scatter)
make_jit_source(indexing/gather kernels/indexing/indexing.h) make_jit_source(indexing/gather kernels/indexing/indexing.h)
make_jit_source(indexing/gather_front kernels/indexing/indexing.h) make_jit_source(indexing/gather_front kernels/indexing/indexing.h)
make_jit_source(indexing/gather_axis) make_jit_source(indexing/gather_axis)

View File

@@ -149,7 +149,9 @@ Buffer MetalAllocator::malloc(size_t size) {
buf = device_->newBuffer(size, resource_options); buf = device_->newBuffer(size, resource_options);
} }
if (!buf) { if (!buf) {
return Buffer{nullptr}; std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
} }
lk.lock(); lk.lock();
num_resources_++; num_resources_++;
@@ -201,6 +203,32 @@ size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length(); return static_cast<MTL::Buffer*>(buffer.ptr())->length();
} }
Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
if (!buf) {
return Buffer{nullptr};
}
std::unique_lock lk(mutex_);
residency_set_.insert(buf);
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
num_resources_++;
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::release(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_);
active_memory_ -= buf->length();
num_resources_--;
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();
}
MetalAllocator& allocator() { MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator // By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This // will not be called on exit and buffers in the cache will be leaked. This

View File

@@ -21,6 +21,9 @@ class MetalAllocator : public allocator::Allocator {
virtual Buffer malloc(size_t size) override; virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override; virtual size_t size(Buffer buffer) const override;
virtual Buffer make_buffer(void* ptr, size_t size) override;
virtual void release(Buffer buffer) override;
size_t get_active_memory() { size_t get_active_memory() {
return active_memory_; return active_memory_;
}; };

Some files were not shown because too many files have changed in this diff Show More