Compare commits

...

89 Commits

Author SHA1 Message Date
Awni Hannun
60d80a3728 fix release builds (#2746) 2025-11-11 07:44:30 -08:00
Pedro Cuenca
eba6a9d163 Compatibility with pip-installed openmpi (#2741)
2025-11-07 16:58:31 -08:00
CCYeh
be9e2aebd6 Shapeless support for zeros/ones_like (#2726)
* shapeless support for zeros/ones_like

* Improvements

* fix access after moved
2025-11-06 19:12:20 -08:00
Awni Hannun
df58b4133a [CUDA] Reduce use of managed memory (#2725)
* Use async cuda malloc managed with cuda 13

* add pool threshold

* refactor for regular cuda malloc

* load eval gpu for cuda

* remove use of cuda pool, use cuda free async

* fix

* fix

* fix

* fix

* fix + comment
2025-11-05 16:05:23 -08:00
Anastasiia Filippova
27778156dc Nccl reduce scatter, all gather (#2727)
* Added reduce scatter and all gather for nccl

* fix unused import, delete unused file

* small fix

* deleted useless condition

* fixed comments

* fix bug in eval_gpu, renamed to sum_scatter, fix docs

* final fix docs

* remove and

* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output

* fixes set output

* typo

* fix typo

* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-11-05 08:21:11 -08:00
Mike Drob
761f901a41 fix property name (#2736) 2025-11-05 06:31:56 -06:00
Angelos Katharopoulos
6ece97f69b Make cpu binary_op easily accessible (#2733) 2025-11-05 01:08:41 -08:00
Awni Hannun
d3bc6a9bff don't test when doing release (#2734)
2025-11-04 15:54:23 -08:00
Awni Hannun
26ceb507eb only build for macos 14 and up (#2731)
* only build for macos 14 and up

* bump metal cpp
2025-11-04 09:44:15 -08:00
Mike Drob
910b3e3299 skip self-hosted runners on forks (#2730)
2025-11-03 16:22:13 -06:00
Harsh Sutaria
50fa315d18 Fix addmm with empty matrices and beta != 1.0 (#2715) 2025-11-03 14:16:15 -08:00
AN Long
1ff2b713b6 Check isnan in maximum / minimum with CPU backend (#2652)
* Check isnan in maximum / minimum with CPU backend

* Add tests

* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-11-03 08:51:14 -08:00
Mike Drob
50514a6146 Set up publishing to PyPI and Test-PyPI (#2721) 2025-11-03 07:20:11 -08:00
Awni Hannun
93d76b0f30 Fix compile multi capture (#2678)
* fix compile when compiling multiple lambdas with the same capture

* add test
2025-11-03 06:33:43 -08:00
David Koski
78678de0cd add null check -- the bundleIdentifier is optional (#2709)
* add null check -- the bundleIdentifier is optional

* use variable
2025-11-03 06:33:21 -08:00
Melissa Kilby
ed9c6b1117 update: add linux fedora container CI - CPP build test only (#2722)
* update: add linux_fedora_build_cpp CI - CPP build test only - x86-64

Signed-off-by: Melissa Kilby <mkilby@apple.com>

* update: add linux_fedora_build_cpp_aarch64 CI - CPP build test only - arm64

Co-authored-by: Mike Drob <mdrob@apple.com>
Signed-off-by: Melissa Kilby <mkilby@apple.com>

* update: convert linux_fedora_build_cpp to matrix.arch loop

Co-authored-by: Mike Drob <mdrob@apple.com>
Signed-off-by: Melissa Kilby <mkilby@apple.com>

---------

Signed-off-by: Melissa Kilby <mkilby@apple.com>
Co-authored-by: Mike Drob <mdrob@apple.com>
2025-11-03 06:33:00 -08:00
Awni Hannun
39b04ce638 use faster dequant for fp4 qmv (#2720)
2025-10-31 11:49:59 -07:00
Mike Drob
d9e6349657 fix docs path (#2719) 2025-10-30 19:12:49 -05:00
Angelos Katharopoulos
b901a9f311 Fix the order of hosts in the ring (#2718)
2025-10-30 15:02:39 -07:00
Awni Hannun
68c5fa1c95 fix memory count bug (#2717) 2025-10-30 14:27:15 -07:00
Christopher Webb
793a31eeb6 Fix missing domain_uuid_key in thunderbolt ring setup (#2682) 2025-10-30 13:17:20 -07:00
Mike Drob
74c1ed25bb Migrate CircleCI to GitHub Actions (#2716)
Co-authored-by: Joseph Heck <j_heck@apple.com>
2025-10-30 12:26:55 -05:00
Awni Hannun
ec72b44417 Add quantize/dequantize for mxfp8 and nvfp4 (#2688)
* Add quantize/dequantize slow path for mxfp8 and nvfp4

* fast cuda kernel for mx/nv quantization

* fallback for cuda < 12.8 (#2697)

* format (#2700)

* fix (#2701)

* metal kernels

* docs

* fix jit

* add default bits and group sizes

* improve quant docs

* fix output type of mxfp4 matmuls
2025-10-28 16:23:12 -07:00
Melissa Kilby
460691a0e8 fix: linux-{fedora}x86_64-build (#2707)
Signed-off-by: Melissa Kilby <mkilby@apple.com>
2025-10-27 16:36:08 -07:00
Awni Hannun
969924cc69 Fp8 conversion (#2686)
* add fp8 e4m3 converters

* add cuda

* default saturate to min/max

* fix for older OS

* fix no gpu/cpu

* fix saturate

* fix compile
2025-10-27 16:35:50 -07:00
Awni Hannun
d1e06117e8 bump python (#2694) 2025-10-27 11:34:31 -07:00
Awni Hannun
539d8322d1 add median op (#2705) 2025-10-27 11:33:42 -07:00
Awni Hannun
c4767d110f fix addmm cpu (#2699) 2025-10-27 11:33:32 -07:00
David Koski
895217f25b optionally load metallib from framework (#2702)
* optionally load metallib from framework

* pre-commit

* adjust logic
2025-10-27 07:52:03 -07:00
Manuel Villanueva
0cfeeb60ca Einsum error msg improvement (#2690)
* Improved error message for Einsum

* Modifications via pre-commit

* format

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-27 06:31:47 -07:00
Ronan Collobert
8f8af61a37 fix warnings showing up with -Wall (#2692) 2025-10-24 11:43:35 -07:00
Manuel Villanueva
233384161e Improved mx.split() docs (#2689)
* Improved mx.split() documentation

* Fix typo in docstring for array split function

* add example

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-24 09:48:41 -07:00
Awni Hannun
5bcf3a6794 format 2025-10-22 16:08:47 -07:00
wickedcoder
7707196297 Merge commit from fork
* add length validation to the header

* fix accessing out of bound index with .at()
2025-10-22 15:31:25 -07:00
wickedcoder
7e3471c987 Merge commit from fork
* add tensor->weights_data validation

* add null pointer check for tensor
2025-10-22 15:31:03 -07:00
Awni Hannun
9f0ba3ddf1 patch bump (#2680) 2025-10-17 12:12:07 -07:00
Awni Hannun
4bce5f9b2d suppress gcc 10.1 warnings (#2679)
* suppress gcc 10.1 warnings

* suppress gcc 10.1 warnings
2025-10-17 12:09:21 -07:00
Anastasiia Filippova
e9eab527eb Nccl timeout (#2673)
* print the error & delete nccl group

* timeout for nccl binding

* typo

* revert error

* fixed a typo
2025-10-14 12:29:54 -07:00
Awni Hannun
36ca62dba8 remove unused unary file (#2672) 2025-10-13 19:36:26 -07:00
Manuel Villanueva
9cbb1b0148 Modified sort behavior when running CPU or Metal to match NumPy/JAX (#2667)
* Modified sort behavior when running CPU or Metal to match NumPy/JAX sorting behavior.

* Modified sort behavior when running CPU or Metal to match NumPy/JAX

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2025-10-13 14:36:45 -07:00
Fabrizio Milo
9bfc476d72 Normalize README bullet formatting (#2671) 2025-10-13 12:13:30 -07:00
Awni Hannun
25e2356316 speed up scalars (#2669) 2025-10-13 12:10:15 -07:00
Awni Hannun
226a1d24e0 Debug cuda conv (#2662)
* use t4

* use t4
2025-10-10 16:12:47 -07:00
Awni Hannun
630350ad3e Precise sigmoid (#2659)
* bump patch

* Sigmoid matches PyTorch and is more precise on tails
2025-10-10 10:05:23 -07:00
Awni Hannun
380aeb58ae enable admm low-precision cpu (#2661) 2025-10-10 09:50:54 -07:00
Awni Hannun
f37389d100 bump patch (#2658) 2025-10-10 08:36:41 -07:00
Awni Hannun
e89e8b4272 Export with callback (#2612)
* export with callback

* export with callback

* Add types, fix kwarg ordering bug + test

* cleanup, test, fix

* typos
2025-10-08 19:24:33 -07:00
AN Long
85a8824a8c Fix cumulative operations when axis=None (#2653) 2025-10-08 15:25:38 -07:00
Awni Hannun
f5d4397e5c Fix fast synch when fence is waited before a command buffer is created (#2657) 2025-10-08 11:23:46 -07:00
Awni Hannun
343e33b6d5 fix all_gather vjp (#2654) 2025-10-07 06:05:23 -07:00
Angelos Katharopoulos
0073096dd1 Split name into directories for cuda jit (#2656) 2025-10-07 01:52:58 -07:00
Angelos Katharopoulos
e3d004fed9 Fix and refactor row-reduce (#2650) 2025-10-07 01:51:08 -07:00
Awni Hannun
a393435d28 Speed up compile for node with many parents (#2649) 2025-10-03 19:30:36 -07:00
Awni Hannun
a7a94b29d7 Fix compile when outputs change (#2648) 2025-10-03 08:40:57 -07:00
Daniel Yeh
22a5da76c8 Faster complex matmul (#2571) 2025-10-02 23:33:15 -07:00
Andrey Portnoy
287c63a093 Configure CMake to export compile_commands.json (#2645)
This helps enable LSP for code navigation using clangd.
2025-10-02 15:40:32 -07:00
Awni Hannun
1c9ae1eaa1 cuda fix flaky test (#2646) 2025-10-02 15:40:04 -07:00
Angelos Katharopoulos
c2c3e0b0a2 [CUDA] Add a small column specialization to reduce (#2642) 2025-10-02 14:41:05 -07:00
Awni Hannun
b0cc71ae71 Faster triu, tril, where with scalar (#2644) 2025-10-02 12:21:27 -07:00
Awni Hannun
e88f2d4a8e fix cross entropy axis param (#2641)
* fix cross entropy axis param

* faster grad clipping
2025-10-01 16:49:55 -07:00
Angelos Katharopoulos
9cee557423 Fix status message (#2638) 2025-10-01 16:43:45 -07:00
Awni Hannun
bbf1423953 wait for tasks in cuda (#2636) 2025-09-30 16:08:46 -07:00
Angelos Katharopoulos
eb24267b56 Compile now can attach arbitrary data to an entry (#2634) 2025-09-30 13:33:27 -07:00
Awni Hannun
dc371ae7a5 fix for max block dim (#2631) 2025-09-29 08:59:25 -07:00
AN Long
e76a8dd5c5 Fix incorrect path and typos (#2630) 2025-09-28 06:03:04 -07:00
Cheng
b466dea982 [CUDA] Make CudaEvent work with multi-device (#2614)
* Set current device when creating cuda event

* Separate cuda events by device

* Avoid race condition in pool
2025-09-27 11:27:17 +09:00
Angelos Katharopoulos
7a6adda1e6 Bump the version (#2627) 2025-09-26 15:15:28 -07:00
Angelos Katharopoulos
1a9f820af6 Compiled should not end in broadcast (#2622) 2025-09-26 13:36:09 -07:00
Awni Hannun
d4f4ff3c5e Allow None input to compiled functions (#2621)
* Allow None input to compiled functions

* Allow None input to compiled functions
2025-09-25 08:42:23 -07:00
Jagrit Digani
7c7e48dbd1 New tuning for small K gemv (#2620)
* New tuning for small K gemv
2025-09-23 12:28:35 -07:00
Daniel Yeh
fbbf3b9b3e Support pickling array for bfloat16 (#2586)
* add bfloat16 pickling

* Improvements

* improve

---------

Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-09-22 20:12:15 -07:00
Daniel Yeh
bf01ad9367 fix (#2613)
Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
2025-09-22 20:12:04 -07:00
Cheng
ae438d05fa [CUDA] Recycle CUDA events (#2604)
* Make CudaEvent a CudaHandle

* Add caching for CudaEvent

* Make sure cuda events are destroyed at last

* Fix headers

* SharedEvent => AtomicEvent

* RawCudaEvent => CudaEventHandle, CudaEventWrapper => CopyableCudaEvent

* Remove unneeded asserts
2025-09-23 10:42:03 +09:00
Awni Hannun
711a645807 avoid producing NaN in attention (#2608) 2025-09-22 13:10:43 -07:00
Josh Bleecher Snyder
aa9d44b3d4 implement Convolution::output_shape (#2601)
- pull conv_out_shape out for re-use
- add Conv::output_shape
- add e2e python tests confirming shapeless=True support and correctness

Updates #2599
2025-09-22 10:09:45 -07:00
Awni Hannun
ec2ab42888 Lower sorted QMM gather threshold (#2609) 2025-09-19 18:22:55 -07:00
Cheng
787c0d90cd Detect cache thrashing in LRUCache (#2600)
* Detect cache thrashing in LRUCache

* Do not check cache thrashing in tests
2025-09-19 09:12:14 +09:00
Oleksandr Bilous
e8b604a6a3 fix: library loading for swift dynamic frameworks (#2568) 2025-09-18 13:54:59 -07:00
Awni Hannun
50cc09887f expose depends (#2606) 2025-09-18 10:06:15 -07:00
Umberto Mignozzetti
3f730e77aa Update export function example for array input (#2598)
After changing the shape to conform (same shapes for all objects), the example works.
2025-09-16 14:38:05 -07:00
Awni Hannun
caecbe876a no copy batch rope (#2595) 2025-09-15 14:23:48 -07:00
Umberto Mignozzetti
8afb6d62f2 Fix typo in average_gradients function call (#2594) 2025-09-15 11:29:21 -07:00
Awni Hannun
6ccfa603cd fix metal scan (#2591) 2025-09-15 11:01:57 -07:00
Umberto Mignozzetti
36cad99a11 Refactor code examples to use 'gelu' (#2592)
Updated code examples to use 'gelu' directly instead of 'nn.gelu'.
2025-09-15 09:47:02 -07:00
Awni Hannun
ee18e1cbf0 patch bump (#2588) 2025-09-11 17:10:09 -07:00
Awni Hannun
af120c2bc0 set nccl ABI version (#2587) 2025-09-11 16:55:53 -07:00
Cheng
6a3acf2301 [CUDA] Set bias as input when using bias epilogue (#2584) 2025-09-11 15:31:09 +09:00
Awni Hannun
d6977f2a57 Add sdpa with sinks (#2558)
* add sdpa with sinks

* fix 2 pass

* fix matrix sdpa

* fix perf regression

* add to cuda (#2580)
2025-09-10 14:53:00 -07:00
Gökdeniz Gülmez
db5443e831 Adding Relu2 (#2582)
* in. com.

* upd. ackn.

* update __init__

* nits

* nits + format

* used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer

* same with _make_activation_module

* Update python/mlx/nn/layers/activations.py

upd

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update funct.rst

* upd. layers.rst

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2025-09-10 07:24:30 -07:00
254 changed files with 7768 additions and 3712 deletions


@@ -26,9 +26,9 @@ jobs:
 name: Install
 command: |
 xcodebuild -downloadComponent MetalToolchain
-brew install python@3.9
+brew install python@3.10
 brew install doxygen
-python3.9 -m venv env
+python3.10 -m venv env
 source env/bin/activate
 pip install --upgrade pip
 pip install --upgrade cmake
@@ -140,7 +140,7 @@ jobs:
 - run:
 name: Install Python package
 command: |
-uv venv --python 3.9
+uv venv --python 3.10
 uv pip install \
 nanobind==2.4.0 \
 cmake \
@@ -273,7 +273,7 @@ jobs:
 parameters:
 python_version:
 type: string
-default: "3.9"
+default: "3.10"
 xcode_version:
 type: string
 default: "26.0.0"
@@ -328,7 +328,7 @@ jobs:
 << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
 - when:
 condition:
-equal: ["3.9", << parameters.python_version >>]
+equal: ["3.10", << parameters.python_version >>]
 steps:
 - run:
 name: Build common package
@@ -351,7 +351,7 @@ jobs:
 parameters:
 python_version:
 type: string
-default: "3.9"
+default: "3.10"
 build_env:
 type: string
 default: ""
@@ -387,7 +387,7 @@ jobs:
 bash python/scripts/repair_linux.sh
 - when:
 condition:
-equal: ["3.9", << parameters.python_version >>]
+equal: ["3.10", << parameters.python_version >>]
 steps:
 - run:
 name: Build common package
@@ -484,7 +484,7 @@ workflows:
 ignore: /.*/
 matrix:
 parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
 macosx_deployment_target: ["13.5", "14.0", "15.0"]
 build_env: ["PYPI_RELEASE=1"]
 xcode_version: ["26.0.0"]
@@ -503,7 +503,7 @@ workflows:
 ignore: /.*/
 matrix:
 parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
 build_env: ["PYPI_RELEASE=1"]
 - build_cuda_release:
 filters:
@@ -546,13 +546,13 @@ workflows:
 - build_release:
 matrix:
 parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
 macosx_deployment_target: ["13.5", "14.0", "15.0"]
 xcode_version: ["26.0.0"]
 - build_linux_release:
 matrix:
 parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
 - build_cuda_release
 build_dev_release:
@@ -564,14 +564,14 @@ workflows:
 - build_release:
 matrix:
 parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
 macosx_deployment_target: ["13.5", "14.0", "15.0"]
 build_env: ["DEV_RELEASE=1"]
 xcode_version: ["26.0.0"]
 - build_linux_release:
 matrix:
 parameters:
-python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
 build_env: ["DEV_RELEASE=1"]
 - build_cuda_release:
 matrix:


@@ -0,0 +1,20 @@
name: 'Build CUDA wheel'
description: 'Build CUDA wheel'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
required: true
runs:
using: "composite"
steps:
- name: Build package
shell: bash
env:
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: |
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh

.github/actions/build-cuda/action.yml (new file)

@@ -0,0 +1,45 @@
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'
inputs:
nvcc-location:
description: 'Location of nvcc compiler'
required: true
default: '/usr/local/cuda-12.9/bin/nvcc'
runs:
using: "composite"
steps:
- name: Install Python package
shell: bash
env:
DEBUG: 1
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
run: pip install -e ".[dev]" -v
- name: Run Python tests - CPU
shell: bash
env:
LOW_MEMORY: 1
DEVICE: cpu
run: python -m unittest discover python/tests -v
- name: Run Python tests - GPU
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
run: python -m tests discover python/tests -v
- name: Build CPP only
shell: bash
run: |
cmake . -B build \
-DMLX_BUILD_CUDA=ON \
-DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
-DCMAKE_BUILD_TYPE=DEBUG
cmake --build build -j $(nproc)
- name: Run CPP tests
shell: bash
run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"

.github/actions/build-docs/action.yml (new file)

@@ -0,0 +1,38 @@
name: 'Build Documentation'
description: 'Build documentation on a mac'
runs:
using: "composite"
steps:
- name: Setup machine
uses: ./.github/actions/setup-macos
- name: Install dependencies
shell: sh
run: |
brew install doxygen
uv pip install --upgrade pip cmake
uv pip install -r docs/requirements.txt
uv pip install . -v
- name: Build documentation
shell: bash
run: |
source .venv/bin/activate
cd docs
doxygen
make html O=-W
- name: Create artifact tar
shell: sh
run: tar -cf artifact.tar --cd docs --dereference build/html index.html
# Do it manually because upload-pages-artifact requires gtar
- name: Upload artifact
id: upload-artifact
uses: actions/upload-artifact@v5
with:
name: github-pages
path: artifact.tar
retention-days: 1
if-no-files-found: error


@@ -0,0 +1,33 @@
name: 'Build Linux wheel'
description: 'Build Linux wheel'
inputs:
build-backend:
description: 'Build the backend mlx-cpu package'
type: boolean
required: false
default: false
runs:
using: "composite"
steps:
- name: Generate package stubs
shell: bash
run: |
pip install -e ".[dev]" -v
pip install typing_extensions
python setup.py generate_stubs
- name: Build Python package
shell: bash
run: |
pip install auditwheel patchelf build
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w
bash python/scripts/repair_linux.sh
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash
run: |
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64

.github/actions/build-linux/action.yml (new file)

@@ -0,0 +1,41 @@
name: 'Build and Test on Linux'
description: 'Build and test MLX on Linux'
runs:
using: "composite"
steps:
- name: Install Python package
shell: sh
env:
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
DEBUG: 1
run: pip install -e ".[dev]" -v
- name: Generate package stubs
shell: sh
run: |
pip install typing_extensions
python setup.py generate_stubs
- name: Run Python tests
shell: bash
run: |
python -m unittest discover python/tests -v
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if grep -Fq '[WARN]' stderr.log ; then
grep -F '[WARN]' stderr.log
echo "Distributed ring test failed";
exit 1;
fi
- name: Build CPP only
shell: bash
run: |
mkdir -p build && cd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
- name: Run CPP tests
shell: sh
run: ./build/tests/tests


@@ -0,0 +1,33 @@
name: 'Build macOS release'
description: 'Build MLX releases macOS'
inputs:
macos-target:
description: 'macOS build target'
required: false
default: '15.0'
build-backend:
description: 'Build the backend mlx-metal package'
type: boolean
required: false
default: false
runs:
using: "composite"
steps:
- name: Build Python package
shell: bash
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
uv pip install build
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=1 uv run -m build -w
- name: Build backend package
if: ${{ inputs.build-backend }}
shell: bash
env:
MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
run: |
uv run --no-project setup.py clean --all
MLX_BUILD_STAGE=2 uv run -m build -w

.github/actions/build-macos/action.yml (new file)

@@ -0,0 +1,88 @@
name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'
runs:
using: "composite"
steps:
- name: Install dependencies
shell: sh
env:
DEBUG: 1
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
run: |
uv pip install --upgrade pip
uv pip install cmake setuptools nanobind==2.4.0
uv pip install -e . -v
- name: Generate package stubs
shell: bash
run: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- name: Install tests dependencies
shell: sh
run: |
uv pip install numpy torch tensorflow unittest-xml-reporting
- name: Run Python tests
shell: bash
env:
LOW_MEMORY: 1
run: |
DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
- name: Build example extension
shell: bash
run: |
cd examples/extensions
uv pip install -r requirements.txt
uv run --no-project setup.py build_ext --inplace
uv run --no-project test.py
- name: Build CPP only
shell: bash
run: |
mkdir -p build
cd build
cmake ..
make -j $(sysctl -n hw.ncpu)
- name: Run CPP tests
shell: bash
env:
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: ./build/tests/tests
- name: Build small binary with JIT
shell: bash
run: |
mkdir -p build
cd build
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
make -j $(sysctl -n hw.ncpu)
- name: Run Python tests with JIT
shell: bash
env:
LOW_MEMORY: 1
DEVICE: gpu
METAL_DEVICE_WRAPPER_TYPE: 1
METAL_DEBUG_ERROR_MODE: 0
run: |
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
uv pip install -e . -v
uv run -m xmlrunner discover \
-v python/tests \
-o test-results/gpu_jit

.github/actions/setup-linux/action.yml (new file)

@@ -0,0 +1,83 @@
name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'
inputs:
runner-type:
description: 'Whether to set this up as a linux or CUDA runner'
required: false
default: 'linux'
type: choice
options:
- linux
- cuda
python-version:
description: 'Version of python to set up'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Free disk space
shell: sh
if: inputs.runner-type == 'linux'
run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Install common dependencies
env:
TZ: Etc/UTC
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
sudo apt autoremove -y
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}
cache: 'pip'
- name: setup python venv
shell: bash
run: |
python -m venv .venv
source .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
pip install --upgrade pip cmake
- name: Install MPI
if: inputs.runner-type == 'linux'
shell: bash
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
- name: Network CUDA installation from packages
id: install-cuda
if: inputs.runner-type == 'cuda'
env:
TZ: Etc/UTC
shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
# Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
# cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
# Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
# This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
- name: Package and Driver Report
if: inputs.runner-type == 'cuda'
shell: bash
run: |
sudo apt-get install -y ubuntu-drivers-common dkms
echo "NVIDIA Driver Packages Available:"
sudo ubuntu-drivers list --gpgpu
echo "NVIDIA Driver Version:"
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
echo "Installed NVIDIA and CUDA packages:"
dpkg -l | egrep "cuda|nvidia" -i
echo "DKMS Status:"
dkms status || echo "dkms not found"
echo "NVIDIA-SMI Status:"
nvidia-smi || echo "nvidia-smi not found"

.github/actions/setup-macos/action.yml (new file)

@@ -0,0 +1,25 @@
name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'
inputs:
python-version:
description: 'Python version to use'
required: false
default: '3.10'
runs:
using: "composite"
steps:
- name: Install Homebrew packages
shell: sh
run: /opt/homebrew/bin/brew install openmpi
- name: Verify MetalToolchain installed
shell: bash
run: xcodebuild -showComponent MetalToolchain
- name: Setup uv
uses: astral-sh/setup-uv@v6
with:
python-version: ${{ inputs.python-version }}
activate-environment: true

.github/dependabot.yml (new file)

@@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"


@@ -0,0 +1,27 @@
#!/bin/bash
set -ex
# [Setup] Install dependencies inside the container.
dnf update -y
dnf install -y \
blas-devel \
lapack-devel \
openblas-devel \
make \
cmake \
clang \
git
dnf clean all
# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
export DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++
mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
./tests/tests
popd

.github/workflows/documentation.yml (new file)

@@ -0,0 +1,28 @@
name: Documentation
on:
workflow_dispatch:
permissions:
contents: read
jobs:
build:
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy:
needs: build
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

.github/workflows/nightly.yml (new file)

@@ -0,0 +1,124 @@
name: Nightly Build
on:
schedule:
- cron: 33 6 * * 1-5
workflow_dispatch:
permissions:
contents: read
jobs:
build_linux_release:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
retention-days: 7
build_linux_with_tests:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.13"]
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- uses: ./.github/actions/build-macos
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
build_cuda_with_tests:
if: github.repository == 'ml-explore/mlx'
runs-on: gpu-t4-4-core
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- uses: ./.github/actions/build-cuda
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
retention-days: 7
linux_fedora_build_cpp:
name: Linux Fedora CPP Build (${{ matrix.arch }})
strategy:
fail-fast: false
matrix:
include:
- host: ubuntu-22.04
arch: x86_64
- host: ubuntu-22.04-arm
arch: aarch64
runs-on: ${{ matrix.host }}
container:
image: fedora:42
steps:
- name: Checkout code
uses: actions/checkout@v5
- name: CPP Build Test - No Release
run: |
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh


@@ -1,20 +1,71 @@
-on:
-pull_request:
-branches:
-- main
+name: Build and Test
+on: pull_request
+permissions:
+contents: read
 jobs:
 check_lint:
-runs-on: ubuntu-latest
+runs-on: ubuntu-22.04
 steps:
-- uses: actions/checkout@v4
-- uses: actions/setup-python@v4
-with:
-python-version: 3.8
-- name: Install dependencies
-run: |
-python -m pip install --upgrade pip
-pip install pre-commit black isort clang-format
-- name: Run lint
-run: |
-pre-commit run --all-files
+- uses: actions/checkout@v5
+- uses: ./.github/actions/setup-linux
+- uses: pre-commit/action@v3.0.1
+linux_build_and_test:
+runs-on: ubuntu-22.04
+steps:
+- uses: actions/checkout@v5
+- uses: ./.github/actions/setup-linux
+- uses: ./.github/actions/build-linux
+mac_build_and_test:
+if: github.repository == 'ml-explore/mlx'
+runs-on: [self-hosted, macos]
+needs: check_lint
+steps:
+- uses: actions/checkout@v5
+- uses: ./.github/actions/setup-macos
+- uses: ./.github/actions/build-macos
+cuda_build_and_test:
+if: github.repository == 'ml-explore/mlx'
+runs-on: gpu-t4-4-core
+needs: check_lint
+steps:
+- uses: actions/checkout@v5
+- uses: ./.github/actions/setup-linux
+with:
+runner-type: 'cuda'
+- uses: ./.github/actions/build-cuda
+build_documentation:
+if: github.repository == 'ml-explore/mlx'
+runs-on: [self-hosted, macos]
+needs: check_lint
+steps:
+- uses: actions/checkout@v5
+- uses: ./.github/actions/build-docs
+linux_fedora_build_cpp:
+name: Linux Fedora CPP Build (${{ matrix.arch }})
+strategy:
+fail-fast: false
+matrix:
+include:
+- host: ubuntu-22.04
+arch: x86_64
+- host: ubuntu-22.04-arm
+arch: aarch64
+runs-on: ${{ matrix.host }}
+container:
+image: fedora:42
+steps:
+- name: Checkout code
+uses: actions/checkout@v5
+- name: CPP Build Test - No Release
+run: |
+bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh

.github/workflows/release.yml (new file)

@@ -0,0 +1,226 @@
name: PyPI Release
on:
push:
tags:
- 'v*'
workflow_dispatch:
permissions:
contents: read
jobs:
setup:
runs-on: ubuntu-latest
outputs:
pypi_env: ${{ github.event_name == 'push' && 'pypi' || 'test-pypi' }}
pypi_url: ${{ github.event_name == 'push' && 'https://upload.pypi.org/legacy/' || 'https://test.pypi.org/legacy/' }}
steps:
- name: Set publishing variables
run: echo "Publishing setup complete"
build_documentation:
if: github.repository == 'ml-explore/mlx'
runs-on: [self-hosted, macos]
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/build-docs
deploy_documentation:
needs: build_documentation
permissions:
pages: write
id-token: write
runs-on: ubuntu-latest
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
build_linux_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: ubuntu-22.04
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
build_mac_release:
if: github.repository == 'ml-explore/mlx'
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
runs-on: [self-hosted, macos]
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-macos
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
shell: sh
run: |
uv pip install --upgrade pip
uv pip install cmake setuptools nanobind==2.4.0
uv pip install -e . -v
- name: Generate package stubs
shell: bash
run: |
uv pip install typing_extensions
uv run --no-project setup.py generate_stubs
- name: Build macOS 14 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 14.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Build macOS 15 package
uses: ./.github/actions/build-macos-release
with:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
with:
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
with:
name: mlx-metal
path: dist/mlx_metal-*.whl
build_cuda_release:
if: github.repository == 'ml-explore/mlx'
runs-on: ubuntu-22-large
env:
PYPI_RELEASE: 1
steps:
- uses: actions/checkout@v5
- uses: ./.github/actions/setup-linux
with:
runner-type: 'cuda'
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
- name: Upload artifacts
uses: actions/upload-artifact@v5
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release, build_mac_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
with:
pattern: linux-wheels-*
merge-multiple: true
path: dist
- uses: actions/download-artifact@v6
with:
pattern: mac-wheels-*
merge-multiple: true
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
pypi-publish-cuda:
name: Upload CUDA release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_cuda_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cuda
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
pypi-publish-cpu:
name: Upload CPU release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_linux_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-cpu
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}
pypi-publish-metal:
name: Upload Metal release to PyPI
runs-on: ubuntu-latest
needs: [setup, build_mac_release]
permissions:
id-token: write
environment:
name: ${{ needs.setup.outputs.pypi_env }}
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
with:
name: mlx-metal
path: dist
- name: Display structure of downloaded files
run: ls -R dist
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: ${{ needs.setup.outputs.pypi_url }}


@@ -1,4 +1,10 @@
 repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+rev: v6.0.0
+hooks:
+- id: check-yaml
+# - id: end-of-file-fixer
+# - id: trailing-whitespace
 - repo: https://github.com/pre-commit/mirrors-clang-format
 rev: v19.1.7
 hooks:


@@ -19,7 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
-- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />


@@ -26,6 +26,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(CMAKE_INSTALL_MESSAGE NEVER)
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 # ----------------------------- Configuration -----------------------------
 option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -87,6 +88,11 @@ cmake_policy(SET CMP0135 NEW)
 add_library(mlx)
+# Supress warnings: note: parameter passing for argument of type
+# std::pair<float, float> when C++17 is enabled changed to match C++14 in GCC
+# 10.1
+target_compile_options(mlx PRIVATE -Wno-psabi)
 if(MLX_BUILD_CUDA)
 enable_language(CUDA)
 endif()
@@ -121,9 +127,12 @@ if(MLX_BUILD_METAL)
 message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
 set(METAL_CPP_URL
-https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
+https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
 if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
+if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
+message(FATAL_ERROR "MLX requires macOS >= 14.0")
+endif()
 set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
 endif()
 execute_process(
@@ -132,7 +141,6 @@ if(MLX_BUILD_METAL)
 "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
 OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
 FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
 FetchContent_MakeAvailable(metal_cpp)
 target_include_directories(
 mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
@@ -173,7 +181,7 @@ if(MLX_BUILD_CPU)
 message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
 set(MLX_BUILD_ACCELERATE ON)
 else()
-message(STATUS "Accelerate or arm neon not found, using default backend.")
+message(STATUS "Accelerate not found, using default backend.")
 set(MLX_BUILD_ACCELERATE OFF)
 endif()


@@ -2,7 +2,7 @@
 [**Quickstart**](#quickstart) | [**Installation**](#installation) |
 [**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
 [**Examples**](#examples)
 [![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
@@ -11,37 +11,37 @@ brought to you by Apple machine learning research.
 Some key features of MLX include:
 - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
 also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
 [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
 the Python API. MLX has higher-level packages like `mlx.nn` and
 `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
 more complex models.
 - **Composable function transformations**: MLX supports composable function
 transformations for automatic differentiation, automatic vectorization,
 and computation graph optimization.
 - **Lazy computation**: Computations in MLX are lazy. Arrays are only
 materialized when needed.
 - **Dynamic graph construction**: Computation graphs in MLX are constructed
 dynamically. Changing the shapes of function arguments does not trigger
 slow compilations, and debugging is simple and intuitive.
 - **Multi-device**: Operations can run on any of the supported devices
 (currently the CPU and the GPU).
 - **Unified memory**: A notable difference from MLX and other frameworks
 is the *unified memory model*. Arrays in MLX live in shared memory.
 Operations on MLX arrays can be performed on any of the supported
 device types without transferring data.
 MLX is designed by machine learning researchers for machine learning
 researchers. The framework is intended to be user-friendly, but still efficient
 to train and deploy models. The design of the framework itself is also
 conceptually simple. We intend to make it easy for researchers to extend and
 improve MLX with the goal of quickly exploring new ideas.
 The design of MLX is inspired by frameworks like
 [NumPy](https://numpy.org/doc/stable/index.html),
@@ -91,7 +91,7 @@ Checkout the
 [documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
 for more information on building the C++ and Python APIs from source.
 ## Contributing
 Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
 on contributing to MLX. See the
@@ -110,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
 MLX useful in your research and wish to cite it, please use the following
 BibTex entry:
-```
+```text
 @software{mlx2023,
 author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
 title = {{MLX}: Efficient and flexible machine learning on Apple silicon},


@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
 t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
 c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
-c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
-np.float32
-)
+c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
 atol = 1e-5 if np_dtype == np.float32 else 1e-4
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
 if __name__ == "__main__":
 parser = argparse.ArgumentParser(description="Run gemm benchmarks")
-dtypes = ("float32", "float16")
+dtypes = ("float32", "float16", "complex64")
 transposes = ("nn", "nt", "tn")
 shapes = (
 (16, 234, 768, 3072),
@@ -187,7 +185,7 @@ if __name__ == "__main__":
 diff = gflops_mx / gflops_pt - 1.0
 print(
-f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
+f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
 )
 if gflops_pt >= 2.0 * gflops_mx:
 print("ATTENTION ^^^^^^^")


@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
 for transpose in (False, True):
-for dtype in ("float32", "float16"):
+for dtype in ("float32", "float16", "complex64"):
 fig, axs = plt.subplots(
 len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
 )
@@ -215,7 +215,7 @@ for transpose in (False, True):
 fig.suptitle(f"{device_name}: {dtype} {op_name}")
 fig.savefig(
 os.path.join(
-results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
+results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
 )
 )
 plt.close(fig)

View File

@@ -16,12 +16,11 @@ silicon computer is
To install from PyPI your system must meet the following requirements: To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon) - Using an M series chip (Apple silicon)
- Using a native Python >= 3.9 - Using a native Python >= 3.10
- macOS >= 13.5 - macOS >= 14.0
.. note:: .. note::
MLX is only available on devices running macOS >= 13.5 MLX is only available on devices running macOS >= 14.0.
It is highly recommended to use macOS 14 (Sonoma)
CUDA CUDA
^^^^ ^^^^
@@ -39,7 +38,7 @@ requirements:
- Nvidia driver >= 550.54.14 - Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0 - CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35 - Linux distribution with glibc >= 2.35
- Python >= 3.9 - Python >= 3.10
CPU-only (Linux) CPU-only (Linux)
@@ -55,7 +54,7 @@ To install the CPU-only package from PyPi your system must meet the following
requirements: requirements:
- Linux distribution with glibc >= 2.35 - Linux distribution with glibc >= 2.35
- Python >= 3.9 - Python >= 3.10
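Whichever of the above packages is installed, a quick way to confirm that it works is to import it and evaluate a small computation. The following is a minimal sketch; it assumes only the standard ``mlx.core`` module and the documented ``mx.default_device()`` helper.

```python
import mlx.core as mx

# Confirm the package imports and can evaluate a small graph.
a = mx.arange(4)
print(mx.sum(a * a))        # expected: 14
print(mx.default_device())  # the device MLX selected by default
```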
Troubleshooting Troubleshooting

View File

@@ -27,6 +27,7 @@ simple functions.
mish mish
prelu prelu
relu relu
relu2
relu6 relu6
selu selu
sigmoid sigmoid

View File

@@ -50,6 +50,7 @@ Layers
QuantizedLinear QuantizedLinear
RMSNorm RMSNorm
ReLU ReLU
ReLU2
ReLU6 ReLU6
RNN RNN
RoPE RoPE

View File

@@ -112,6 +112,7 @@ Operations
max max
maximum maximum
mean mean
median
meshgrid meshgrid
min min
minimum minimum
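Since ``median`` is added to the reductions listed above, its usage presumably mirrors the existing reductions such as ``mean``; the ``axis`` argument in this sketch is an assumption based on that convention rather than something confirmed by the diff.

```python
import mlx.core as mx

x = mx.array([[1.0, 7.0, 3.0],
              [4.0, 2.0, 9.0]])

print(mx.median(x))          # median over all elements
print(mx.median(x, axis=1))  # per-row medians, assuming mean-style axis semantics
```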

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python .. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096)) x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x) timeit(gelu, x)
timeit(mx.compile(nn.gelu), x) timeit(mx.compile(gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster. five times faster.
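For reference, a self-contained version of that benchmark might look like the sketch below. The ``gelu`` definition and the ``timeit`` helper are assumptions following the usual erf-based formulation and a simple wall-clock loop; the documentation page defines its own versions earlier in the text.

```python
import math
import time
import mlx.core as mx

def gelu(x):
    # Assumed erf-based GELU, matching the common definition.
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2

def timeit(fn, x, iters=100):
    mx.eval(fn(x))  # warm up (and trigger compilation for compiled functions)
    tic = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn(x))
    return 1e3 * (time.perf_counter() - tic) / iters  # milliseconds per call

x = mx.random.uniform(shape=(32, 1000, 4096))
print(timeit(gelu, x))              # eager
print(timeit(mx.compile(gelu), x))  # compiled; expected to be several times faster
```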

View File

@@ -7,12 +7,13 @@ Distributed Communication
MLX supports distributed communication operations that allow the computational cost MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the of training or inference to be shared across many physical machines. At the
moment we support two different communication backends: moment we support three different communication backends:
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a * `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
full-featured and mature distributed communications library full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be * A **ring** backend of our own that uses native TCP sockets. It should be
faster for Thunderbolt connections. faster for Thunderbolt connections, but it also works over Ethernet.
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
The list of all currently supported operations and their documentation can be The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`. seen in the :ref:`API docs<distributed>`.
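As a concrete illustration of those operations from Python (a hedged sketch assuming only the documented ``mx.distributed`` module), every process runs the same script, initializes a group, and calls the same collective:

```python
import mlx.core as mx

# Every process runs this script (e.g. launched with mlx.launch or mpirun).
world = mx.distributed.init()
print(f"rank {world.rank()} of {world.size()}")

# all_sum adds the arrays contributed by all processes; every rank gets the total.
x = mx.distributed.all_sum(mx.ones(4))
mx.eval(x)  # each entry now equals world.size()
```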
@@ -84,9 +85,8 @@ Selecting Backend
^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^
You can select the backend you want to use when calling :func:`init` by passing You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they available backends. If they all fail then a singleton group is created.
both fail then a singleton group is created.
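In code, the selection is a single argument to :func:`init`; the ``backend`` keyword below reflects the API described above and is the only assumption in this sketch.

```python
import mlx.core as mx

# Request a specific backend.
group = mx.distributed.init(backend="ring")

# Or let MLX choose: "any" tries the available backends and falls back to a
# singleton group if none of them can be initialized.
group = mx.distributed.init(backend="any")
print(group.size())  # 1 when no backend could be initialized
```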
.. note:: .. note::
After a distributed backend is successfully initialized :func:`init` will After a distributed backend is successfully initialized :func:`init` will
@@ -184,7 +184,7 @@ almost identical to the example above:
def step(model, x, y): def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y) loss, grads = loss_grad_fn(model, x, y)
grads = mlx.nn.average_gradients(grads) # <---- This line was added grads = mx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads) optimizer.update(model, grads)
return loss return loss
@@ -220,7 +220,7 @@ print 4 etc.
Installing MPI Installing MPI
^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^
MPI can be installed with Homebrew, using the Anaconda package manager or MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows: with the Anaconda package manager as follows:
@@ -228,14 +228,16 @@ with the Anaconda package manager as follows:
$ conda install conda-forge::openmpi $ conda install conda-forge::openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dylib`` Installing with Homebrew or pip requires specifying the location of ``libmpi.dylib``
so that MLX can find it and load it at runtime. This can simply be achieved by so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``. done automatically by ``mlx.launch``. Some environments use a non-standard
library filename that can be specified using the ``MPI_LIBNAME`` environment
variable. This is automatically taken care of by ``mlx.launch`` as well.
.. code:: shell .. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
$ # or simply $ # or simply
$ mlx.launch -n 2 test.py $ mlx.launch -n 2 test.py
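The ``test.py`` referenced by those commands is not shown here; a minimal hypothetical stand-in that exercises the MPI backend could look like this (the file name and the rank-sum check are illustrative only):

```python
# test.py -- hypothetical script for the launch commands above.
import mlx.core as mx

world = mx.distributed.init(backend="mpi")
# Sum the ranks across processes: with -n 2 this prints 1.0 on both ranks.
total = mx.distributed.all_sum(mx.array(float(world.rank())))
mx.eval(total)
print(f"rank {world.rank()}: {total.item()}")
```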

View File

@@ -164,11 +164,11 @@ to export a function which can be used for inputs with variable shapes:
.. code-block:: python .. code-block:: python
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True) mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
imported_abs = mx.import_function("fun.mlxfn") imported_abs = mx.import_function("fun.mlxfn")
# Ok # Ok
out, = imported_abs(mx.array(-1.0)) out, = imported_abs(mx.array([-1.0]))
# Also ok # Also ok
out, = imported_abs(mx.array([-1.0, -2.0])) out, = imported_abs(mx.array([-1.0, -2.0]))
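The same shapeless mechanism extends to functions of several arguments. The sketch below follows the one-argument pattern above; the ``add.mlxfn`` file name and the two-argument function are illustrative assumptions.

```python
import mlx.core as mx

def add(x, y):
    return x + y

# With shapeless=True the example inputs only pin down dtypes and ranks.
mx.export_function("add.mlxfn", add, mx.array([0.0]), mx.array([0.0]), shapeless=True)

imported_add = mx.import_function("add.mlxfn")
out, = imported_add(mx.array([1.0, 2.0]), mx.array([3.0, 4.0]))
print(out)  # array([4, 6], dtype=float32)
```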

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_; void* ptr_;
public: public:
Buffer(void* ptr) : ptr_(ptr) {}; explicit Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer // Get the raw data pointer from the buffer
void* raw_ptr(); void* raw_ptr();

View File

@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
other.strides(), other.strides(),
other.flags(), other.flags(),
[](auto) {}); [](auto) {});
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr; cpy.array_desc_->offset = other.array_desc_->offset;
return cpy; return cpy;
} }
@@ -141,7 +141,7 @@ bool array::is_tracer() const {
void array::set_data(allocator::Buffer buffer, Deleter d) { void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d); array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr(); array_desc_->offset = 0;
array_desc_->data_size = size(); array_desc_->data_size = size();
array_desc_->flags.contiguous = true; array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true; array_desc_->flags.row_contiguous = true;
@@ -156,7 +156,7 @@ void array::set_data(
Flags flags, Flags flags,
Deleter d) { Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d); array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr(); array_desc_->offset = 0;
array_desc_->data_size = data_size; array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides); array_desc_->strides = std::move(strides);
array_desc_->flags = flags; array_desc_->flags = flags;
@@ -172,9 +172,8 @@ void array::copy_shared_buffer(
array_desc_->strides = strides; array_desc_->strides = strides;
array_desc_->flags = flags; array_desc_->flags = flags;
array_desc_->data_size = data_size; array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset; array_desc_->offset =
array_desc_->data_ptr = static_cast<void*>( sizeof(char) * itemsize() * offset + other.array_desc_->offset;
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
} }
void array::copy_shared_buffer(const array& other) { void array::copy_shared_buffer(const array& other) {
@@ -241,8 +240,8 @@ array::ArrayDesc::ArrayDesc(
std::vector<array> inputs) std::vector<array> inputs)
: shape(std::move(shape)), : shape(std::move(shape)),
dtype(dtype), dtype(dtype),
status(Status::unscheduled),
primitive(std::move(primitive)), primitive(std::move(primitive)),
status(Status::unscheduled),
inputs(std::move(inputs)) { inputs(std::move(inputs)) {
init(); init();
} }

View File

@@ -294,6 +294,11 @@ class array {
return array_desc_->siblings; return array_desc_->siblings;
} }
/** The array's position in the sibling list. */
int sibling_position() const {
return array_desc_->position;
}
void set_siblings(std::vector<array> siblings, uint16_t position) { void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings); array_desc_->siblings = std::move(siblings);
array_desc_->position = position; array_desc_->position = position;
@@ -349,15 +354,23 @@ class array {
return array_desc_->data; return array_desc_->data;
} }
// Return a raw pointer to the arrays data // Return a raw pointer to the array's data. This function may do a copy if
// the underlying buffer is not accessible on the CPU. When accessing the
// data for GPU kernels, be sure to use the correct method / function for the
// given backend to access the GPU pointer.
template <typename T> template <typename T>
T* data() { T* data() {
return static_cast<T*>(array_desc_->data_ptr); return reinterpret_cast<T*>(
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
} }
template <typename T> template <typename T>
const T* data() const { const T* data() const {
return static_cast<T*>(array_desc_->data_ptr); return const_cast<array&>(*this).data<T>();
}
int64_t offset() const {
return array_desc_->offset;
} }
enum Status { enum Status {
@@ -461,8 +474,8 @@ class array {
// can share the underlying data buffer. // can share the underlying data buffer.
std::shared_ptr<Data> data; std::shared_ptr<Data> data;
// Properly offset data pointer // Offset from beginning of data pointer
void* data_ptr{nullptr}; int64_t offset{0};
// The size in elements of the data buffer the array accesses // The size in elements of the data buffer the array accesses
size_t data_size; size_t data_size;

View File

@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
const array& a, const array& a,
const array& b, const array& b,
array& out, array& out,
BinaryOpType bopt) { BinaryOpType bopt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
bool b_donatable = is_donatable(b, out); bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out); bool a_donatable = is_donatable(a, out);
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
out.set_data( out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
if (b_donatable) { if (b_donatable) {
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data( out.set_data(
allocator::malloc(b.data_size() * out.itemsize()), mallocfn(b.data_size() * out.itemsize()),
b.data_size(), b.data_size(),
b.strides(), b.strides(),
b.flags()); b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(a); out.copy_shared_buffer(a);
} else { } else {
out.set_data( out.set_data(
allocator::malloc(a.data_size() * out.itemsize()), mallocfn(a.data_size() * out.itemsize()),
a.data_size(), a.data_size(),
a.strides(), a.strides(),
a.flags()); a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data( out.set_data(
allocator::malloc(a.data_size() * out.itemsize()), mallocfn(a.data_size() * out.itemsize()),
a.data_size(), a.data_size(),
a.strides(), a.strides(),
a.flags()); a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
b_donatable && b.flags().row_contiguous && b.size() == out.size()) { b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
out.copy_shared_buffer(b); out.copy_shared_buffer(b);
} else { } else {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(mallocfn(out.nbytes()));
} }
break; break;
} }

View File

@@ -6,7 +6,7 @@ namespace mlx::core {
void broadcast(const array& in, array& out) { void broadcast(const array& in, array& out) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }
Strides strides(out.ndim(), 0); Strides strides(out.ndim(), 0);

View File

@@ -114,7 +114,9 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant, const std::function<bool(size_t)>& is_constant,
bool contiguous) { bool contiguous,
const std::function<allocator::Buffer(size_t)>&
mallocfn /* = allocator::malloc */) {
if (contiguous) { if (contiguous) {
int o = 0; int o = 0;
Strides strides; Strides strides;
@@ -140,7 +142,7 @@ void compiled_allocate_outputs(
} }
for (; o < outputs.size(); ++o) { for (; o < outputs.size(); ++o) {
outputs[o].set_data( outputs[o].set_data(
allocator::malloc(data_size * outputs[o].itemsize()), mallocfn(data_size * outputs[o].itemsize()),
data_size, data_size,
strides, strides,
flags); flags);
@@ -163,7 +165,7 @@ void compiled_allocate_outputs(
} }
} }
for (; o < outputs.size(); ++o) { for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc(outputs[o].nbytes())); outputs[o].set_data(mallocfn(outputs[o].nbytes()));
} }
} }
} }

View File

@@ -58,7 +58,9 @@ void compiled_allocate_outputs(
const std::vector<array>& inputs, const std::vector<array>& inputs,
std::vector<array>& outputs, std::vector<array>& outputs,
const std::function<bool(size_t)>& is_constant, const std::function<bool(size_t)>& is_constant,
bool contiguous); bool contiguous,
const std::function<allocator::Buffer(size_t)>& mallocfn =
allocator::malloc);
// Collapse contiguous dims ignoring scalars and constants. // Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims( std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(

View File

@@ -22,7 +22,11 @@ enum class CopyType {
GeneralGeneral GeneralGeneral
}; };
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) { inline bool set_copy_output_data(
const array& in,
array& out,
CopyType ctype,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types // If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output. // have the same size, then the input buffer can hold the output.
@@ -31,14 +35,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
return true; return true;
} else { } else {
out.set_data( out.set_data(
allocator::malloc(in.data_size() * out.itemsize()), mallocfn(in.data_size() * out.itemsize()),
in.data_size(), in.data_size(),
in.strides(), in.strides(),
in.flags()); in.flags());
return false; return false;
} }
} else { } else {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(mallocfn(out.nbytes()));
return false; return false;
} }
} }

View File

@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
const array& a, const array& a,
const array& b) { const array& b) {
if (a.ndim() == 2) { if (a.ndim() == 2) {
return {{1}, {0}, {0}}; return {Shape{1}, Strides{0}, Strides{0}};
} }
Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
inline std::tuple<Shape, Strides, Strides, Strides> inline std::tuple<Shape, Strides, Strides, Strides>
collapse_batches(const array& a, const array& b, const array& c) { collapse_batches(const array& a, const array& b, const array& c) {
if (a.ndim() == 2) { if (a.ndim() == 2) {
return {{1}, {0}, {0}, {0}}; return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
} }
Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; Shape A_bshape{a.shape().begin(), a.shape().end() - 2};

View File

@@ -45,7 +45,7 @@ void slice(
const Shape& start_indices, const Shape& start_indices,
const Shape& strides) { const Shape& strides) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }

View File

@@ -11,6 +11,8 @@ namespace mlx::core {
enum class TernaryOpType { enum class TernaryOpType {
ScalarScalarScalar, ScalarScalarScalar,
VectorVectorVector, VectorVectorVector,
VectorVectorScalar,
VectorScalarVector,
General, General,
}; };
@@ -25,6 +27,14 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
(a.flags().col_contiguous && b.flags().col_contiguous && (a.flags().col_contiguous && b.flags().col_contiguous &&
c.flags().col_contiguous)) { c.flags().col_contiguous)) {
topt = TernaryOpType::VectorVectorVector; topt = TernaryOpType::VectorVectorVector;
} else if (
b.data_size() == 1 && a.flags().row_contiguous &&
c.flags().row_contiguous) {
topt = TernaryOpType::VectorScalarVector;
} else if (
c.data_size() == 1 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
topt = TernaryOpType::VectorVectorScalar;
} else { } else {
topt = TernaryOpType::General; topt = TernaryOpType::General;
} }
@@ -36,7 +46,8 @@ inline void set_ternary_op_output_data(
const array& b, const array& b,
const array& c, const array& c,
array& out, array& out,
TernaryOpType topt) { TernaryOpType topt,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
auto maybe_donate = [&out](const array& x) { auto maybe_donate = [&out](const array& x) {
if (is_donatable(x, out)) { if (is_donatable(x, out)) {
out.copy_shared_buffer(x); out.copy_shared_buffer(x);
@@ -47,24 +58,25 @@ inline void set_ternary_op_output_data(
switch (topt) { switch (topt) {
case TernaryOpType::ScalarScalarScalar: case TernaryOpType::ScalarScalarScalar:
out.set_data( out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
break; break;
case TernaryOpType::VectorVectorVector: case TernaryOpType::VectorVectorVector:
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
out.set_data( out.set_data(
allocator::malloc(out.itemsize() * b.data_size()), mallocfn(out.itemsize() * b.data_size()),
b.data_size(), b.data_size(),
b.strides(), b.strides(),
b.flags()); b.flags());
} }
break; break;
case TernaryOpType::VectorVectorScalar:
case TernaryOpType::VectorScalarVector:
case TernaryOpType::General: case TernaryOpType::General:
// Try to donate an input which is row_contiguous // Try to donate an input which is row_contiguous
if (!((a.flags().row_contiguous && maybe_donate(a)) || if (!((a.flags().row_contiguous && maybe_donate(a)) ||
(b.flags().row_contiguous && maybe_donate(b)) || (b.flags().row_contiguous && maybe_donate(b)) ||
(c.flags().row_contiguous && maybe_donate(c)))) { (c.flags().row_contiguous && maybe_donate(c)))) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(mallocfn(out.nbytes()));
} }
break; break;
} }

View File

@@ -7,19 +7,22 @@
namespace mlx::core { namespace mlx::core {
inline void set_unary_output_data(const array& in, array& out) { inline void set_unary_output_data(
const array& in,
array& out,
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
if (in.flags().contiguous) { if (in.flags().contiguous) {
if (is_donatable(in, out)) { if (is_donatable(in, out)) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
out.set_data( out.set_data(
allocator::malloc(in.data_size() * out.itemsize()), mallocfn(in.data_size() * out.itemsize()),
in.data_size(), in.data_size(),
in.strides(), in.strides(),
in.flags()); in.flags());
} }
} else { } else {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(mallocfn(out.nbytes()));
} }
} }

View File

@@ -14,233 +14,11 @@
namespace mlx::core { namespace mlx::core {
namespace {
template <typename Op>
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace
void Add::eval_cpu(const std::vector<array>& inputs, array& out) { void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Add(), stream()); binary_op_cpu(a, b, out, detail::Add(), stream());
} }
void DivMod::eval_cpu( void DivMod::eval_cpu(
@@ -324,14 +102,14 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Divide(), stream()); binary_op_cpu(a, b, out, detail::Divide(), stream());
} }
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) { void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Remainder(), stream()); binary_op_cpu(a, b, out, detail::Remainder(), stream());
} }
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) { void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -372,89 +150,90 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
}); });
} else { } else {
comparison_op(a, b, out, detail::Equal(), stream()); comparison_op_cpu(a, b, out, detail::Equal(), stream());
} }
} }
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) { void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream()); comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());
} }
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream()); comparison_op_cpu(
inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
} }
void Less::eval_cpu(const std::vector<array>& inputs, array& out) { void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream()); comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
} }
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream()); comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());
} }
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) { void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary_float(a, b, out, detail::LogAddExp(), stream()); binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());
} }
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) { void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0]; auto& in1 = inputs[0];
auto& in2 = inputs[1]; auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd(), stream()); binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());
} }
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) { void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0]; auto& in1 = inputs[0];
auto& in2 = inputs[1]; auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr(), stream()); binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());
} }
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) { void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Maximum(), stream()); binary_op_cpu(a, b, out, detail::Maximum(), stream());
} }
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) { void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Minimum(), stream()); binary_op_cpu(a, b, out, detail::Minimum(), stream());
} }
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) { void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Multiply(), stream()); binary_op_cpu(a, b, out, detail::Multiply(), stream());
} }
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) { void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream()); comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());
} }
void Power::eval_cpu(const std::vector<array>& inputs, array& out) { void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Power(), stream()); binary_op_cpu(a, b, out, detail::Power(), stream());
} }
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) { void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
binary(a, b, out, detail::Subtract(), stream()); binary_op_cpu(a, b, out, detail::Subtract(), stream());
} }
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) { void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -463,19 +242,19 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1]; auto& b = inputs[1];
switch (op_) { switch (op_) {
case BitwiseBinary::And: case BitwiseBinary::And:
binary_int(a, b, out, detail::BitwiseAnd(), stream()); binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());
break; break;
case BitwiseBinary::Or: case BitwiseBinary::Or:
binary_int(a, b, out, detail::BitwiseOr(), stream()); binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());
break; break;
case BitwiseBinary::Xor: case BitwiseBinary::Xor:
binary_int(a, b, out, detail::BitwiseXor(), stream()); binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());
break; break;
case BitwiseBinary::LeftShift: case BitwiseBinary::LeftShift:
binary_int(a, b, out, detail::LeftShift(), stream()); binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());
break; break;
case BitwiseBinary::RightShift: case BitwiseBinary::RightShift:
binary_int(a, b, out, detail::RightShift(), stream()); binary_int_op_cpu(a, b, out, detail::RightShift(), stream());
break; break;
} }
} }
@@ -484,7 +263,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
const auto& a = inputs[0]; const auto& a = inputs[0];
const auto& b = inputs[1]; const auto& b = inputs[1];
binary_float(a, b, out, detail::ArcTan2(), stream()); binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());
} }
} // namespace mlx::core } // namespace mlx::core

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
namespace mlx::core { namespace mlx::core {
@@ -290,4 +291,227 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
binary_op<T, T, Op>(a, b, out, bopt); binary_op<T, T, Op>(a, b, out, bopt);
} }
template <typename Op>
void binary_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void comparison_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (a.dtype()) {
case bool_:
binary_op<bool, bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, bool, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, bool, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, bool, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, bool, Op>(a, b, out, bopt);
break;
case float16:
binary_op<float16_t, bool, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, bool, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, bool, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
break;
}
});
}
template <typename Op>
void binary_float_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case float16:
binary_op<float16_t, Op>(a, b, out, bopt);
break;
case float32:
binary_op<float, Op>(a, b, out, bopt);
break;
case float64:
binary_op<double, Op>(a, b, out, bopt);
break;
case bfloat16:
binary_op<bfloat16_t, Op>(a, b, out, bopt);
break;
case complex64:
binary_op<complex64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error(
"[binary_float] Only supports floating point types.");
}
});
}
template <typename Op>
void binary_int_op_cpu(
const array& a,
const array& b,
array& out,
Op op,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.dispatch([a = array::unsafe_weak_copy(a),
b = array::unsafe_weak_copy(b),
out = array::unsafe_weak_copy(out),
bopt]() mutable {
switch (out.dtype()) {
case bool_:
binary_op<bool, Op>(a, b, out, bopt);
break;
case uint8:
binary_op<uint8_t, Op>(a, b, out, bopt);
break;
case uint16:
binary_op<uint16_t, Op>(a, b, out, bopt);
break;
case uint32:
binary_op<uint32_t, Op>(a, b, out, bopt);
break;
case uint64:
binary_op<uint64_t, Op>(a, b, out, bopt);
break;
case int8:
binary_op<int8_t, Op>(a, b, out, bopt);
break;
case int16:
binary_op<int16_t, Op>(a, b, out, bopt);
break;
case int32:
binary_op<int32_t, Op>(a, b, out, bopt);
break;
case int64:
binary_op<int64_t, Op>(a, b, out, bopt);
break;
default:
throw std::runtime_error("[binary_int] Type not supported");
break;
}
});
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -996,131 +996,6 @@ void explicit_gemm_conv_1D_cpu(
encoder.add_temporaries(std::move(temps)); encoder.add_temporaries(std::move(temps));
} }
void explicit_gemm_conv_2D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding_lo,
const std::vector<int>& padding_hi,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
Stream stream) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int iW = in.shape(2); // Input spatial dim
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
auto conv_dtype = out.dtype();
auto& encoder = cpu::get_command_encoder(stream);
// Pad input
Shape padded_shape = {
N,
iH + padding_lo[0] + padding_hi[0],
iW + padding_lo[1] + padding_hi[1],
C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
std::vector<array> temps;
temps.push_back(array(0, conv_dtype));
copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);
// Pick input slice from padded
size_t data_offset = padding_lo[0] * in_padded.strides()[1] +
padding_lo[1] * in_padded.strides()[2];
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
temps.push_back(in_padded_slice);
// Copy input values into the slice
copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
// Make strided view
Shape strided_shape = {N, oH, oW, wH, wW, C};
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
in_padded.strides()[1],
in_padded.strides()[2],
in_padded.strides()[3]};
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
temps.push_back(in_strided);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy_cpu(wt, gemm_wt, ctype, stream);
temps.push_back(gemm_wt);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
temps.push_back(gemm_out);
}
encoder.set_input_array(in_strided);
encoder.set_input_array(gemm_wt);
encoder.set_output_array(gemm_out);
encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
gemm_wt_ptr = gemm_wt.data<float>(),
gemm_out_ptr = gemm_out.data<float>(),
strided_reshape = std::move(strided_reshape),
O]() {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided_ptr,
strided_reshape[1], // lda
gemm_wt_ptr,
strided_reshape[1], // ldb
0.0f, // beta
gemm_out_ptr,
O // ldc
);
});
// Copy results if needed
if (out.dtype() != float32) {
copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
}
encoder.add_temporaries(std::move(temps));
}
void explicit_gemm_conv_ND_cpu( void explicit_gemm_conv_ND_cpu(
const array& in, const array& in,
const array& wt, const array& wt,

View File

@@ -95,4 +95,9 @@ void Recv::eval_cpu(
distributed::detail::recv(group(), outputs[0], src_, stream()); distributed::detail::recv(group(), outputs[0], src_, stream());
} }
void ReduceScatter::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
}
} // namespace mlx::core::distributed } // namespace mlx::core::distributed

View File

@@ -46,7 +46,6 @@ void eig_impl(
int info; int info;
{ {
T work; T work;
int iwork;
geev<T>( geev<T>(
&jobl, &jobl,
&jobr, &jobr,

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <Accelerate/Accelerate.h> #include <Accelerate/Accelerate.h>
#include "mlx/array.h" #include "mlx/array.h"
@@ -49,9 +48,15 @@ void matmul_bnns(
size_t K = a_shape[ndim - 1]; size_t K = a_shape[ndim - 1];
BNNSDataType bnns_dtype = to_bnns_dtype<T>(); BNNSDataType bnns_dtype = to_bnns_dtype<T>();
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations" #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (size_t i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
}
const BNNSLayerParametersBroadcastMatMul gemm_params{ const BNNSLayerParametersBroadcastMatMul gemm_params{
/* float alpha = */ alpha, /* float alpha = */ alpha,
/* float beta = */ beta, /* float beta = */ beta,

View File

@@ -88,4 +88,47 @@ void matmul<double>(
} }
} }
template <>
void matmul<complex64_t>(
const complex64_t* a,
const complex64_t* b,
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
float alpha,
float beta,
size_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);
for (int i = 0; i < batch_size; ++i) {
cblas_cgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
&calpha,
a + elem_to_loc(M * K * i, a_shape, a_strides),
lda,
b + elem_to_loc(K * N * i, b_shape, b_strides),
ldb,
&cbeta,
out + M * N * i,
ldc);
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -215,18 +215,18 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
const void* a_mask_ptr; const void* a_mask_ptr = nullptr;
const void* b_mask_ptr; const void* b_mask_ptr = nullptr;
const void* out_mask_ptr; const void* out_mask_ptr = nullptr;
Shape a_mask_shape; Shape a_mask_shape;
Shape b_mask_shape; Shape b_mask_shape;
Shape out_mask_shape; Shape out_mask_shape;
Strides a_mask_strides; Strides a_mask_strides;
Strides b_mask_strides; Strides b_mask_strides;
Strides out_mask_strides; Strides out_mask_strides;
bool a_mask_bool; bool a_mask_bool = false;
bool b_mask_bool; bool b_mask_bool = false;
bool out_mask_bool; bool out_mask_bool = false;
if (has_op_mask) { if (has_op_mask) {
auto& a_mask = inputs[inputs.size() - 2]; auto& a_mask = inputs[inputs.size() - 2];
auto& b_mask = inputs[inputs.size() - 1]; auto& b_mask = inputs[inputs.size() - 1];
@@ -423,7 +423,6 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& rhs_indices = inputs[3]; auto& rhs_indices = inputs[3];
auto batch_shape = get_batch_dims(out.shape()); auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
auto batch_shape_A = get_batch_dims(a.shape()); auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides()); auto batch_strides_A = get_batch_dims(a.strides());

View File

@@ -2,6 +2,8 @@
#include <cstring> #include <cstring>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/cpu/binary.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/gemm.h" #include "mlx/backend/cpu/gemm.h"
@@ -91,7 +93,6 @@ void matmul_general(
auto [b_transposed, ldb, b] = check_transpose(b_pre); auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2); size_t M = a.shape(-2);
size_t N = b.shape(-1); size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) { if (M == 0 || N == 0) {
return; return;
} }
@@ -108,6 +109,9 @@ void matmul_general(
} else if (out.dtype() == float64) { } else if (out.dtype() == float64) {
matmul_dispatch<double>( matmul_dispatch<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream); a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else if (out.dtype() == complex64) {
matmul_dispatch<complex64_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
} else { } else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type."); throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
} }
@@ -128,24 +132,34 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
} }
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[AddMM::eval_cpu] Currently only supports float32.");
}
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(allocator::malloc(out.nbytes()));
return; return;
} }
// Handle empty matrix case (K=0)
if (inputs[0].shape(-1) == 0) {
auto& c = inputs[2];
if (beta_ == 1.0f) {
CopyType ctype = c.data_size() == 1
? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream());
} else {
array beta_scalar = array(beta_, c.dtype());
auto& encoder = cpu::get_command_encoder(stream());
binary_float_op_cpu(c, beta_scalar, out, detail::Multiply(), stream());
encoder.add_temporary(std::move(beta_scalar));
}
return;
}
// Fill output with C // Fill output with C
auto& c = inputs[2]; auto& c = inputs[2];
CopyType ctype = c.data_size() == 1 CopyType ctype = c.data_size() == 1
? CopyType::Scalar ? CopyType::Scalar
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General); : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_cpu(c, out, ctype, stream()); copy_cpu(c, out, ctype, stream());
if (inputs[0].shape(-1) == 0) {
return;
}
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_); matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
} }

View File

@@ -333,7 +333,7 @@ void Reshape::eval_cpu(const std::vector<array>& inputs, array& out) {
void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) { void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }
auto& in = inputs[0]; auto& in = inputs[0];
@@ -361,7 +361,7 @@ void DynamicSliceUpdate::eval_cpu(
const std::vector<array>& inputs, const std::vector<array>& inputs,
array& out) { array& out) {
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }
@@ -396,7 +396,7 @@ void DynamicSliceUpdate::eval_cpu(
void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) { void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); assert(inputs.size() == 2);
if (out.size() == 0) { if (out.size() == 0) {
out.set_data(nullptr); out.set_data(allocator::malloc(0));
return; return;
} }

View File

@@ -1,8 +1,11 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h" #include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h" #include "mlx/backend/cpu/simd/simd.h"
#include "mlx/backend/cpu/unary.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@@ -445,7 +448,6 @@ void mxfp4_qmm(
int K) { int K) {
constexpr int group_size = 32; constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8); constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor; constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
@@ -487,7 +489,6 @@ void mxfp4_qmm_t(
int K) { int K) {
constexpr int group_size = 32; constexpr int group_size = 32;
constexpr int pack_factor = get_pack_factor(4, 8); constexpr int pack_factor = get_pack_factor(4, 8);
constexpr int bytes_per_pack = get_bytes_per_pack(4);
constexpr int packs_in_group = group_size / pack_factor; constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
@@ -1104,4 +1105,44 @@ void fast::Quantize::eval_cpu(
}); });
} }
void fast::ConvertFP8::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& in = inputs[0];
auto& out = outputs[0];
set_unary_output_data(in, out);
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
to_fp8 = to_fp8_]() mutable {
if (to_fp8) {
switch (in.dtype()) {
case float16:
unary_op<float16_t, uint8_t>(in, out, detail::ToFP8());
break;
case bfloat16:
unary_op<bfloat16_t, uint8_t>(in, out, detail::ToFP8());
break;
default:
unary_op<float, uint8_t>(in, out, detail::ToFP8());
break;
}
} else {
switch (out.dtype()) {
case float16:
unary_op<uint8_t, float16_t>(in, out, detail::FromFP8());
break;
case bfloat16:
unary_op<uint8_t, bfloat16_t>(in, out, detail::FromFP8());
break;
default:
unary_op<uint8_t, float>(in, out, detail::FromFP8());
break;
}
}
});
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -1,5 +1,6 @@
#pragma once #pragma once
#include <arm_neon.h>
#include <simd/math.h> #include <simd/math.h>
#include <simd/vector.h> #include <simd/vector.h>
@@ -9,7 +10,7 @@
#include "mlx/backend/cpu/simd/base_simd.h" #include "mlx/backend/cpu/simd/base_simd.h"
// There seems to be a bug in sims/base.h // There seems to be a bug in simd/base_simd.h
// __XROS_2_0 is not defined, the expression evaluates // __XROS_2_0 is not defined, the expression evaluates
// to true instead of false setting the SIMD library // to true instead of false setting the SIMD library
// higher than it should be even on macOS < 15 // higher than it should be even on macOS < 15
@@ -200,6 +201,15 @@ SIMD_DEFAULT_COMPARISONS(<=)
SIMD_DEFAULT_COMPARISONS(==) SIMD_DEFAULT_COMPARISONS(==)
SIMD_DEFAULT_COMPARISONS(!=) SIMD_DEFAULT_COMPARISONS(!=)
template <typename T, int N>
Simd<T, N> clz(Simd<T, N> x) {
auto a = *(uint32x4_t*)(&x);
auto b = *((uint32x4_t*)(&x) + 1);
a = vclzq_u32(a);
b = vclzq_u32(b);
return asd::make_uint8(a, b);
}
template <typename T, int N> template <typename T, int N>
Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) { Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
return asd::atan2(a.value, b.value); return asd::atan2(a.value, b.value);
@@ -207,14 +217,20 @@ Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
template <typename T, int N> template <typename T, int N>
Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) { Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan auto out = Simd<T, N>(asd::max(a.value, b.value));
return asd::max(a.value, b.value); if constexpr (!std::is_integral_v<T>) {
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
} }
template <typename T, int N> template <typename T, int N>
Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) { Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
// TODO add isnan auto out = Simd<T, N>(asd::min(a.value, b.value));
return asd::min(a.value, b.value); if constexpr (!std::is_integral_v<T>) {
out = select(isnan(b), b, select(isnan(a), a, out));
}
return out;
} }
template <typename T, int N> template <typename T, int N>

View File

@@ -171,6 +171,11 @@ DEFAULT_BINARY(&)
DEFAULT_BINARY(&&) DEFAULT_BINARY(&&)
DEFAULT_BINARY(||) DEFAULT_BINARY(||)
template <typename T>
Simd<T, 1> clz(Simd<T, 1> x_) {
return __builtin_clz(x_.value);
}
template <typename T> template <typename T>
Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) { Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
T a = a_.value; T a = a_.value;

View File

@@ -15,6 +15,18 @@ namespace mlx::core {
namespace { namespace {
// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
return true;
}
return a < b;
}
template <typename T> template <typename T>
struct StridedIterator { struct StridedIterator {
using iterator_category = std::random_access_iterator_tag; using iterator_category = std::random_access_iterator_tag;
@@ -27,7 +39,7 @@ struct StridedIterator {
StridedIterator() = default; StridedIterator() = default;
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0) explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
: ptr_(ptr + offset * stride), stride_(stride) {} : stride_(stride), ptr_(ptr + offset * stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0) explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {} : StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
@@ -130,7 +142,7 @@ void sort(array& out, int axis) {
StridedIterator st(data_ptr, axis_stride, 0); StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed); std::stable_sort(st, ed, nan_aware_less<T>);
src_it.step(); src_it.step();
} }
} }
@@ -184,6 +196,15 @@ void argsort(const array& in, array& out, int axis) {
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
@@ -219,7 +240,7 @@ void partition(array& out, int axis, int kth) {
StridedIterator md(data_ptr, axis_stride, kth); StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size); StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed); std::nth_element(st, md, ed, nan_aware_less<T>);
} }
} }
@@ -276,6 +297,15 @@ void argpartition(const array& in, array& out, int axis, int kth) {
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride]; auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride]; auto v2 = data_ptr[b * in_stride];
// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
return true;
}
return v1 < v2 || (v1 == v2 && a < b); return v1 < v2 || (v1 == v2 && a < b);
}); });
} }
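The sort and partition paths now use a NaN-aware comparator so NaNs land at the end of the sorted axis, matching the lambdas added to argsort and argpartition. A standalone sketch of that ordering with std::stable_sort (plain float only, hypothetical driver code):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// NaN-aware "less than": any NaN compares greater than every ordinary
// value, so std::stable_sort pushes NaNs to the end of the range.
bool nan_aware_less(float a, float b) {
  if (std::isnan(a)) return false;  // a is never "less" if it is NaN
  if (std::isnan(b)) return true;   // any non-NaN is less than NaN
  return a < b;
}

int main() {
  std::vector<float> v = {3.0f, std::nanf(""), -1.0f, 2.0f, std::nanf("")};
  std::stable_sort(v.begin(), v.end(), nan_aware_less);
  for (float x : v) std::printf("%g ", x);  // -1 2 3 nan nan
  std::printf("\n");
  return 0;
}
```

The comparator is still a strict weak ordering: all NaNs compare equivalent to each other and greater than every ordinary value, which is what allows stable_sort and nth_element to use it safely.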

View File

@@ -83,8 +83,6 @@ void svd_impl(
auto jobz = (u_ptr) ? "A" : "N"; auto jobz = (u_ptr) ? "A" : "N";
// Will contain the number of singular values after the call has returned.
int ns = 0;
T workspace_dimension = 0; T workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not // Will contain the indices of eigenvectors that failed to converge (not

View File

@@ -24,9 +24,9 @@ void unary_op(const array& a, array& out, Op) {
auto ndim = a.ndim(); auto ndim = a.ndim();
if (a.flags().contiguous) { if (a.flags().contiguous) {
auto size = a.data_size(); auto size = a.data_size();
constexpr int N = simd::max_size<T>; constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
while (size >= N) { while (size >= N) {
simd::store(dst, Op{}(simd::load<T, N>(src))); simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
size -= N; size -= N;
src += N; src += N;
dst += N; dst += N;
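Since the loop now loads a Simd<T, N> and stores a Simd<U, N> with the same N, the chunk width has to be one that both element types support, hence the std::min over the two per-type widths. A scalar-emulated sketch of the same chunked loop with a tail; the width values below are made up and are not the real simd::max_size:

```cpp
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>

// Toy per-type widths standing in for simd::max_size<T> (illustrative only).
template <typename T>
constexpr int max_size = 16 / sizeof(T);

// Convert src -> dst in fixed-width chunks, then a scalar tail. The chunk
// width must be valid for BOTH element types, mirroring the change above.
template <typename T, typename U, typename Op>
void convert(const T* src, U* dst, size_t size, Op op) {
  constexpr int N = std::min(max_size<T>, max_size<U>);
  while (size >= N) {
    for (int i = 0; i < N; ++i) {  // stand-in for a vector load/op/store
      dst[i] = static_cast<U>(op(src[i]));
    }
    src += N;
    dst += N;
    size -= N;
  }
  while (size-- > 0) {
    *dst++ = static_cast<U>(op(*src++));
  }
}

int main() {
  std::array<float, 10> in{};
  std::array<uint8_t, 10> out{};
  for (int i = 0; i < 10; ++i) in[i] = 2.0f * i;
  convert(in.data(), out.data(), in.size(), [](float x) { return x; });
  std::printf("%d %d\n", out[3], out[9]);  // 6 18
  return 0;
}
```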

View File

@@ -77,7 +77,8 @@ struct Real {
struct Sigmoid { struct Sigmoid {
template <int N, typename T> template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) { Simd<T, N> operator()(Simd<T, N> x) {
return 1.0f / (1.0f + simd::exp(-x)); auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
} }
SINGLE() SINGLE()
}; };
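The rewritten Sigmoid evaluates y = 1/(1 + exp(|x|)), which is sigmoid(-|x|), and mirrors it for non-negative inputs; both branches equal 1/(1 + exp(-x)). The small-valued branch (x < 0) is produced directly rather than as 1 minus a number very close to 1, which preserves its relative precision. A scalar sketch of the identity, separate from the SIMD code:

```cpp
#include <cassert>
#include <cmath>

// Scalar form of the branchless sigmoid above: y = 1 / (1 + exp(|x|))
// equals sigmoid(-|x|); negative inputs return y directly, others 1 - y.
float stable_sigmoid(float x) {
  float y = 1.0f / (1.0f + std::exp(std::fabs(x)));
  return (x < 0.0f) ? y : 1.0f - y;
}

int main() {
  assert(stable_sigmoid(0.0f) == 0.5f);
  assert(std::fabs(stable_sigmoid(2.0f) + stable_sigmoid(-2.0f) - 1.0f) < 1e-6f);
  // ~1.8e-35: forming this as 1 - 1/(1 + exp(-80)) would round to exactly 0.
  assert(stable_sigmoid(-80.0f) > 0.0f);
  assert(stable_sigmoid(1000.0f) == 1.0f);  // exp overflows to inf, y becomes 0
  return 0;
}
```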
@@ -107,4 +108,73 @@ struct Square {
SINGLE() SINGLE()
}; };
template <int N>
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
return *(Simd<float, N>*)(&x);
}
template <int N>
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
return *(Simd<uint32_t, N>*)(&x);
}
struct ToFP8 {
template <typename T, int N>
Simd<uint8_t, N> operator()(Simd<T, N> f) {
uint32_t fp8_max = 543 << 21;
auto denorm_mask = Simd<uint32_t, N>(141 << 23);
Simd<uint32_t, N> f_bits;
Simd<float, N> f32 = f;
f_bits = fp32_to_bits(f32);
Simd<uint8_t, N> result = 0u;
auto sign = f_bits & 0x80000000;
f_bits = f_bits ^ sign;
auto f_bits_low =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);
auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);
auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
result = select(f_bits < (121 << 23), result_low, result_high);
auto result_sat = Simd<uint8_t, N>(0x7E);
result = select(f_bits >= fp8_max, result_sat, result);
return result | Simd<uint8_t, N>(sign >> 24);
}
template <typename T>
uint8_t operator()(T x) {
return (*this)(Simd<T, 1>(x)).value;
}
};
struct FromFP8 {
template <int N>
Simd<float, N> operator()(Simd<uint8_t, N> x) {
auto w = Simd<uint32_t, N>(x) << 24;
auto sign = w & 0x80000000;
auto nonsign = w & 0x7FFFFFFF;
auto renorm_shift = clz(nonsign);
renorm_shift = simd::select(
renorm_shift > Simd<uint32_t, N>{4},
renorm_shift - Simd<uint32_t, N>{4},
Simd<uint32_t, N>{0});
Simd<int32_t, N> inf_nan_mask =
(Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
auto result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
float operator()(uint8_t x) {
return (*this)(Simd<uint8_t, 1>(x)).value;
}
};
} // namespace mlx::core::detail } // namespace mlx::core::detail
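ToFP8 and FromFP8 are SIMD bit-twiddling implementations of the FP8 E4M3 ("fn") encoding: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, 0x7F/0xFF reserved for NaN, and a largest finite value of 448, which is why saturation writes 0x7E. The scalar decoder below is written straight from that format definition (it is not the kernel above) and is handy for sanity-checking individual byte patterns:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Decode one FP8 E4M3 ("fn" flavor) byte: exponent == 0 is subnormal,
// exponent == 15 with mantissa == 7 is NaN, everything else is finite.
float fp8_e4m3_to_float(uint8_t v) {
  int sign = (v >> 7) & 0x1;
  int exponent = (v >> 3) & 0xF;
  int mantissa = v & 0x7;
  float out;
  if (exponent == 0xF && mantissa == 0x7) {
    out = std::numeric_limits<float>::quiet_NaN();
  } else if (exponent == 0) {
    out = std::ldexp(static_cast<float>(mantissa), -9);  // subnormal: m * 2^-9
  } else {
    out = std::ldexp(1.0f + mantissa / 8.0f, exponent - 7);
  }
  return sign ? -out : out;
}

int main() {
  assert(fp8_e4m3_to_float(0x00) == 0.0f);
  assert(fp8_e4m3_to_float(0x38) == 1.0f);    // exponent 7, mantissa 0
  assert(fp8_e4m3_to_float(0xB8) == -1.0f);   // same with the sign bit set
  assert(fp8_e4m3_to_float(0x7E) == 448.0f);  // largest finite (saturation byte)
  assert(fp8_e4m3_to_float(0x01) == std::ldexp(1.0f, -9));  // smallest subnormal
  assert(std::isnan(fp8_e4m3_to_float(0x7F)));
  return 0;
}
```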

View File

@@ -32,6 +32,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
@@ -51,12 +52,19 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)
# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
target_sources( target_sources(
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu) mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
@@ -170,7 +178,6 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
# Suppress nvcc warnings on MLX headers. # Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
--diag_suppress=997>) --diag_suppress=997>)
# Install CCCL headers for JIT. # Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl) DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

View File

@@ -1,6 +1,7 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/utils.h"
#include "mlx/utils.h" #include "mlx/utils.h"
@@ -30,15 +31,20 @@ SmallSizePool::SmallSizePool() {
next_free_ = buffer_; next_free_ = buffer_;
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size)); CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
for (int i = 0; i < device_count; ++i) {
#if CUDART_VERSION >= 13000 #if CUDART_VERSION >= 13000
cudaMemLocation loc; cudaMemLocation loc;
loc.type = cudaMemLocationTypeDevice; loc.type = cudaMemLocationTypeDevice;
loc.id = 0; loc.id = i;
#else #else
int loc = 0; int loc = i;
#endif // CUDART_VERSION >= 13000 #endif // CUDART_VERSION >= 13000
CHECK_CUDA_ERROR( CHECK_CUDA_ERROR(
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc)); cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
}
auto curr = next_free_; auto curr = next_free_;
for (size_t i = 1; i < num_blocks; ++i) { for (size_t i = 1; i < num_blocks; ++i) {
@@ -62,6 +68,7 @@ CudaBuffer* SmallSizePool::malloc() {
next_free_ = next_free_->next; next_free_ = next_free_->next;
b->buf.data = static_cast<char*>(data_) + i * small_block_size; b->buf.data = static_cast<char*>(data_) + i * small_block_size;
b->buf.size = small_block_size; b->buf.size = small_block_size;
b->buf.device = -1;
return &b->buf; return &b->buf;
} }
@@ -83,16 +90,41 @@ CudaAllocator::CudaAllocator()
page_size, page_size,
[](CudaBuffer* buf) { return buf->size; }, [](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) { cuda_free(buf); }) { [this](CudaBuffer* buf) { cuda_free(buf); }) {
// TODO: Set memory limit for multi-device.
size_t free, total; size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total)); CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8; memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_; max_pool_size_ = memory_limit_;
int device_count = 0;
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
int curr;
CHECK_CUDA_ERROR(cudaGetDevice(&curr));
for (int i = 0; i < device_count; ++i) {
CHECK_CUDA_ERROR(cudaSetDevice(i));
cudaStream_t s;
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
free_streams_.push_back(s);
}
CHECK_CUDA_ERROR(cudaSetDevice(curr));
} }
Buffer CudaAllocator::malloc(size_t size) { void copy_to_managed(CudaBuffer& buf) {
// TODO maybe make this async on an i/o stream to avoid synchronizing the
// device on malloc and free
void* new_data;
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, buf.size));
buf.device = -1;
CHECK_CUDA_ERROR(cudaMemcpy(new_data, buf.data, buf.size, cudaMemcpyDefault));
CHECK_CUDA_ERROR(cudaFree(buf.data));
buf.data = new_data;
}
Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
if (size == 0) {
return Buffer{new CudaBuffer{nullptr, 0, -1}};
}
// Find available buffer from cache. // Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_); std::unique_lock lock(mutex_);
if (size <= small_block_size) { if (size <= small_block_size) {
size = 8; size = 8;
@@ -102,6 +134,11 @@ Buffer CudaAllocator::malloc(size_t size) {
size = page_size * ((size + page_size - 1) / page_size); size = page_size * ((size + page_size - 1) / page_size);
} }
int device = -1;
if (size > small_block_size && stream != nullptr) {
CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size); CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) { if (!buf) {
// If we have a lot of memory pressure try to reclaim memory from the cache. // If we have a lot of memory pressure try to reclaim memory from the cache.
@@ -117,8 +154,13 @@ Buffer CudaAllocator::malloc(size_t size) {
} }
lock.unlock(); lock.unlock();
if (!buf) { if (!buf) {
buf = new CudaBuffer{nullptr, size}; buf = new CudaBuffer{nullptr, size, device};
cudaError_t err = cudaMallocManaged(&buf->data, size); cudaError_t err;
if (device == -1) {
err = cudaMallocManaged(&buf->data, size);
} else {
err = cudaMallocAsync(&buf->data, size, stream);
}
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
throw std::runtime_error(fmt::format( throw std::runtime_error(fmt::format(
"cudaMallocManaged failed: {}.", cudaGetErrorString(err))); "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
@@ -126,21 +168,37 @@ Buffer CudaAllocator::malloc(size_t size) {
} }
lock.lock(); lock.lock();
} }
active_memory_ += size; active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_); peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit. // Maintain the cache below the requested limit.
if (get_cache_memory() > max_pool_size_) { if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
} }
// Copy to managed here if the buffer is not on the right device
if (buf->device != device) {
copy_to_managed(*buf);
}
return Buffer{buf}; return Buffer{buf};
} }
Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
return malloc_impl(size, stream);
}
Buffer CudaAllocator::malloc(size_t size) {
return malloc_impl(size, nullptr);
}
void CudaAllocator::free(Buffer buffer) { void CudaAllocator::free(Buffer buffer) {
auto* buf = static_cast<CudaBuffer*>(buffer.ptr()); auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
if (!buf) { if (!buf) {
return; return;
} }
if (buf->size == 0) {
delete buf;
return;
}
std::unique_lock lock(mutex_); std::unique_lock lock(mutex_);
active_memory_ -= buf->size; active_memory_ -= buf->size;
@@ -164,7 +222,11 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
if (scalar_pool_.in_pool(buf)) { if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf); scalar_pool_.free(buf);
} else { } else {
cudaFree(buf->data); if (buf->device >= 0) {
cudaFreeAsync(buf->data, free_streams_[buf->device]);
} else {
cudaFree(buf->data);
}
delete buf; delete buf;
} }
} }
@@ -215,6 +277,16 @@ CudaAllocator& allocator() {
return *allocator_; return *allocator_;
} }
Buffer malloc_async(size_t size, cudaStream_t stream) {
auto buffer = allocator().malloc_async(size, stream);
if (size && !buffer.ptr()) {
std::ostringstream msg;
msg << "[malloc_async] Unable to allocate " << size << " bytes.";
throw std::runtime_error(msg.str());
}
return buffer;
}
} // namespace cu } // namespace cu
namespace allocator { namespace allocator {
@@ -227,7 +299,11 @@ void* Buffer::raw_ptr() {
if (!ptr_) { if (!ptr_) {
return nullptr; return nullptr;
} }
return static_cast<cu::CudaBuffer*>(ptr_)->data; auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
if (cbuf.device != -1) {
copy_to_managed(cbuf);
}
return cbuf.data;
} }
} // namespace allocator } // namespace allocator
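With these changes the allocator takes the stream-ordered path (cudaMallocAsync / cudaFreeAsync) for regular device buffers and keeps managed memory only where host visibility is required; raw_ptr migrates a device buffer to managed memory on demand via copy_to_managed. A minimal standalone sketch of the underlying CUDA runtime pattern, not using the MLX allocator classes:

```cpp
#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

#define CHECK(cmd)                                                 \
  do {                                                             \
    cudaError_t e = (cmd);                                         \
    if (e != cudaSuccess) {                                        \
      std::printf("%s failed: %s\n", #cmd, cudaGetErrorString(e)); \
      return 1;                                                    \
    }                                                              \
  } while (0)

int main() {
  cudaStream_t stream;
  CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

  // Stream-ordered allocation: the buffer becomes usable in stream order,
  // and cudaFreeAsync returns it to the pool without a device sync.
  size_t n = 1 << 20;
  float* device_buf = nullptr;
  CHECK(cudaMallocAsync(&device_buf, n * sizeof(float), stream));
  CHECK(cudaMemsetAsync(device_buf, 0, n * sizeof(float), stream));

  // Host needs the data: copy it out (the MLX allocator instead migrates
  // such buffers to managed memory on demand).
  std::vector<float> host(n);
  CHECK(cudaMemcpyAsync(host.data(), device_buf, n * sizeof(float),
                        cudaMemcpyDeviceToHost, stream));
  CHECK(cudaFreeAsync(device_buf, stream));
  CHECK(cudaStreamSynchronize(stream));

  std::printf("host[0] = %g\n", host[0]);
  CHECK(cudaStreamDestroy(stream));
  return 0;
}
```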

View File

@@ -4,7 +4,9 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h" #include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/cuda/cuda_utils.h"
#include <cuda_runtime.h>
#include <mutex> #include <mutex>
#include <set> #include <set>
#include <utility> #include <utility>
@@ -17,6 +19,7 @@ using allocator::Buffer;
struct CudaBuffer { struct CudaBuffer {
void* data; void* data;
size_t size; size_t size;
int device; // -1 for managed
}; };
class SmallSizePool { class SmallSizePool {
@@ -45,6 +48,7 @@ class SmallSizePool {
class CudaAllocator : public allocator::Allocator { class CudaAllocator : public allocator::Allocator {
public: public:
Buffer malloc(size_t size) override; Buffer malloc(size_t size) override;
Buffer malloc_async(size_t size, cudaStream_t stream);
void free(Buffer buffer) override; void free(Buffer buffer) override;
size_t size(Buffer buffer) const override; size_t size(Buffer buffer) const override;
@@ -58,6 +62,7 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache(); void clear_cache();
private: private:
Buffer malloc_impl(size_t size, cudaStream_t stream);
void cuda_free(CudaBuffer* buf); void cuda_free(CudaBuffer* buf);
CudaAllocator(); CudaAllocator();
@@ -69,9 +74,12 @@ class CudaAllocator : public allocator::Allocator {
BufferCache<CudaBuffer> buffer_cache_; BufferCache<CudaBuffer> buffer_cache_;
size_t active_memory_{0}; size_t active_memory_{0};
size_t peak_memory_{0}; size_t peak_memory_{0};
std::vector<cudaStream_t> free_streams_;
SmallSizePool scalar_pool_; SmallSizePool scalar_pool_;
}; };
CudaAllocator& allocator(); CudaAllocator& allocator();
Buffer malloc_async(size_t size, cudaStream_t stream);
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -41,9 +41,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(stream()); auto& encoder = cu::get_command_encoder(stream());
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_output_array(out); encoder.set_output_array(out);
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
@@ -58,7 +57,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks, num_blocks,
block_dims, block_dims,
0, 0,
out.data<OutType>(), gpu_ptr<OutType>(out),
out.data_size(), out.data_size(),
static_cast<CTYPE>(start_), static_cast<CTYPE>(start_),
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_)); static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));

View File

@@ -140,8 +140,10 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ArgReduce::eval_gpu"); nvtx3::scoped_range r("ArgReduce::eval_gpu");
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream(); auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
// Prepare the shapes, strides and axis arguments. // Prepare the shapes, strides and axis arguments.
Shape shape = remove_index(in.shape(), axis_); Shape shape = remove_index(in.shape(), axis_);
@@ -154,7 +156,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int32_t ndim = shape.size(); int32_t ndim = shape.size();
// ArgReduce. // ArgReduce.
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
@@ -172,8 +173,8 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
num_blocks, num_blocks,
block_dim(), block_dim(),
0, 0,
in.data<T>(), gpu_ptr<T>(in),
out.data<uint32_t>(), gpu_ptr<uint32_t>(out),
out.size(), out.size(),
const_param(shape), const_param(shape),
const_param(in_strides), const_param(in_strides),

View File

@@ -292,9 +292,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y}, {num_blocks_x, num_blocks_y},
block_dims, block_dims,
0, 0,
a.data<InType>(), gpu_ptr<InType>(a),
b.data<InType>(), gpu_ptr<InType>(b),
out.data<OutType>(), gpu_ptr<OutType>(out),
rest, rest,
const_param<dims_constant()>(shape), const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides), const_param<dims_constant()>(a_strides),
@@ -310,9 +310,9 @@ void binary_op_gpu_inplace(
{num_blocks_x, num_blocks_y}, {num_blocks_x, num_blocks_y},
block_dims, block_dims,
0, 0,
a.data<InType>(), gpu_ptr<InType>(a),
b.data<InType>(), gpu_ptr<InType>(b),
out.data<OutType>(), gpu_ptr<OutType>(out),
rest, rest,
const_param(shape), const_param(shape),
const_param(a_strides), const_param(a_strides),
@@ -339,9 +339,9 @@ void binary_op_gpu_inplace(
num_blocks, num_blocks,
block_dims, block_dims,
0, 0,
a.data<InType>(), gpu_ptr<InType>(a),
b.data<InType>(), gpu_ptr<InType>(b),
out.data<OutType>(), gpu_ptr<OutType>(out),
out.data_size()); out.data_size());
}); });
} }
@@ -365,7 +365,11 @@ void binary_op_gpu(
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt); auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
binary_op_gpu_inplace<Op>(inputs, out, op, s); binary_op_gpu_inplace<Op>(inputs, out, op, s);
} }

View File

@@ -245,14 +245,18 @@ void binary_two_op_gpu_inplace(
auto& out_a = outputs[0]; auto& out_a = outputs[0];
auto& out_b = outputs[1]; auto& out_b = outputs[1];
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out_a, bopt); auto& encoder = cu::get_command_encoder(s);
set_binary_op_output_data(a, b, out_b, bopt); set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
if (out_a.size() == 0) { if (out_a.size() == 0) {
return; return;
} }
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out_a); encoder.set_output_array(out_a);
@@ -313,10 +317,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y}, {num_blocks_x, num_blocks_y},
block_dims, block_dims,
0, 0,
a.data<InType>(), gpu_ptr<InType>(a),
b.data<InType>(), gpu_ptr<InType>(b),
out_a.data<OutType>(), gpu_ptr<OutType>(out_a),
out_b.data<OutType>(), gpu_ptr<OutType>(out_b),
rest, rest,
const_param<dims_constant()>(shape), const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides), const_param<dims_constant()>(a_strides),
@@ -332,10 +336,10 @@ void binary_two_op_gpu_inplace(
{num_blocks_x, num_blocks_y}, {num_blocks_x, num_blocks_y},
block_dims, block_dims,
0, 0,
a.data<InType>(), gpu_ptr<InType>(a),
b.data<InType>(), gpu_ptr<InType>(b),
out_a.data<OutType>(), gpu_ptr<OutType>(out_a),
out_b.data<OutType>(), gpu_ptr<OutType>(out_b),
rest, rest,
const_param(shape), const_param(shape),
const_param(a_strides), const_param(a_strides),
@@ -366,10 +370,10 @@ void binary_two_op_gpu_inplace(
num_blocks, num_blocks,
block_dims, block_dims,
0, 0,
a.data<InType>(), gpu_ptr<InType>(a),
b.data<InType>(), gpu_ptr<InType>(b),
out_a.data<OutType>(), gpu_ptr<OutType>(out_a),
out_b.data<OutType>(), gpu_ptr<OutType>(out_b),
out_a.data_size()); out_a.data_size());
}); });
} }

View File

@@ -293,8 +293,13 @@ void Compiled::eval_gpu(
} }
} }
auto& encoder = cu::get_command_encoder(s);
// Put outputs. // Put outputs.
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); compiled_allocate_outputs(
inputs, outputs, is_constant_, contiguous, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
for (auto& x : outputs) { for (auto& x : outputs) {
args.append(x); args.append(x);
} }
@@ -324,7 +329,6 @@ void Compiled::eval_gpu(
kernel_name += fmt::format( kernel_name += fmt::format(
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread); "_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
} }
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) { for (const auto& in : inputs) {
encoder.set_input_array(in); encoder.set_input_array(in);
} }
@@ -332,9 +336,9 @@ void Compiled::eval_gpu(
encoder.set_output_array(out); encoder.set_output_array(out);
} }
auto kernel = mod.get_kernel(kernel_name); auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
auto [num_blocks, block_dims] = auto [num_blocks, block_dims] =
get_launch_args(outputs[0], large, work_per_thread); get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
} }

View File

@@ -47,7 +47,7 @@ auto& conv_cache() {
std::pair< std::pair<
cudnnBackendDescriptorType_t, cudnnBackendDescriptorType_t,
std::optional<cudnn_frontend::ExecutionPlan>>> std::optional<cudnn_frontend::ExecutionPlan>>>
cache(/* capacity */ 128); cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
return cache; return cache;
} }
@@ -270,17 +270,16 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
if (out_.size() == 0) { if (out_.size() == 0) {
return; return;
} }
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
assert(inputs.size() == 2); assert(inputs.size() == 2);
array in = inputs[0]; array in = inputs[0];
array wt = inputs[1]; array wt = inputs[1];
array out = out_; array out = out_;
out.set_data(allocator::malloc(out.nbytes())); out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
Dtype dtype = out.dtype(); Dtype dtype = out.dtype();
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
// Search cache. // Search cache.
ConvCacheKey cache_key{ ConvCacheKey cache_key{
encoder.device().cuda_device(), encoder.device().cuda_device(),
@@ -382,20 +381,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
} }
if (op_graph) { if (op_graph) {
// Setup inputs and outputs.
register_args(encoder, backend_type, in, wt, out, out_);
// Find a plan for the graph and execute it. // Find a plan for the graph and execute it.
auto plan = find_cudnn_plan_from_op_graph( auto plan = find_cudnn_plan_from_op_graph(
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph); encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
if (!plan) { if (plan) {
throw std::runtime_error("[conv] Unable to find an execution plan."); // Setup inputs and outputs.
} register_args(encoder, backend_type, in, wt, out, out_);
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) { auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
conv_cache().emplace( if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
cache_key, std::make_pair(backend_type, std::move(*plan))); conv_cache().emplace(
return; cache_key, std::make_pair(backend_type, std::move(*plan)));
return;
}
} }
} }

View File

@@ -86,7 +86,7 @@ array unfold_inputs_nd(
int mat_N, int mat_N,
ConvParams<NDIM>& params) { ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {}); array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes())); unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
encoder.add_temporary(unfolded); encoder.add_temporary(unfolded);
int filter_size = params.C; int filter_size = params.C;
@@ -118,8 +118,8 @@ array unfold_inputs_nd(
num_blocks, num_blocks,
block_dims, block_dims,
0, 0,
in.data<DataType>(), gpu_ptr<DataType>(in),
unfolded.data<DataType>(), gpu_ptr<DataType>(unfolded),
filter_size, filter_size,
out_pixels, out_pixels,
params); params);

View File

@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
int mat_N, int mat_N,
ConvParams<NDIM>& params) { ConvParams<NDIM>& params) {
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {}); array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
unfolded.set_data(allocator::malloc(unfolded.nbytes())); unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
encoder.add_temporary(unfolded); encoder.add_temporary(unfolded);
int filter_size = params.C; int filter_size = params.C;
@@ -121,8 +121,8 @@ array grouped_unfold_transpose_inputs_nd(
num_blocks, num_blocks,
block_dims, block_dims,
0, 0,
in.data<DataType>(), gpu_ptr<DataType>(in),
unfolded.data<DataType>(), gpu_ptr<DataType>(unfolded),
filter_size, filter_size,
out_pixels, out_pixels,
params); params);

View File

@@ -5,6 +5,22 @@
namespace mlx::core { namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
copy_gpu_inplace(in, out, ctype, s);
}
void copy_gpu_inplace( void copy_gpu_inplace(
const array& in, const array& in,
array& out, array& out,
@@ -87,11 +103,31 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }
out.set_data(allocator::malloc(out.nbytes()));
auto& encoder = cu::get_command_encoder(s); auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
encoder.set_input_array(in); encoder.set_input_array(in);
encoder.set_output_array(out); encoder.set_output_array(out);
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0); copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
} }
void reshape_gpu(const array& in, array& out, Stream s) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
make_contiguous_strides(in.shape()),
0,
0,
CopyType::General,
s);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -77,8 +77,8 @@ void copy_contiguous(
num_blocks, num_blocks,
block_dims, block_dims,
0, 0,
in.data<InType>() + in_offset, gpu_ptr<InType>(in) + in_offset,
out.data<OutType>() + out_offset, gpu_ptr<OutType>(out) + out_offset,
out.data_size()); out.data_size());
}); });
}); });

View File

@@ -106,8 +106,8 @@ void copy_general(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>; using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>; using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in; const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out; OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size(); int ndim = shape.size();
size_t data_size = 1; size_t data_size = 1;
for (auto& s : shape) for (auto& s : shape)

View File

@@ -69,8 +69,8 @@ void copy_general_dynamic(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>; using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>; using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in; const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out; OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size(); int ndim = shape.size();
if (ndim <= 3) { if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto dims_constant) { dispatch_1_2_3(ndim, [&](auto dims_constant) {
@@ -90,8 +90,8 @@ void copy_general_dynamic(
const_param<dims_constant()>(shape), const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in), const_param<dims_constant()>(strides_in),
const_param<dims_constant()>(strides_out), const_param<dims_constant()>(strides_out),
dynamic_offset_in.data<int64_t>(), gpu_ptr<int64_t>(dynamic_offset_in),
dynamic_offset_out.data<int64_t>()); gpu_ptr<int64_t>(dynamic_offset_out));
}); });
} else { // ndim >= 4 } else { // ndim >= 4
auto [num_blocks, block_dims] = get_launch_args(out, large()); auto [num_blocks, block_dims] = get_launch_args(out, large());
@@ -107,8 +107,8 @@ void copy_general_dynamic(
const_param(strides_in), const_param(strides_in),
const_param(strides_out), const_param(strides_out),
ndim, ndim,
dynamic_offset_in.data<int64_t>(), gpu_ptr<int64_t>(dynamic_offset_in),
dynamic_offset_out.data<int64_t>()); gpu_ptr<int64_t>(dynamic_offset_out));
} }
}); });
}); });

View File

@@ -92,8 +92,8 @@ void copy_general_input(
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>; using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>; using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in; const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out; OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
int ndim = shape.size(); int ndim = shape.size();
int work_per_thread = 1; int work_per_thread = 1;
auto dim0 = ndim > 0 ? shape.back() : 1; auto dim0 = ndim > 0 ? shape.back() : 1;

View File

@@ -0,0 +1,82 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_runtime.h>
namespace mlx::core {
// Throw exception if the cuda API does not succeed.
void check_cublas_error(const char* name, cublasStatus_t err);
void check_cuda_error(const char* name, cudaError_t err);
void check_cuda_error(const char* name, CUresult err);
// The macro version that prints the command that failed.
#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
// Base class for RAII managed CUDA resources.
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
public:
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
assert(this != &other);
other.handle_ = nullptr;
}
~CudaHandle() {
reset();
}
CudaHandle(const CudaHandle&) = delete;
CudaHandle& operator=(const CudaHandle&) = delete;
CudaHandle& operator=(CudaHandle&& other) {
assert(this != &other);
reset();
std::swap(handle_, other.handle_);
return *this;
}
void reset() {
if (handle_ != nullptr) {
CHECK_CUDA_ERROR(Destroy(handle_));
handle_ = nullptr;
}
}
operator Handle() const {
return handle_;
}
protected:
Handle handle_;
};
namespace cu {
class Device;
}; // namespace cu
// Wrappers of CUDA resources.
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
public:
using CudaHandle::CudaHandle;
explicit CudaGraph(cu::Device& device);
void end_capture(cudaStream_t stream);
};
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
public:
void instantiate(cudaGraph_t graph);
};
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
public:
explicit CudaStream(cu::Device& device);
};
} // namespace mlx::core
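CudaHandle is a move-only RAII template over a handle type and its destroy function; CudaGraph, CudaGraphExec, and CudaStream are built on it. The sketch below re-declares an equivalent template so it compiles on its own and instantiates it for cudaEvent_t; that alias is illustrative and not part of this header:

```cpp
#include <cuda_runtime.h>

#include <utility>

// Stand-alone copy of the RAII pattern above (move-only, destroys on reset).
template <typename Handle, cudaError_t (*Destroy)(Handle)>
class CudaHandle {
 public:
  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
  CudaHandle(CudaHandle&& other) : handle_(std::exchange(other.handle_, nullptr)) {}
  CudaHandle& operator=(CudaHandle&& other) {
    reset();
    std::swap(handle_, other.handle_);
    return *this;
  }
  CudaHandle(const CudaHandle&) = delete;
  CudaHandle& operator=(const CudaHandle&) = delete;
  ~CudaHandle() {
    reset();
  }
  void reset() {
    if (handle_ != nullptr) {
      Destroy(handle_);
      handle_ = nullptr;
    }
  }
  operator Handle() const {
    return handle_;
  }

 private:
  Handle handle_;
};

// Hypothetical instantiation: wrap cudaEvent_t the same way the header
// wraps graphs and streams.
using ScopedCudaEvent = CudaHandle<cudaEvent_t, cudaEventDestroy>;

int main() {
  cudaEvent_t raw = nullptr;
  cudaEventCreateWithFlags(&raw, cudaEventDisableTiming);
  ScopedCudaEvent event(raw);            // takes ownership
  cudaEventRecord(event, /*stream=*/0);  // implicit conversion to cudaEvent_t
  cudaEventSynchronize(event);
  return 0;  // destructor calls cudaEventDestroy
}
```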

View File

@@ -132,14 +132,18 @@ bool prepare_cudnn_plan(
void** data_ptrs, void** data_ptrs,
F&& execute) { F&& execute) {
int workspace_size = plan.getWorkspaceSize(); int workspace_size = plan.getWorkspaceSize();
array workspace( void* workspace_ptr = nullptr;
workspace_size > 0 ? allocator::malloc(workspace_size) if (workspace_size > 0) {
: allocator::Buffer(nullptr), array workspace(
{workspace_size}, cu::malloc_async(workspace_size, encoder.stream()),
uint8); {workspace_size},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}
auto args = cudnn_frontend::VariantPackBuilder() auto args = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data<void>()) .setWorkspacePointer(workspace_ptr)
.setDataPointers(num_args, data_ptrs) .setDataPointers(num_args, data_ptrs)
.setUids(num_args, uids) .setUids(num_args, uids)
.build(); .build();
@@ -151,7 +155,6 @@ bool prepare_cudnn_plan(
return false; return false;
} }
encoder.add_temporary(workspace);
return true; return true;
} }
@@ -210,6 +213,9 @@ std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
Dtype dtype, Dtype dtype,
cudnn_frontend::OperationGraph& op_graph) { cudnn_frontend::OperationGraph& op_graph) {
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph); auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
if (engine_configs.empty()) {
return std::nullopt;
}
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph); return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
} }

View File

@@ -3,6 +3,7 @@
#pragma once #pragma once
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h" #include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
@@ -23,7 +24,7 @@ class CommandEncoder;
// Return pointer alignment of |x|'s data. // Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) { inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1; uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>()); uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
for (; alignment < 32; alignment *= 2) { for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) { if (address % (alignment * 2)) {
return alignment; return alignment;
@@ -56,7 +57,7 @@ inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
// Helpers used by get_data_ptrs to get pointers. // Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) { inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>()); return const_cast<void*>(gpu_ptr<void>(arr));
} }
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>> template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>

View File

@@ -279,6 +279,7 @@ void CustomKernel::eval_gpu(
std::vector<array>& outputs) { std::vector<array>& outputs) {
nvtx3::scoped_range r("CustomKernel::eval_gpu"); nvtx3::scoped_range r("CustomKernel::eval_gpu");
auto& s = stream(); auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
std::vector<array> copies; std::vector<array> copies;
@@ -288,7 +289,7 @@ void CustomKernel::eval_gpu(
copies.emplace_back(init_value_.value(), out.dtype()); copies.emplace_back(init_value_.value(), out.dtype());
fill_gpu(copies.back(), out, s); fill_gpu(copies.back(), out, s);
} else { } else {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
} }
} }
@@ -356,7 +357,6 @@ void CustomKernel::eval_gpu(
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz); dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel // Call the kernel
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : checked_inputs) { for (const auto& in : checked_inputs) {
encoder.set_input_array(in); encoder.set_input_array(in);
} }

View File

@@ -14,10 +14,6 @@ namespace mlx::core::cu {
namespace { namespace {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
// This should be less than 255
constexpr int default_max_nodes_per_graph = 20;
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd)) #define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
void check_cudnn_error(const char* name, cudnnStatus_t err) { void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -27,13 +23,6 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
} }
} }
int cuda_graph_cache_size() {
static int cache_size = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
}();
return cache_size;
}
bool use_cuda_graphs() { bool use_cuda_graphs() {
static bool use_graphs = []() { static bool use_graphs = []() {
return env::get_var("MLX_USE_CUDA_GRAPHS", true); return env::get_var("MLX_USE_CUDA_GRAPHS", true);
@@ -75,8 +64,8 @@ Device::~Device() {
void Device::make_current() { void Device::make_current() {
// We need to set/get current CUDA device very frequently, cache it to reduce // We need to set/get current CUDA device very frequently, cache it to reduce
// actual calls of CUDA APIs. This function assumes single-thread in host. // actual calls of CUDA APIs.
static int current = 0; static thread_local int current = 0;
if (current != device_) { if (current != device_) {
CHECK_CUDA_ERROR(cudaSetDevice(device_)); CHECK_CUDA_ERROR(cudaSetDevice(device_));
current = device_; current = device_;
@@ -102,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
CommandEncoder::CaptureContext::~CaptureContext() { CommandEncoder::CaptureContext::~CaptureContext() {
if (!use_cuda_graphs()) { if (!use_cuda_graphs()) {
enc.node_count_++;
return; return;
} }
@@ -203,7 +193,8 @@ CommandEncoder::CommandEncoder(Device& d)
: device_(d), : device_(d),
stream_(d), stream_(d),
graph_(d), graph_(d),
graph_cache_(cuda_graph_cache_size()) {} worker_(d),
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) { void CommandEncoder::add_completed_handler(std::function<void()> task) {
worker_.add_task(std::move(task)); worker_.add_task(std::move(task));
@@ -227,12 +218,6 @@ void CommandEncoder::set_output_array(const array& arr) {
active_outputs_.push_back(id); active_outputs_.push_back(id);
} }
void CommandEncoder::maybe_commit() {
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
commit();
}
}
void CommandEncoder::add_kernel_node( void CommandEncoder::add_kernel_node(
void* func, void* func,
dim3 grid_dim, dim3 grid_dim,
@@ -240,6 +225,7 @@ void CommandEncoder::add_kernel_node(
uint32_t smem_bytes, uint32_t smem_bytes,
void** params) { void** params) {
if (!use_cuda_graphs()) { if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cudaLaunchKernel( CHECK_CUDA_ERROR(cudaLaunchKernel(
func, grid_dim, block_dim, params, smem_bytes, stream())); func, grid_dim, block_dim, params, smem_bytes, stream()));
return; return;
@@ -260,6 +246,7 @@ void CommandEncoder::add_kernel_node(
uint32_t smem_bytes, uint32_t smem_bytes,
void** params) { void** params) {
if (!use_cuda_graphs()) { if (!use_cuda_graphs()) {
node_count_++;
CHECK_CUDA_ERROR(cuLaunchKernel( CHECK_CUDA_ERROR(cuLaunchKernel(
func, func,
grid_dim.x, grid_dim.x,
@@ -302,22 +289,28 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
void CommandEncoder::add_graph_node(cudaGraph_t child) { void CommandEncoder::add_graph_node(cudaGraph_t child) {
if (!use_cuda_graphs()) { if (!use_cuda_graphs()) {
node_count_++;
CudaGraphExec graph_exec; CudaGraphExec graph_exec;
graph_exec.instantiate(child); graph_exec.instantiate(child);
device_.make_current(); device_.make_current();
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream())); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
return;
} }
cudaGraphNode_t node; cudaGraphNode_t node;
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child)); CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
insert_graph_dependencies(GraphNode{node, 'G'}); insert_graph_dependencies(GraphNode{node, 'G'});
} }
int CommandEncoder::get_num_ops() {
return node_count_;
}
void CommandEncoder::commit() { void CommandEncoder::commit() {
nvtx3::scoped_range r("CommandEncoder::commit"); nvtx3::scoped_range r("CommandEncoder::commit");
if (!temporaries_.empty()) { if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {}); add_completed_handler([temporaries = std::move(temporaries_)]() {});
} }
if (node_count_ > 0) { if (use_cuda_graphs() && node_count_ > 0) {
if (!from_nodes_.empty()) { if (!from_nodes_.empty()) {
CHECK_CUDA_ERROR(cudaGraphAddDependencies( CHECK_CUDA_ERROR(cudaGraphAddDependencies(
graph_, graph_,
@@ -360,7 +353,6 @@ void CommandEncoder::commit() {
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_)); CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
// Reset state // Reset state
node_count_ = 0;
graph_node_count_ = 0; graph_node_count_ = 0;
empty_node_count_ = 0; empty_node_count_ = 0;
from_nodes_.clear(); from_nodes_.clear();
@@ -372,6 +364,7 @@ void CommandEncoder::commit() {
// Put completion handlers in a batch. // Put completion handlers in a batch.
worker_.commit(stream_); worker_.commit(stream_);
node_count_ = 0;
} }
void CommandEncoder::synchronize() { void CommandEncoder::synchronize() {
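The convolution plan cache and the per-encoder graph cache are now constructed from an environment-variable name plus a default capacity (MLX_CUDA_CONV_CACHE_SIZE, MLX_CUDA_GRAPH_CACHE_SIZE) rather than a hard-coded size. The sketch below shows that construction pattern with a deliberately simplified, insert-only cache; the EnvLRUCache interface is hypothetical and not the MLX LRUCache:

```cpp
#include <cstdlib>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>

// Read an integer environment variable, falling back to a default.
int env_capacity(const char* name, int default_capacity) {
  const char* v = std::getenv(name);
  return v ? std::atoi(v) : default_capacity;
}

// Simplified cache whose capacity is taken from the environment at
// construction time. Insert-only with unique keys; a real LRU cache
// would also refresh recency on lookup.
template <typename K, typename V>
class EnvLRUCache {
 public:
  EnvLRUCache(const char* env_name, int default_capacity)
      : capacity_(env_capacity(env_name, default_capacity)) {}

  void put(const K& key, V value) {
    order_.push_front(key);
    items_.emplace(key, std::move(value));
    if (items_.size() > static_cast<size_t>(capacity_)) {
      items_.erase(order_.back());  // evict the oldest entry
      order_.pop_back();
    }
  }

  V* get(const K& key) {
    auto it = items_.find(key);
    return it == items_.end() ? nullptr : &it->second;
  }

 private:
  int capacity_;
  std::list<K> order_;
  std::unordered_map<K, V> items_;
};

int main() {
  // Setting MLX_CUDA_CONV_CACHE_SIZE=256 would raise the capacity at runtime.
  EnvLRUCache<std::string, int> cache("MLX_CUDA_CONV_CACHE_SIZE", 128);
  cache.put("plan-a", 1);
  return cache.get("plan-a") ? 0 : 1;
}
```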

View File

@@ -3,6 +3,7 @@
#pragma once #pragma once
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/cuda/worker.h" #include "mlx/backend/cuda/worker.h"
#include "mlx/stream.h" #include "mlx/stream.h"
@@ -83,7 +84,7 @@ class CommandEncoder {
} }
void add_completed_handler(std::function<void()> task); void add_completed_handler(std::function<void()> task);
void maybe_commit(); int get_num_ops();
void commit(); void commit();
Device& device() { Device& device() {
@@ -140,7 +141,7 @@ class Device {
Device(const Device&) = delete; Device(const Device&) = delete;
Device& operator=(const Device&) = delete; Device& operator=(const Device&) = delete;
// Make this device the current cuda device, required by some cuda calls. // Make this device the current cuda device, this method is thread-safe.
void make_current(); void make_current();
CommandEncoder& get_command_encoder(Stream s); CommandEncoder& get_command_encoder(Stream s);

View File

@@ -2,6 +2,8 @@
#pragma once #pragma once
#include <cuda_fp8.h>
#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
@@ -257,8 +259,8 @@ struct Round {
struct Sigmoid { struct Sigmoid {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
T y = 1 / (1 + exp(-abs(x))); T y = 1 / (1 + exp(abs(x)));
return (x < 0) ? 1 - y : y; return (x < 0) ? y : 1 - y;
} }
}; };
@@ -334,4 +336,17 @@ struct Tanh {
} }
}; };
struct ToFP8 {
template <typename T>
__device__ uint8_t operator()(T x) {
return __nv_fp8_e4m3(x).__x;
}
};
struct FromFP8 {
__device__ float operator()(uint8_t x) {
return float(*(__nv_fp8_e4m3*)(&x));
}
};
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
// This file must not include any host-only code, utilies that work under both // This file must not include any host-only code, utilities that work under both
// host and device can be put here. // host and device can be put here.
// //
// See more about the requirements at: // See more about the requirements at:
@@ -202,7 +202,7 @@ struct Limits<
} }
}; };
// CUDA 11 does not have host side arithmatic operators for half types. // CUDA 11 does not have host side arithmetic operators for half types.
template <typename T> template <typename T>
struct Limits< struct Limits<
T, T,

View File

@@ -15,8 +15,10 @@ void AllReduce::eval_gpu(
assert(inputs.size() == 1); assert(inputs.size() == 1);
assert(outputs.size() == 1); assert(outputs.size() == 1);
auto set_input_output = auto& s = stream();
[s = stream()](const array& in, array& out) -> std::pair<array, array> { auto& encoder = cu::get_command_encoder(s);
auto set_input_output = [&](const array& in,
array& out) -> std::pair<array, array> {
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
copy_gpu(in, out, CopyType::General, s); copy_gpu(in, out, CopyType::General, s);
return {out, out}; return {out, out};
@@ -24,19 +26,17 @@ void AllReduce::eval_gpu(
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
return {in, out}; return {in, out};
} else { } else {
out.set_data(allocator::malloc(out.nbytes())); out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
return {in, out}; return {in, out};
} }
}; };
auto [input, output] = set_input_output(inputs[0], outputs[0]); auto [input, output] = set_input_output(inputs[0], outputs[0]);
auto& encoder = cu::get_command_encoder(stream());
encoder.set_input_array(input); encoder.set_input_array(input);
encoder.set_output_array(output); encoder.set_output_array(output);
auto capture = encoder.capture_context(); auto capture = encoder.capture_context();
auto& s = stream();
switch (reduce_type_) { switch (reduce_type_) {
case Sum: case Sum:
@@ -53,4 +53,69 @@ void AllReduce::eval_gpu(
"Only all reduce sum, max, and min are supported."); "Only all reduce sum, max, and min are supported.");
} }
} }
void AllGather::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().row_contiguous) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
auto capture = encoder.capture_context();
distributed::detail::all_gather(group(), input, outputs[0], s);
}
void ReduceScatter::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().row_contiguous) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
auto capture = encoder.capture_context();
switch (reduce_type_) {
case Sum:
distributed::detail::sum_scatter(group(), input, outputs[0], s);
break;
default:
throw std::runtime_error("Only sum scatter is supported. ");
}
}
} // namespace mlx::core::distributed } // namespace mlx::core::distributed
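AllGather concatenates each rank's contiguous input, and ReduceScatter with Sum reduces across ranks and leaves every rank with its own slice, so the output is world_size times larger or smaller than the input respectively. The host-only simulation below pins down the expected values with plain vectors; it makes no NCCL or MLX calls:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simulated collectives across ranks, each holding one contiguous buffer.
// all_gather concatenates in rank order; reduce_scatter_sum reduces
// elementwise and returns only the chunk belonging to `rank`.
std::vector<float> all_gather(const std::vector<std::vector<float>>& inputs) {
  std::vector<float> out;
  for (const auto& in : inputs) {
    out.insert(out.end(), in.begin(), in.end());
  }
  return out;
}

std::vector<float> reduce_scatter_sum(
    const std::vector<std::vector<float>>& inputs, int rank) {
  size_t world = inputs.size();
  size_t chunk = inputs[0].size() / world;
  std::vector<float> out(chunk, 0.0f);
  for (const auto& in : inputs) {
    for (size_t i = 0; i < chunk; ++i) {
      out[i] += in[rank * chunk + i];
    }
  }
  return out;
}

int main() {
  // Two ranks, four elements each.
  std::vector<std::vector<float>> inputs = {{1, 2, 3, 4}, {10, 20, 30, 40}};
  auto gathered = all_gather(inputs);
  assert(gathered.size() == 8 && gathered[4] == 10);
  assert(reduce_scatter_sum(inputs, 0) == (std::vector<float>{11, 22}));
  assert(reduce_scatter_sum(inputs, 1) == (std::vector<float>{33, 44}));
  return 0;
}
```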

View File

@@ -5,18 +5,24 @@
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/gpu/available.h" #include "mlx/backend/gpu/available.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
namespace mlx::core::gpu { namespace mlx::core::gpu {
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
constexpr int default_max_nodes_per_graph = 20;
bool is_available() { bool is_available() {
return true; return true;
} }
void new_stream(Stream s) { void new_stream(Stream s) {
// Force initalization of cuda, so cuda runtime get destroyed at last. // Force initalization of CUDA, so CUDA runtime get destroyed at last.
cudaFree(nullptr); cudaFree(nullptr);
// Make sure CUDA event pool get destroyed after device and stream.
cu::CudaEvent::init_pool();
// Ensure the static stream objects get created. // Ensure the static stream objects get created.
cu::get_command_encoder(s); cu::get_command_encoder(s);
} }
@@ -34,7 +40,8 @@ void eval(array& arr) {
arr.primitive().eval_gpu(arr.inputs(), outputs); arr.primitive().eval_gpu(arr.inputs(), outputs);
} }
auto& encoder = cu::get_command_encoder(arr.primitive().stream()); auto& stream = arr.primitive().stream();
auto& encoder = cu::get_command_encoder(stream);
// Keep used buffers alive until kernel finishes running. // Keep used buffers alive until kernel finishes running.
for (auto& in : arr.inputs()) { for (auto& in : arr.inputs()) {
// Except for the donated one. // Except for the donated one.
@@ -45,7 +52,14 @@ void eval(array& arr) {
for (auto& s : arr.siblings()) { for (auto& s : arr.siblings()) {
encoder.add_temporary(s); encoder.add_temporary(s);
} }
encoder.maybe_commit();
if (encoder.get_num_ops() >=
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
scheduler::notify_new_task(stream);
encoder.add_completed_handler(
[stream]() { scheduler::notify_task_completion(stream); });
encoder.commit();
}
} }
void finalize(Stream s) { void finalize(Stream s) {

View File

@@ -3,10 +3,12 @@
#include "mlx/backend/cuda/allocator.h" #include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/event.h" #include "mlx/backend/cuda/event.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/event.h" #include "mlx/event.h"
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
#include <map>
#include <vector>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
namespace mlx::core { namespace mlx::core {
@@ -17,104 +19,180 @@ namespace cu {
 // CudaEvent implementations
 ///////////////////////////////////////////////////////////////////////////////

-// Cuda event managed with RAII.
-class CudaEventHandle {
+namespace {
+
+// Manage cached cudaEvent_t objects.
+class CudaEventPool {
  public:
-  CudaEventHandle() {
-    CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
-        &event_, cudaEventDisableTiming | cudaEventBlockingSync));
+  CudaEventHandle create(Device& d, int flags) {
+    if (!on_creation_thread()) {
+      return CudaEventHandle(d, flags);
+    }
+    auto& cache = cache_for(d, flags);
+    if (cache.empty()) {
+      return CudaEventHandle(d, flags);
+    } else {
+      CudaEventHandle ret = std::move(cache.back());
+      cache.pop_back();
+      return ret;
+    }
   }

-  ~CudaEventHandle() {
-    CHECK_CUDA_ERROR(cudaEventDestroy(event_));
-  }
-
-  CudaEventHandle(const CudaEventHandle&) = delete;
-  CudaEventHandle& operator=(const CudaEventHandle&) = delete;
-
-  operator cudaEvent_t() const {
-    return event_;
+  void release(CudaEventHandle event) {
+    if (!on_creation_thread()) {
+      // Event will be destroyed directly instead of getting moved to cache.
+      return;
+    }
+    cache_for(event.device, event.flags).push_back(std::move(event));
   }

  private:
-  cudaEvent_t event_;
+  std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
+    return cache_[d.cuda_device()][flags];
+  }
+
+  bool on_creation_thread() {
+    return std::this_thread::get_id() == thread_id_;
+  }
+
+  // The CudaEvent may be created and destroyed on different threads (for
+  // example when waiting on GPU work in CPU stream), we don't want to make
+  // the cache thread-safe as it adds overhead, so we just skip cache when
+  // using events in worker threads.
+  std::thread::id thread_id_{std::this_thread::get_id()};
+
+  // {device: {flags: [events]}}
+  std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
 };

-CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
+CudaEventPool& cuda_event_pool() {
+  static CudaEventPool pool;
+  return pool;
+}
+
+} // namespace
+
+CudaEventHandle::CudaEventHandle(Device& d, int flags)
+    : device(d), flags(flags) {
+  device.make_current();
+  CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
+  assert(handle_ != nullptr);
+}
+
+CudaEvent::CudaEvent(Device& d, int flags)
+    : event_(cuda_event_pool().create(d, flags)) {}
+
+CudaEvent::~CudaEvent() {
+  cuda_event_pool().release(std::move(event_));
+}
 void CudaEvent::wait() {
   nvtx3::scoped_range r("cu::CudaEvent::wait");
-  if (!recorded_) {
-    throw std::runtime_error("Should not wait on a CudaEvent before record.");
-  }
-  cudaEventSynchronize(*event_);
+  event_.device.make_current();
+  cudaEventSynchronize(event_);
 }

 void CudaEvent::wait(cudaStream_t stream) {
-  if (!recorded_) {
-    throw std::runtime_error("Should not wait on a CudaEvent before record.");
-  }
-  cudaStreamWaitEvent(stream, *event_);
-}
-
-void CudaEvent::wait(Stream s) {
-  if (s.device == mlx::core::Device::cpu) {
-    scheduler::enqueue(s, [*this]() mutable { wait(); });
-  } else {
-    auto& enc = cu::get_command_encoder(s);
-    enc.commit();
-    wait(enc.stream());
-  }
+  event_.device.make_current();
+  cudaStreamWaitEvent(stream, event_);
 }

 void CudaEvent::record(cudaStream_t stream) {
-  cudaEventRecord(*event_, stream);
-  recorded_ = true;
-}
-
-void CudaEvent::record(Stream s) {
-  if (s.device == mlx::core::Device::cpu) {
-    throw std::runtime_error("CudaEvent can not wait on cpu stream.");
-  } else {
-    auto& enc = cu::get_command_encoder(s);
-    enc.commit();
-    record(enc.stream());
-  }
+  event_.device.make_current();
+  cudaEventRecord(event_, stream);
 }

 bool CudaEvent::completed() const {
-  return cudaEventQuery(*event_) == cudaSuccess;
+  // Note: cudaEventQuery can be safely called from any device.
+  return cudaEventQuery(event_) == cudaSuccess;
 }
+
+// static
+void CudaEvent::init_pool() {
+  cuda_event_pool();
+}
// Wraps CudaEvent with a few features:
// 1. The class can be copied.
// 2. Make wait/record work with CPU streams.
// 3. Add checks for waiting on un-recorded event.
class CopyableCudaEvent {
public:
explicit CopyableCudaEvent(Device& d)
: event_(std::make_shared<CudaEvent>(
d,
cudaEventDisableTiming | cudaEventBlockingSync)) {}
void wait() {
event_->wait();
}
void wait(Stream s) {
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this]() mutable {
check_recorded();
event_->wait();
});
} else {
check_recorded();
auto& encoder = cu::get_command_encoder(s);
encoder.commit();
event_->wait(encoder.stream());
}
}
void record(Stream s) {
if (s.device == mlx::core::Device::cpu) {
throw std::runtime_error("CudaEvent can not wait on CPU stream.");
} else {
auto& encoder = cu::get_command_encoder(s);
encoder.commit();
event_->record(encoder.stream());
recorded_ = true;
}
}
bool is_signaled() const {
return recorded_ && event_->completed();
}
private:
void check_recorded() const {
if (!recorded_) {
throw std::runtime_error(
"Should not wait on a CudaEvent before recording.");
}
}
std::shared_ptr<CudaEvent> event_;
bool recorded_{false};
};
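The cached-event pattern above avoids paying for cudaEventCreate/cudaEventDestroy on every synchronization point. As a standalone, single-threaded, single-device illustration (not MLX code; SimpleEventCache and its methods are hypothetical names used only for this sketch), the create/record/release cycle looks roughly like this:

#include <cuda_runtime.h>
#include <cstdio>
#include <map>
#include <vector>

class SimpleEventCache {
 public:
  cudaEvent_t create(unsigned int flags) {
    auto& bucket = cache_[flags];
    if (!bucket.empty()) {
      cudaEvent_t e = bucket.back();  // reuse a previously created event
      bucket.pop_back();
      return e;
    }
    cudaEvent_t e;
    cudaEventCreateWithFlags(&e, flags);
    return e;
  }
  void release(unsigned int flags, cudaEvent_t e) {
    cache_[flags].push_back(e);  // keep it around instead of destroying it
  }
  ~SimpleEventCache() {
    for (auto& entry : cache_) {
      for (auto e : entry.second) {
        cudaEventDestroy(e);
      }
    }
  }

 private:
  std::map<unsigned int, std::vector<cudaEvent_t>> cache_;
};

int main() {
  SimpleEventCache cache;
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  unsigned int flags = cudaEventDisableTiming | cudaEventBlockingSync;
  cudaEvent_t e = cache.create(flags);
  cudaEventRecord(e, stream);
  cudaEventSynchronize(e);
  cache.release(flags, e);  // the next create(flags) hands this event back
  cudaStreamDestroy(stream);
  std::printf("done\n");
}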
 ///////////////////////////////////////////////////////////////////////////////
-// SharedEvent implementations
+// AtomicEvent implementations
 ///////////////////////////////////////////////////////////////////////////////

-__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
+__host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) {
   uint64_t current;
   while ((current = ac->load()) < value) {
     ac->wait(current);
   }
 }

-__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
+__host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) {
   ac->store(value);
   ac->notify_all();
 }

-__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
+__global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
   event_wait(ac, value);
 }

-__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
+__global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
   event_signal(ac, value);
 }

-SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
-  return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
-}
-
-SharedEvent::SharedEvent() {
+AtomicEvent::AtomicEvent() {
   buf_ = std::shared_ptr<Buffer>(
       new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
         allocator().free(*ptr);
@@ -123,17 +201,17 @@ SharedEvent::SharedEvent() {
   *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
 }

-void SharedEvent::wait(uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::wait");
-  event_wait(to_atomic(buf_), value);
+void AtomicEvent::wait(uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::wait");
+  event_wait(atomic(), value);
 }

-void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
-  event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
+void AtomicEvent::wait(cudaStream_t stream, uint64_t value) {
+  event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value);
 }

-void SharedEvent::wait(Stream s, uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
+void AtomicEvent::wait(Stream s, uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
   if (s.device == mlx::core::Device::cpu) {
     scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
   } else {
@@ -144,17 +222,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
   }
 }

-void SharedEvent::signal(uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::signal");
-  event_signal(to_atomic(buf_), value);
+void AtomicEvent::signal(uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::signal");
+  event_signal(atomic(), value);
 }

-void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
-  event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
+void AtomicEvent::signal(cudaStream_t stream, uint64_t value) {
+  event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value);
 }

-void SharedEvent::signal(Stream s, uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
+void AtomicEvent::signal(Stream s, uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
   if (s.device == mlx::core::Device::cpu) {
     // Signal through a GPU stream so the atomic is updated in GPU - updating
     // the atomic in CPU sometimes does not get GPU notified.
@@ -168,14 +246,14 @@ void SharedEvent::signal(Stream s, uint64_t value) {
   }
 }

-bool SharedEvent::is_signaled(uint64_t value) const {
-  nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
-  return to_atomic(buf_)->load() >= value;
+bool AtomicEvent::is_signaled(uint64_t value) const {
+  nvtx3::scoped_range r("cu::AtomicEvent::is_signaled");
+  return atomic()->load() >= value;
 }

-uint64_t SharedEvent::value() const {
-  nvtx3::scoped_range r("cu::SharedEvent::value");
-  return to_atomic(buf_)->load();
+uint64_t AtomicEvent::value() const {
+  nvtx3::scoped_range r("cu::AtomicEvent::value");
+  return atomic()->load();
 }

 } // namespace cu
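For reference, the AtomicEvent mechanism boils down to a system-scope cuda::atomic placed in memory visible to both host and device. Below is a minimal standalone sketch (not MLX code) of that idea; it assumes a platform with concurrent managed-memory access, and keep in mind the caveat in the comment above that a CPU-side store is not always reliably observed by a spinning GPU kernel, which is why MLX signals through a GPU stream:

#include <cuda/atomic>
#include <cuda_runtime.h>
#include <cstdio>
#include <new>

using Flag = cuda::atomic<uint64_t, cuda::thread_scope_system>;

__global__ void wait_for(Flag* f, uint64_t value) {
  uint64_t current;
  while ((current = f->load()) < value) {
    f->wait(current);  // sleep until the value changes
  }
}

int main() {
  Flag* f;
  cudaMallocManaged(&f, sizeof(Flag));
  new (f) Flag(0);  // construct the atomic in managed memory

  cudaStream_t s;
  cudaStreamCreate(&s);
  wait_for<<<1, 1, 0, s>>>(f, 1);

  // Signal from the host: bump the value and wake any waiters.
  f->store(1);
  f->notify_all();

  cudaStreamSynchronize(s);
  std::printf("kernel observed the signal\n");
  cudaStreamDestroy(s);
  cudaFree(f);
}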
@@ -188,14 +266,14 @@ namespace {
 struct EventImpl {
   // CudaEvent is preferred when possible because it is fast, however we have
-  // to fallback to SharedEvent in following cases:
+  // to fallback to AtomicEvent in following cases:
   // 1. the event is used to wait/signal a cpu stream;
   // 2. signal value other than 1 has been specified.
-  std::unique_ptr<cu::CudaEvent> cuda;
-  std::unique_ptr<cu::SharedEvent> shared;
+  std::unique_ptr<cu::CopyableCudaEvent> cuda;
+  std::unique_ptr<cu::AtomicEvent> atomic;

   bool is_created() const {
-    return cuda || shared;
+    return cuda || atomic;
   }

   void ensure_created(Stream s, uint64_t signal_value) {
@@ -203,10 +281,10 @@ struct EventImpl {
       return;
     }
     if (s.device == mlx::core::Device::cpu || signal_value > 1) {
-      nvtx3::mark("Using slow SharedEvent");
-      shared = std::make_unique<cu::SharedEvent>();
+      nvtx3::mark("Using slow AtomicEvent");
+      atomic = std::make_unique<cu::AtomicEvent>();
     } else {
-      cuda = std::make_unique<cu::CudaEvent>();
+      cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
     }
   }
 };
@@ -225,7 +303,7 @@ void Event::wait() {
     assert(value() == 1);
     event->cuda->wait();
   } else {
-    event->shared->wait(value());
+    event->atomic->wait(value());
   }
 }
@@ -236,7 +314,7 @@ void Event::wait(Stream s) {
     assert(value() == 1);
     event->cuda->wait(s);
   } else {
-    event->shared->wait(s, value());
+    event->atomic->wait(s, value());
   }
 }
@@ -247,7 +325,7 @@ void Event::signal(Stream s) {
     assert(value() == 1);
     event->cuda->record(s);
   } else {
-    event->shared->signal(s, value());
+    event->atomic->signal(s, value());
   }
 }
@@ -258,9 +336,9 @@ bool Event::is_signaled() const {
   }
   if (event->cuda) {
     assert(value() == 1);
-    return event->cuda->recorded() && event->cuda->completed();
+    return event->cuda->is_signaled();
   } else {
-    return event->shared->is_signaled(value());
+    return event->atomic->is_signaled(value());
   }
 }

View File

@@ -3,49 +3,60 @@
#pragma once #pragma once
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/stream.h" #include "mlx/stream.h"
#include <memory>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda/atomic> #include <cuda/atomic>
#include <memory>
namespace mlx::core::cu { namespace mlx::core::cu {
class CudaEventHandle; class Device;
// RAII-managed move-only wrapper of cudaEvent_t.
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
CudaEventHandle(Device& d, int flags);
Device& device;
int flags;
};
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait // Wrapper of native cuda event. It can synchronize between GPU streams, or wait
// on GPU stream in CPU stream, but can not wait on CPU stream. // on GPU stream in CPU stream, but can not wait on CPU stream.
class CudaEvent { class CudaEvent {
public: public:
CudaEvent(); CudaEvent(Device& d, int flags);
~CudaEvent();
CudaEvent(CudaEvent&&) = default;
CudaEvent& operator=(CudaEvent&&) = default;
CudaEvent(const CudaEvent&) = delete;
CudaEvent& operator=(const CudaEvent&) = delete;
void wait(); void wait();
void wait(cudaStream_t stream); void wait(cudaStream_t stream);
void wait(Stream s);
void record(cudaStream_t stream); void record(cudaStream_t stream);
void record(Stream s);
// Return whether the recorded kernels have completed. Note that this method // Return whether the recorded kernels have completed. Note that this method
// returns true if record() has not been called. // returns true if record() has not been called.
bool completed() const; bool completed() const;
bool recorded() const { // Internal: make sure event pool is initialized.
return recorded_; static void init_pool();
}
private: private:
bool recorded_{false}; CudaEventHandle event_;
std::shared_ptr<CudaEventHandle> event_;
}; };
// Event that can synchronize between CPU and GPU. It is much slower than // Event that can synchronize between CPU and GPU. It is much slower than
// CudaEvent so the latter should always be preferred when possible. // CudaEvent so the latter should always be preferred when possible.
class SharedEvent { class AtomicEvent {
public: public:
using Atomic = cuda::atomic<uint64_t>; using Atomic = cuda::atomic<uint64_t>;
SharedEvent(); AtomicEvent();
void wait(uint64_t value); void wait(uint64_t value);
void wait(cudaStream_t stream, uint64_t value); void wait(cudaStream_t stream, uint64_t value);
@@ -57,7 +68,11 @@ class SharedEvent {
uint64_t value() const; uint64_t value() const;
private: private:
std::shared_ptr<mlx::core::allocator::Buffer> buf_; Atomic* atomic() const {
return static_cast<AtomicEvent::Atomic*>(buf_->raw_ptr());
}
std::shared_ptr<allocator::Buffer> buf_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -1,13 +1,15 @@
 // Copyright © 2025 Apple Inc.

 #include "mlx/fence.h"
+#include "mlx/backend/cuda/allocator.h"
+#include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/event.h"

 namespace mlx::core {

 struct FenceImpl {
   uint32_t count;
-  cu::SharedEvent event;
+  cu::AtomicEvent event;
 };

 Fence::Fence(Stream s) {
@@ -20,8 +22,24 @@ void Fence::wait(Stream s, const array&) {
   fence->event.wait(fence->count);
 }

-void Fence::update(Stream s, const array&) {
+void Fence::update(Stream s, const array& a, bool cross_device) {
   auto* fence = static_cast<FenceImpl*>(fence_.get());
+  if (cross_device) {
+    // Move to managed memory if there is a device switch
+    auto& cbuf =
+        *static_cast<cu::CudaBuffer*>(const_cast<array&>(a).buffer().ptr());
+    if (cbuf.device != -1) {
+      void* new_data;
+      CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
+      cbuf.device = -1;
+      auto& encoder = cu::device(s.device).get_command_encoder(s);
+      encoder.commit();
+      CHECK_CUDA_ERROR(cudaMemcpyAsync(
+          new_data, cbuf.data, cbuf.size, cudaMemcpyDefault, encoder.stream()));
+      CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream()));
+      cbuf.data = new_data;
+    }
+  }
   fence->count++;
   fence->event.signal(s, fence->count);
 }
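The cross_device branch above migrates the array's buffer into CUDA managed memory so a consumer on another device (or the CPU) can read it after the fence is signaled. A minimal standalone sketch of that migration, with a hypothetical migrate_to_managed helper and error checking omitted:

#include <cuda_runtime.h>

// Replace a device allocation with a managed one so the data stays readable
// after a device switch (hypothetical helper, not MLX code).
void migrate_to_managed(void*& data, size_t size, cudaStream_t stream) {
  void* managed = nullptr;
  cudaMallocManaged(&managed, size);
  // Issue the copy on the producing stream so it is ordered after the
  // kernels that wrote `data`.
  cudaMemcpyAsync(managed, data, size, cudaMemcpyDefault, stream);
  // Release the old device buffer once the stream reaches this point.
  cudaFreeAsync(data, stream);
  data = managed;
}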

View File

@@ -50,8 +50,10 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
     return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
                                           : CUBLAS_COMPUTE_32F;
   case float64:
-  case complex64:
     return CUBLAS_COMPUTE_64F;
+  case complex64:
+    return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
+                                          : CUBLAS_COMPUTE_32F;
   default:
     throw std::runtime_error(fmt::format(
         "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
@@ -126,12 +128,13 @@ CublasGemm::CublasGemm(
       N_(b_cols) {
   heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
-  auto scale_type = dtype_to_cublas_type(dtype);
+  scale_type_ = dtype_to_cublas_type(dtype);
   if (dtype == bfloat16 || dtype == float16) {
-    scale_type = CUDA_R_32F;
+    scale_type_ = CUDA_R_32F;
   }
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
-      &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
+      &matmul_desc_, dtype_to_compute_type(dtype), scale_type_));
   int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
@@ -230,15 +233,20 @@ void CublasGemm::set_out(
       batch_stride);
 }

-void CublasGemm::set_bias(void* bias) {
+void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
+  encoder.set_input_array(bias);
   cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
       matmul_desc_,
       CUBLASLT_MATMUL_DESC_EPILOGUE,
       &epilogue,
       sizeof(epilogue)));
+  auto* bias_ptr = gpu_ptr<void>(bias);
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
-      matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+      &bias_ptr,
+      sizeof(bias_ptr)));
 }
 void CublasGemm::run(
@@ -270,9 +278,9 @@ void CublasGemm::run(
   execute(
       encoder,
-      out.data<void>(),
-      a.data<void>(),
-      b.data<void>(),
+      gpu_ptr<void>(out),
+      gpu_ptr<void>(a),
+      gpu_ptr<void>(b),
       nullptr,
       alpha);
 }
@@ -313,10 +321,10 @@ void CublasGemm::run(
   execute(
       encoder,
-      out.data<void>(),
-      a.data<void>(),
-      b.data<void>(),
-      c.data<void>(),
+      gpu_ptr<void>(out),
+      gpu_ptr<void>(a),
+      gpu_ptr<void>(b),
+      gpu_ptr<void>(c),
       alpha,
       beta);
 }
@@ -347,28 +355,38 @@ void CublasGemm::execute(
     }
   }

+  const void* alpha_ptr = &alpha;
+  const void* beta_ptr = &beta;
+  complex64_t alpha_c, beta_c;
+  if (scale_type_ == CUDA_C_32F) {
+    alpha_c = complex64_t{alpha, 0.0f};
+    beta_c = complex64_t{beta, 0.0f};
+    alpha_ptr = &alpha_c;
+    beta_ptr = &beta_c;
+  }
+
   void* workspace_ptr = nullptr;
   if (heuristic_.workspaceSize > 0) {
     // Ensure workspace is 256-byte aligned
     int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
     array workspace(
-        allocator::malloc(nbytes),
+        cu::malloc_async(nbytes, encoder.stream()),
         {static_cast<int>(heuristic_.workspaceSize)},
         int8);
     encoder.add_temporary(workspace);
-    workspace_ptr = workspace.data<void>();
+    workspace_ptr = gpu_ptr<void>(workspace);
   }

   auto capture = encoder.capture_context();
   CHECK_CUBLAS_ERROR(cublasLtMatmul(
       handle_,
       matmul_desc_,
-      &alpha,
+      alpha_ptr,
       b, // a and b are swapped
       a_desc_,
       a,
       b_desc_,
-      &beta,
+      beta_ptr,
       c ? c : out,
       c ? c_desc_ : out_desc_,
       out,

View File

@@ -55,7 +55,7 @@ class CublasGemm {
       int32_t batch_count,
       int64_t batch_stride);

-  void set_bias(void* bias);
+  void set_bias(cu::CommandEncoder& encoder, const array& bias);

   void run(
       cu::CommandEncoder& encoder,
@@ -115,6 +115,7 @@ class CublasGemm {
   uint64_t M_;
   uint64_t N_;

+  cudaDataType_t scale_type_;
   cublasLtMatmulPreference_t pref_{nullptr};
   cublasLtHandle_t handle_{nullptr};
   cublasLtMatmulDesc_t matmul_desc_{nullptr};

View File

@@ -25,9 +25,10 @@ void CublasGemm::run_batched(
   for (size_t i = 0; i < nbatch; ++i) {
     execute(
         encoder,
-        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
-        a.data<int8_t>() + a.itemsize() * a_it.loc,
-        b.data<int8_t>() + b.itemsize() * b_it.loc,
+        gpu_ptr<int8_t>(out) +
+            out.itemsize() * i * batch_shape.back() * M_ * N_,
+        gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
+        gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
         nullptr,
         alpha);
     a_it.step();
@@ -60,10 +61,11 @@ void CublasGemm::run_batched(
   for (size_t i = 0; i < nbatch; ++i) {
     execute(
         encoder,
-        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
-        a.data<int8_t>() + a.itemsize() * a_it.loc,
-        b.data<int8_t>() + b.itemsize() * b_it.loc,
-        c.data<int8_t>() + c.itemsize() * c_it.loc,
+        gpu_ptr<int8_t>(out) +
+            out.itemsize() * i * batch_shape.back() * M_ * N_,
+        gpu_ptr<int8_t>(a) + a.itemsize() * a_it.loc,
+        gpu_ptr<int8_t>(b) + b.itemsize() * b_it.loc,
+        gpu_ptr<int8_t>(c) + c.itemsize() * c_it.loc,
         alpha,
         beta);
     a_it.step();

View File

@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(void*) * 3),
+      cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
       {batch_count * 3},
       uint64);
@@ -183,10 +183,10 @@ void CublasGemm::run_batched(
         num_blocks,
         block_dims,
         0,
-        pointers.data<int8_t*>(),
-        a.data<int8_t>(),
-        b.data<int8_t>(),
-        out.data<int8_t>(),
+        gpu_ptr<int8_t*>(pointers),
+        gpu_ptr<int8_t>(a),
+        gpu_ptr<int8_t>(b),
+        gpu_ptr<int8_t>(out),
         item_size,
         const_param<ndim_constant()>(batch_shape),
         const_param<ndim_constant()>(a_batch_strides),
@@ -200,10 +200,10 @@ void CublasGemm::run_batched(
         num_blocks,
         block_dims,
         0,
-        pointers.data<int8_t*>(),
-        a.data<int8_t>(),
-        b.data<int8_t>(),
-        out.data<int8_t>(),
+        gpu_ptr<int8_t*>(pointers),
+        gpu_ptr<int8_t>(a),
+        gpu_ptr<int8_t>(b),
+        gpu_ptr<int8_t>(out),
         item_size,
         const_param(batch_shape),
         const_param(a_batch_strides),
@@ -219,7 +219,7 @@ void CublasGemm::run_batched(
   encoder.set_input_array(b);
   encoder.set_output_array(out);

-  auto a_pointers = pointers.data<int8_t*>();
+  auto a_pointers = gpu_ptr<int8_t*>(pointers);
   auto b_pointers = a_pointers + batch_count;
   auto out_pointers = b_pointers + batch_count;
   execute(
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 4),
+      cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
       {batch_count * 4},
       uint64);
@@ -271,11 +271,11 @@ void CublasGemm::run_batched(
         num_blocks,
         block_dims,
         0,
-        pointers.data<int8_t*>(),
-        a.data<int8_t>(),
-        b.data<int8_t>(),
-        c.data<int8_t>(),
-        out.data<int8_t>(),
+        gpu_ptr<int8_t*>(pointers),
+        gpu_ptr<int8_t>(a),
+        gpu_ptr<int8_t>(b),
+        gpu_ptr<int8_t>(c),
+        gpu_ptr<int8_t>(out),
         item_size,
         const_param<ndim_constant()>(batch_shape),
         const_param<ndim_constant()>(a_batch_strides),
@@ -290,11 +290,11 @@ void CublasGemm::run_batched(
         num_blocks,
         block_dims,
         0,
-        pointers.data<int8_t*>(),
-        a.data<int8_t>(),
-        b.data<int8_t>(),
-        c.data<int8_t>(),
-        out.data<int8_t>(),
+        gpu_ptr<int8_t*>(pointers),
+        gpu_ptr<int8_t>(a),
+        gpu_ptr<int8_t>(b),
+        gpu_ptr<int8_t>(c),
+        gpu_ptr<int8_t>(out),
         item_size,
         const_param(batch_shape),
         const_param(a_batch_strides),
@@ -312,7 +312,7 @@ void CublasGemm::run_batched(
   encoder.set_input_array(c);
   encoder.set_output_array(out);

-  auto a_pointers = pointers.data<int8_t*>();
+  auto a_pointers = gpu_ptr<int8_t*>(pointers);
   auto b_pointers = a_pointers + batch_count;
   auto c_pointers = b_pointers + batch_count;
   auto out_pointers = c_pointers + batch_count;

View File

@@ -13,6 +13,37 @@ namespace cg = cooperative_groups;
static constexpr int rows_per_block = 8; static constexpr int rows_per_block = 8;
// Accumulator type selection per input element type T.
template <typename T>
struct GemvAccType {
using type = T;
};
template <>
struct GemvAccType<__half> {
using type = float;
};
template <>
struct GemvAccType<__nv_bfloat16> {
using type = float;
};
template <>
struct GemvAccType<float> {
using type = float;
};
template <>
struct GemvAccType<double> {
using type = double;
};
template <>
struct GemvAccType<cu::complex64_t> {
using type = cu::complex64_t;
};
 template <typename T, int rows_per_block, int n_per_thread>
 __device__ void
 gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
@@ -24,7 +55,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
   int row = g_idx.x * rows_per_block + t_idx.y;
   if (row < rows) {
-    float sum = 0.0f;
+    using Acc = typename GemvAccType<T>::type;
+    Acc sum = Acc(0);
     for (int col = n_per_thread * warp.thread_rank(); col < cols;
          col += (WARP_SIZE * n_per_thread)) {
       auto local_mat =
@@ -32,12 +64,11 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
       auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
 #pragma unroll
       for (int j = 0; j < n_per_thread; ++j) {
-        sum +=
-            static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
+        sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);
       }
     }
-    sum = cg::reduce(warp, sum, cg::plus<float>{});
+    sum = cg::reduce(warp, sum, cg::plus<Acc>{});
     if (warp.thread_rank() == 0) {
       out[row] = static_cast<T>(sum);
     }
@@ -107,7 +138,7 @@ void gemv(
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
-  dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
+  dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     dim3 block_dims{WARP_SIZE, rows_per_block};
     const DataType* mat;
@@ -118,13 +149,13 @@ void gemv(
     auto vec_strides = const_param(b_batch_strides);
     if (M == 1) {
-      mat = b.data<DataType>();
-      vec = a.data<DataType>();
+      mat = gpu_ptr<DataType>(b);
+      vec = gpu_ptr<DataType>(a);
       rows = N;
       std::swap(mat_strides, vec_strides);
     } else {
-      mat = a.data<DataType>();
-      vec = b.data<DataType>();
+      mat = gpu_ptr<DataType>(a);
+      vec = gpu_ptr<DataType>(b);
       rows = M;
     }
     uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block;
@@ -146,7 +177,7 @@ void gemv(
           0,
           mat,
           vec,
-          out.data<DataType>(),
+          gpu_ptr<DataType>(out),
           rows,
           cols);
     } else {
@@ -158,7 +189,7 @@ void gemv(
           0,
           mat,
           vec,
-          out.data<DataType>(),
+          gpu_ptr<DataType>(out),
           rows,
           cols,
           const_param(batch_shape),
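The GemvAccType trait above keeps half/bfloat16 element types for loads and stores but accumulates the dot product in float (and double/complex64 in their own type). As a standalone host-side illustration of why a wider accumulator matters, the snippet below uses float vs. double to play the role of half vs. float (not MLX code):

#include <cstdio>

int main() {
  float narrow = 0.0f;   // stands in for a half-precision accumulator
  double wide = 0.0;     // stands in for the float accumulator
  for (int i = 0; i < 10000000; ++i) {
    narrow += 0.1f;
    wide += 0.1;
  }
  // The exact sum is 1,000,000. The narrow accumulator drifts visibly once the
  // running total is large relative to its precision; the wide one stays
  // accurate to a tiny fraction of a unit.
  std::printf("narrow accumulator: %.2f\n", narrow);
  std::printf("wide accumulator:   %.2f\n", wide);
}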

View File

@@ -31,7 +31,7 @@ void append_indices_arg(
     int idx_ndim) {
   SmallVector<const void*> indices(nidx);
   for (int i = 0; i < nidx; ++i) {
-    indices[i] = inputs[i + 1].data<void>();
+    indices[i] = gpu_ptr<void>(inputs[i + 1]);
   }
   args.append(std::move(indices));
   SmallVector<int32_t> indices_shape(nidx * idx_ndim);
@@ -59,7 +59,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() > 0);
   const auto& src = inputs[0];

-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   if (out.size() == 0) {
     return;
   }
@@ -80,7 +82,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
       dtype_to_string(idx_dtype),
       nidx);

-  auto& s = stream();
   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
     std::vector<std::string> kernel_names;
     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -121,7 +122,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
       idx_ndim,
       large ? "int64_t" : "int32_t");

-  auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
     encoder.set_input_array(in);
   }
@@ -239,7 +239,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   const auto& src = inputs[0];
   const auto& idx = inputs[1];

-  out.set_data(allocator::malloc(out.nbytes()));
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   if (out.size() == 0) {
     return;
   }
@@ -251,7 +253,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       dtype_to_string(out.dtype()),
       dtype_to_string(idx.dtype()));

-  auto& s = stream();
   cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
     std::vector<std::string> kernel_names;
     for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
@@ -312,7 +313,6 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
       idx.flags().row_contiguous,
       large ? "int64_t" : "int32_t");

-  auto& encoder = cu::get_command_encoder(s);
   for (const auto& in : inputs) {
     encoder.set_input_array(in);
   }

View File

@@ -99,6 +99,30 @@ const std::filesystem::path& ptx_cache_dir() {
return cache; return cache;
} }
std::filesystem::path get_ptx_path(
const std::filesystem::path& cache_dir,
const std::string& module_name) {
#ifdef _WIN32
constexpr int max_file_name_length = 140;
#else
constexpr int max_file_name_length = 245;
#endif
if (module_name.size() <= max_file_name_length) {
return cache_dir / (module_name + ".ptx");
}
auto ptx_path = cache_dir;
int offset = 0;
while (module_name.size() - offset > max_file_name_length) {
ptx_path /= module_name.substr(offset, max_file_name_length);
offset += max_file_name_length;
}
ptx_path /= module_name.substr(offset) + ".ptx";
return ptx_path;
}
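get_ptx_path splits an over-long module name into nested directories so the final file name stays under the platform limit (140 characters on Windows, 245 elsewhere). Below is a standalone illustration of the same splitting scheme with an artificially small limit so the nesting is visible (not MLX code; split_name and the example name are made up for this sketch):

#include <filesystem>
#include <iostream>
#include <string>

std::filesystem::path split_name(
    const std::filesystem::path& cache_dir,
    const std::string& name,
    size_t max_len) {
  if (name.size() <= max_len) {
    return cache_dir / (name + ".ptx");
  }
  auto path = cache_dir;
  size_t offset = 0;
  while (name.size() - offset > max_len) {
    path /= name.substr(offset, max_len);  // peel off one directory level
    offset += max_len;
  }
  path /= name.substr(offset) + ".ptx";
  return path;
}

int main() {
  // A 20-character module name with a limit of 8 becomes nested directories:
  // cache/gather_b/ool__int/32_2.ptx
  std::cout << split_name("cache", "gather_bool__int32_2", 8) << "\n";
}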
 // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
 bool read_cached_ptx(
     const std::filesystem::path& cache_dir,
@@ -109,7 +133,7 @@ bool read_cached_ptx(
     return false;
   }

-  auto ptx_path = cache_dir / (module_name + ".ptx");
+  auto ptx_path = get_ptx_path(cache_dir, module_name);
   std::error_code error;
   auto ptx_size = std::filesystem::file_size(ptx_path, error);
   if (error) {
@@ -122,7 +146,7 @@ bool read_cached_ptx(
   ptx.resize(ptx_size);
   ptx_file.read(ptx.data(), ptx_size);

-  std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
+  std::ifstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
   std::string line;
   while (std::getline(txt_file, line)) {
     auto tab = line.find('\t');
@@ -144,16 +168,26 @@ void write_cached_ptx(
     return;
   }

-  std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
+  auto ptx_path = get_ptx_path(cache_dir, module_name);
+
+  // Ensure that the directory exists
+  auto parent = ptx_path.parent_path();
+  if (parent != cache_dir) {
+    std::filesystem::create_directories(parent);
+  }
+
+  // Write the compiled code and mangled names
+  std::ofstream ptx_file(ptx_path, std::ios::binary);
   if (!ptx.empty()) {
     ptx_file.write(&ptx.front(), ptx.size());
   }
-  std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
+  std::ofstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
   for (const auto& [name, mangled] : ptx_kernels) {
     txt_file << name << "\t" << mangled << std::endl;
   }

-  std::ofstream source_file(cache_dir / (module_name + ".cu"));
+  // Write the generated code
+  std::ofstream source_file(ptx_path.replace_extension(".cu"));
   source_file << source_code;
 }
@@ -297,7 +331,8 @@ void load_module(
     const std::string& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
     CUmodule& module_,
-    std::unordered_map<std::string, std::pair<CUfunction, bool>>& kernels) {
+    std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
+        kernels) {
   // Load module.
   char jit_log[4089] = {};
   CUjit_option options[] = {
@@ -314,7 +349,7 @@ void load_module(
   for (const auto& [name, mangled] : ptx_kernels) {
     CUfunction kernel;
     CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
-    kernels[name] = std::make_pair(kernel, false);
+    kernels[name] = std::make_tuple(kernel, false, 0);
   }
 }
@@ -358,7 +393,7 @@ JitModule::~JitModule() {
   CHECK_CUDA_ERROR(cuModuleUnload(module_));
 }

-CUfunction JitModule::get_kernel(
+std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
     const std::string& kernel_name,
     std::function<void(CUfunction)> configure_kernel) {
   auto it = kernels_.find(kernel_name);
@@ -369,14 +404,22 @@ CUfunction JitModule::get_kernel(
   // If it is the first time we run this kernel then configure it. Do it only
   // once!
-  if (!it->second.second) {
+  auto kernel = std::get<0>(it->second);
+  if (!std::get<1>(it->second)) {
     if (configure_kernel) {
-      configure_kernel(it->second.first);
+      configure_kernel(kernel);
     }
-    it->second.second = true;
+    std::get<1>(it->second) = true;
+    std::get<2>(it->second) = max_occupancy_block_dim(kernel);
   }
-  return it->second.first;
+  return {kernel, std::get<2>(it->second)};
+}
+
+CUfunction JitModule::get_kernel(
+    const std::string& kernel_name,
+    std::function<void(CUfunction)> configure_kernel = nullptr) {
+  return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
 }

 std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
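max_occupancy_block_dim itself is not shown in this diff. Assuming it wraps the driver-API occupancy query, a plausible sketch (hypothetical helper name, not the actual MLX implementation) is:

#include <cuda.h>

unsigned int max_occupancy_block_dim_sketch(CUfunction kernel) {
  int min_grid_size = 0;
  int block_size = 0;
  // Ask the driver for the block size that maximizes occupancy for `kernel`,
  // assuming no dynamic shared memory and no explicit block-size limit.
  cuOccupancyMaxPotentialBlockSize(
      &min_grid_size, &block_size, kernel, nullptr, 0, 0);
  return static_cast<unsigned int>(block_size);
}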

View File

@@ -31,7 +31,7 @@ struct KernelArgs {
   }

   void append(const array& a) {
-    append(reinterpret_cast<CUdeviceptr>(a.data<void>()));
+    append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a)));
   }

   template <typename T>
@@ -99,10 +99,13 @@ class JitModule {
   CUfunction get_kernel(
       const std::string& kernel_name,
       std::function<void(CUfunction)> configure_kernel = nullptr);
+  std::pair<CUfunction, uint> get_kernel_and_dims(
+      const std::string& kernel_name,
+      std::function<void(CUfunction)> configure_kernel = nullptr);

  private:
   CUmodule module_{nullptr};
-  std::unordered_map<std::string, std::pair<CUfunction, bool>> kernels_;
+  std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
 };

 std::unordered_map<std::string, JitModule>& get_jit_module_cache();

View File

@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread) {
+    int work_per_thread /* = 1 */,
+    uint max_block_dim /* = 1024 */) {
   size_t nthreads = cuda::ceil_div(size, work_per_thread);
-  uint block_dim = 1024;
-  if (block_dim > nthreads) {
-    block_dim = nthreads;
-  }
+  uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
   dim3 num_blocks;
   if (large) {
     num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
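A standalone worked example of the block/grid arithmetic above for the simple (not "large") case; launch_dims is a hypothetical helper used only for this sketch, not MLX code:

#include <cstdio>

void launch_dims(size_t size, int work_per_thread, unsigned int max_block_dim) {
  // Each thread handles work_per_thread elements.
  size_t nthreads = (size + work_per_thread - 1) / work_per_thread;
  unsigned int block_dim = max_block_dim < nthreads
      ? max_block_dim
      : static_cast<unsigned int>(nthreads);
  size_t num_blocks = (nthreads + block_dim - 1) / block_dim;
  std::printf("block_dim = %u, num_blocks = %zu\n", block_dim, num_blocks);
}

int main() {
  launch_dims(4096, 4, 256);  // 1024 threads -> block_dim 256, 4 blocks
  launch_dims(100, 1, 1024);  // 100 threads  -> block_dim 100, 1 block
}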

View File

@@ -1,14 +1,15 @@
 // Copyright © 2025 Apple Inc.

-// This file includes host-only utilies for writing CUDA kernels, the difference
-// from backend/cuda/device/utils.cuh is that the latter file only include
-// device-only code.
+// This file includes host-only utilities for writing CUDA kernels, the
+// difference from backend/cuda/device/utils.cuh is that the latter file only
+// include device-only code.

 #pragma once

 #include <type_traits>

 #include "mlx/array.h"
+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device/utils.cuh"

 #include <cuda.h>
@@ -120,19 +121,28 @@
     size_t divisor);
 std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);

-// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
-// assuming each thread handles |work_per_thread| elements of |arr|.
+// Get the num_blocks and block_dims assuming each thread handles
+// |work_per_thread| elements of |arr|.
 std::tuple<dim3, uint> get_launch_args(
     size_t size,
     const Shape& shape,
     const Strides& strides,
     bool large,
-    int work_per_thread = 1);
+    int work_per_thread = 1,
+    uint max_block_dim = 1024);

-inline std::tuple<dim3, uint>
-get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
+inline std::tuple<dim3, uint> get_launch_args(
+    const array& arr,
+    bool large,
+    int work_per_thread = 1,
+    uint max_block_dim = 1024) {
   return get_launch_args(
-      arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
+      arr.size(),
+      arr.shape(),
+      arr.strides(),
+      large,
+      work_per_thread,
+      max_block_dim);
 }

 } // namespace mlx::core

View File

@@ -230,9 +230,10 @@ void LayerNorm::eval_gpu(
nvtx3::scoped_range r("LayerNorm::eval_gpu"); nvtx3::scoped_range r("LayerNorm::eval_gpu");
auto& s = stream(); auto& s = stream();
auto& out = outputs[0]; auto& out = outputs[0];
auto& encoder = cu::get_command_encoder(s);
// Make sure that the last dimension is contiguous. // Make sure that the last dimension is contiguous.
auto set_output = [&s, &out](const array& x) { auto set_output = [&s, &out, &encoder](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) { if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2]; auto s = x.strides()[x.ndim() - 2];
@@ -243,7 +244,7 @@ void LayerNorm::eval_gpu(
out.copy_shared_buffer(x); out.copy_shared_buffer(x);
} else { } else {
out.set_data( out.set_data(
allocator::malloc(x.data_size() * x.itemsize()), cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
x.data_size(), x.data_size(),
x.strides(), x.strides(),
x.flags()); x.flags());
@@ -265,7 +266,6 @@ void LayerNorm::eval_gpu(
int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(x); encoder.set_input_array(x);
encoder.set_input_array(w); encoder.set_input_array(w);
encoder.set_input_array(b); encoder.set_input_array(b);
@@ -280,10 +280,10 @@ void LayerNorm::eval_gpu(
n_rows, n_rows,
block_dim(), block_dim(),
0, 0,
x.data<DataType>(), gpu_ptr<DataType>(x),
w.data<DataType>(), gpu_ptr<DataType>(w),
b.data<DataType>(), gpu_ptr<DataType>(b),
out.data<DataType>(), gpu_ptr<DataType>(out),
eps_, eps_,
axis_size, axis_size,
w_stride, w_stride,
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
gx.copy_shared_buffer(g); gx.copy_shared_buffer(g);
g_in_gx = true; g_in_gx = true;
} else { } else {
gx.set_data(allocator::malloc(gx.nbytes())); gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
} }
if (g_copied && !g_in_gx) { if (g_copied && !g_in_gx) {
encoder.add_temporary(g); encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
g_in_gw = true; g_in_gw = true;
gw_temp.copy_shared_buffer(g); gw_temp.copy_shared_buffer(g);
} else { } else {
gw_temp.set_data(allocator::malloc(gw_temp.nbytes())); gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
encoder.add_temporary(gw_temp); encoder.add_temporary(gw_temp);
} }
} }
@@ -393,11 +393,11 @@ void LayerNormVJP::eval_gpu(
n_rows, n_rows,
block_dim(), block_dim(),
0, 0,
x.data<DataType>(), gpu_ptr<DataType>(x),
w.data<DataType>(), gpu_ptr<DataType>(w),
g.data<DataType>(), gpu_ptr<DataType>(g),
gx.data<DataType>(), gpu_ptr<DataType>(gx),
gw_temp.data<DataType>(), gpu_ptr<DataType>(gw_temp),
eps_, eps_,
axis_size, axis_size,
w_stride); w_stride);

mlx/backend/cuda/load.cpp (new file, 60 lines)
View File

@@ -0,0 +1,60 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <utility>
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/primitives.h"
namespace {
template <const uint8_t scalar_size>
void swap_endianness(uint8_t* data_bytes, size_t N) {
struct Elem {
uint8_t bytes[scalar_size];
};
Elem* data = reinterpret_cast<Elem*>(data_bytes);
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < (scalar_size / 2); j++) {
std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]);
}
}
}
} // namespace
namespace mlx::core {
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(stream());
auto size = out.size();
auto nbytes = size * out.itemsize();
out.set_data(cu::malloc_async(nbytes, encoder.stream()));
auto out_ptr = malloc(nbytes);
reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
if (swap_endianness_) {
switch (out.itemsize()) {
case 2:
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 4:
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
case 8:
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
break;
}
}
CHECK_CUDA_ERROR(cudaMemcpyAsync(
gpu_ptr<void>(out),
out_ptr,
nbytes,
cudaMemcpyDefault,
encoder.stream()));
CHECK_CUDA_ERROR(cudaLaunchHostFunc(encoder.stream(), free, out_ptr));
}
} // namespace mlx::core
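A standalone illustration of the byte reversal performed by swap_endianness before the host-to-device copy (not MLX code):

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <utility>

int main() {
  // Four bytes as stored in a big-endian file.
  uint8_t bytes[4] = {0x12, 0x34, 0x56, 0x78};
  // Reverse the bytes in place, exactly as swap_endianness<4> does per element.
  for (size_t j = 0; j < 2; ++j) {
    std::swap(bytes[j], bytes[4 - j - 1]);
  }
  uint32_t value;
  std::memcpy(&value, bytes, sizeof(value));
  std::printf("0x%08x\n", value);  // prints 0x12345678 on a little-endian host
}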

View File

@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto in = ensure_contiguous(inputs[0]);
   if (in.flags().row_contiguous) {
-    out.set_data(allocator::malloc(out.nbytes()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
   } else {
     auto n = in.shape(-1);
     auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
     flags.col_contiguous = col_contig;
     out.set_data(
-        allocator::malloc(in.nbytes() / n),
+        cu::malloc_async(in.nbytes() / n, encoder.stream()),
         in.data_size() / n,
         std::move(strides),
         flags);
@@ -151,8 +151,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
         n_rows,
         block_dim(),
         0,
-        in.data<DataType>(),
-        out.data<DataType>(),
+        gpu_ptr<DataType>(in),
+        gpu_ptr<DataType>(out),
         axis_size);
   });
 });

Some files were not shown because too many files have changed in this diff.