mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
111 Commits
6441c21a94
...
v0.29.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4bce5f9b2d | ||
|
|
e9eab527eb | ||
|
|
36ca62dba8 | ||
|
|
9cbb1b0148 | ||
|
|
9bfc476d72 | ||
|
|
25e2356316 | ||
|
|
226a1d24e0 | ||
|
|
630350ad3e | ||
|
|
380aeb58ae | ||
|
|
f37389d100 | ||
|
|
e89e8b4272 | ||
|
|
85a8824a8c | ||
|
|
f5d4397e5c | ||
|
|
343e33b6d5 | ||
|
|
0073096dd1 | ||
|
|
e3d004fed9 | ||
|
|
a393435d28 | ||
|
|
a7a94b29d7 | ||
|
|
22a5da76c8 | ||
|
|
287c63a093 | ||
|
|
1c9ae1eaa1 | ||
|
|
c2c3e0b0a2 | ||
|
|
b0cc71ae71 | ||
|
|
e88f2d4a8e | ||
|
|
9cee557423 | ||
|
|
bbf1423953 | ||
|
|
eb24267b56 | ||
|
|
dc371ae7a5 | ||
|
|
e76a8dd5c5 | ||
|
|
b466dea982 | ||
|
|
7a6adda1e6 | ||
|
|
1a9f820af6 | ||
|
|
d4f4ff3c5e | ||
|
|
7c7e48dbd1 | ||
|
|
fbbf3b9b3e | ||
|
|
bf01ad9367 | ||
|
|
ae438d05fa | ||
|
|
711a645807 | ||
|
|
aa9d44b3d4 | ||
|
|
ec2ab42888 | ||
|
|
787c0d90cd | ||
|
|
e8b604a6a3 | ||
|
|
50cc09887f | ||
|
|
3f730e77aa | ||
|
|
caecbe876a | ||
|
|
8afb6d62f2 | ||
|
|
6ccfa603cd | ||
|
|
36cad99a11 | ||
|
|
ee18e1cbf0 | ||
|
|
af120c2bc0 | ||
|
|
6a3acf2301 | ||
|
|
d6977f2a57 | ||
|
|
db5443e831 | ||
|
|
52b8384d10 | ||
|
|
44cc5da4bc | ||
|
|
dde3682b69 | ||
|
|
17310d91a6 | ||
|
|
b194d65a6a | ||
|
|
a44b27f5f8 | ||
|
|
e5a33f2223 | ||
|
|
c1e3340b23 | ||
|
|
8f163a367d | ||
|
|
89a3df9014 | ||
|
|
c5d2937aa5 | ||
|
|
b61a65e313 | ||
|
|
04cbb4191c | ||
|
|
c5460762e7 | ||
|
|
8ce49cd39e | ||
|
|
9c68b50853 | ||
|
|
111f1e71af | ||
|
|
827003d568 | ||
|
|
d363a76aa4 | ||
|
|
70560b6bd5 | ||
|
|
7ef8a6f2d5 | ||
|
|
31c6f6e33f | ||
|
|
584d48458e | ||
|
|
5cf984ca87 | ||
|
|
a9bac3d9e5 | ||
|
|
5458d43247 | ||
|
|
a4dba65220 | ||
|
|
3dcb286baf | ||
|
|
4822c3dbe9 | ||
|
|
2ca75bb529 | ||
|
|
db14e29a0b | ||
|
|
d2f540f4e0 | ||
|
|
333ffea273 | ||
|
|
f55b6f1f2f | ||
|
|
30561229c7 | ||
|
|
068a4612e9 | ||
|
|
5722c147de | ||
|
|
f6819a1f26 | ||
|
|
f93f87c802 | ||
|
|
9392fc3f88 | ||
|
|
e843c4d8d5 | ||
|
|
0c5fc63a36 | ||
|
|
e397177f6e | ||
|
|
f4c8888cbe | ||
|
|
25c1e03205 | ||
|
|
512281781c | ||
|
|
ac85ddfdb7 | ||
|
|
65d0d40232 | ||
|
|
cea9369610 | ||
|
|
e7c6e1db82 | ||
|
|
c5fcd5b61b | ||
|
|
1df9887998 | ||
|
|
73f22d6226 | ||
|
|
c422050ca7 | ||
|
|
1ba18ff7d9 | ||
|
|
37b440faa8 | ||
|
|
888b13ed63 | ||
|
|
4abb218d21 |
@@ -18,13 +18,14 @@ jobs:
|
||||
type: boolean
|
||||
default: false
|
||||
macos:
|
||||
xcode: "16.2.0"
|
||||
resource_class: m2pro.medium
|
||||
xcode: "26.0.0"
|
||||
resource_class: m4pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install
|
||||
command: |
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
brew install python@3.9
|
||||
brew install doxygen
|
||||
python3.9 -m venv env
|
||||
@@ -89,7 +90,8 @@ jobs:
|
||||
command: |
|
||||
uv venv
|
||||
uv pip install cmake
|
||||
uv pip install -e ".[dev]" -v
|
||||
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
uv pip install -e ".[dev]" -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
@@ -118,7 +120,7 @@ jobs:
|
||||
parameters:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "16.2.0"
|
||||
default: "26.0.0"
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
@@ -126,12 +128,13 @@ jobs:
|
||||
xcode: << parameters.xcode_version >>
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
resource_class: m2pro.medium
|
||||
resource_class: m4pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
|
||||
brew install openmpi uv
|
||||
- run:
|
||||
@@ -196,7 +199,7 @@ jobs:
|
||||
name: Run Python tests with JIT
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
uv pip install -e .
|
||||
uv pip install -e . -v
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
uv run --no-project python -m xmlrunner discover \
|
||||
@@ -222,15 +225,20 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install libnccl2 libnccl-dev
|
||||
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
|
||||
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
|
||||
rm -rf ccache-4.11.3-linux-x86_64
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
- run:
|
||||
name: Set CCache size
|
||||
command: ccache --max-size 1G
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
uv venv
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
uv pip install cmake
|
||||
DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
uv pip install -e ".[dev]" -v
|
||||
- run:
|
||||
name: Run Python tests
|
||||
@@ -238,12 +246,23 @@ jobs:
|
||||
source .venv/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
cmake . -B build \
|
||||
-DMLX_BUILD_CUDA=ON \
|
||||
-DCMAKE_CUDA_COMPILER=`which nvcc` \
|
||||
-DCMAKE_BUILD_TYPE=DEBUG
|
||||
cmake --build build -j `nproc`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
|
||||
- run:
|
||||
name: CCache report
|
||||
command: |
|
||||
ccache --show-stats
|
||||
ccache --zero-stats
|
||||
ccache --max-size 400MB
|
||||
ccache --cleanup
|
||||
- save_cache:
|
||||
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
|
||||
@@ -257,7 +276,7 @@ jobs:
|
||||
default: "3.9"
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "16.2.0"
|
||||
default: "26.0.0"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
@@ -266,7 +285,7 @@ jobs:
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: m2pro.medium
|
||||
resource_class: m4pro.medium
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
steps:
|
||||
@@ -274,11 +293,15 @@ jobs:
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@<< parameters.python_version >>
|
||||
brew install openmpi
|
||||
python<< parameters.python_version >> -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
mkdir -p ~/miniconda3
|
||||
curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
source ~/miniconda3/bin/activate
|
||||
conda init --all
|
||||
conda create -n env python=<< parameters.python_version >> -y
|
||||
conda activate env
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
@@ -288,19 +311,19 @@ jobs:
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
conda activate env
|
||||
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||
pip install . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
conda activate env
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
conda activate env
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
|
||||
- when:
|
||||
@@ -310,7 +333,7 @@ jobs:
|
||||
- run:
|
||||
name: Build common package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
conda activate env
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
|
||||
- when:
|
||||
@@ -319,7 +342,7 @@ jobs:
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
conda activate env
|
||||
twine upload dist/*
|
||||
- store_artifacts:
|
||||
path: dist/
|
||||
@@ -392,7 +415,7 @@ jobs:
|
||||
default: ""
|
||||
machine:
|
||||
image: ubuntu-2204:current
|
||||
resource_class: large
|
||||
resource_class: xlarge
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -439,7 +462,7 @@ workflows:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
macosx_deployment_target: ["13.5", "15.0"]
|
||||
- linux_build_and_test
|
||||
- cuda_build_and_test:
|
||||
matrix:
|
||||
@@ -464,68 +487,7 @@ workflows:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
xcode_version: ["26.0.0"]
|
||||
- build_documentation:
|
||||
filters:
|
||||
tags:
|
||||
@@ -567,7 +529,7 @@ workflows:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
macosx_deployment_target: ["13.5", "15.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
- cuda_build_and_test:
|
||||
@@ -586,53 +548,7 @@ workflows:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
xcode_version: ["26.0.0"]
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
@@ -651,68 +567,7 @@ workflows:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
xcode_version: ["26.0.0"]
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
|
||||
@@ -19,12 +19,17 @@ MLX was developed with contributions from the following individuals:
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
</a>
|
||||
|
||||
# Organizations
|
||||
|
||||
MLX has received contributions from the following companies:
|
||||
- NVIDIA Corporation & Affiliates
|
||||
|
||||
# Third-Party Software
|
||||
|
||||
MLX leverages several third-party software, listed here together with
|
||||
|
||||
@@ -26,6 +26,7 @@ set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
# ----------------------------- Configuration -----------------------------
|
||||
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
|
||||
@@ -87,22 +88,21 @@ cmake_policy(SET CMP0135 NEW)
|
||||
|
||||
add_library(mlx)
|
||||
|
||||
if(MLX_BUILD_METAL)
|
||||
set(METAL_LIB "-framework Metal")
|
||||
set(FOUNDATION_LIB "-framework Foundation")
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
set(MLX_METAL_DEBUG OFF)
|
||||
elseif(MLX_BUILD_METAL)
|
||||
message(STATUS "Building METAL sources")
|
||||
if(MLX_BUILD_METAL)
|
||||
find_library(METAL_LIB Metal)
|
||||
find_library(FOUNDATION_LIB Foundation)
|
||||
find_library(QUARTZ_LIB QuartzCore)
|
||||
if(METAL_LIB)
|
||||
message(STATUS "Metal found ${METAL_LIB}")
|
||||
else()
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
|
||||
endif()
|
||||
|
||||
if(MLX_METAL_DEBUG)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
@@ -111,7 +111,8 @@ elseif(MLX_BUILD_METAL)
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
@@ -140,6 +141,12 @@ elseif(MLX_BUILD_METAL)
|
||||
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||
endif()
|
||||
|
||||
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||
# With newer clang/gcc versions following libs are implicitly linked, but when
|
||||
# building on old distributions they need to be explicitly listed.
|
||||
target_link_libraries(mlx PRIVATE dl pthread)
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
if(MSVC)
|
||||
# GGUF does not build with MSVC.
|
||||
@@ -167,7 +174,7 @@ if(MLX_BUILD_CPU)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
message(STATUS "Accelerate not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
endif()
|
||||
|
||||
|
||||
38
README.md
38
README.md
@@ -2,7 +2,7 @@
|
||||
|
||||
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
|
||||
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
|
||||
[**Examples**](#examples)
|
||||
[**Examples**](#examples)
|
||||
|
||||
[](https://circleci.com/gh/ml-explore/mlx)
|
||||
|
||||
@@ -11,37 +11,37 @@ brought to you by Apple machine learning research.
|
||||
|
||||
Some key features of MLX include:
|
||||
|
||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
||||
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
||||
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
||||
more complex models.
|
||||
|
||||
- **Composable function transformations**: MLX supports composable function
|
||||
transformations for automatic differentiation, automatic vectorization,
|
||||
and computation graph optimization.
|
||||
- **Composable function transformations**: MLX supports composable function
|
||||
transformations for automatic differentiation, automatic vectorization,
|
||||
and computation graph optimization.
|
||||
|
||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||
materialized when needed.
|
||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||
materialized when needed.
|
||||
|
||||
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
||||
dynamically. Changing the shapes of function arguments does not trigger
|
||||
slow compilations, and debugging is simple and intuitive.
|
||||
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
||||
dynamically. Changing the shapes of function arguments does not trigger
|
||||
slow compilations, and debugging is simple and intuitive.
|
||||
|
||||
- **Multi-device**: Operations can run on any of the supported devices
|
||||
(currently the CPU and the GPU).
|
||||
- **Multi-device**: Operations can run on any of the supported devices
|
||||
(currently the CPU and the GPU).
|
||||
|
||||
- **Unified memory**: A notable difference from MLX and other frameworks
|
||||
is the *unified memory model*. Arrays in MLX live in shared memory.
|
||||
Operations on MLX arrays can be performed on any of the supported
|
||||
device types without transferring data.
|
||||
- **Unified memory**: A notable difference from MLX and other frameworks
|
||||
is the *unified memory model*. Arrays in MLX live in shared memory.
|
||||
Operations on MLX arrays can be performed on any of the supported
|
||||
device types without transferring data.
|
||||
|
||||
MLX is designed by machine learning researchers for machine learning
|
||||
researchers. The framework is intended to be user-friendly, but still efficient
|
||||
to train and deploy models. The design of the framework itself is also
|
||||
conceptually simple. We intend to make it easy for researchers to extend and
|
||||
improve MLX with the goal of quickly exploring new ideas.
|
||||
improve MLX with the goal of quickly exploring new ideas.
|
||||
|
||||
The design of MLX is inspired by frameworks like
|
||||
[NumPy](https://numpy.org/doc/stable/index.html),
|
||||
@@ -91,7 +91,7 @@ Checkout the
|
||||
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
|
||||
for more information on building the C++ and Python APIs from source.
|
||||
|
||||
## Contributing
|
||||
## Contributing
|
||||
|
||||
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
|
||||
on contributing to MLX. See the
|
||||
@@ -110,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
|
||||
MLX useful in your research and wish to cite it, please use the following
|
||||
BibTex entry:
|
||||
|
||||
```
|
||||
```text
|
||||
@software{mlx2023,
|
||||
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
||||
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
||||
|
||||
@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
|
||||
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
|
||||
|
||||
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
|
||||
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
|
||||
np.float32
|
||||
)
|
||||
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
|
||||
|
||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||
|
||||
dtypes = ("float32", "float16")
|
||||
dtypes = ("float32", "float16", "complex64")
|
||||
transposes = ("nn", "nt", "tn")
|
||||
shapes = (
|
||||
(16, 234, 768, 3072),
|
||||
@@ -187,7 +185,7 @@ if __name__ == "__main__":
|
||||
diff = gflops_mx / gflops_pt - 1.0
|
||||
|
||||
print(
|
||||
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
|
||||
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
|
||||
)
|
||||
if gflops_pt >= 2.0 * gflops_mx:
|
||||
print("ATTENTION ^^^^^^^")
|
||||
|
||||
@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
|
||||
|
||||
|
||||
for transpose in (False, True):
|
||||
for dtype in ("float32", "float16"):
|
||||
for dtype in ("float32", "float16", "complex64"):
|
||||
fig, axs = plt.subplots(
|
||||
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
|
||||
)
|
||||
@@ -215,7 +215,7 @@ for transpose in (False, True):
|
||||
fig.suptitle(f"{device_name}: {dtype} {op_name}")
|
||||
fig.savefig(
|
||||
os.path.join(
|
||||
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
|
||||
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
|
||||
)
|
||||
)
|
||||
plt.close(fig)
|
||||
|
||||
54
cmake/FindNCCL.cmake
Normal file
54
cmake/FindNCCL.cmake
Normal file
@@ -0,0 +1,54 @@
|
||||
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
|
||||
# directories.
|
||||
|
||||
set(NCCL_ROOT_DIR
|
||||
$ENV{NCCL_ROOT_DIR}
|
||||
CACHE PATH "Folder contains NVIDIA NCCL")
|
||||
|
||||
find_path(
|
||||
NCCL_INCLUDE_DIRS
|
||||
NAMES nccl.h
|
||||
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||
|
||||
if($ENV{USE_STATIC_NCCL})
|
||||
message(
|
||||
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
||||
set(NCCL_LIBNAME "libnccl_static.a")
|
||||
else()
|
||||
set(NCCL_LIBNAME "nccl")
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
NCCL_LIBRARIES
|
||||
NAMES ${NCCL_LIBNAME}
|
||||
HINTS ${NCCL_LIB_DIR}
|
||||
${NCCL_ROOT_DIR}
|
||||
${NCCL_ROOT_DIR}/lib
|
||||
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
||||
${NCCL_ROOT_DIR}/lib64
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
||||
NCCL_LIBRARIES)
|
||||
|
||||
if(NCCL_FOUND)
|
||||
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||
message(
|
||||
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
||||
file(
|
||||
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
||||
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
||||
LIMIT_COUNT 1)
|
||||
if(NCCL_MAJOR_VERSION_DEFINED)
|
||||
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
||||
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
||||
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
||||
endif()
|
||||
message(
|
||||
STATUS
|
||||
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||
endif()
|
||||
@@ -127,7 +127,8 @@ relying on a copy from ``ensure_row_contiguous``:
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
source=source,
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
@@ -138,7 +139,6 @@ relying on a copy from ``ensure_row_contiguous``:
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ are the CPU and GPU.
|
||||
python/fft
|
||||
python/linalg
|
||||
python/metal
|
||||
python/cuda
|
||||
python/memory_management
|
||||
python/nn
|
||||
python/optimizers
|
||||
|
||||
@@ -271,7 +271,7 @@ and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
apt-get update -y
|
||||
apt-get -y install cuda-toolkit-12-9
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
|
||||
|
||||
|
||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||
|
||||
9
docs/src/python/cuda.rst
Normal file
9
docs/src/python/cuda.rst
Normal file
@@ -0,0 +1,9 @@
|
||||
CUDA
|
||||
=====
|
||||
|
||||
.. currentmodule:: mlx.core.cuda
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
is_available
|
||||
@@ -13,3 +13,4 @@ Fast
|
||||
rope
|
||||
scaled_dot_product_attention
|
||||
metal_kernel
|
||||
cuda_kernel
|
||||
|
||||
@@ -27,6 +27,7 @@ simple functions.
|
||||
mish
|
||||
prelu
|
||||
relu
|
||||
relu2
|
||||
relu6
|
||||
selu
|
||||
sigmoid
|
||||
|
||||
@@ -50,6 +50,7 @@ Layers
|
||||
QuantizedLinear
|
||||
RMSNorm
|
||||
ReLU
|
||||
ReLU2
|
||||
ReLU6
|
||||
RNN
|
||||
RoPE
|
||||
|
||||
@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
|
||||
.. code-block:: python
|
||||
|
||||
x = mx.random.uniform(shape=(32, 1000, 4096))
|
||||
timeit(nn.gelu, x)
|
||||
timeit(mx.compile(nn.gelu), x)
|
||||
timeit(gelu, x)
|
||||
timeit(mx.compile(gelu), x)
|
||||
|
||||
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
||||
five times faster.
|
||||
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
|
||||
def fun(x, y):
|
||||
z = x + y
|
||||
state.append(z)
|
||||
return mx.exp(z), state
|
||||
return mx.exp(z)
|
||||
|
||||
fun(mx.array(1.0), mx.array(2.0))
|
||||
# Prints [array(3, dtype=float32)]
|
||||
|
||||
@@ -184,7 +184,7 @@ almost identical to the example above:
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
grads = mlx.nn.average_gradients(grads) # <---- This line was added
|
||||
grads = mx.nn.average_gradients(grads) # <---- This line was added
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
|
||||
@@ -164,11 +164,11 @@ to export a function which can be used for inputs with variable shapes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
|
||||
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
|
||||
imported_abs = mx.import_function("fun.mlxfn")
|
||||
|
||||
# Ok
|
||||
out, = imported_abs(mx.array(-1.0))
|
||||
out, = imported_abs(mx.array([-1.0]))
|
||||
|
||||
# Also ok
|
||||
out, = imported_abs(mx.array([-1.0, -2.0]))
|
||||
|
||||
@@ -107,8 +107,20 @@ same array:
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
Note that unlike NumPy, slicing an array creates a copy, not a view. So
|
||||
mutating it does not mutate the original array:
|
||||
|
||||
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> b = a[:]
|
||||
>>> b[2] = 0
|
||||
>>> b
|
||||
array([1, 2, 0], dtype=int32)
|
||||
>>> a
|
||||
array([1, 2, 3], dtype=int32)
|
||||
|
||||
Also unlike NumPy, updates to the same location are nondeterministic:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
const array& a,
|
||||
const array& b) {
|
||||
if (a.ndim() == 2) {
|
||||
return {{1}, {0}, {0}};
|
||||
return {Shape{1}, Strides{0}, Strides{0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||
collapse_batches(const array& a, const array& b, const array& c) {
|
||||
if (a.ndim() == 2) {
|
||||
return {{1}, {0}, {0}, {0}};
|
||||
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
|
||||
@@ -11,6 +11,8 @@ namespace mlx::core {
|
||||
enum class TernaryOpType {
|
||||
ScalarScalarScalar,
|
||||
VectorVectorVector,
|
||||
VectorVectorScalar,
|
||||
VectorScalarVector,
|
||||
General,
|
||||
};
|
||||
|
||||
@@ -25,6 +27,14 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
|
||||
(a.flags().col_contiguous && b.flags().col_contiguous &&
|
||||
c.flags().col_contiguous)) {
|
||||
topt = TernaryOpType::VectorVectorVector;
|
||||
} else if (
|
||||
b.data_size() == 1 && a.flags().row_contiguous &&
|
||||
c.flags().row_contiguous) {
|
||||
topt = TernaryOpType::VectorScalarVector;
|
||||
} else if (
|
||||
c.data_size() == 1 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
topt = TernaryOpType::VectorVectorScalar;
|
||||
} else {
|
||||
topt = TernaryOpType::General;
|
||||
}
|
||||
@@ -59,6 +69,8 @@ inline void set_ternary_op_output_data(
|
||||
b.flags());
|
||||
}
|
||||
break;
|
||||
case TernaryOpType::VectorVectorScalar:
|
||||
case TernaryOpType::VectorScalarVector:
|
||||
case TernaryOpType::General:
|
||||
// Try to donate an input which is row_contiguous
|
||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||
|
||||
@@ -228,31 +228,4 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||
}
|
||||
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
||||
int ndim = x.ndim();
|
||||
if (axis1 < 0) {
|
||||
axis1 += ndim;
|
||||
}
|
||||
if (axis2 < 0) {
|
||||
axis2 += ndim;
|
||||
}
|
||||
|
||||
auto shape = x.shape();
|
||||
std::swap(shape[axis1], shape[axis2]);
|
||||
auto strides = x.strides();
|
||||
std::swap(strides[axis1], strides[axis2]);
|
||||
|
||||
auto [data_size, row_contiguous, col_contiguous] =
|
||||
check_contiguity(shape, strides);
|
||||
bool contiguous = data_size == x.data_size();
|
||||
|
||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(
|
||||
x,
|
||||
std::move(strides),
|
||||
{contiguous, row_contiguous, col_contiguous},
|
||||
x.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -196,9 +196,6 @@ void shared_buffer_reshape(
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
|
||||
// Like the swapaxes op but safe to call in eval_gpu.
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||
|
||||
template <typename T>
|
||||
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "mlx/backend/cpu/jit_compiler.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/version.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -94,7 +95,11 @@ void* compile(
|
||||
kernel_file_name = kernel_name;
|
||||
}
|
||||
|
||||
auto output_dir = std::filesystem::temp_directory_path();
|
||||
auto output_dir =
|
||||
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
|
||||
if (!std::filesystem::exists(output_dir)) {
|
||||
std::filesystem::create_directories(output_dir);
|
||||
}
|
||||
|
||||
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
|
||||
auto shared_lib_path = (output_dir / shared_lib_name).string();
|
||||
@@ -157,10 +162,12 @@ inline void build_kernel(
|
||||
#endif
|
||||
|
||||
// Start the kernel
|
||||
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
||||
os << "void " << kernel_name
|
||||
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
|
||||
|
||||
// Add the input arguments
|
||||
int cnt = 0;
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants from the input list
|
||||
if (is_constant(i)) {
|
||||
@@ -175,8 +182,8 @@ inline void build_kernel(
|
||||
<< "];" << std::endl;
|
||||
// Scalars and contiguous need no strides
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
os << " const int64_t* " << xname << "_strides = strides["
|
||||
<< strides_index++ << "];" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,10 +193,8 @@ inline void build_kernel(
|
||||
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
||||
<< "*)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
|
||||
} else {
|
||||
// Add output size
|
||||
if (contiguous) {
|
||||
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
|
||||
@@ -288,17 +293,8 @@ void Compiled::eval_cpu(
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Force allocating shape/strides on heap so we can take their data() first
|
||||
// and then std::move them.
|
||||
// TODO: Refactor code to avoid heap allocation.
|
||||
shape.grow();
|
||||
for (auto& s : strides) {
|
||||
s.grow();
|
||||
}
|
||||
|
||||
// Collect function input arguments.
|
||||
std::vector<void*> args;
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
@@ -306,9 +302,6 @@ void Compiled::eval_cpu(
|
||||
const auto& x = inputs[i];
|
||||
encoder.set_input_array(x);
|
||||
args.push_back((void*)x.data<void>());
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
args.push_back(strides[strides_index++].data());
|
||||
}
|
||||
}
|
||||
|
||||
// Get the kernel name from the lib
|
||||
@@ -343,16 +336,20 @@ void Compiled::eval_cpu(
|
||||
args.push_back(x.data<void>());
|
||||
encoder.set_output_array(x);
|
||||
}
|
||||
if (!contiguous) {
|
||||
args.push_back((void*)shape.data());
|
||||
} else {
|
||||
if (contiguous) {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
|
||||
encoder.dispatch([fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
shape = std::move(shape)]() mutable { fun(args.data()); });
|
||||
shape = std::move(shape)]() mutable {
|
||||
SmallVector<int64_t*> strides_ptrs;
|
||||
for (auto& s : strides) {
|
||||
strides_ptrs.push_back(s.data());
|
||||
}
|
||||
fun(shape.data(), strides_ptrs.data(), args.data());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -88,4 +88,47 @@ void matmul<double>(
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void matmul<complex64_t>(
|
||||
const complex64_t* a,
|
||||
const complex64_t* b,
|
||||
complex64_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
auto calpha = static_cast<complex64_t>(alpha);
|
||||
auto cbeta = static_cast<complex64_t>(beta);
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
cblas_cgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
&calpha,
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||
lda,
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||
ldb,
|
||||
&cbeta,
|
||||
out + M * N * i,
|
||||
ldc);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr)
|
||||
INSTANTIATE_LAPACK_REAL(syevd)
|
||||
INSTANTIATE_LAPACK_REAL(geev)
|
||||
INSTANTIATE_LAPACK_REAL(potrf)
|
||||
INSTANTIATE_LAPACK_REAL(gesvdx)
|
||||
INSTANTIATE_LAPACK_REAL(gesdd)
|
||||
INSTANTIATE_LAPACK_REAL(getrf)
|
||||
INSTANTIATE_LAPACK_REAL(getri)
|
||||
INSTANTIATE_LAPACK_REAL(trtri)
|
||||
|
||||
@@ -108,6 +108,9 @@ void matmul_general(
|
||||
} else if (out.dtype() == float64) {
|
||||
matmul_dispatch<double>(
|
||||
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
|
||||
} else if (out.dtype() == complex64) {
|
||||
matmul_dispatch<complex64_t>(
|
||||
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
|
||||
} else {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
|
||||
}
|
||||
@@ -128,10 +131,6 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[AddMM::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
if (out.size() == 0) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return;
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
@@ -13,6 +11,35 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
const static float MXFP4_LUT[16] = {
|
||||
+0.0f,
|
||||
+0.5f,
|
||||
+1.0f,
|
||||
+1.5f,
|
||||
+2.0f,
|
||||
+3.0f,
|
||||
+4.0f,
|
||||
+6.0f,
|
||||
-0.0f,
|
||||
-0.5f,
|
||||
-1.0f,
|
||||
-1.5f,
|
||||
-2.0f,
|
||||
-3.0f,
|
||||
-4.0f,
|
||||
-6.0f};
|
||||
|
||||
template <typename T>
|
||||
static inline T dequantize_scale(uint8_t s) {
|
||||
using FOrI = union {
|
||||
bfloat16_t f;
|
||||
uint16_t i;
|
||||
};
|
||||
FOrI out;
|
||||
out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
|
||||
return static_cast<T>(out.f);
|
||||
}
|
||||
|
||||
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
|
||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||
}
|
||||
@@ -407,6 +434,231 @@ void _qmm_dispatch(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_qmm(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const uint8_t* scales,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int group_size = 32;
|
||||
constexpr int pack_factor = get_pack_factor(4, 8);
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack(4);
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint8_t* w_local = (const uint8_t*)w;
|
||||
const uint8_t* scales_local = scales;
|
||||
|
||||
std::fill(result, result + N, 0);
|
||||
|
||||
for (int k = 0; k < K; k++) {
|
||||
T* result_local = result;
|
||||
T xi = *x++;
|
||||
|
||||
for (int n = 0; n < N; n += group_size) {
|
||||
T scale = dequantize_scale<T>(*scales_local++);
|
||||
for (int ng = 0; ng < packs_in_group; ng++) {
|
||||
uint8_t wi = *w_local++;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
(*result_local++) +=
|
||||
xi * scale * static_cast<T>(MXFP4_LUT[wi & 0xf]);
|
||||
wi >>= 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result += N;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_qmm_t(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const uint8_t* scales,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int group_size = 32;
|
||||
constexpr int pack_factor = get_pack_factor(4, 8);
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack(4);
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint8_t* w_local = (const uint8_t*)w;
|
||||
const uint8_t* scales_local = scales;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
const T* x_local = x;
|
||||
T sum = 0;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
T scale = dequantize_scale<T>(*scales_local++);
|
||||
|
||||
T gsum = 0;
|
||||
for (int kw = 0; kw < packs_in_group; kw++) {
|
||||
uint8_t wi = *w_local++;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
gsum += (*x_local++) * static_cast<T>(MXFP4_LUT[wi & 0xf]);
|
||||
wi >>= 4;
|
||||
}
|
||||
}
|
||||
sum += scale * gsum;
|
||||
}
|
||||
*result = sum;
|
||||
result++;
|
||||
}
|
||||
|
||||
x += K;
|
||||
}
|
||||
}
|
||||
|
||||
template <int S>
|
||||
simd::Simd<float, S> mxfp4_extract_bits_simd(const uint32_t* w) {
|
||||
if constexpr (S == 8) {
|
||||
constexpr std::array<uint32_t, 8> shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}};
|
||||
auto shifts(*(simd::Simd<uint32_t, S>*)&shifts_);
|
||||
auto wi = simd::Simd<uint32_t, S>(*w);
|
||||
wi = wi >> shifts;
|
||||
wi = wi & 0xf;
|
||||
simd::Simd<float, S> w_out;
|
||||
for (int i = 0; i < S; ++i) {
|
||||
w_out[i] = MXFP4_LUT[wi[i]];
|
||||
}
|
||||
return w_out;
|
||||
} else {
|
||||
// Appease compiler.. but should never get here
|
||||
throw std::runtime_error("Unsupported combination for simd qmm.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_qmm_t_simd(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const uint8_t* scales,
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int group_size = 32;
|
||||
constexpr int pack_factor = 32 / 4;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
constexpr int S = simd::max_size<T>;
|
||||
static_assert(
|
||||
S % pack_factor == 0, "SIMD size must be divisible by pack factor");
|
||||
constexpr int packs_per_simd = S / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const uint8_t* scales_local = scales;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
simd::Simd<float, S> acc(0);
|
||||
auto x_local = x;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
T scale = dequantize_scale<T>(*scales_local++);
|
||||
|
||||
simd::Simd<float, S> g_acc(0);
|
||||
for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) {
|
||||
// Extract bits
|
||||
auto wf = mxfp4_extract_bits_simd<S>(w_local);
|
||||
w_local += packs_per_simd;
|
||||
simd::Simd<float, S> x_simd = simd::load<T, S>(x_local);
|
||||
g_acc = g_acc + x_simd * wf;
|
||||
x_local += S;
|
||||
}
|
||||
acc = acc + scale * g_acc;
|
||||
}
|
||||
|
||||
*result = T(simd::sum(acc));
|
||||
result++;
|
||||
}
|
||||
x += K;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_qmm_dispatch_transpose(
|
||||
T* result,
|
||||
const T* x,
|
||||
const uint32_t* w,
|
||||
const uint8_t* scales,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
bool transposed_w) {
|
||||
if (transposed_w) {
|
||||
// the simd size must be a multiple of the number of elements per word
|
||||
if constexpr (simd::max_size<T> % 8 == 0) {
|
||||
mxfp4_qmm_t_simd<T>(result, x, w, scales, M, N, K);
|
||||
} else {
|
||||
mxfp4_qmm_t<T>(result, x, w, scales, M, N, K);
|
||||
}
|
||||
} else {
|
||||
mxfp4_qmm<T>(result, x, w, scales, M, N, K);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_qmm_dispatch_typed(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int N = out.shape(-1);
|
||||
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
int batch_size = x.size() / (K * M);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<uint8_t>();
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
mxfp4_qmm_dispatch_transpose<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
transposed_w);
|
||||
}
|
||||
}
|
||||
|
||||
void mxfp4_qmm_dispatch(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
bool transposed_w) {
|
||||
switch (x.dtype()) {
|
||||
case bfloat16:
|
||||
mxfp4_qmm_dispatch_typed<bfloat16_t>(out, x, w, scales, transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
mxfp4_qmm_dispatch_typed<float16_t>(out, x, w, scales, transposed_w);
|
||||
break;
|
||||
case float32:
|
||||
mxfp4_qmm_dispatch_typed<float>(out, x, w, scales, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void _bs_qmm_dispatch_typed(
|
||||
array& out,
|
||||
@@ -513,115 +765,198 @@ void _bs_qmm_dispatch(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void mxfp4_bs_qmm_dispatch_typed(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
|
||||
int w_els = w.shape(-1) * w.shape(-2);
|
||||
int g_els = scales.shape(-1) * scales.shape(-2);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<uint8_t>();
|
||||
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
||||
for (int i = 0; i < lhs_indices.size(); i++) {
|
||||
int x_idx = lhs_indices_ptr[elem_to_loc(
|
||||
i, lhs_indices.shape(), lhs_indices.strides())];
|
||||
int w_idx = rhs_indices_ptr[elem_to_loc(
|
||||
i, rhs_indices.shape(), rhs_indices.strides())];
|
||||
mxfp4_qmm_dispatch_transpose<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
|
||||
scales_ptr +
|
||||
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
transposed_w);
|
||||
}
|
||||
}
|
||||
|
||||
void mxfp4_bs_qmm_dispatch(
|
||||
array& out,
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& lhs_indices,
|
||||
const array& rhs_indices,
|
||||
bool transposed_w) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
mxfp4_bs_qmm_dispatch_typed<float>(
|
||||
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
mxfp4_bs_qmm_dispatch_typed<float16_t>(
|
||||
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
mxfp4_bs_qmm_dispatch_typed<bfloat16_t>(
|
||||
out, x, w, scales, lhs_indices, rhs_indices, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"[quantized_matmul] only floating types are supported");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
auto& scales_pre = inputs[2];
|
||||
auto& biases_pre = inputs[3];
|
||||
|
||||
std::vector<array> temps;
|
||||
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_cpy, CopyType::General, s);
|
||||
encoder.add_temporary(arr_cpy);
|
||||
return arr_cpy;
|
||||
}
|
||||
};
|
||||
|
||||
auto x = ensure_row_contiguous(x_pre);
|
||||
auto w = ensure_row_contiguous(w_pre);
|
||||
auto scales = ensure_row_contiguous(scales_pre);
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
});
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto biases = ensure_row_contiguous(inputs[3]);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
});
|
||||
} else {
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
transpose_ = transpose_]() mutable {
|
||||
mxfp4_qmm_dispatch(out, x, w, scales, transpose_);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 6);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
auto& w_pre = inputs[1];
|
||||
auto& scales_pre = inputs[2];
|
||||
auto& biases_pre = inputs[3];
|
||||
auto& lhs_indices = inputs[4];
|
||||
auto& rhs_indices = inputs[5];
|
||||
auto& lhs_indices = inputs[inputs.size() - 2];
|
||||
auto& rhs_indices = inputs[inputs.size() - 1];
|
||||
|
||||
std::vector<array> temps;
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
auto ensure_row_contiguous_last_dims = [s = stream(),
|
||||
&temps](const array& arr) {
|
||||
&encoder](const array& arr) {
|
||||
auto stride_0 = arr.strides()[arr.ndim() - 2];
|
||||
auto stride_1 = arr.strides()[arr.ndim() - 1];
|
||||
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
|
||||
return arr;
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_cpy, CopyType::General, s);
|
||||
encoder.add_temporary(arr_cpy);
|
||||
return arr_cpy;
|
||||
}
|
||||
};
|
||||
|
||||
auto x = ensure_row_contiguous_last_dims(x_pre);
|
||||
auto w = ensure_row_contiguous_last_dims(w_pre);
|
||||
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_input_array(lhs_indices);
|
||||
encoder.set_input_array(rhs_indices);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
lhs_indices = array::unsafe_weak_copy(lhs_indices),
|
||||
rhs_indices = array::unsafe_weak_copy(rhs_indices),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_);
|
||||
});
|
||||
if (mode_ == QuantizationMode::Affine) {
|
||||
auto biases = ensure_row_contiguous_last_dims(inputs[3]);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
lhs_indices = array::unsafe_weak_copy(lhs_indices),
|
||||
rhs_indices = array::unsafe_weak_copy(rhs_indices),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_);
|
||||
});
|
||||
} else {
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
lhs_indices = array::unsafe_weak_copy(lhs_indices),
|
||||
rhs_indices = array::unsafe_weak_copy(rhs_indices),
|
||||
transpose_ = transpose_]() mutable {
|
||||
mxfp4_bs_qmm_dispatch(
|
||||
out, x, w, scales, lhs_indices, rhs_indices, transpose_);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
@@ -705,7 +1040,7 @@ void dispatch_quantize(
|
||||
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_cpu(
|
||||
void fast::Quantize::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto ensure_row_contiguous = [s = stream()](const array& arr) {
|
||||
@@ -764,7 +1099,7 @@ void fast::AffineQuantize::eval_cpu(
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
"[fast::Quantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
// There seems to be a bug in sims/base.h
|
||||
// There seems to be a bug in simd/base_simd.h
|
||||
// __XROS_2_0 is not defined, the expression evaluates
|
||||
// to true instead of false setting the SIMD library
|
||||
// higher than it should be even on macOS < 15
|
||||
@@ -234,6 +234,7 @@ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
|
||||
|
||||
template <typename MaskT, typename T1, typename T2, int N>
|
||||
Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
|
||||
static_assert(std::is_same_v<MaskT, bool>);
|
||||
if constexpr (sizeof(T1) == 1) {
|
||||
return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
|
||||
} else if constexpr (sizeof(T1) == 2) {
|
||||
@@ -251,9 +252,13 @@ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
return asd::pow(base.value, exp.value);
} else {
Simd<T, N> res = 1;
while (any(exp)) {
res = select(exp & 1, res * base, res);
base = select(exp, base * base, base);
// Raising an integer to a negative power is undefined
if (any(exp < 0)) {
return 0;
}
while (any(exp > 0)) {
res = select((exp & 1) != 0, res * base, res);
base = select(exp > 0, base * base, base);
exp = exp >> 1;
}
return res;

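The loop above is exponentiation by squaring carried out lane-wise with select; a scalar sketch of the same scheme, using the same convention that a negative exponent yields 0:

// Scalar sketch of integer pow by squaring, mirroring the vector loop above.
long ipow(long base, long exp) {
  if (exp < 0) {
    return 0; // raising an integer to a negative power is treated as undefined
  }
  long res = 1;
  while (exp > 0) {
    if (exp & 1) {
      res *= base;
    }
    base *= base;
    exp >>= 1;
  }
  return res;
}
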
@@ -15,6 +15,18 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
return true;
}
return a < b;
}
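Plugged into std::stable_sort and std::nth_element below, this ordering keeps every NaN after the finite values; a small self-contained sketch of the same comparator:

#include <algorithm>
#include <cmath>
#include <vector>

// Sketch: sorting with a NaN-last comparator like nan_aware_less above.
int main() {
  std::vector<float> v{3.0f, NAN, 1.0f, 2.0f};
  std::stable_sort(v.begin(), v.end(), [](float a, float b) {
    if (std::isnan(a))
      return false; // a is never "less", so NaNs sink to the end
    if (std::isnan(b))
      return true;
    return a < b;
  });
  // v is now {1, 2, 3, nan}
}
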
|
||||
template <typename T>
|
||||
struct StridedIterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
@@ -130,7 +142,7 @@ void sort(array& out, int axis) {
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed);
|
||||
std::stable_sort(st, ed, nan_aware_less<T>);
|
||||
src_it.step();
|
||||
}
|
||||
}
|
||||
@@ -184,6 +196,15 @@ void argsort(const array& in, array& out, int axis) {
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
|
||||
// Handle NaNs (place them at the end)
|
||||
if (std::is_floating_point<T>::value) {
|
||||
if (std::isnan(v1))
|
||||
return false;
|
||||
if (std::isnan(v2))
|
||||
return true;
|
||||
}
|
||||
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
@@ -219,7 +240,7 @@ void partition(array& out, int axis, int kth) {
|
||||
StridedIterator md(data_ptr, axis_stride, kth);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::nth_element(st, md, ed);
|
||||
std::nth_element(st, md, ed, nan_aware_less<T>);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -276,6 +297,15 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
|
||||
// Handle NaNs (place them at the end)
|
||||
if (std::is_floating_point<T>::value) {
|
||||
if (std::isnan(v1))
|
||||
return false;
|
||||
if (std::isnan(v2))
|
||||
return true;
|
||||
}
|
||||
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -81,9 +81,7 @@ void svd_impl(
|
||||
// Vᵀ of shape N x N. (M x M in lapack).
|
||||
const int ldvt = M;
|
||||
|
||||
auto job_u = (u_ptr) ? "V" : "N";
|
||||
auto job_vt = (u_ptr) ? "V" : "N";
|
||||
static constexpr auto range = "A";
|
||||
auto jobz = (u_ptr) ? "A" : "N";
|
||||
|
||||
// Will contain the number of singular values after the call has returned.
|
||||
int ns = 0;
|
||||
@@ -91,30 +89,20 @@ void svd_impl(
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
// used here but required by lapack).
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
|
||||
|
||||
static const int lwork_query = -1;
|
||||
|
||||
static const int ignored_int = 0;
|
||||
static const T ignored_float = 0;
|
||||
|
||||
int info;
|
||||
|
||||
// Compute workspace size.
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
gesdd<T>(
|
||||
/* jobz = */ jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ nullptr,
|
||||
/* u = */ nullptr,
|
||||
/* ldu = */ &ldu,
|
||||
@@ -136,20 +124,13 @@ void svd_impl(
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
gesvdx<T>(
|
||||
/* jobu = */ job_u,
|
||||
/* jobvt = */ job_vt,
|
||||
/* range = */ range,
|
||||
gesdd<T>(
|
||||
/* jobz = */ jobz,
|
||||
// M and N are swapped since lapack expects column-major.
|
||||
/* m = */ &N,
|
||||
/* n = */ &M,
|
||||
/* a = */ in_ptr + M * N * i,
|
||||
/* lda = */ &lda,
|
||||
/* vl = */ &ignored_float,
|
||||
/* vu = */ &ignored_float,
|
||||
/* il = */ &ignored_int,
|
||||
/* iu = */ &ignored_int,
|
||||
/* ns = */ &ns,
|
||||
/* s = */ s_ptr + K * i,
|
||||
// According to the identity above, lapack will write Vᵀᵀ as U.
|
||||
/* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
|
||||
@@ -167,13 +148,6 @@ void svd_impl(
|
||||
ss << "svd_impl: sgesvdx_ failed with code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
if (ns != K) {
|
||||
std::stringstream ss;
|
||||
ss << "svd_impl: expected " << K << " singular values, but " << ns
|
||||
<< " were computed.";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
});
|
||||
encoder.add_temporary(in);
|
||||
|
||||
@@ -77,7 +77,8 @@ struct Real {
struct Sigmoid {
template <int N, typename T>
Simd<T, N> operator()(Simd<T, N> x) {
return 1.0f / (1.0f + simd::exp(-x));
auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
}
SINGLE()
};
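The rewritten operator computes y = 1 / (1 + exp(|x|)) once and then selects: for x < 0 this y already equals sigmoid(x), and for x >= 0 the value 1 - y does. A scalar sketch of the same formulation:

#include <cmath>

// Scalar sketch of the select-based sigmoid above.
float stable_sigmoid(float x) {
  float y = 1.0f / (1.0f + std::exp(std::fabs(x)));
  return x < 0.0f ? y : 1.0f - y;
}
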
|
||||
@@ -16,8 +16,13 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
@@ -149,7 +154,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||
FetchContent_Declare(
|
||||
cudnn
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
||||
GIT_TAG v1.12.1
|
||||
GIT_TAG v1.14.0
|
||||
GIT_SHALLOW TRUE
|
||||
EXCLUDE_FROM_ALL)
|
||||
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
||||
@@ -165,6 +170,10 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
||||
# Supress warnings: note: parameter passing for argument of type
|
||||
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||
# 10.1
|
||||
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||
|
||||
# Install CCCL headers for JIT.
|
||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||
|
||||
@@ -30,8 +30,20 @@ SmallSizePool::SmallSizePool() {
|
||||
next_free_ = buffer_;
|
||||
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
|
||||
|
||||
int device_count = 0;
|
||||
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
#if CUDART_VERSION >= 13000
|
||||
cudaMemLocation loc;
|
||||
loc.type = cudaMemLocationTypeDevice;
|
||||
loc.id = i;
|
||||
#else
|
||||
int loc = i;
|
||||
#endif // CUDART_VERSION >= 13000
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
|
||||
}
|
||||
|
||||
auto curr = next_free_;
|
||||
for (size_t i = 1; i < num_blocks; ++i) {
|
||||
@@ -79,7 +91,7 @@ CudaAllocator::CudaAllocator()
// TODO: Set memory limit for multi-device.
size_t free, total;
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
memory_limit_ = total * 0.8;
memory_limit_ = total * 0.95;
max_pool_size_ = memory_limit_;
}
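The limit is now 95% of the total device memory reported by the runtime; a minimal host-side sketch of the same query (assumes the CUDA runtime headers and library are available):

#include <cuda_runtime.h>
#include <cstddef>

// Sketch: derive a soft memory limit from total device memory, as above.
std::size_t default_memory_limit() {
  std::size_t free_bytes = 0, total_bytes = 0;
  cudaMemGetInfo(&free_bytes, &total_bytes);
  return static_cast<std::size_t>(total_bytes * 0.95);
}
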
|
||||
|
||||
@@ -6,23 +6,33 @@
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
template <typename T>
|
||||
struct Arange {
|
||||
const T start;
|
||||
const T step;
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
__device__ T operator()(uint32_t i) const {
|
||||
return start + i * step;
|
||||
template <typename T, typename IdxT, int N_WRITES>
|
||||
__global__ void arange(T* out, IdxT size, T start, T step) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_WRITES > size) {
|
||||
for (IdxT i = index * N_WRITES; i < size; ++i) {
|
||||
out[i] = start + i * step;
|
||||
}
|
||||
} else {
|
||||
AlignedVector<T, N_WRITES> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_WRITES; ++i) {
|
||||
out_vec[i] = start + (index * N_WRITES + i) * step;
|
||||
}
|
||||
|
||||
store_vector<N_WRITES>(out, index, out_vec);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
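In the kernel above each thread owns N_WRITES consecutive outputs and falls back to a scalar loop for the ragged tail; a plain C++ sketch of that index split (AlignedVector and store_vector are project helpers and are not reproduced here):

#include <vector>

// Sketch of the per-thread split used by cu::arange: thread `index` fills
// outputs [index * N, index * N + N), with a scalar loop for the tail chunk.
template <int N, typename T>
void arange_chunk(std::vector<T>& out, long index, T start, T step) {
  long size = static_cast<long>(out.size());
  if ((index + 1) * N > size) {
    for (long i = index * N; i < size; ++i) {
      out[i] = start + static_cast<T>(i) * step;
    }
  } else {
    for (int j = 0; j < N; ++j) {
      long i = index * N + j;
      out[i] = start + static_cast<T>(i) * step;
    }
  }
}
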
@@ -36,19 +46,23 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
|
||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||
using OutType = cuda_type_t<CTYPE>;
|
||||
CTYPE step =
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(encoder.stream()),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(out.data_size()),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
cu::Arange<OutType>{
|
||||
static_cast<OutType>(start_), static_cast<OutType>(step)});
|
||||
constexpr int N_WRITES = 16 / sizeof(OutType);
|
||||
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
|
||||
encoder.add_kernel_node(
|
||||
cu::arange<OutType, IdxT, N_WRITES>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
out.data<OutType>(),
|
||||
out.data_size(),
|
||||
static_cast<CTYPE>(start_),
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -267,7 +267,8 @@ void Compiled::eval_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||
return std::make_tuple(
|
||||
false, std::move(builder.os), std::move(kernel_names));
|
||||
});
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
@@ -331,9 +332,9 @@ void Compiled::eval_gpu(
|
||||
encoder.set_output_array(out);
|
||||
}
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [kernel, max_block_dims] = mod.get_kernel_and_dims(kernel_name);
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(outputs[0], large, work_per_thread);
|
||||
get_launch_args(outputs[0], large, work_per_thread, max_block_dims);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/conv/conv.h"
|
||||
#include "mlx/backend/cuda/cudnn_utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
// cudnn_frontend.h redefines this macro.
|
||||
#undef CHECK_CUDA_ERROR
|
||||
|
||||
#include <cudnn_frontend.h>
|
||||
#include <cudnn_frontend_find_plan.h>
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
@@ -21,9 +15,6 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Not all engines support it so can not use this API now.
|
||||
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
|
||||
|
||||
// Alias for better readability.
|
||||
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
|
||||
#define CONV_BACKWARD_INPUT \
|
||||
@@ -31,6 +22,9 @@ namespace {
|
||||
#define CONV_BACKWARD_WEIGHT \
|
||||
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
|
||||
|
||||
// Custom placeholder representing fallback kernel.
|
||||
#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
|
||||
|
||||
struct ConvCacheKey {
|
||||
int device_id;
|
||||
cudnnDataType_t cudnn_dtype;
|
||||
@@ -50,203 +44,13 @@ struct ConvCacheKey {
|
||||
auto& conv_cache() {
|
||||
static LRUBytesKeyCache<
|
||||
ConvCacheKey,
|
||||
std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
|
||||
cache(/* capacity */ 128);
|
||||
std::pair<
|
||||
cudnnBackendDescriptorType_t,
|
||||
std::optional<cudnn_frontend::ExecutionPlan>>>
|
||||
cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
|
||||
return cache;
|
||||
}
|
||||
|
||||
template <typename T, typename Vec>
|
||||
inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||
return SmallVector<T>(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
template <typename T, template <typename U> class Vec>
|
||||
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
}
|
||||
std::array<T, MAX_NDIM> result = {};
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
auto nhwc_to_nchw(const array& x) {
|
||||
auto shape = convert_vector<int64_t>(x.shape());
|
||||
shape.insert(shape.begin() + 1, shape.back());
|
||||
shape.erase(shape.end() - 1);
|
||||
auto strides = convert_vector<int64_t>(x.strides());
|
||||
strides.insert(strides.begin() + 1, strides.back());
|
||||
strides.erase(strides.end() - 1);
|
||||
return std::make_tuple(std::move(shape), std::move(strides));
|
||||
}
|
||||
|
||||
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
return CUDNN_DATA_INT8;
|
||||
case int32:
|
||||
return CUDNN_DATA_INT32;
|
||||
case uint8:
|
||||
return CUDNN_DATA_UINT8;
|
||||
case float16:
|
||||
return CUDNN_DATA_HALF;
|
||||
case bfloat16:
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
case float32:
|
||||
return CUDNN_DATA_FLOAT;
|
||||
case float64:
|
||||
return CUDNN_DATA_DOUBLE;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
inline uint8_t get_alignment(const array& x) {
|
||||
uint8_t alignment = 1;
|
||||
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
|
||||
for (; alignment < 32; alignment *= 2) {
|
||||
if (address % (alignment * 2)) {
|
||||
return alignment;
|
||||
}
|
||||
}
|
||||
return alignment;
|
||||
}
|
||||
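The probe reports the largest power of two, capped at 32 bytes, that divides the data pointer; a standalone sketch of the same loop on a raw pointer:

#include <cstdint>

// Sketch: largest power-of-two alignment (up to 32 bytes) of an address.
inline unsigned alignment_of(const void* p) {
  auto address = reinterpret_cast<std::uintptr_t>(p);
  unsigned alignment = 1;
  for (; alignment < 32; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment;
    }
  }
  return alignment; // at least 32-byte aligned
}
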
|
||||
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
|
||||
auto [shape, strides] = nhwc_to_nchw(x);
|
||||
return cudnn_frontend::TensorBuilder()
|
||||
.setDim(shape.size(), shape.data())
|
||||
.setStrides(strides.size(), strides.data())
|
||||
.setId(id)
|
||||
.setAlignment(get_alignment(x))
|
||||
.setDataType(dtype_to_cudnn_type(x.dtype()))
|
||||
.build();
|
||||
}
|
||||
|
||||
cudnn_frontend::EngineConfigList get_engine_configs(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph,
|
||||
bool use_fallback = false) {
|
||||
cudnn_frontend::GeneratorSource source;
|
||||
if (use_fallback) {
|
||||
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setOperation(backend_type)
|
||||
.build();
|
||||
return fallback.getFallbackList();
|
||||
};
|
||||
} else {
|
||||
source = [](cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setHeurMode(CUDNN_HEUR_MODE_A)
|
||||
.build();
|
||||
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
|
||||
};
|
||||
}
|
||||
|
||||
cudnn_frontend::EngineConfigGenerator generator(1, &source);
|
||||
auto configs = generator.generate_engine_config(op_graph);
|
||||
|
||||
cudnn_frontend::EngineConfigList filtered_configs;
|
||||
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||
if (cudnn_frontend::hasNumericalNote<
|
||||
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||
return true;
|
||||
}
|
||||
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
|
||||
dtype == float32 && !env::enable_tf32()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return filtered_configs;
|
||||
}
|
||||
|
||||
bool execute_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y) {
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
|
||||
|
||||
int64_t uids[3] = {'x', 'w', 'y'};
|
||||
void* data_ptrs[3] = {
|
||||
x.data<void>(),
|
||||
w.data<void>(),
|
||||
y.data<void>(),
|
||||
};
|
||||
|
||||
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
||||
.setWorkspacePointer(workspace.data<void>())
|
||||
.setDataPointers(3, data_ptrs)
|
||||
.setUids(3, uids)
|
||||
.build();
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
|
||||
cudaGraph_t graph;
|
||||
cudaGraphCreate(&graph, 0);
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||
if (cudnnBackendPopulateCudaGraph(
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
encoder.add_graph_node(graph);
|
||||
#else
|
||||
auto capture = encoder.capture_context();
|
||||
if (cudnnBackendExecute(
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
// Discard the captured graph when failed.
|
||||
capture.discard = true;
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
encoder.add_temporary(workspace);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool try_engines(
|
||||
cu::CommandEncoder& encoder,
|
||||
const ConvCacheKey& cache_key,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
cudnn_frontend::EngineConfigList& configs,
|
||||
const std::string& op_graph_tag,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y) {
|
||||
for (auto& config : configs) {
|
||||
try {
|
||||
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setEngineConfig(config, op_graph_tag)
|
||||
.build();
|
||||
if (execute_plan(encoder, plan, x, w, y)) {
|
||||
conv_cache().emplace(
|
||||
cache_key, std::make_pair(backend_type, std::move(plan)));
|
||||
return true;
|
||||
}
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
auto get_conv_op_settings(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array& x,
|
||||
@@ -291,7 +95,7 @@ auto get_conv_op_settings(
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
||||
std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
@@ -317,9 +121,9 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
||||
.build();
|
||||
|
||||
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||
.setxDesc(build_tensor('x', x))
|
||||
.setwDesc(build_tensor('w', w))
|
||||
.setyDesc(build_tensor('y', y))
|
||||
.setxDesc(build_cudnn_tensor_nchw('x', x))
|
||||
.setwDesc(build_cudnn_tensor_nchw('w', w))
|
||||
.setyDesc(build_cudnn_tensor_nchw('y', y))
|
||||
.setcDesc(conv_desc)
|
||||
.build();
|
||||
|
||||
@@ -336,6 +140,42 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
|
||||
array group_transpose(
|
||||
const array& x,
|
||||
int groups,
|
||||
int group_dim,
|
||||
int axis1,
|
||||
int axis2,
|
||||
Stream s) {
|
||||
if (groups == 1) {
|
||||
return swapaxes_in_eval(x, axis1, axis2);
|
||||
}
|
||||
int ndim = x.ndim();
|
||||
if (group_dim < 0) {
|
||||
group_dim += ndim;
|
||||
}
|
||||
if (axis1 < 0) {
|
||||
axis1 += ndim;
|
||||
}
|
||||
if (axis2 < 0) {
|
||||
axis2 += ndim;
|
||||
}
|
||||
if (group_dim <= axis1) {
|
||||
axis1 += 1;
|
||||
}
|
||||
if (group_dim <= axis2) {
|
||||
axis2 += 1;
|
||||
}
|
||||
auto shape = x.shape();
|
||||
shape.insert(shape.begin() + group_dim, groups);
|
||||
shape[group_dim + 1] = shape[group_dim + 1] / groups;
|
||||
array x_trans = reshape_in_eval(x, std::move(shape), s);
|
||||
x_trans = swapaxes_in_eval(x_trans, axis1, axis2);
|
||||
x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s);
|
||||
return x_trans;
|
||||
}
|
||||
|
||||
// Do necessary transposes and copies to prepare the inputs and outputs for
|
||||
// building the cuDNN conv op. It is safe to be called multiple times in one
|
||||
// eval_gpu, with cost of possible redundant copies.
|
||||
@@ -345,13 +185,14 @@ std::tuple<array, array, array> prepare_args(
|
||||
array in,
|
||||
array wt,
|
||||
array out,
|
||||
int groups,
|
||||
Stream s) {
|
||||
// Transpose the args depending on the backend type.
|
||||
// TODO: Handle groups.
|
||||
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||
wt = swapaxes_in_eval(wt, 0, -1);
|
||||
wt = group_transpose(wt, groups, 0, 0, -1, s);
|
||||
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||
in = swapaxes_in_eval(in, 0, -1);
|
||||
in = group_transpose(in, groups, -1, 0, -1, s);
|
||||
wt = swapaxes_in_eval(wt, 0, -1);
|
||||
// Create a contiguous array that shares the data with |out|, but with dim
|
||||
// C_in and C_out swapped.
|
||||
@@ -444,12 +285,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
ConvCacheKey cache_key{
|
||||
encoder.device().cuda_device(),
|
||||
dtype_to_cudnn_type(dtype),
|
||||
fixed_vector(in.shape()),
|
||||
fixed_vector(wt.shape()),
|
||||
fixed_vector(kernel_strides_),
|
||||
fixed_vector(padding_lo_),
|
||||
fixed_vector(padding_hi_),
|
||||
fixed_vector(kernel_dilation_),
|
||||
vector_key(in.shape()),
|
||||
vector_key(wt.shape()),
|
||||
vector_key(kernel_strides_),
|
||||
vector_key(padding_lo_),
|
||||
vector_key(padding_hi_),
|
||||
vector_key(kernel_dilation_),
|
||||
groups_,
|
||||
flip_,
|
||||
get_alignment(in),
|
||||
@@ -457,11 +298,29 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
get_alignment(out)};
|
||||
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||
auto& [backend_type, plan] = it->second;
|
||||
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (!execute_plan(encoder, plan, x, w, y)) {
|
||||
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
||||
if (plan) {
|
||||
// Run cached plan.
|
||||
std::tie(in, wt, out) =
|
||||
prepare_args(encoder, backend_type, in, wt, out, groups_, s);
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
||||
}
|
||||
} else {
|
||||
// Run fallback kernel.
|
||||
gemm_conv(
|
||||
encoder,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
kernel_strides_,
|
||||
padding_lo_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_,
|
||||
s);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -490,7 +349,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
std::optional<cudnn_frontend::OperationGraph> op_graph;
|
||||
for (auto try_backend : try_backends) {
|
||||
auto [in_copy, wt_copy, out_copy] =
|
||||
prepare_args(encoder, try_backend, in, wt, out, s);
|
||||
prepare_args(encoder, try_backend, in, wt, out, groups_, s);
|
||||
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
|
||||
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
|
||||
try_backend,
|
||||
@@ -502,7 +361,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
padding_hi_,
|
||||
kernel_dilation_,
|
||||
input_dilation_);
|
||||
op_graph = build_op_graph(
|
||||
op_graph = build_conv_op_graph(
|
||||
encoder,
|
||||
try_backend,
|
||||
dtype,
|
||||
@@ -521,26 +380,38 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!op_graph) {
|
||||
throw std::runtime_error("[conv] Can not build op graph.");
|
||||
|
||||
if (op_graph) {
|
||||
// Find a plan for the graph and execute it.
|
||||
auto plan = find_cudnn_plan_from_op_graph(
|
||||
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
|
||||
if (plan) {
|
||||
// Setup inputs and outputs.
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||
conv_cache().emplace(
|
||||
cache_key, std::make_pair(backend_type, std::move(*plan)));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get ready to execute the graph.
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
|
||||
// Try to run plans based on heuristics.
|
||||
auto configs = get_engine_configs(backend_type, dtype, *op_graph);
|
||||
auto tag = op_graph->getTag();
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
|
||||
return;
|
||||
}
|
||||
// Then try fallback plans.
|
||||
configs = get_engine_configs(backend_type, dtype, *op_graph);
|
||||
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("[conv] Unable to find a working engine.");
|
||||
// Use fallback kernel for settings not supported by cuDNN.
|
||||
gemm_conv(
|
||||
encoder,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
kernel_strides_,
|
||||
padding_lo_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_,
|
||||
s);
|
||||
conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
mlx/backend/cuda/conv/conv.h (new file, 126 lines)
@@ -0,0 +1,126 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <int NDIM>
|
||||
struct ConvParams {
|
||||
int N; // Batch size
|
||||
int C; // In channels
|
||||
int O; // Out channels
|
||||
int strides[NDIM];
|
||||
int padding[NDIM];
|
||||
int kernel_dilation[NDIM];
|
||||
int input_dilation[NDIM];
|
||||
int groups;
|
||||
bool flip;
|
||||
int in_spatial_dims[NDIM];
|
||||
int wt_spatial_dims[NDIM];
|
||||
int out_spatial_dims[NDIM];
|
||||
int64_t in_strides[NDIM + 2];
|
||||
|
||||
ConvParams(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
const array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
int groups,
|
||||
bool flip)
|
||||
: N(in.shape(0)),
|
||||
C(in.shape(-1)),
|
||||
O(wt.shape(0)),
|
||||
groups(groups),
|
||||
flip(flip) {
|
||||
std::copy_n(strides.begin(), NDIM, this->strides);
|
||||
std::copy_n(padding.begin(), NDIM, this->padding);
|
||||
std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
|
||||
std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
|
||||
std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
|
||||
std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
|
||||
std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
|
||||
std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
|
||||
}
|
||||
};
|
||||
|
||||
void gemm_grouped_conv(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
int groups,
|
||||
bool flip,
|
||||
Stream s);
|
||||
|
||||
void gemm_conv(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
bool flip,
|
||||
Stream s);
|
||||
|
||||
inline void gemm_conv(
|
||||
cu::CommandEncoder& encoder,
|
||||
array in,
|
||||
array wt,
|
||||
array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
int groups,
|
||||
bool flip,
|
||||
Stream s) {
|
||||
if (!in.flags().row_contiguous) {
|
||||
in = contiguous_copy_gpu(in, s);
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
if (!wt.flags().row_contiguous) {
|
||||
wt = contiguous_copy_gpu(wt, s);
|
||||
encoder.add_temporary(wt);
|
||||
}
|
||||
|
||||
if (groups == 1) {
|
||||
gemm_conv(
|
||||
encoder,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
strides,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
flip,
|
||||
s);
|
||||
} else {
|
||||
gemm_grouped_conv(
|
||||
encoder,
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
strides,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
groups,
|
||||
flip,
|
||||
s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
mlx/backend/cuda/conv/gemm_conv.cu (new file, 217 lines)
@@ -0,0 +1,217 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/conv/conv.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, int NDIM>
|
||||
__global__ void naive_unfold_nd(
|
||||
const T* in,
|
||||
T* out,
|
||||
int filter_size,
|
||||
int out_pixels,
|
||||
const __grid_constant__ ConvParams<NDIM> params) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto tid = block.group_index();
|
||||
auto lid = block.thread_index();
|
||||
|
||||
int index_batch = tid.z / out_pixels; // [0, N)
|
||||
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
|
||||
int index_wt_spatial =
|
||||
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
|
||||
|
||||
if (index_wt_spatial >= filter_size / params.C) {
|
||||
return;
|
||||
}
|
||||
|
||||
in += tid.y; // [0, C)
|
||||
out += tid.z * filter_size + index_wt_spatial * params.C + tid.y;
|
||||
|
||||
bool valid = index_batch < params.N;
|
||||
|
||||
// Get the coordinates in input.
|
||||
int index_in[NDIM] = {};
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int index_out = index_out_spatial % params.out_spatial_dims[i];
|
||||
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
|
||||
|
||||
if (params.flip) {
|
||||
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
|
||||
}
|
||||
|
||||
int index = index_out * params.strides[i] - params.padding[i] +
|
||||
index_wt * params.kernel_dilation[i];
|
||||
int index_max =
|
||||
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
|
||||
|
||||
valid &= (index >= 0) && (index < index_max) &&
|
||||
(index % params.input_dilation[i] == 0);
|
||||
|
||||
index_in[i] = index / params.input_dilation[i];
|
||||
|
||||
index_out_spatial /= params.out_spatial_dims[i];
|
||||
index_wt_spatial /= params.wt_spatial_dims[i];
|
||||
}
|
||||
|
||||
if (valid) {
|
||||
int in_offset = index_batch * params.in_strides[0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
in_offset += index_in[i] * params.in_strides[i + 1];
|
||||
}
|
||||
*out = in[in_offset];
|
||||
} else {
|
||||
*out = T{0};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <int NDIM>
|
||||
array unfold_inputs_nd(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
int mat_M,
|
||||
int mat_K,
|
||||
int mat_N,
|
||||
ConvParams<NDIM>& params) {
|
||||
array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
|
||||
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||
encoder.add_temporary(unfolded);
|
||||
|
||||
int filter_size = params.C;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
filter_size *= params.wt_spatial_dims[i];
|
||||
}
|
||||
|
||||
int out_pixels = 1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
out_pixels *= params.out_spatial_dims[i];
|
||||
}
|
||||
|
||||
int wt_spatial_size = mat_K / params.C;
|
||||
dim3 block_dims;
|
||||
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
|
||||
dim3 num_blocks;
|
||||
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
|
||||
num_blocks.y = params.C;
|
||||
num_blocks.z = mat_M;
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(unfolded);
|
||||
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
encoder.add_kernel_node(
|
||||
cu::naive_unfold_nd<DataType, NDIM>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<DataType>(),
|
||||
unfolded.data<DataType>(),
|
||||
filter_size,
|
||||
out_pixels,
|
||||
params);
|
||||
});
|
||||
|
||||
return unfolded;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
void gemm_conv_nd(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
ConvParams<NDIM>& params,
|
||||
Stream s) {
|
||||
// Get gemm shapes.
|
||||
int mat_M = out.size() / params.O; // N * H_out * W_out
|
||||
int mat_K = wt.size() / params.O; // C * H_wt * W_wt
|
||||
int mat_N = params.O; // O
|
||||
|
||||
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
|
||||
array in_unfolded =
|
||||
unfold_inputs_nd<NDIM>(encoder, in, mat_M, mat_K, mat_N, params);
|
||||
|
||||
// Reshape weight to (C * H_wt * W_wt, O) for gemm.
|
||||
array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {});
|
||||
wt_reshaped.copy_shared_buffer(
|
||||
wt,
|
||||
{1, mat_K},
|
||||
{false, false, /* col_contiguous */ true},
|
||||
wt.data_size());
|
||||
|
||||
// Single batch.
|
||||
Shape batch_shape{1};
|
||||
Strides a_batch_strides{0};
|
||||
Strides b_batch_strides{0};
|
||||
|
||||
// Run matmul.
|
||||
CublasGemm gemm(
|
||||
encoder.device(),
|
||||
in.dtype(),
|
||||
false, // a_transposed
|
||||
mat_M, // a_rows
|
||||
mat_K, // a_cols
|
||||
mat_K, // lda
|
||||
true, // b_transposed
|
||||
mat_K, // b_rows
|
||||
mat_N, // b_cols
|
||||
mat_K, // ldb
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
gemm.run(
|
||||
encoder,
|
||||
out,
|
||||
in_unfolded,
|
||||
wt_reshaped,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides);
|
||||
}
|
||||
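The unfold kernel above is an im2col lowering: every output position is expanded into a row holding the input patch it reads, and the convolution then reduces to one matmul with the reshaped weights. A minimal 1-D sketch of that idea (stride 1, no padding, dilation, batching, or groups):

#include <vector>

// Sketch: 1-D convolution as im2col + naive matmul.
// in: [L] input, wt: [O][K] filters, result: [L - K + 1][O].
std::vector<std::vector<float>> conv1d_im2col(
    const std::vector<float>& in,
    const std::vector<std::vector<float>>& wt) {
  int O = static_cast<int>(wt.size());
  int K = static_cast<int>(wt[0].size());
  int L_out = static_cast<int>(in.size()) - K + 1;

  // Unfold: each output position becomes a row of the K inputs it touches.
  std::vector<std::vector<float>> unfolded(L_out, std::vector<float>(K));
  for (int i = 0; i < L_out; ++i) {
    for (int k = 0; k < K; ++k) {
      unfolded[i][k] = in[i + k];
    }
  }

  // Matmul: (L_out x K) * (K x O) -> (L_out x O).
  std::vector<std::vector<float>> out(L_out, std::vector<float>(O, 0.0f));
  for (int i = 0; i < L_out; ++i) {
    for (int o = 0; o < O; ++o) {
      for (int k = 0; k < K; ++k) {
        out[i][o] += unfolded[i][k] * wt[o][k];
      }
    }
  }
  return out;
}
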
|
||||
void gemm_conv(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
bool flip,
|
||||
Stream s) {
|
||||
int conv_ndim = in.ndim() - 2;
|
||||
if (conv_ndim < 1 || conv_ndim > 3) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
||||
}
|
||||
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
||||
ConvParams<ndim_constant()> params(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
strides,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
1, // groups
|
||||
flip);
|
||||
gemm_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
mlx/backend/cuda/conv/gemm_grouped_conv.cu (new file, 231 lines)
@@ -0,0 +1,231 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/conv/conv.h"
|
||||
#include "mlx/backend/cuda/gemms/cublas_gemm.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, int NDIM>
|
||||
__global__ void naive_grouped_unfold_transpose_nd(
|
||||
const T* in,
|
||||
T* out,
|
||||
int filter_size,
|
||||
int out_pixels,
|
||||
const __grid_constant__ ConvParams<NDIM> params) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto tid = block.group_index();
|
||||
auto lid = block.thread_index();
|
||||
|
||||
int index_batch = tid.z / out_pixels; // [0, N)
|
||||
int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out)
|
||||
int index_wt_spatial =
|
||||
tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt)
|
||||
|
||||
if (index_wt_spatial >= filter_size / params.C) {
|
||||
return;
|
||||
}
|
||||
|
||||
in += tid.y; // [0, C)
|
||||
out += tid.z * filter_size + tid.y * (filter_size / params.C);
|
||||
|
||||
bool valid = index_batch < params.N;
|
||||
|
||||
// Get the coordinates in input.
|
||||
int index_in[NDIM] = {};
|
||||
int wt_stride = 1;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int index_out = index_out_spatial % params.out_spatial_dims[i];
|
||||
int index_wt = index_wt_spatial % params.wt_spatial_dims[i];
|
||||
out += index_wt * wt_stride;
|
||||
|
||||
if (params.flip) {
|
||||
index_wt = params.wt_spatial_dims[i] - index_wt - 1;
|
||||
}
|
||||
|
||||
int index = index_out * params.strides[i] - params.padding[i] +
|
||||
index_wt * params.kernel_dilation[i];
|
||||
int index_max =
|
||||
1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1);
|
||||
|
||||
valid &= (index >= 0) && (index < index_max) &&
|
||||
(index % params.input_dilation[i] == 0);
|
||||
|
||||
index_in[i] = index / params.input_dilation[i];
|
||||
|
||||
index_out_spatial /= params.out_spatial_dims[i];
|
||||
index_wt_spatial /= params.wt_spatial_dims[i];
|
||||
wt_stride *= params.wt_spatial_dims[i];
|
||||
}
|
||||
|
||||
if (valid) {
|
||||
int in_offset = index_batch * params.in_strides[0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
in_offset += index_in[i] * params.in_strides[i + 1];
|
||||
}
|
||||
*out = in[in_offset];
|
||||
} else {
|
||||
*out = T{0};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <int NDIM>
|
||||
array grouped_unfold_transpose_inputs_nd(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
int mat_M,
|
||||
int mat_K,
|
||||
int mat_N,
|
||||
ConvParams<NDIM>& params) {
|
||||
array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
|
||||
unfolded.set_data(allocator::malloc(unfolded.nbytes()));
|
||||
encoder.add_temporary(unfolded);
|
||||
|
||||
int filter_size = params.C;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
filter_size *= params.wt_spatial_dims[i];
|
||||
}
|
||||
|
||||
int out_pixels = 1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
out_pixels *= params.out_spatial_dims[i];
|
||||
}
|
||||
|
||||
int wt_spatial_size = (mat_K * params.groups) / params.C;
|
||||
dim3 block_dims;
|
||||
block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024);
|
||||
dim3 num_blocks;
|
||||
num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x);
|
||||
num_blocks.y = params.C;
|
||||
num_blocks.z = mat_M;
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(unfolded);
|
||||
dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
encoder.add_kernel_node(
|
||||
cu::naive_grouped_unfold_transpose_nd<DataType, NDIM>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<DataType>(),
|
||||
unfolded.data<DataType>(),
|
||||
filter_size,
|
||||
out_pixels,
|
||||
params);
|
||||
});
|
||||
|
||||
return unfolded;
|
||||
}
|
||||
|
||||
template <int NDIM>
|
||||
void gemm_grouped_conv_nd(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
ConvParams<NDIM>& params,
|
||||
Stream s) {
|
||||
// Get gemm shapes.
|
||||
int C_per_group = params.C / params.groups;
|
||||
int O_per_group = params.O / params.groups;
|
||||
int mat_M = out.size() / params.O; // N * H_out * W_out
|
||||
int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
|
||||
int mat_N = O_per_group; // O_per_group
|
||||
|
||||
// Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
|
||||
array in_unfolded = grouped_unfold_transpose_inputs_nd<NDIM>(
|
||||
encoder, in, mat_M, mat_K, mat_N, params);
|
||||
|
||||
// Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm.
|
||||
int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1);
|
||||
array wt_view(
|
||||
{params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {});
|
||||
wt_view.copy_shared_buffer(
|
||||
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||
array wt_reshaped = contiguous_copy_gpu(wt_view, s);
|
||||
|
||||
// Batch with size of groups.
|
||||
Shape batch_shape{params.groups};
|
||||
Strides a_batch_strides{mat_K};
|
||||
Strides b_batch_strides{mat_N * mat_K};
|
||||
|
||||
// Run matmul.
|
||||
CublasGemm gemm(
|
||||
encoder.device(),
|
||||
in.dtype(),
|
||||
false, // a_transposed
|
||||
mat_M, // a_rows
|
||||
mat_K, // a_cols
|
||||
mat_K * params.groups, // lda
|
||||
true, // b_transposed
|
||||
mat_K, // b_rows
|
||||
mat_N, // b_cols
|
||||
mat_K, // ldb
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
gemm.set_out(
|
||||
out.dtype(),
|
||||
false, // out_transposed
|
||||
mat_M, // out_rows
|
||||
mat_N, // out_cols
|
||||
mat_N * params.groups, // out_ld
|
||||
params.groups, // batch_count
|
||||
mat_N); // batch_stride
|
||||
gemm.run(
|
||||
encoder,
|
||||
out,
|
||||
in_unfolded,
|
||||
wt_reshaped,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides);
|
||||
}
|
||||
|
||||
void gemm_grouped_conv(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array& out,
|
||||
const std::vector<int>& strides,
|
||||
const std::vector<int>& padding,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation,
|
||||
int groups,
|
||||
bool flip,
|
||||
Stream s) {
|
||||
int conv_ndim = in.ndim() - 2;
|
||||
if (conv_ndim < 1 || conv_ndim > 3) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
|
||||
}
|
||||
dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
|
||||
ConvParams<ndim_constant()> params(
|
||||
in,
|
||||
wt,
|
||||
out,
|
||||
strides,
|
||||
padding,
|
||||
kernel_dilation,
|
||||
input_dilation,
|
||||
groups,
|
||||
flip);
|
||||
gemm_grouped_conv_nd<ndim_constant()>(encoder, in, wt, out, params, s);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -15,8 +15,8 @@ void copy_gpu_inplace(
|
||||
int64_t offset_out,
|
||||
CopyType ctype,
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_offset_in,
|
||||
const std::optional<array>& dynamic_offset_out) {
|
||||
std::optional<array> dynamic_offset_in,
|
||||
std::optional<array> dynamic_offset_out) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -44,6 +44,16 @@ void copy_gpu_inplace(
|
||||
strides_vec[0]);
|
||||
} else {
|
||||
if (dynamic_offset_in || dynamic_offset_out) {
|
||||
if (!dynamic_offset_in) {
|
||||
dynamic_offset_in = array(0, int64);
|
||||
encoder.add_temporary(*dynamic_offset_in);
|
||||
}
|
||||
if (!dynamic_offset_out) {
|
||||
dynamic_offset_out = array(0, int64);
|
||||
encoder.add_temporary(*dynamic_offset_out);
|
||||
}
|
||||
encoder.set_input_array(*dynamic_offset_in);
|
||||
encoder.set_input_array(*dynamic_offset_out);
|
||||
copy_general_dynamic(
|
||||
encoder,
|
||||
ctype,
|
||||
@@ -54,8 +64,8 @@ void copy_gpu_inplace(
|
||||
shape_collapsed,
|
||||
strides_vec[0],
|
||||
strides_vec[1],
|
||||
dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
|
||||
dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
|
||||
*dynamic_offset_in,
|
||||
*dynamic_offset_out);
|
||||
} else {
|
||||
copy_general(
|
||||
encoder,
|
||||
|
||||
mlx/backend/cuda/cudnn_utils.cpp (new file, 275 lines)
@@ -0,0 +1,275 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cudnn_utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Create a cudnn tensor descriptor.
|
||||
template <typename Vec>
|
||||
inline cudnn_frontend::Tensor build_cudnn_tensor(
|
||||
int64_t id,
|
||||
const array& x,
|
||||
const Vec& shape,
|
||||
const Vec& strides) {
|
||||
return cudnn_frontend::TensorBuilder()
|
||||
.setDim(shape.size(), shape.data())
|
||||
.setStrides(strides.size(), strides.data())
|
||||
.setId(id)
|
||||
.setAlignment(get_alignment(x))
|
||||
.setDataType(dtype_to_cudnn_type(x.dtype()))
|
||||
.build();
|
||||
}
|
||||
|
||||
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
// whether a tensor is contiguous is determined with:
//   strides[dim] == shape[dim + 1] * strides[dim + 1]
// So a contiguous array with singleton dims in MLX may be mistakenly treated
// as strided in cuDNN, and we work around it by normalizing the strides.
Strides normalized_strides(const array& x) {
if (!x.flags().row_contiguous || x.ndim() < 2) {
return x.strides();
}
Strides strides = x.strides();
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
}
}
return strides;
}

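As a concrete example, a row-contiguous {2, 1, 3} array may carry strides {3, 1, 1} in MLX; rewriting the singleton dim's stride as shape[i + 1] * strides[i + 1] yields {3, 3, 1}, which cuDNN recognizes as packed. A small sketch of the same normalization on plain vectors:

#include <cstdint>
#include <vector>

// Sketch of the stride normalization above for a row-contiguous shape:
// singleton dims get the stride cuDNN expects, other dims are untouched.
std::vector<std::int64_t> normalize(
    const std::vector<std::int64_t>& shape,
    std::vector<std::int64_t> strides) {
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    if (shape[i] == 1) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }
  }
  return strides; // e.g. {2, 1, 3} / {3, 1, 1} -> {3, 3, 1}
}
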
// Return the shape and strides after transposing from NHWC to NCHW.
|
||||
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
|
||||
assert(shape.size() >= 3);
|
||||
shape.insert(shape.begin() + 1, shape.back());
|
||||
shape.erase(shape.end() - 1);
|
||||
strides.insert(strides.begin() + 1, strides.back());
|
||||
strides.erase(strides.end() - 1);
|
||||
return std::make_tuple(std::move(shape), std::move(strides));
|
||||
}
|
||||
|
||||
inline auto nhwc_to_nchw(const array& x) {
|
||||
return nhwc_to_nchw(
|
||||
convert_vector<int64_t>(x.shape()), normalized_strides(x));
|
||||
}
|
||||
|
||||
// Return available engines for a |op_graph|.
|
||||
cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph,
|
||||
bool use_fallback = true) {
|
||||
SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
|
||||
sources.push_back([](auto& op_graph) {
|
||||
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setHeurMode(CUDNN_HEUR_MODE_A)
|
||||
.build();
|
||||
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
|
||||
});
|
||||
if (use_fallback) {
|
||||
sources.push_back([&backend_type](auto& op_graph) {
|
||||
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setOperation(backend_type)
|
||||
.build();
|
||||
return fallback.getFallbackList();
|
||||
});
|
||||
}
|
||||
|
||||
auto configs =
|
||||
cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
|
||||
.generate_engine_config(op_graph);
|
||||
|
||||
cudnn_frontend::EngineConfigList filtered_configs;
|
||||
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||
if (cudnn_frontend::hasNumericalNote<
|
||||
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||
return true;
|
||||
}
|
||||
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
|
||||
dtype == float32 && !env::enable_tf32()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return filtered_configs;
|
||||
}
|
||||
|
||||
// Take |engine_configs| and |op_graph| and find a working execution plans
|
||||
// from them.
|
||||
std::optional<cudnn_frontend::ExecutionPlan>
|
||||
find_cudnn_plan_from_engine_configs(
|
||||
cudnnHandle_t handle,
|
||||
const cudnn_frontend::EngineConfigList& engine_configs,
|
||||
const cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto op_graph_tag = op_graph.getTag();
|
||||
for (const auto& config : engine_configs) {
|
||||
try {
|
||||
return cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(handle)
|
||||
.setEngineConfig(config, op_graph_tag)
|
||||
.build();
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Prepare workspace and args to execute plan.
|
||||
template <typename F>
|
||||
bool prepare_cudnn_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs,
|
||||
F&& execute) {
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
array workspace(
|
||||
workspace_size > 0 ? allocator::malloc(workspace_size)
|
||||
: allocator::Buffer(nullptr),
|
||||
{workspace_size},
|
||||
uint8);
|
||||
|
||||
auto args = cudnn_frontend::VariantPackBuilder()
|
||||
.setWorkspacePointer(workspace.data<void>())
|
||||
.setDataPointers(num_args, data_ptrs)
|
||||
.setUids(num_args, uids)
|
||||
.build();
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
encoder.add_temporary(workspace);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
|
||||
auto shape = convert_vector<int64_t>(x.shape());
|
||||
return build_cudnn_tensor(id, x, shape, normalized_strides(x));
|
||||
}
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
|
||||
auto [shape, strides] = nhwc_to_nchw(x);
|
||||
return build_cudnn_tensor(id, x, shape, strides);
|
||||
}
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
|
||||
if (x.ndim() == 0) {
|
||||
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
|
||||
return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
|
||||
}
|
||||
if (x.ndim() == 1) {
|
||||
int64_t s = x.shape(0);
|
||||
SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
|
||||
SmallVector<int64_t, 4> strides = {s, 1, s, s};
|
||||
return build_cudnn_tensor(id, x, shape, strides);
|
||||
}
|
||||
if (x.ndim() == 2) {
|
||||
int64_t s =
|
||||
x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
|
||||
SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
|
||||
SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
|
||||
return build_cudnn_tensor(id, x, shape, strides);
|
||||
}
|
||||
if (x.ndim() == 3 || x.ndim() == 4) {
|
||||
return build_cudnn_tensor_nchw(id, x);
|
||||
}
|
||||
throw std::runtime_error(
|
||||
fmt::format("Unsupported array with {} dims.", x.ndim()));
|
||||
}
|
||||
|
||||
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
|
||||
SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
|
||||
return cudnn_frontend::TensorBuilder()
|
||||
.setDim(scalar_dims.size(), scalar_dims.data())
|
||||
.setStrides(scalar_dims.size(), scalar_dims.data())
|
||||
.setId(id)
|
||||
.setAlignment(16)
|
||||
.setDataType(dtype_to_cudnn_type(dtype))
|
||||
.setByValue(true)
|
||||
.build();
|
||||
}
|
||||
|
||||
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
|
||||
cudnnHandle_t handle,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
|
||||
if (engine_configs.empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
|
||||
}
|
||||
|
||||
bool encode_cudnn_plan_with_capturing(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs) {
|
||||
return prepare_cudnn_plan(
|
||||
encoder,
|
||||
plan,
|
||||
num_args,
|
||||
uids,
|
||||
data_ptrs,
|
||||
[&](auto handle, auto plan, auto args) {
|
||||
auto capture = encoder.capture_context();
|
||||
if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
|
||||
// Discard the captured graph when failed.
|
||||
capture.discard = true;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
#if CUDNN_VERSION >= 90500
|
||||
bool encode_cudnn_plan_with_graph_api(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
CudaGraph& graph,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs) {
|
||||
return prepare_cudnn_plan(
|
||||
encoder,
|
||||
plan,
|
||||
num_args,
|
||||
uids,
|
||||
data_ptrs,
|
||||
[&](auto handle, auto plan, auto args) {
|
||||
if (!graph) {
|
||||
graph = CudaGraph(encoder.device());
|
||||
if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
encoder.add_graph_node(graph);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace mlx::core
|
||||
164
mlx/backend/cuda/cudnn_utils.h
Normal file
@@ -0,0 +1,164 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <cudnn_frontend.h>
|
||||
#include <cudnn_frontend_find_plan.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
class CommandEncoder;
|
||||
}
|
||||
|
||||
// Return pointer alignment of |x|'s data.
|
||||
inline uint8_t get_alignment(const array& x) {
|
||||
uint8_t alignment = 1;
|
||||
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
|
||||
for (; alignment < 32; alignment *= 2) {
|
||||
if (address % (alignment * 2)) {
|
||||
return alignment;
|
||||
}
|
||||
}
|
||||
return alignment;
|
||||
}
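// Illustration: an address ending in 0x18 is divisible by 8 but not by 16,
// so the loop returns 8; a pointer divisible by 32 falls through the loop
// and returns the cap of 32.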
|
||||
|
||||
// Convert the type of elements in |vec| to |T|.
|
||||
template <typename T, typename Vec>
|
||||
inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||
return SmallVector<T>(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
// Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
|
||||
//
|
||||
// There are 2 differences from the const_param util in kernel_utils.cuh:
// 1. The rest of the array is filled with 0.
|
||||
// 2. This util can be used in .cpp files.
|
||||
template <typename T, template <typename U> class Vec>
|
||||
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
}
|
||||
std::array<T, MAX_NDIM> result = {};
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
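// Illustration (assuming MAX_NDIM is 8): a shape {8, 16, 32} becomes the key
// {8, 16, 32, 0, 0, 0, 0, 0}, so the fixed-size array can be compared and
// hashed like any other map key regardless of the original ndim.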
|
||||
|
||||
// Helpers used by get_data_ptrs to get pointers.
|
||||
inline void* get_data_ptr(const array& arr) {
|
||||
return const_cast<void*>(arr.data<void>());
|
||||
}
|
||||
|
||||
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
|
||||
inline void* get_data_ptr(T& scalar) {
|
||||
return &scalar;
|
||||
}
|
||||
|
||||
// Return an array filled with data pointers of args.
|
||||
template <typename... Args>
|
||||
inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
|
||||
return {get_data_ptr(args)...};
|
||||
}
|
||||
|
||||
// Map dtype to cudnn data type.
|
||||
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
return CUDNN_DATA_INT8;
|
||||
case int32:
|
||||
return CUDNN_DATA_INT32;
|
||||
case uint8:
|
||||
return CUDNN_DATA_UINT8;
|
||||
case float16:
|
||||
return CUDNN_DATA_HALF;
|
||||
case bfloat16:
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
case float32:
|
||||
return CUDNN_DATA_FLOAT;
|
||||
case float64:
|
||||
return CUDNN_DATA_DOUBLE;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
// Create a tensor descriptor from |x|.
|
||||
cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
|
||||
|
||||
// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
|
||||
|
||||
// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
|
||||
// from NHWC to NCHW.
|
||||
cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
|
||||
|
||||
// Create a 4D scalar tensor descriptor, which is passed by value.
|
||||
cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
|
||||
|
||||
// Find a working plan for |op_graph|.
|
||||
std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
|
||||
cudnnHandle_t handle,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph);
|
||||
|
||||
// Encode the plan into the command buffer by capturing.
|
||||
bool encode_cudnn_plan_with_capturing(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs);
|
||||
|
||||
#if CUDNN_VERSION >= 90500
|
||||
// Encode the plan into the command buffer using cuDNN's native graph API. If
// the |graph| is empty it will be populated, otherwise it will be updated.
|
||||
bool encode_cudnn_plan_with_graph_api(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
CudaGraph& graph,
|
||||
int num_args,
|
||||
const int64_t* uids,
|
||||
void** data_ptrs);
|
||||
#endif
|
||||
|
||||
// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
|
||||
template <typename... Args>
|
||||
bool encode_cudnn_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
std::initializer_list<int64_t> uids,
|
||||
Args&... args) {
|
||||
assert(uids.size() == sizeof...(args));
|
||||
auto data_ptrs = get_data_ptrs(args...);
|
||||
return encode_cudnn_plan_with_capturing(
|
||||
encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
|
||||
}
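// A typical call sequence (sketch only; 'x', 'w' and 'y' are placeholder uids
// as in the comment above):
//   auto plan = find_cudnn_plan_from_op_graph(
//       encoder.device().cudnn_handle(), backend_type, dtype, op_graph);
//   if (plan) {
//     encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y);
//   }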
|
||||
|
||||
#if CUDNN_VERSION >= 90500
|
||||
template <typename... Args>
|
||||
bool encode_cudnn_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
CudaGraph& graph,
|
||||
std::initializer_list<int64_t> uids,
|
||||
Args&... args) {
|
||||
assert(uids.size() == sizeof...(args));
|
||||
auto data_ptrs = get_data_ptrs(args...);
|
||||
return encode_cudnn_plan_with_graph_api(
|
||||
encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace mlx::core
|
||||
379
mlx/backend/cuda/custom_kernel.cpp
Normal file
@@ -0,0 +1,379 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr const char* default_header = R"(
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#define inf cuda::std::numeric_limits<float>::infinity()
|
||||
|
||||
)";
|
||||
|
||||
std::string template_arguments_hash(
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
|
||||
if (template_args.empty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string hash;
|
||||
hash.reserve(512);
|
||||
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
hash += fmt::format("_{}", std::get<int>(arg));
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
hash += (std::get<bool>(arg)) ? "_t" : "_f";
|
||||
} else if (std::holds_alternative<Dtype>(arg)) {
|
||||
hash += "_";
|
||||
hash += get_type_string(std::get<Dtype>(arg));
|
||||
}
|
||||
}
|
||||
|
||||
return hash;
|
||||
}
|
||||
|
||||
std::string build_kernel(
|
||||
const std::string& func_name,
|
||||
const std::string& header,
|
||||
const std::string& source,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
|
||||
const std::vector<CustomKernelShapeInfo>& shape_infos) {
|
||||
std::string kernel_source;
|
||||
kernel_source.reserve(header.size() + source.size() + 8192);
|
||||
kernel_source += default_header;
|
||||
kernel_source += header;
|
||||
kernel_source +=
|
||||
"namespace mlx::core::cu {\n\n"
|
||||
"namespace cg = cooperative_groups;\n\n";
|
||||
|
||||
kernel_source += "__global__ void ";
|
||||
kernel_source += func_name;
|
||||
kernel_source += "(\n";
|
||||
|
||||
// Add inputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const auto& name = input_names[i];
|
||||
const auto& arr = inputs[i];
|
||||
kernel_source += " const ";
|
||||
kernel_source += dtype_to_cuda_type(arr.dtype());
|
||||
kernel_source += "* ";
|
||||
kernel_source += name;
|
||||
kernel_source += ",\n";
|
||||
// Add input shape, strides and ndim if present in the source
|
||||
if (arr.ndim() > 0) {
|
||||
if (shape_infos[i].shape) {
|
||||
kernel_source += " const __grid_constant__ Shape ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_shape,\n";
|
||||
}
|
||||
if (shape_infos[i].strides) {
|
||||
kernel_source += " const __grid_constant__ Strides ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_strides,\n";
|
||||
}
|
||||
if (shape_infos[i].ndim) {
|
||||
kernel_source += " const __grid_constant__ int ";
|
||||
kernel_source += name;
|
||||
kernel_source += "_ndim,\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add outputs
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
const auto& name = output_names[i];
|
||||
const auto& dtype = output_dtypes[i];
|
||||
kernel_source += " ";
|
||||
kernel_source += dtype_to_cuda_type(dtype);
|
||||
kernel_source += "* ";
|
||||
kernel_source += name;
|
||||
if (i < output_names.size() - 1) {
|
||||
kernel_source += ",\n";
|
||||
} else {
|
||||
kernel_source += ") {\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Set compile time constants
|
||||
if (!template_args.empty()) {
|
||||
for (const auto& [name, arg] : template_args) {
|
||||
if (std::holds_alternative<int>(arg)) {
|
||||
kernel_source +=
|
||||
fmt::format(" constexpr int {} = {};\n", name, std::get<int>(arg));
|
||||
} else if (std::holds_alternative<bool>(arg)) {
|
||||
kernel_source += fmt::format(
|
||||
" constexpr bool {} = {};\n", name, std::get<bool>(arg));
|
||||
} else {
|
||||
kernel_source += fmt::format(
|
||||
" using {} = {};\n",
|
||||
name,
|
||||
dtype_to_cuda_type(std::get<Dtype>(arg)));
|
||||
}
|
||||
}
|
||||
kernel_source += "\n";
|
||||
}
|
||||
|
||||
kernel_source += source;
|
||||
kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
|
||||
|
||||
return kernel_source;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
CustomKernelFunction cuda_kernel(
|
||||
const std::string& name,
|
||||
const std::vector<std::string>& input_names,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::string& source,
|
||||
const std::string& header,
|
||||
bool ensure_row_contiguous,
|
||||
int shared_memory) {
|
||||
if (output_names.empty()) {
|
||||
throw std::invalid_argument(
|
||||
"[custom_kernel] Must specify at least one output.");
|
||||
}
|
||||
|
||||
std::vector<CustomKernelShapeInfo> shape_infos;
|
||||
for (auto& n : input_names) {
|
||||
CustomKernelShapeInfo shape_info;
|
||||
shape_info.shape = source.find(n + "_shape") != std::string::npos;
|
||||
shape_info.strides = source.find(n + "_strides") != std::string::npos;
|
||||
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
|
||||
shape_infos.push_back(shape_info);
|
||||
}
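// Note: these flags come from plain substring searches on |source|, so a
// kernel body that mentions, say, "a_shape" or "a_strides" for an input named
// "a" (an illustrative name) will receive the corresponding __grid_constant__
// Shape/Strides parameters emitted by build_kernel above.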
|
||||
|
||||
return [=, shape_infos = std::move(shape_infos)](
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
const std::vector<std::pair<std::string, TemplateArg>>&
|
||||
template_args = {},
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s_ = {}) {
|
||||
if (inputs.size() != input_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[custom_kernel] Expected `inputs` to have size "
|
||||
<< input_names.size() << " but got size " << inputs.size() << "."
|
||||
<< std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_shapes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[custom_kernel] Expected `output_shapes` to have size "
|
||||
<< output_names.size() << " but got size " << output_shapes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (output_dtypes.size() != output_names.size()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[custom_kernel] Expected `output_dtypes` to have size "
|
||||
<< output_names.size() << " but got size " << output_dtypes.size()
|
||||
<< "." << std::endl;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto s = to_stream(s_);
|
||||
if (s.device != Device::gpu) {
|
||||
throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
|
||||
}
|
||||
|
||||
std::string kernel_name =
|
||||
"custom_kernel_" + name + template_arguments_hash(template_args);
|
||||
std::string kernel_source = build_kernel(
|
||||
kernel_name,
|
||||
header,
|
||||
source,
|
||||
input_names,
|
||||
inputs,
|
||||
output_names,
|
||||
output_dtypes,
|
||||
template_args,
|
||||
shape_infos);
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Generated source code for `" << kernel_name
|
||||
<< "`:" << std::endl
|
||||
<< "```" << std::endl
|
||||
<< kernel_source << std::endl
|
||||
<< "```" << std::endl;
|
||||
}
|
||||
|
||||
return array::make_arrays(
|
||||
std::move(output_shapes),
|
||||
std::move(output_dtypes),
|
||||
std::make_shared<CustomKernel>(
|
||||
s,
|
||||
std::move(kernel_name),
|
||||
std::move(kernel_source),
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value,
|
||||
std::vector<ScalarArg>{},
|
||||
false,
|
||||
shared_memory),
|
||||
std::move(inputs));
|
||||
};
|
||||
}
|
||||
|
||||
std::vector<array> precompiled_cuda_kernel(
|
||||
const std::string& name,
|
||||
const std::string& compiled_source,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<Shape>& output_shapes,
|
||||
const std::vector<Dtype>& output_dtypes,
|
||||
const std::vector<ScalarArg>& scalars,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
int shared_memory,
|
||||
std::optional<float> init_value,
|
||||
bool ensure_row_contiguous,
|
||||
StreamOrDevice s) {
|
||||
std::vector<CustomKernelShapeInfo> shape_infos(
|
||||
inputs.size(), CustomKernelShapeInfo{false, false, false});
|
||||
return array::make_arrays(
|
||||
output_shapes,
|
||||
output_dtypes,
|
||||
std::make_shared<CustomKernel>(
|
||||
to_stream(s),
|
||||
name,
|
||||
compiled_source,
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous,
|
||||
init_value,
|
||||
scalars,
|
||||
true,
|
||||
shared_memory),
|
||||
inputs);
|
||||
}
|
||||
|
||||
void CustomKernel::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("CustomKernel::eval_gpu");
|
||||
auto& s = stream();
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
// Allocate and initialize the output arrays
|
||||
for (auto& out : outputs) {
|
||||
if (init_value_) {
|
||||
copies.emplace_back(init_value_.value(), out.dtype());
|
||||
fill_gpu(copies.back(), out, s);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
// Create the input arrays and copy if needed
|
||||
auto check_input = [&copies, &s, this](const array& x) -> const array {
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (!ensure_row_contiguous_ || no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
std::vector<array> checked_inputs;
|
||||
for (const array& in : inputs) {
|
||||
checked_inputs.push_back(check_input(in));
|
||||
}
|
||||
|
||||
// Compile the custom kernel
|
||||
std::string kernel_name =
|
||||
(is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
|
||||
cu::JitModule& mod = cu::get_jit_module(
|
||||
s.device,
|
||||
name_,
|
||||
[&]() {
|
||||
return std::make_tuple(
|
||||
is_precompiled_, source_, std::vector{kernel_name});
|
||||
},
|
||||
false);
|
||||
|
||||
// Make the arguments
|
||||
cu::KernelArgs args;
|
||||
for (int i = 0; i < checked_inputs.size(); i++) {
|
||||
const array& in = checked_inputs[i];
|
||||
auto& shape_info = shape_infos_[i];
|
||||
args.append(in);
|
||||
if (shape_info.shape) {
|
||||
args.append_ndim(in.shape());
|
||||
}
|
||||
if (shape_info.strides) {
|
||||
args.append_ndim(in.strides());
|
||||
}
|
||||
if (shape_info.ndim) {
|
||||
args.append<int32_t>(in.ndim());
|
||||
}
|
||||
}
|
||||
for (auto& out : outputs) {
|
||||
args.append(out);
|
||||
}
|
||||
for (auto& s : scalar_arguments_) {
|
||||
if (std::holds_alternative<bool>(s)) {
|
||||
args.append(std::get<bool>(s));
|
||||
} else if (std::holds_alternative<int>(s)) {
|
||||
args.append(std::get<int>(s));
|
||||
} else if (std::holds_alternative<float>(s)) {
|
||||
args.append(std::get<float>(s));
|
||||
}
|
||||
}
|
||||
|
||||
// Make the grid
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
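// Illustration: with grid_ = {1000, 1, 1} and threadgroup_ = {256, 1, 1},
// block is (256, 1, 1) and grid is ((1000 + 255) / 256, 1, 1) = (4, 1, 1);
// the user-facing grid is given in threads and rounded up to whole blocks.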
|
||||
|
||||
// Call the kernel
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : checked_inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
for (const auto& out : outputs) {
|
||||
encoder.set_output_array(out);
|
||||
}
|
||||
for (const auto& t : copies) {
|
||||
encoder.add_temporary(t);
|
||||
}
|
||||
auto kernel =
|
||||
mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
|
||||
if (smem > 0 && smem > 48000) {
|
||||
cuFuncSetAttribute(
|
||||
kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
|
||||
}
|
||||
});
|
||||
encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
@@ -14,10 +14,6 @@ namespace mlx::core::cu {
|
||||
|
||||
namespace {
|
||||
|
||||
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
||||
// This should be less than 255
|
||||
constexpr int default_max_nodes_per_graph = 20;
|
||||
|
||||
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
|
||||
|
||||
void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||
@@ -27,11 +23,11 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||
}
|
||||
}
|
||||
|
||||
int cuda_graph_cache_size() {
|
||||
static int cache_size = []() {
|
||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
||||
bool use_cuda_graphs() {
|
||||
static bool use_graphs = []() {
|
||||
return env::get_var("MLX_USE_CUDA_GRAPHS", true);
|
||||
}();
|
||||
return cache_size;
|
||||
return use_graphs;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -68,8 +64,8 @@ Device::~Device() {
|
||||
|
||||
void Device::make_current() {
|
||||
// We need to set/get current CUDA device very frequently, cache it to reduce
|
||||
// actual calls of CUDA APIs. This function assumes single-thread in host.
|
||||
static int current = 0;
|
||||
// actual calls of CUDA APIs.
|
||||
static thread_local int current = 0;
|
||||
if (current != device_) {
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(device_));
|
||||
current = device_;
|
||||
@@ -86,14 +82,20 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||
|
||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
enc.device().make_current();
|
||||
if (!use_cuda_graphs()) {
|
||||
return;
|
||||
}
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
|
||||
if (!use_cuda_graphs()) {
|
||||
enc.node_count_++;
|
||||
return;
|
||||
}
|
||||
|
||||
graph.end_capture(enc.stream());
|
||||
if (discard) {
|
||||
return;
|
||||
}
|
||||
@@ -107,6 +109,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
||||
|
||||
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
||||
enc.in_concurrent_ = false;
|
||||
if (!use_cuda_graphs()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Use an empty graph node for synchronization
|
||||
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
||||
@@ -185,37 +190,46 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(Device& d)
|
||||
: device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
}
|
||||
: device_(d),
|
||||
stream_(d),
|
||||
graph_(d),
|
||||
worker_(d),
|
||||
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.add_task(std::move(task));
|
||||
}
|
||||
|
||||
void CommandEncoder::set_input_array(const array& arr) {
|
||||
if (!use_cuda_graphs()) {
|
||||
return;
|
||||
}
|
||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||
active_deps_.push_back(id);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_output_array(const array& arr) {
|
||||
if (!use_cuda_graphs()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||
active_deps_.push_back(id);
|
||||
active_outputs_.push_back(id);
|
||||
}
|
||||
|
||||
void CommandEncoder::maybe_commit() {
|
||||
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
|
||||
commit();
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(
|
||||
void* func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
void** params) {
|
||||
if (!use_cuda_graphs()) {
|
||||
node_count_++;
|
||||
CHECK_CUDA_ERROR(cudaLaunchKernel(
|
||||
func, grid_dim, block_dim, params, smem_bytes, stream()));
|
||||
return;
|
||||
}
|
||||
cudaKernelNodeParams kernel_params = {0};
|
||||
kernel_params.func = func;
|
||||
kernel_params.gridDim = grid_dim;
|
||||
@@ -231,6 +245,23 @@ void CommandEncoder::add_kernel_node(
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
void** params) {
|
||||
if (!use_cuda_graphs()) {
|
||||
node_count_++;
|
||||
CHECK_CUDA_ERROR(cuLaunchKernel(
|
||||
func,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z,
|
||||
smem_bytes,
|
||||
stream(),
|
||||
params,
|
||||
nullptr));
|
||||
return;
|
||||
}
|
||||
|
||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||
kernel_params.func = func;
|
||||
kernel_params.gridDimX = grid_dim.x;
|
||||
@@ -257,20 +288,38 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||
}
|
||||
|
||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
if (!use_cuda_graphs()) {
|
||||
node_count_++;
|
||||
CudaGraphExec graph_exec;
|
||||
graph_exec.instantiate(child);
|
||||
device_.make_current();
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
|
||||
return;
|
||||
}
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||
insert_graph_dependencies(GraphNode{node, 'G'});
|
||||
}
|
||||
|
||||
int CommandEncoder::get_num_ops() {
|
||||
return node_count_;
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
nvtx3::scoped_range r("CommandEncoder::commit");
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
if (node_count_ > 0) {
|
||||
if (use_cuda_graphs() && node_count_ > 0) {
|
||||
if (!from_nodes_.empty()) {
|
||||
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
|
||||
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
|
||||
graph_,
|
||||
from_nodes_.data(),
|
||||
to_nodes_.data(),
|
||||
#if CUDART_VERSION >= 13000
|
||||
nullptr, // edgeData
|
||||
#endif // CUDART_VERSION >= 13000
|
||||
from_nodes_.size()));
|
||||
}
|
||||
|
||||
graph_key_ += ".";
|
||||
@@ -304,19 +353,18 @@ void CommandEncoder::commit() {
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
|
||||
// Reset state
|
||||
node_count_ = 0;
|
||||
graph_node_count_ = 0;
|
||||
empty_node_count_ = 0;
|
||||
from_nodes_.clear();
|
||||
to_nodes_.clear();
|
||||
graph_key_.clear();
|
||||
node_map_.clear();
|
||||
CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
graph_ = CudaGraph(device_);
|
||||
}
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
worker_.commit(stream_);
|
||||
node_count_ = 0;
|
||||
}
|
||||
|
||||
void CommandEncoder::synchronize() {
|
||||
|
||||
@@ -21,7 +21,7 @@ class CommandEncoder {
|
||||
struct CaptureContext {
|
||||
CaptureContext(CommandEncoder& enc);
|
||||
~CaptureContext();
|
||||
cudaGraph_t graph;
|
||||
CudaGraph graph;
|
||||
CommandEncoder& enc;
|
||||
bool discard{false};
|
||||
};
|
||||
@@ -76,9 +76,6 @@ class CommandEncoder {
|
||||
uint32_t smem_bytes,
|
||||
void** params);
|
||||
|
||||
// Low-level graph helpers.
|
||||
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
||||
void add_graph_node(cudaGraph_t child);
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
@@ -86,7 +83,7 @@ class CommandEncoder {
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void maybe_commit();
|
||||
int get_num_ops();
|
||||
void commit();
|
||||
|
||||
Device& device() {
|
||||
@@ -101,6 +98,9 @@ class CommandEncoder {
|
||||
void synchronize();
|
||||
|
||||
private:
|
||||
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
||||
|
||||
struct GraphNode {
|
||||
cudaGraphNode_t node;
|
||||
// K = kernel
|
||||
@@ -115,7 +115,7 @@ class CommandEncoder {
|
||||
|
||||
Device& device_;
|
||||
CudaStream stream_;
|
||||
cudaGraph_t graph_;
|
||||
CudaGraph graph_;
|
||||
Worker worker_;
|
||||
char node_count_{0};
|
||||
char graph_node_count_{0};
|
||||
@@ -140,7 +140,7 @@ class Device {
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
|
||||
// Make this device the current cuda device, required by some cuda calls.
|
||||
// Make this device the current cuda device, this method is thread-safe.
|
||||
void make_current();
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
@@ -204,6 +204,12 @@ struct Power {
|
||||
__device__ T operator()(T base, T exp) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
T res = 1;
|
||||
// Raising an integer to a negative power is undefined
|
||||
if constexpr (cuda::std::is_signed_v<T>) {
|
||||
if (exp < 0) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
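// Illustration (the remainder of the loop, squaring |base| and shifting
// |exp|, lies outside this hunk): for base = 3, exp = 5 = 0b101, the
// iterations give res = 3, then base = 9, base = 81, and finally
// res = 3 * 81 = 243, i.e. 3^5.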
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
@@ -116,15 +115,4 @@ inline __host__ __device__ auto cast_to(SrcT x) {
|
||||
return CastOp<SrcT, DstT>{}(x);
|
||||
}
|
||||
|
||||
// Return an iterator that cast the value to DstT using CastOp.
|
||||
template <typename DstT, typename Iterator>
|
||||
inline __host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
|
||||
if constexpr (std::is_same_v<SrcT, DstT>) {
|
||||
return it;
|
||||
} else {
|
||||
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
@@ -257,8 +257,8 @@ struct Round {
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
T y = 1 / (1 + exp(-abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
T y = 1 / (1 + exp(abs(x)));
|
||||
return (x < 0) ? y : 1 - y;
|
||||
}
|
||||
};
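// With y = 1 / (1 + exp(|x|)) = sigmoid(-|x|), the identity
// sigmoid(x) = y for x < 0 and sigmoid(x) = 1 - y for x >= 0 is what the
// revised branch returns.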
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file must not include any host-only code, utilies that work under both
|
||||
// This file must not include any host-only code, utilities that work under both
|
||||
// host and device can be put here.
|
||||
//
|
||||
// See more about the requirements at:
|
||||
@@ -202,7 +202,7 @@ struct Limits<
|
||||
}
|
||||
};
|
||||
|
||||
// CUDA 11 does not have host side arithmatic operators for half types.
|
||||
// CUDA 11 does not have host side arithmetic operators for half types.
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
|
||||
56
mlx/backend/cuda/distributed.cu
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
void AllReduce::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto set_input_output =
|
||||
[s = stream()](const array& in, array& out) -> std::pair<array, array> {
|
||||
if (!in.flags().row_contiguous) {
|
||||
copy_gpu(in, out, CopyType::General, s);
|
||||
return {out, out};
|
||||
} else if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
return {in, out};
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return {in, out};
|
||||
}
|
||||
};
|
||||
|
||||
auto [input, output] = set_input_output(inputs[0], outputs[0]);
|
||||
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(output);
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
auto& s = stream();
|
||||
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
distributed::detail::all_sum(group(), input, output, s);
|
||||
break;
|
||||
case Max:
|
||||
distributed::detail::all_max(group(), input, output, s);
|
||||
break;
|
||||
case Min:
|
||||
distributed::detail::all_min(group(), input, output, s);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Only all reduce sum, max, and min are supported.");
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core::distributed
|
||||
@@ -5,18 +5,24 @@
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
||||
constexpr int default_max_nodes_per_graph = 20;
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void new_stream(Stream s) {
|
||||
// Force initalization of cuda, so cuda runtime get destroyed at last.
|
||||
// Force initialization of CUDA, so the CUDA runtime gets destroyed last.
|
||||
cudaFree(nullptr);
|
||||
// Make sure the CUDA event pool gets destroyed after device and stream.
|
||||
cu::CudaEvent::init_pool();
|
||||
// Ensure the static stream objects get created.
|
||||
cu::get_command_encoder(s);
|
||||
}
|
||||
@@ -34,7 +40,8 @@ void eval(array& arr) {
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||
auto& stream = arr.primitive().stream();
|
||||
auto& encoder = cu::get_command_encoder(stream);
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
for (auto& in : arr.inputs()) {
|
||||
// Except for the donated one.
|
||||
@@ -45,7 +52,14 @@ void eval(array& arr) {
|
||||
for (auto& s : arr.siblings()) {
|
||||
encoder.add_temporary(s);
|
||||
}
|
||||
encoder.maybe_commit();
|
||||
|
||||
if (encoder.get_num_ops() >=
|
||||
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
|
||||
scheduler::notify_new_task(stream);
|
||||
encoder.add_completed_handler(
|
||||
[stream]() { scheduler::notify_task_completion(stream); });
|
||||
encoder.commit();
|
||||
}
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -17,104 +19,180 @@ namespace cu {
|
||||
// CudaEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Cuda event managed with RAII.
|
||||
class CudaEventHandle {
|
||||
namespace {
|
||||
|
||||
// Manage cached cudaEvent_t objects.
|
||||
class CudaEventPool {
|
||||
public:
|
||||
CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
|
||||
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
|
||||
CudaEventHandle create(Device& d, int flags) {
|
||||
if (!on_creation_thread()) {
|
||||
return CudaEventHandle(d, flags);
|
||||
}
|
||||
auto& cache = cache_for(d, flags);
|
||||
if (cache.empty()) {
|
||||
return CudaEventHandle(d, flags);
|
||||
} else {
|
||||
CudaEventHandle ret = std::move(cache.back());
|
||||
cache.pop_back();
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
~CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
|
||||
}
|
||||
|
||||
CudaEventHandle(const CudaEventHandle&) = delete;
|
||||
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
|
||||
|
||||
operator cudaEvent_t() const {
|
||||
return event_;
|
||||
void release(CudaEventHandle event) {
|
||||
if (!on_creation_thread()) {
|
||||
// Event will be destroyed directly instead of getting moved to cache.
|
||||
return;
|
||||
}
|
||||
cache_for(event.device, event.flags).push_back(std::move(event));
|
||||
}
|
||||
|
||||
private:
|
||||
cudaEvent_t event_;
|
||||
std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
|
||||
return cache_[d.cuda_device()][flags];
|
||||
}
|
||||
|
||||
bool on_creation_thread() {
|
||||
return std::this_thread::get_id() == thread_id_;
|
||||
}
|
||||
|
||||
// The CudaEvent may be created and destroyed on different threads (for
// example when waiting on GPU work in a CPU stream). We don't want to make
// the cache thread-safe as that adds overhead, so we simply skip the cache
// when events are used from worker threads.
|
||||
std::thread::id thread_id_{std::this_thread::get_id()};
|
||||
|
||||
// {device: {flags: [events]}}
|
||||
std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
|
||||
};
|
||||
|
||||
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
|
||||
CudaEventPool& cuda_event_pool() {
|
||||
static CudaEventPool pool;
|
||||
return pool;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
CudaEventHandle::CudaEventHandle(Device& d, int flags)
|
||||
: device(d), flags(flags) {
|
||||
device.make_current();
|
||||
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
|
||||
assert(handle_ != nullptr);
|
||||
}
|
||||
|
||||
CudaEvent::CudaEvent(Device& d, int flags)
|
||||
: event_(cuda_event_pool().create(d, flags)) {}
|
||||
|
||||
CudaEvent::~CudaEvent() {
|
||||
cuda_event_pool().release(std::move(event_));
|
||||
}
|
||||
|
||||
void CudaEvent::wait() {
|
||||
nvtx3::scoped_range r("cu::CudaEvent::wait");
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaEventSynchronize(*event_);
|
||||
event_.device.make_current();
|
||||
cudaEventSynchronize(event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(cudaStream_t stream) {
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaStreamWaitEvent(stream, *event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this]() mutable { wait(); });
|
||||
} else {
|
||||
auto& enc = cu::get_command_encoder(s);
|
||||
enc.commit();
|
||||
wait(enc.stream());
|
||||
}
|
||||
event_.device.make_current();
|
||||
cudaStreamWaitEvent(stream, event_);
|
||||
}
|
||||
|
||||
void CudaEvent::record(cudaStream_t stream) {
|
||||
cudaEventRecord(*event_, stream);
|
||||
recorded_ = true;
|
||||
}
|
||||
|
||||
void CudaEvent::record(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
|
||||
} else {
|
||||
auto& enc = cu::get_command_encoder(s);
|
||||
enc.commit();
|
||||
record(enc.stream());
|
||||
}
|
||||
event_.device.make_current();
|
||||
cudaEventRecord(event_, stream);
|
||||
}
|
||||
|
||||
bool CudaEvent::completed() const {
|
||||
return cudaEventQuery(*event_) == cudaSuccess;
|
||||
// Note: cudaEventQuery can be safely called from any device.
|
||||
return cudaEventQuery(event_) == cudaSuccess;
|
||||
}
|
||||
|
||||
// static
|
||||
void CudaEvent::init_pool() {
|
||||
cuda_event_pool();
|
||||
}
|
||||
|
||||
// Wraps CudaEvent with a few features:
|
||||
// 1. The class can be copied.
|
||||
// 2. Make wait/record work with CPU streams.
|
||||
// 3. Add checks for waiting on un-recorded event.
|
||||
class CopyableCudaEvent {
|
||||
public:
|
||||
explicit CopyableCudaEvent(Device& d)
|
||||
: event_(std::make_shared<CudaEvent>(
|
||||
d,
|
||||
cudaEventDisableTiming | cudaEventBlockingSync)) {}
|
||||
|
||||
void wait() {
|
||||
event_->wait();
|
||||
}
|
||||
|
||||
void wait(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this]() mutable {
|
||||
check_recorded();
|
||||
event_->wait();
|
||||
});
|
||||
} else {
|
||||
check_recorded();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.commit();
|
||||
event_->wait(encoder.stream());
|
||||
}
|
||||
}
|
||||
|
||||
void record(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
throw std::runtime_error("CudaEvent can not wait on CPU stream.");
|
||||
} else {
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.commit();
|
||||
event_->record(encoder.stream());
|
||||
recorded_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_signaled() const {
|
||||
return recorded_ && event_->completed();
|
||||
}
|
||||
|
||||
private:
|
||||
void check_recorded() const {
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error(
|
||||
"Should not wait on a CudaEvent before recording.");
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<CudaEvent> event_;
|
||||
bool recorded_{false};
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SharedEvent implementations
|
||||
// AtomicEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
__host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) {
|
||||
uint64_t current;
|
||||
while ((current = ac->load()) < value) {
|
||||
ac->wait(current);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
__host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) {
|
||||
ac->store(value);
|
||||
ac->notify_all();
|
||||
}
|
||||
|
||||
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
__global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
|
||||
event_wait(ac, value);
|
||||
}
|
||||
|
||||
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
__global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
|
||||
event_signal(ac, value);
|
||||
}
|
||||
|
||||
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
|
||||
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
|
||||
}
|
||||
|
||||
SharedEvent::SharedEvent() {
|
||||
AtomicEvent::AtomicEvent() {
|
||||
buf_ = std::shared_ptr<Buffer>(
|
||||
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
|
||||
allocator().free(*ptr);
|
||||
@@ -123,17 +201,17 @@ SharedEvent::SharedEvent() {
|
||||
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
|
||||
}
|
||||
|
||||
void SharedEvent::wait(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
||||
event_wait(to_atomic(buf_), value);
|
||||
void AtomicEvent::wait(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::AtomicEvent::wait");
|
||||
event_wait(atomic(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||
void AtomicEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||
event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
|
||||
void AtomicEvent::wait(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
|
||||
} else {
|
||||
@@ -144,17 +222,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
|
||||
}
|
||||
}
|
||||
|
||||
void SharedEvent::signal(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
||||
event_signal(to_atomic(buf_), value);
|
||||
void AtomicEvent::signal(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::AtomicEvent::signal");
|
||||
event_signal(atomic(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||
void AtomicEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||
event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
||||
void AtomicEvent::signal(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
// Signal through a GPU stream so the atomic is updated on the GPU; updating
// the atomic from the CPU sometimes does not notify the GPU.
|
||||
@@ -168,14 +246,14 @@ void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
}
|
||||
}
|
||||
|
||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
||||
return to_atomic(buf_)->load() >= value;
|
||||
bool AtomicEvent::is_signaled(uint64_t value) const {
|
||||
nvtx3::scoped_range r("cu::AtomicEvent::is_signaled");
|
||||
return atomic()->load() >= value;
|
||||
}
|
||||
|
||||
uint64_t SharedEvent::value() const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
||||
return to_atomic(buf_)->load();
|
||||
uint64_t AtomicEvent::value() const {
|
||||
nvtx3::scoped_range r("cu::AtomicEvent::value");
|
||||
return atomic()->load();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
@@ -188,14 +266,14 @@ namespace {
|
||||
|
||||
struct EventImpl {
|
||||
// CudaEvent is preferred when possible because it is fast, however we have
|
||||
// to fallback to SharedEvent in following cases:
|
||||
// to fallback to AtomicEvent in following cases:
|
||||
// 1. the event is used to wait/signal a cpu stream;
|
||||
// 2. signal value other than 1 has been specified.
|
||||
std::unique_ptr<cu::CudaEvent> cuda;
|
||||
std::unique_ptr<cu::SharedEvent> shared;
|
||||
std::unique_ptr<cu::CopyableCudaEvent> cuda;
|
||||
std::unique_ptr<cu::AtomicEvent> atomic;
|
||||
|
||||
bool is_created() const {
|
||||
return cuda || shared;
|
||||
return cuda || atomic;
|
||||
}
|
||||
|
||||
void ensure_created(Stream s, uint64_t signal_value) {
|
||||
@@ -203,10 +281,10 @@ struct EventImpl {
|
||||
return;
|
||||
}
|
||||
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
|
||||
nvtx3::mark("Using slow SharedEvent");
|
||||
shared = std::make_unique<cu::SharedEvent>();
|
||||
nvtx3::mark("Using slow AtomicEvent");
|
||||
atomic = std::make_unique<cu::AtomicEvent>();
|
||||
} else {
|
||||
cuda = std::make_unique<cu::CudaEvent>();
|
||||
cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -225,7 +303,7 @@ void Event::wait() {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait();
|
||||
} else {
|
||||
event->shared->wait(value());
|
||||
event->atomic->wait(value());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,7 +314,7 @@ void Event::wait(Stream s) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait(s);
|
||||
} else {
|
||||
event->shared->wait(s, value());
|
||||
event->atomic->wait(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,7 +325,7 @@ void Event::signal(Stream s) {
|
||||
assert(value() == 1);
|
||||
event->cuda->record(s);
|
||||
} else {
|
||||
event->shared->signal(s, value());
|
||||
event->atomic->signal(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -258,9 +336,9 @@ bool Event::is_signaled() const {
|
||||
}
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
return event->cuda->recorded() && event->cuda->completed();
|
||||
return event->cuda->is_signaled();
|
||||
} else {
|
||||
return event->shared->is_signaled(value());
|
||||
return event->atomic->is_signaled(value());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,49 +3,60 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda/atomic>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class CudaEventHandle;
|
||||
class Device;
|
||||
|
||||
// RAII-managed move-only wrapper of cudaEvent_t.
|
||||
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
|
||||
CudaEventHandle(Device& d, int flags);
|
||||
Device& device;
|
||||
int flags;
|
||||
};
|
||||
|
||||
// Wrapper of a native CUDA event. It can synchronize between GPU streams, or
// let a CPU stream wait on GPU work, but it can not wait on a CPU stream.
|
||||
class CudaEvent {
|
||||
public:
|
||||
CudaEvent();
|
||||
CudaEvent(Device& d, int flags);
|
||||
~CudaEvent();
|
||||
|
||||
CudaEvent(CudaEvent&&) = default;
|
||||
CudaEvent& operator=(CudaEvent&&) = default;
|
||||
|
||||
CudaEvent(const CudaEvent&) = delete;
|
||||
CudaEvent& operator=(const CudaEvent&) = delete;
|
||||
|
||||
void wait();
|
||||
void wait(cudaStream_t stream);
|
||||
void wait(Stream s);
|
||||
void record(cudaStream_t stream);
|
||||
void record(Stream s);
|
||||
|
||||
// Return whether the recorded kernels have completed. Note that this method
|
||||
// returns true if record() has not been called.
|
||||
bool completed() const;
|
||||
|
||||
bool recorded() const {
|
||||
return recorded_;
|
||||
}
|
||||
// Internal: make sure event pool is initialized.
|
||||
static void init_pool();
|
||||
|
||||
private:
|
||||
bool recorded_{false};
|
||||
std::shared_ptr<CudaEventHandle> event_;
|
||||
CudaEventHandle event_;
|
||||
};
|
||||
|
||||
// Event that can synchronize between CPU and GPU. It is much slower than
|
||||
// CudaEvent so the latter should always be preferred when possible.
|
||||
class SharedEvent {
|
||||
class AtomicEvent {
|
||||
public:
|
||||
using Atomic = cuda::atomic<uint64_t>;
|
||||
|
||||
SharedEvent();
|
||||
AtomicEvent();
|
||||
|
||||
void wait(uint64_t value);
|
||||
void wait(cudaStream_t stream, uint64_t value);
|
||||
@@ -57,7 +68,11 @@ class SharedEvent {
|
||||
uint64_t value() const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<mlx::core::allocator::Buffer> buf_;
|
||||
Atomic* atomic() const {
|
||||
return static_cast<AtomicEvent::Atomic*>(buf_->raw_ptr());
|
||||
}
|
||||
|
||||
std::shared_ptr<allocator::Buffer> buf_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
@@ -7,7 +7,7 @@ namespace mlx::core {
|
||||
|
||||
struct FenceImpl {
|
||||
uint32_t count;
|
||||
cu::SharedEvent event;
|
||||
cu::AtomicEvent event;
|
||||
};
|
||||
|
||||
Fence::Fence(Stream s) {
|
||||
|
||||
@@ -50,8 +50,10 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
|
||||
: CUBLAS_COMPUTE_32F;
|
||||
case float64:
|
||||
case complex64:
|
||||
return CUBLAS_COMPUTE_64F;
|
||||
case complex64:
|
||||
return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
|
||||
: CUBLAS_COMPUTE_32F;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
|
||||
@@ -85,10 +87,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride) {
|
||||
cublasLtMatrixLayout_t desc;
|
||||
if (transposed) {
|
||||
std::swap(rows, cols);
|
||||
}
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
|
||||
cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
|
||||
if (batch_count > 1) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
@@ -126,37 +128,47 @@ CublasGemm::CublasGemm(
|
||||
N_(b_cols) {
|
||||
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
|
||||
|
||||
auto scale_type = dtype_to_cublas_type(dtype);
|
||||
scale_type_ = dtype_to_cublas_type(dtype);
|
||||
if (dtype == bfloat16 || dtype == float16) {
|
||||
scale_type = CUDA_R_32F;
|
||||
scale_type_ = CUDA_R_32F;
|
||||
}
|
||||
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
|
||||
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
|
||||
&matmul_desc_, dtype_to_compute_type(dtype), scale_type_));
|
||||
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
||||
&pointer_mode,
|
||||
sizeof(int32_t)));
|
||||
cublasOperation_t op = CUBLAS_OP_N;
|
||||
|
||||
// In cublasLt, matrices use column-major layout. While it is possible to use
// the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
// epilogue does not work with that option. So instead we swap A and B to make
// cublasLt return the row-major result, which works because:
|
||||
// - the data of a matrix in row-major layout is identical to its transpose in
|
||||
// column-major layout
|
||||
// - C^T = (A @ B)^T = B^T @ A^T
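// For example, with A = [[1, 2], [3, 4]] and B = [[5, 6], [7, 8]]
// (hypothetical values), A @ B = [[19, 22], [43, 50]] and
// B^T @ A^T = [[19, 43], [22, 50]], which is (A @ B)^T; stored column-major,
// its bytes are exactly the row-major A @ B the caller expects.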
|
||||
cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_TRANSA,
|
||||
&op,
|
||||
&a_op,
|
||||
sizeof(cublasOperation_t)));
|
||||
cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_TRANSB,
|
||||
&op,
|
||||
&b_op,
|
||||
sizeof(cublasOperation_t)));
|
||||
|
||||
auto type = dtype_to_cublas_type(dtype);
|
||||
a_desc_ = create_matrix_layout(
|
||||
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
|
||||
type, b_cols, b_rows, b_transposed, ldb, batch_count, b_batch_stride);
|
||||
b_desc_ = create_matrix_layout(
|
||||
type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
|
||||
type, a_cols, a_rows, a_transposed, lda, batch_count, a_batch_stride);
|
||||
out_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
|
||||
type, b_cols, a_rows, false, b_cols, batch_count, a_rows * b_cols);
|
||||
}
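
The swap described in the comment above can be checked entirely on the host. The sketch below is not MLX code; the naive gemm_col_major stands in for cublasLtMatmul and only illustrates that passing the operands in reverse to a column-major GEMM leaves row-major A @ B in the output buffer.

// Host-only illustration of the A/B swap trick; gemm_col_major is a naive
// stand-in for cublasLtMatmul.
#include <cassert>
#include <vector>

// Computes C = X @ Y with all matrices stored column-major.
// X is m x k, Y is k x n, C is m x n; element (r, c) lives at r + c * rows.
static void gemm_col_major(
    const std::vector<float>& X,
    const std::vector<float>& Y,
    std::vector<float>& C,
    int m, int n, int k) {
  for (int c = 0; c < n; ++c) {
    for (int r = 0; r < m; ++r) {
      float acc = 0.0f;
      for (int i = 0; i < k; ++i) {
        acc += X[r + i * m] * Y[i + c * k];
      }
      C[r + c * m] = acc;
    }
  }
}

int main() {
  constexpr int M = 2, K = 3, N = 4;
  // A (M x K) and B (K x N) stored row-major, as MLX arrays are.
  std::vector<float> a = {1, 2, 3, 4, 5, 6};                    // 2x3
  std::vector<float> b = {1, 0, 2, 1, 0, 1, 3, 0, 2, 2, 0, 1};  // 3x4

  // Reference: C = A @ B computed directly in row-major.
  std::vector<float> c_ref(M * N, 0.0f);
  for (int r = 0; r < M; ++r)
    for (int c = 0; c < N; ++c)
      for (int i = 0; i < K; ++i)
        c_ref[r * N + c] += a[r * K + i] * b[i * N + c];

  // Reinterpreting a row-major buffer as column-major gives the transpose,
  // so passing (b, a) to a column-major GEMM computes B^T @ A^T = (A @ B)^T.
  // That result, stored column-major, is byte-identical to row-major A @ B.
  std::vector<float> c_swapped(M * N, 0.0f);
  gemm_col_major(b, a, c_swapped, /*m=*/N, /*n=*/M, /*k=*/K);

  assert(c_swapped == c_ref);
  return 0;
}
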
|
||||
|
||||
CublasGemm::CublasGemm(
|
||||
@@ -191,7 +203,7 @@ CublasGemm::CublasGemm(
|
||||
b_batch_stride) {
|
||||
auto type = dtype_to_cublas_type(dtype);
|
||||
c_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||
type, b_cols, a_rows, false, ldc, batch_count, c_batch_stride);
|
||||
}
|
||||
|
||||
CublasGemm::~CublasGemm() {
|
||||
@@ -202,6 +214,41 @@ CublasGemm::~CublasGemm() {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
|
||||
}
|
||||
|
||||
void CublasGemm::set_out(
|
||||
Dtype dtype,
|
||||
bool transposed,
|
||||
uint64_t rows,
|
||||
uint64_t cols,
|
||||
int64_t ld,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride) {
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
|
||||
out_desc_ = create_matrix_layout(
|
||||
dtype_to_cublas_type(dtype),
|
||||
cols,
|
||||
rows,
|
||||
transposed,
|
||||
ld,
|
||||
batch_count,
|
||||
batch_stride);
|
||||
}
|
||||
|
||||
void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
|
||||
encoder.set_input_array(bias);
|
||||
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_EPILOGUE,
|
||||
&epilogue,
|
||||
sizeof(epilogue)));
|
||||
auto* bias_ptr = bias.data<void>();
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
||||
matmul_desc_,
|
||||
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
|
||||
&bias_ptr,
|
||||
sizeof(bias_ptr)));
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
@@ -209,11 +256,19 @@ void CublasGemm::run(
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
const Strides& b_batch_strides,
|
||||
float alpha) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
if (batch_count / batch_shape.back() > 1) {
|
||||
run_batched(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
encoder,
|
||||
out,
|
||||
a,
|
||||
b,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
alpha);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -221,7 +276,13 @@ void CublasGemm::run(
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
execute(encoder, out.data<void>(), a.data<void>(), b.data<void>(), nullptr);
|
||||
execute(
|
||||
encoder,
|
||||
out.data<void>(),
|
||||
a.data<void>(),
|
||||
b.data<void>(),
|
||||
nullptr,
|
||||
alpha);
|
||||
}
|
||||
|
||||
void CublasGemm::run(
|
||||
@@ -294,6 +355,16 @@ void CublasGemm::execute(
|
||||
}
|
||||
}
|
||||
|
||||
const void* alpha_ptr = &alpha;
|
||||
const void* beta_ptr = &beta;
|
||||
complex64_t alpha_c, beta_c;
|
||||
if (scale_type_ == CUDA_C_32F) {
|
||||
alpha_c = complex64_t{alpha, 0.0f};
|
||||
beta_c = complex64_t{beta, 0.0f};
|
||||
alpha_ptr = &alpha_c;
|
||||
beta_ptr = &beta_c;
|
||||
}
|
||||
|
||||
void* workspace_ptr = nullptr;
|
||||
if (heuristic_.workspaceSize > 0) {
|
||||
// Ensure workspace is 256-byte aligned
|
||||
@@ -310,12 +381,12 @@ void CublasGemm::execute(
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||
handle_,
|
||||
matmul_desc_,
|
||||
&alpha,
|
||||
a,
|
||||
alpha_ptr,
|
||||
b, // a and b are swapped
|
||||
a_desc_,
|
||||
b,
|
||||
a,
|
||||
b_desc_,
|
||||
&beta,
|
||||
beta_ptr,
|
||||
c ? c : out,
|
||||
c ? c_desc_ : out_desc_,
|
||||
out,
|
||||
|
||||
@@ -44,6 +44,19 @@ class CublasGemm {
|
||||
|
||||
~CublasGemm();
|
||||
|
||||
// The output's descriptor is inferred from inputs by default; use this method
|
||||
// for unusual output.
|
||||
void set_out(
|
||||
Dtype dtype,
|
||||
bool transposed,
|
||||
uint64_t rows,
|
||||
uint64_t cols,
|
||||
int64_t ld,
|
||||
int32_t batch_count,
|
||||
int64_t batch_stride);
|
||||
|
||||
void set_bias(cu::CommandEncoder& encoder, const array& bias);
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
array& out,
|
||||
@@ -51,7 +64,8 @@ class CublasGemm {
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides);
|
||||
const Strides& b_batch_strides,
|
||||
float alpha = 1.0f);
|
||||
|
||||
void run(
|
||||
cu::CommandEncoder& encoder,
|
||||
@@ -74,7 +88,8 @@ class CublasGemm {
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides);
|
||||
const Strides& b_batch_strides,
|
||||
float alpha);
|
||||
|
||||
void run_batched(
|
||||
cu::CommandEncoder& encoder,
|
||||
@@ -100,6 +115,7 @@ class CublasGemm {
|
||||
|
||||
uint64_t M_;
|
||||
uint64_t N_;
|
||||
cudaDataType_t scale_type_;
|
||||
cublasLtMatmulPreference_t pref_{nullptr};
|
||||
cublasLtHandle_t handle_{nullptr};
|
||||
cublasLtMatmulDesc_t matmul_desc_{nullptr};
|
||||
|
||||
@@ -13,7 +13,8 @@ void CublasGemm::run_batched(
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
const Strides& b_batch_strides,
|
||||
float alpha) {
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
@@ -27,7 +28,8 @@ void CublasGemm::run_batched(
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
b.data<int8_t>() + b.itemsize() * b_it.loc,
|
||||
nullptr);
|
||||
nullptr,
|
||||
alpha);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
|
||||
@@ -154,7 +154,8 @@ void CublasGemm::run_batched(
|
||||
const array& b,
|
||||
const Shape& batch_shape,
|
||||
const Strides& a_batch_strides,
|
||||
const Strides& b_batch_strides) {
|
||||
const Strides& b_batch_strides,
|
||||
float alpha) {
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
@@ -226,7 +227,8 @@ void CublasGemm::run_batched(
|
||||
reinterpret_cast<void*>(out_pointers),
|
||||
reinterpret_cast<void*>(a_pointers),
|
||||
reinterpret_cast<void*>(b_pointers),
|
||||
nullptr);
|
||||
nullptr,
|
||||
alpha);
|
||||
}
|
||||
|
||||
void CublasGemm::run_batched(
|
||||
|
||||
@@ -13,6 +13,37 @@ namespace cg = cooperative_groups;
|
||||
|
||||
static constexpr int rows_per_block = 8;
|
||||
|
||||
// Accumulator type selection per input element type T.
|
||||
template <typename T>
|
||||
struct GemvAccType {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemvAccType<__half> {
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemvAccType<__nv_bfloat16> {
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemvAccType<float> {
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemvAccType<double> {
|
||||
using type = double;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemvAccType<cu::complex64_t> {
|
||||
using type = cu::complex64_t;
|
||||
};
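
A host-side analogy for the widening performed by this trait: float inputs with a double accumulator stand in for __half inputs with a float accumulator, since __half is only usable in device code. The types and numbers below are illustrative only, not MLX code.

#include <cstdio>
#include <vector>

template <typename T>
struct AccType {
  using type = T;
};

template <>
struct AccType<float> {
  using type = double;  // widen, as GemvAccType widens __half to float
};

template <typename T, typename Acc = typename AccType<T>::type>
Acc accumulate_widened(const std::vector<T>& x) {
  Acc acc = Acc(0);
  for (const T& v : x) {
    acc += v;
  }
  return acc;
}

int main() {
  std::vector<float> x(10'000'000, 0.1f);

  // Same-width accumulation drifts once the running sum is much larger than
  // each addend; the widened accumulator stays close to the true value.
  float narrow = 0.0f;
  for (float v : x) {
    narrow += v;
  }
  double wide = accumulate_widened(x);

  std::printf("narrow: %.2f  wide: %.2f\n", narrow, wide);
  return 0;
}
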
|
||||
|
||||
template <typename T, int rows_per_block, int n_per_thread>
|
||||
__device__ void
|
||||
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
|
||||
@@ -24,7 +55,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
|
||||
int row = g_idx.x * rows_per_block + t_idx.y;
|
||||
|
||||
if (row < rows) {
|
||||
float sum = 0.0f;
|
||||
using Acc = typename GemvAccType<T>::type;
|
||||
Acc sum = Acc(0);
|
||||
for (int col = n_per_thread * warp.thread_rank(); col < cols;
|
||||
col += (WARP_SIZE * n_per_thread)) {
|
||||
auto local_mat =
|
||||
@@ -32,12 +64,11 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
|
||||
auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < n_per_thread; ++j) {
|
||||
sum +=
|
||||
static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
|
||||
sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);
|
||||
}
|
||||
}
|
||||
|
||||
sum = cg::reduce(warp, sum, cg::plus<float>{});
|
||||
sum = cg::reduce(warp, sum, cg::plus<Acc>{});
|
||||
if (warp.thread_rank() == 0) {
|
||||
out[row] = static_cast<T>(sum);
|
||||
}
|
||||
@@ -107,7 +138,7 @@ void gemv(
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
|
||||
dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
dim3 block_dims{WARP_SIZE, rows_per_block};
|
||||
const DataType* mat;
|
||||
|
||||
@@ -94,7 +94,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_gather, std::move(kernel_names));
|
||||
return std::make_tuple(false, jit_source_gather, std::move(kernel_names));
|
||||
});
|
||||
|
||||
cu::KernelArgs args;
|
||||
@@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
args.append<int32_t>(src.ndim());
|
||||
args.append_ndim(slice_sizes_);
|
||||
args.append(slice_size);
|
||||
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
|
||||
args.append(axes_);
|
||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
@@ -189,7 +189,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
large ? "int64_t" : "int32_t"));
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_scatter, std::move(kernel_names));
|
||||
return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
|
||||
});
|
||||
|
||||
cu::KernelArgs args;
|
||||
@@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
args.append_ndim(out.shape());
|
||||
args.append_ndim(out.strides());
|
||||
args.append<int32_t>(out.ndim());
|
||||
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
|
||||
args.append(axes_);
|
||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
@@ -268,7 +268,8 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
|
||||
return std::make_tuple(
|
||||
false, jit_source_gather_axis, std::move(kernel_names));
|
||||
});
|
||||
|
||||
size_t idx_size_pre = 1;
|
||||
@@ -371,7 +372,8 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
|
||||
return std::make_tuple(
|
||||
false, jit_source_scatter_axis, std::move(kernel_names));
|
||||
});
|
||||
|
||||
size_t idx_size_pre = 1;
|
||||
|
||||
@@ -67,9 +67,11 @@ const std::string& cccl_dir() {
|
||||
return path.string();
|
||||
}
|
||||
// Finally check the environment variable.
|
||||
path = std::getenv("MLX_CCCL_DIR");
|
||||
if (!path.empty() && std::filesystem::exists(path)) {
|
||||
return path.string();
|
||||
if (const char* env = std::getenv("MLX_CCCL_DIR"); env) {
|
||||
path = env;
|
||||
if (!path.empty() && std::filesystem::exists(path)) {
|
||||
return path.string();
|
||||
}
|
||||
}
|
||||
return std::string();
|
||||
}();
|
||||
@@ -97,17 +99,41 @@ const std::filesystem::path& ptx_cache_dir() {
|
||||
return cache;
|
||||
}
|
||||
|
||||
std::filesystem::path get_ptx_path(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name) {
|
||||
#ifdef _WIN32
|
||||
constexpr int max_file_name_length = 140;
|
||||
#else
|
||||
constexpr int max_file_name_length = 245;
|
||||
#endif
|
||||
|
||||
if (module_name.size() <= max_file_name_length) {
|
||||
return cache_dir / (module_name + ".ptx");
|
||||
}
|
||||
|
||||
auto ptx_path = cache_dir;
|
||||
int offset = 0;
|
||||
while (module_name.size() - offset > max_file_name_length) {
|
||||
ptx_path /= module_name.substr(offset, max_file_name_length);
|
||||
offset += max_file_name_length;
|
||||
}
|
||||
ptx_path /= module_name.substr(offset) + ".ptx";
|
||||
|
||||
return ptx_path;
|
||||
}
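
The splitting performed by get_ptx_path can be sketched standalone. The snippet below is not MLX code; it uses a deliberately tiny limit so the chunking is easy to see (the real limits are 245 on POSIX and 140 on Windows, as above).

#include <filesystem>
#include <iostream>
#include <string>

std::filesystem::path split_ptx_path(
    const std::filesystem::path& cache_dir,
    const std::string& module_name,
    size_t max_file_name_length) {
  if (module_name.size() <= max_file_name_length) {
    return cache_dir / (module_name + ".ptx");
  }
  // Peel off directory-sized chunks until the remainder fits in a filename.
  auto path = cache_dir;
  size_t offset = 0;
  while (module_name.size() - offset > max_file_name_length) {
    path /= module_name.substr(offset, max_file_name_length);
    offset += max_file_name_length;
  }
  path /= module_name.substr(offset) + ".ptx";
  return path;
}

int main() {
  std::string name(20, 'a');
  name += std::string(7, 'b');
  // With a limit of 10 this becomes roughly:
  //   cache/aaaaaaaaaa/aaaaaaaaaa/bbbbbbb.ptx
  std::cout << split_ptx_path("cache", name, 10) << "\n";
  return 0;
}
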
|
||||
|
||||
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
|
||||
bool read_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
std::vector<char>* ptx,
|
||||
std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
|
||||
std::string& ptx,
|
||||
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
if (cache_dir.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto ptx_path = cache_dir / (module_name + ".ptx");
|
||||
auto ptx_path = get_ptx_path(cache_dir, module_name);
|
||||
std::error_code error;
|
||||
auto ptx_size = std::filesystem::file_size(ptx_path, error);
|
||||
if (error) {
|
||||
@@ -117,15 +143,15 @@ bool read_cached_ptx(
|
||||
if (!ptx_file.good()) {
|
||||
return false;
|
||||
}
|
||||
ptx->resize(ptx_size);
|
||||
ptx_file.read(ptx->data(), ptx_size);
|
||||
ptx.resize(ptx_size);
|
||||
ptx_file.read(ptx.data(), ptx_size);
|
||||
|
||||
std::ifstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||
std::ifstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
|
||||
std::string line;
|
||||
while (std::getline(txt_file, line)) {
|
||||
auto tab = line.find('\t');
|
||||
if (tab != std::string::npos) {
|
||||
ptx_kernels->emplace_back(line.substr(0, tab), line.substr(tab + 1));
|
||||
ptx_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
@@ -135,23 +161,33 @@ bool read_cached_ptx(
|
||||
void write_cached_ptx(
|
||||
const std::filesystem::path& cache_dir,
|
||||
const std::string& module_name,
|
||||
const std::vector<char>& ptx,
|
||||
const std::string& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||
const std::string& source_code) {
|
||||
if (cache_dir.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
|
||||
auto ptx_path = get_ptx_path(cache_dir, module_name);
|
||||
|
||||
// Ensure that the directory exists
|
||||
auto parent = ptx_path.parent_path();
|
||||
if (parent != cache_dir) {
|
||||
std::filesystem::create_directories(parent);
|
||||
}
|
||||
|
||||
// Write the compiled code and mangled names
|
||||
std::ofstream ptx_file(ptx_path, std::ios::binary);
|
||||
if (!ptx.empty()) {
|
||||
ptx_file.write(&ptx.front(), ptx.size());
|
||||
}
|
||||
std::ofstream txt_file(cache_dir / (module_name + ".txt"), std::ios::binary);
|
||||
std::ofstream txt_file(ptx_path.replace_extension(".txt"), std::ios::binary);
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
txt_file << name << "\t" << mangled << std::endl;
|
||||
}
|
||||
|
||||
std::ofstream source_file(cache_dir / (module_name + ".cu"));
|
||||
// Write the generated code
|
||||
std::ofstream source_file(ptx_path.replace_extension(".cu"));
|
||||
source_file << source_code;
|
||||
}
|
||||
|
||||
@@ -217,85 +253,86 @@ constexpr const char* g_headers[] = {
|
||||
jit_source_utils,
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
JitModule::JitModule(
|
||||
void compile(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder) {
|
||||
// Check cache.
|
||||
std::vector<char> ptx;
|
||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||
if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
|
||||
// Create program.
|
||||
auto [source_code, kernel_names] = builder();
|
||||
nvrtcProgram prog;
|
||||
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
|
||||
&prog,
|
||||
source_code.c_str(),
|
||||
(module_name + ".cu").c_str(),
|
||||
std::size(g_headers),
|
||||
g_headers,
|
||||
g_include_names));
|
||||
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
|
||||
&prog,
|
||||
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
|
||||
for (const auto& name : kernel_names) {
|
||||
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
|
||||
}
|
||||
|
||||
// Compile program.
|
||||
std::vector<const char*> args;
|
||||
bool use_sass = compiler_supports_device_sass(device);
|
||||
std::string compute = fmt::format(
|
||||
"--gpu-architecture={}_{}{}",
|
||||
use_sass ? "sm" : "compute",
|
||||
device.compute_capability_major(),
|
||||
device.compute_capability_minor());
|
||||
args.push_back(compute.c_str());
|
||||
std::string cccl_include = cccl_dir();
|
||||
if (!cccl_include.empty()) {
|
||||
cccl_include = fmt::format("--include-path={}", cccl_include);
|
||||
args.push_back(cccl_include.c_str());
|
||||
}
|
||||
std::string cuda_include =
|
||||
fmt::format("--include-path={}/include", cuda_home());
|
||||
args.push_back(cuda_include.c_str());
|
||||
nvrtcResult compile_result =
|
||||
nvrtcCompileProgram(prog, args.size(), args.data());
|
||||
if (compile_result != NVRTC_SUCCESS) {
|
||||
size_t log_size;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
|
||||
std::vector<char> log(log_size + 1, 0);
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
|
||||
throw std::runtime_error(
|
||||
fmt::format("Failed to compile kernel: {}.", log.data()));
|
||||
}
|
||||
|
||||
// Get mangled names of kernel names.
|
||||
for (const auto& name : kernel_names) {
|
||||
const char* mangled;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
|
||||
ptx_kernels.emplace_back(name, mangled);
|
||||
}
|
||||
|
||||
// Get ptx data.
|
||||
size_t ptx_size;
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
|
||||
}
|
||||
ptx.resize(ptx_size, 0);
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
write_cached_ptx(
|
||||
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
||||
const std::string& source,
|
||||
const std::vector<std::string>& kernel_names,
|
||||
std::string& ptx,
|
||||
std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
|
||||
// Create the program
|
||||
nvrtcProgram prog;
|
||||
CHECK_NVRTC_ERROR(nvrtcCreateProgram(
|
||||
&prog,
|
||||
source.c_str(),
|
||||
(module_name + ".cu").c_str(),
|
||||
std::size(g_headers),
|
||||
g_headers,
|
||||
g_include_names));
|
||||
std::unique_ptr<nvrtcProgram, void (*)(nvrtcProgram*)> prog_freer(
|
||||
&prog,
|
||||
[](nvrtcProgram* p) { CHECK_NVRTC_ERROR(nvrtcDestroyProgram(p)); });
|
||||
for (const auto& name : kernel_names) {
|
||||
CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str()));
|
||||
}
|
||||
|
||||
// Compile program.
|
||||
std::vector<const char*> args;
|
||||
bool use_sass = compiler_supports_device_sass(device);
|
||||
std::string compute = fmt::format(
|
||||
"--gpu-architecture={}_{}{}",
|
||||
use_sass ? "sm" : "compute",
|
||||
device.compute_capability_major(),
|
||||
device.compute_capability_minor());
|
||||
args.push_back(compute.c_str());
|
||||
std::string cccl_include = cccl_dir();
|
||||
if (!cccl_include.empty()) {
|
||||
cccl_include = fmt::format("--include-path={}", cccl_include);
|
||||
args.push_back(cccl_include.c_str());
|
||||
}
|
||||
std::string cuda_include =
|
||||
fmt::format("--include-path={}/include", cuda_home());
|
||||
args.push_back(cuda_include.c_str());
|
||||
nvrtcResult compile_result =
|
||||
nvrtcCompileProgram(prog, args.size(), args.data());
|
||||
if (compile_result != NVRTC_SUCCESS) {
|
||||
size_t log_size;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLogSize(prog, &log_size));
|
||||
std::vector<char> log(log_size + 1, 0);
|
||||
CHECK_NVRTC_ERROR(nvrtcGetProgramLog(prog, log.data()));
|
||||
throw std::runtime_error(
|
||||
fmt::format("Failed to compile kernel: {}.", log.data()));
|
||||
}
|
||||
|
||||
// Get mangled names of kernel names.
|
||||
for (const auto& name : kernel_names) {
|
||||
const char* mangled;
|
||||
CHECK_NVRTC_ERROR(nvrtcGetLoweredName(prog, name.c_str(), &mangled));
|
||||
ptx_kernels.emplace_back(name, mangled);
|
||||
}
|
||||
|
||||
// Get ptx data.
|
||||
size_t ptx_size;
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBINSize(prog, &ptx_size));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTXSize(prog, &ptx_size));
|
||||
}
|
||||
ptx.resize(ptx_size);
|
||||
if (use_sass) {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetCUBIN(prog, ptx.data()));
|
||||
} else {
|
||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||
}
|
||||
}
|
||||
|
||||
void load_module(
|
||||
const std::string& module_name,
|
||||
const std::string& ptx,
|
||||
const std::vector<std::pair<std::string, std::string>>& ptx_kernels,
|
||||
CUmodule& module_,
|
||||
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>>&
|
||||
kernels) {
|
||||
// Load module.
|
||||
char jit_log[4089] = {};
|
||||
CUjit_option options[] = {
|
||||
@@ -312,21 +349,77 @@ JitModule::JitModule(
|
||||
for (const auto& [name, mangled] : ptx_kernels) {
|
||||
CUfunction kernel;
|
||||
CHECK_CUDA_ERROR(cuModuleGetFunction(&kernel, module_, mangled.c_str()));
|
||||
kernels_[name] = kernel;
|
||||
kernels[name] = std::make_tuple(kernel, false, 0);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
JitModule::JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder,
|
||||
bool use_disk_cache) {
|
||||
// Will hold the actual device executable source code and kernel names
|
||||
std::string ptx;
|
||||
std::vector<std::pair<std::string, std::string>> ptx_kernels;
|
||||
|
||||
// Try to load them from the file cache
|
||||
if (!read_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels)) {
|
||||
auto [precompiled, source_code, kernel_names] = builder();
|
||||
|
||||
// Get the PTX or cubin
|
||||
if (precompiled) {
|
||||
ptx = std::move(source_code);
|
||||
for (auto& name : kernel_names) {
|
||||
ptx_kernels.emplace_back(name, name);
|
||||
}
|
||||
} else {
|
||||
compile(device, module_name, source_code, kernel_names, ptx, ptx_kernels);
|
||||
}
|
||||
|
||||
// If requested save them in the file cache for the next launch
|
||||
if (use_disk_cache) {
|
||||
write_cached_ptx(
|
||||
ptx_cache_dir(), module_name, ptx, ptx_kernels, source_code);
|
||||
}
|
||||
}
|
||||
|
||||
// Load the module
|
||||
load_module(module_name, ptx, ptx_kernels, module_, kernels_);
|
||||
}
|
||||
|
||||
JitModule::~JitModule() {
|
||||
CHECK_CUDA_ERROR(cuModuleUnload(module_));
|
||||
}
|
||||
|
||||
CUfunction JitModule::get_kernel(const std::string& kernel_name) {
|
||||
std::pair<CUfunction, uint> JitModule::get_kernel_and_dims(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel) {
|
||||
auto it = kernels_.find(kernel_name);
|
||||
if (it == kernels_.end()) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("There is no kernel named {}.", kernel_name));
|
||||
}
|
||||
return it->second;
|
||||
|
||||
// If it is the first time we run this kernel then configure it. Do it only
|
||||
// once!
|
||||
auto kernel = std::get<0>(it->second);
|
||||
if (!std::get<1>(it->second)) {
|
||||
if (configure_kernel) {
|
||||
configure_kernel(kernel);
|
||||
}
|
||||
std::get<1>(it->second) = true;
|
||||
std::get<2>(it->second) = max_occupancy_block_dim(kernel);
|
||||
}
|
||||
|
||||
return {kernel, std::get<2>(it->second)};
|
||||
}
|
||||
|
||||
CUfunction JitModule::get_kernel(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel) {
|
||||
return get_kernel_and_dims(kernel_name, std::move(configure_kernel)).first;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
|
||||
@@ -337,11 +430,12 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache() {
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
const KernelBuilder& builder) {
|
||||
const KernelBuilder& builder,
|
||||
bool cache) {
|
||||
auto& map = get_jit_module_cache();
|
||||
auto it = map.find(name);
|
||||
if (it == map.end()) {
|
||||
it = map.try_emplace(name, cu::device(device), name, builder).first;
|
||||
it = map.try_emplace(name, cu::device(device), name, builder, cache).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
@@ -19,7 +19,8 @@ namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
|
||||
using KernelBuilderResult = std::pair<
|
||||
using KernelBuilderResult = std::tuple<
|
||||
/* precompiled */ bool,
|
||||
/* source code */ std::string,
|
||||
/* kernel names */ std::vector<std::string>>;
|
||||
using KernelBuilder = std::function<KernelBuilderResult()>;
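
For reference, a builder under the new tuple shape might look like the sketch below. The alias mirrors this header, while the kernel source strings and names are placeholders, not real MLX kernels.

#include <cstdio>
#include <functional>
#include <string>
#include <tuple>
#include <vector>

using KernelBuilderResult = std::tuple<
    /* precompiled */ bool,
    /* source code */ std::string,
    /* kernel names */ std::vector<std::string>>;
using KernelBuilder = std::function<KernelBuilderResult()>;

int main() {
  // CUDA C++ that still needs NVRTC: precompiled == false and the names are
  // the expressions whose mangled forms NVRTC will report.
  KernelBuilder jit_builder = []() {
    return KernelBuilderResult{
        false,
        "__global__ void axpy(float* y, const float* x, float a) { /* ... */ }",
        std::vector<std::string>{"axpy"}};
  };

  // Already-compiled PTX or a cubin: precompiled == true, the "source" is the
  // module image and the kernel names are used as-is (no mangling lookup).
  KernelBuilder precompiled_builder = []() {
    return KernelBuilderResult{
        true,
        "/* PTX text or cubin bytes */",
        std::vector<std::string>{"axpy"}};
  };

  auto [precompiled, source, names] = jit_builder();
  std::printf("precompiled=%d kernels=%zu bytes=%zu\n",
              int(precompiled), names.size(), source.size());
  (void)precompiled_builder;
  return 0;
}
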
|
||||
@@ -45,6 +46,11 @@ struct KernelArgs {
|
||||
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append(const std::vector<T>& vec) {
|
||||
append(SmallVector<T>(vec.begin(), vec.end()));
|
||||
}
|
||||
|
||||
// Make sure the arg is copied to an array with size of NDIM.
|
||||
template <size_t NDIM = MAX_NDIM, typename T>
|
||||
void append_ndim(SmallVector<T> vec) {
|
||||
@@ -63,14 +69,16 @@ struct KernelArgs {
|
||||
private:
|
||||
std::vector<void*> args_;
|
||||
|
||||
// The cuLaunchKernel API requires passing pointers to arguments so store
|
||||
// temporary values untill kernel is launched.
|
||||
// The cuGraphAddKernelNode API requires passing pointers to arguments so
|
||||
// store temporary values until the node is created.
|
||||
using Arg = std::variant<
|
||||
std::monostate,
|
||||
CUdeviceptr,
|
||||
bool,
|
||||
int32_t,
|
||||
uint32_t,
|
||||
int64_t,
|
||||
float,
|
||||
SmallVector<const void*>,
|
||||
SmallVector<int32_t>,
|
||||
SmallVector<int64_t>>;
|
||||
@@ -82,16 +90,22 @@ class JitModule {
|
||||
JitModule(
|
||||
Device& device,
|
||||
const std::string& module_name,
|
||||
const KernelBuilder& builder);
|
||||
const KernelBuilder& builder,
|
||||
bool cache);
|
||||
~JitModule();
|
||||
|
||||
JitModule(const JitModule&) = delete;
|
||||
JitModule& operator=(const JitModule&) = delete;
|
||||
CUfunction get_kernel(const std::string& kernel_name);
|
||||
CUfunction get_kernel(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel = nullptr);
|
||||
std::pair<CUfunction, uint> get_kernel_and_dims(
|
||||
const std::string& kernel_name,
|
||||
std::function<void(CUfunction)> configure_kernel = nullptr);
|
||||
|
||||
private:
|
||||
CUmodule module_{nullptr};
|
||||
std::unordered_map<std::string, CUfunction> kernels_;
|
||||
std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
|
||||
};
|
||||
|
||||
std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||
@@ -99,6 +113,7 @@ std::unordered_map<std::string, JitModule>& get_jit_module_cache();
|
||||
JitModule& get_jit_module(
|
||||
const mlx::core::Device& device,
|
||||
const std::string& name,
|
||||
const KernelBuilder& builder);
|
||||
const KernelBuilder& builder,
|
||||
bool use_disk_cache = true);
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
||||
@@ -35,12 +35,10 @@ std::tuple<dim3, uint> get_launch_args(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
bool large,
|
||||
int work_per_thread) {
|
||||
int work_per_thread /* = 1 */,
|
||||
uint max_block_dim /* = 1024 */) {
|
||||
size_t nthreads = cuda::ceil_div(size, work_per_thread);
|
||||
uint block_dim = 1024;
|
||||
if (block_dim > nthreads) {
|
||||
block_dim = nthreads;
|
||||
}
|
||||
uint block_dim = max_block_dim < nthreads ? max_block_dim : nthreads;
|
||||
dim3 num_blocks;
|
||||
if (large) {
|
||||
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
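
The block/grid arithmetic for the non-large path reduces to the small host-side sketch below. It is illustrative only; the 2-D grid path through get_2d_grid_dims is omitted.

#include <cstdio>
#include <utility>

std::pair<unsigned, unsigned> launch_dims(
    size_t size, int work_per_thread, unsigned max_block_dim) {
  size_t nthreads = (size + work_per_thread - 1) / work_per_thread;
  unsigned block_dim = max_block_dim < nthreads
      ? max_block_dim
      : static_cast<unsigned>(nthreads);
  unsigned num_blocks =
      static_cast<unsigned>((nthreads + block_dim - 1) / block_dim);
  return {num_blocks, block_dim};
}

int main() {
  // 1M elements, 4 per thread, default 1024-thread cap -> 245 blocks of 1024.
  auto [blocks, threads] = launch_dims(1'000'000, 4, 1024);
  std::printf("%u blocks x %u threads\n", blocks, threads);

  // A kernel that prefers smaller blocks (e.g. the occupancy-tuned dim
  // returned by get_kernel_and_dims) just lowers the cap.
  auto [blocks2, threads2] = launch_dims(1'000'000, 4, 256);
  std::printf("%u blocks x %u threads\n", blocks2, threads2);
  return 0;
}
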
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file includes host-only utilies for writing CUDA kernels, the difference
|
||||
// from backend/cuda/device/utils.cuh is that the latter file only include
|
||||
// device-only code.
|
||||
// This file includes host-only utilities for writing CUDA kernels; the
|
||||
// difference from backend/cuda/device/utils.cuh is that the latter file only
|
||||
// includes device-only code.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -120,19 +120,28 @@ dim3 get_2d_grid_dims(
|
||||
size_t divisor);
|
||||
std::pair<dim3, dim3> get_grid_and_block(int dim0, int dim1, int dim2);
|
||||
|
||||
// Get the num_blocks and block_dims that maximize occupancy for |kernel|,
|
||||
// assuming each thread handles |work_per_thread| elements of |arr|.
|
||||
// Get the num_blocks and block_dims assuming each thread handles
|
||||
// |work_per_thread| elements of |arr|.
|
||||
std::tuple<dim3, uint> get_launch_args(
|
||||
size_t size,
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
bool large,
|
||||
int work_per_thread = 1);
|
||||
int work_per_thread = 1,
|
||||
uint max_block_dim = 1024);
|
||||
|
||||
inline std::tuple<dim3, uint>
|
||||
get_launch_args(const array& arr, bool large, int work_per_thread = 1) {
|
||||
inline std::tuple<dim3, uint> get_launch_args(
|
||||
const array& arr,
|
||||
bool large,
|
||||
int work_per_thread = 1,
|
||||
uint max_block_dim = 1024) {
|
||||
return get_launch_args(
|
||||
arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||
arr.size(),
|
||||
arr.shape(),
|
||||
arr.strides(),
|
||||
large,
|
||||
work_per_thread,
|
||||
max_block_dim);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -2,11 +2,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <list>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <
|
||||
@@ -27,6 +31,14 @@ class LRUCache {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize with capacity read from |env_name|.
|
||||
LRUCache(const char* env_name, int default_capacity)
|
||||
: LRUCache(env::get_var(env_name, default_capacity)) {
|
||||
if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
|
||||
env_name_ = env_name;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return map_.size();
|
||||
}
|
||||
@@ -76,6 +88,14 @@ class LRUCache {
|
||||
return {it->second, false};
|
||||
}
|
||||
|
||||
if (env_name_ && ++cache_misses_ > 2 * capacity_) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Cache thrashing is happening, please set the environment variable "
|
||||
"{} to a larger value than {} to fix degraded performance.",
|
||||
env_name_,
|
||||
capacity_));
|
||||
}
|
||||
|
||||
vlist_.emplace_front(key, std::forward<U>(value));
|
||||
map_[key] = vlist_.begin();
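
The thrashing check amounts to counting misses and complaining once they exceed twice the capacity. The toy cache below reproduces just that logic; it is not the MLX LRUCache, and the env-var plumbing and real key/value interface are left out.

#include <cstdio>
#include <list>
#include <stdexcept>
#include <unordered_map>

class TinyCache {
 public:
  explicit TinyCache(size_t capacity) : capacity_(capacity) {}

  // Returns true on a hit; inserts (and possibly evicts) on a miss.
  bool lookup(int key) {
    if (map_.count(key)) {
      return true;
    }
    if (++misses_ > 2 * capacity_) {
      throw std::runtime_error(
          "Cache thrashing is happening, increase the capacity.");
    }
    order_.push_front(key);
    map_[key] = order_.begin();
    if (map_.size() > capacity_) {
      map_.erase(order_.back());
      order_.pop_back();
    }
    return false;
  }

 private:
  size_t capacity_;
  size_t misses_{0};
  std::list<int> order_;
  std::unordered_map<int, std::list<int>::iterator> map_;
};

int main() {
  TinyCache cache(4);
  try {
    // Cycling through more distinct keys than the capacity never hits, so the
    // miss counter passes 2 * capacity and the check fires.
    for (int i = 0; i < 100; ++i) {
      cache.lookup(i % 8);
    }
  } catch (const std::runtime_error& e) {
    std::printf("%s\n", e.what());
  }
  return 0;
}
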
|
||||
|
||||
@@ -106,6 +126,9 @@ class LRUCache {
|
||||
}
|
||||
}
|
||||
|
||||
const char* env_name_{nullptr};
|
||||
size_t cache_misses_{0};
|
||||
|
||||
list_type vlist_;
|
||||
map_type map_;
|
||||
size_t capacity_;
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <numeric>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::tuple<bool, int64_t, array>
|
||||
@@ -28,6 +29,80 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
|
||||
}
|
||||
}
|
||||
|
||||
void gemm_and_bias(
|
||||
cu::CommandEncoder& encoder,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
bool a_transposed,
|
||||
int64_t lda,
|
||||
bool b_transposed,
|
||||
int64_t ldb,
|
||||
array& out,
|
||||
const array& a,
|
||||
const array& b,
|
||||
const std::optional<array>& bias = std::nullopt,
|
||||
float alpha = 1.0f) {
|
||||
// Check and collapse batch dimensions
|
||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||
|
||||
auto batch_count = out.size() / (M * N);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||
b_batch_strides.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_count = 1;
|
||||
|
||||
a_batch_strides = {0};
|
||||
b_batch_strides = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
// Use gemv when possible
|
||||
if (!bias && cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
|
||||
cu::gemv(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_count,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
encoder);
|
||||
return;
|
||||
}
|
||||
|
||||
// Invoke cublasLt
|
||||
CublasGemm gemm(
|
||||
encoder.device(),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
M,
|
||||
K,
|
||||
lda,
|
||||
b_transposed,
|
||||
K,
|
||||
N,
|
||||
ldb,
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
if (bias) {
|
||||
if (a.dtype() == complex64) {
|
||||
throw std::runtime_error(
|
||||
"[gemm_and_bias] complex64 bias epilogue isn't supported in cublasLtMatmul.");
|
||||
}
|
||||
gemm.set_bias(encoder, *bias);
|
||||
}
|
||||
gemm.run(
|
||||
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
|
||||
}
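
The "collapse batches into M" fast path relies on a contiguous (batch, M, K) stack of A being indistinguishable from a single (batch * M, K) matrix when B is shared across the batch. A naive host check of that equivalence (not MLX code, sizes are made up):

#include <cassert>
#include <vector>

static std::vector<float> matmul_row_major(
    const std::vector<float>& a, const std::vector<float>& b,
    int m, int k, int n) {
  std::vector<float> c(static_cast<size_t>(m) * n, 0.0f);
  for (int r = 0; r < m; ++r)
    for (int col = 0; col < n; ++col)
      for (int i = 0; i < k; ++i)
        c[r * n + col] += a[r * k + i] * b[i * n + col];
  return c;
}

int main() {
  constexpr int batch = 3, M = 2, K = 4, N = 5;
  std::vector<float> a(batch * M * K), b(K * N);
  for (size_t i = 0; i < a.size(); ++i) a[i] = 0.25f * static_cast<float>(i % 7);
  for (size_t i = 0; i < b.size(); ++i) b[i] = 0.5f * static_cast<float>(i % 3);

  // Batched: one M x N product per batch element, using A's i-th slice.
  std::vector<float> batched;
  for (int i = 0; i < batch; ++i) {
    std::vector<float> a_i(a.begin() + i * M * K, a.begin() + (i + 1) * M * K);
    auto c_i = matmul_row_major(a_i, b, M, K, N);
    batched.insert(batched.end(), c_i.begin(), c_i.end());
  }

  // Collapsed: the same memory treated as a single (batch * M) x K operand.
  auto collapsed = matmul_row_major(a, b, batch * M, K, N);

  assert(collapsed == batched);
  return 0;
}
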
|
||||
|
||||
} // namespace
|
||||
|
||||
void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -48,9 +123,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
|
||||
int M = a_pre.shape(-2);
|
||||
int N = b_pre.shape(-1);
|
||||
int K = a_pre.shape(-1);
|
||||
@@ -60,58 +132,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
|
||||
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
|
||||
|
||||
auto batch_count = out.size() / (M * N);
|
||||
|
||||
// Collapse batches into M if needed
|
||||
if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
|
||||
a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
|
||||
b_batch_strides.back() == 0) {
|
||||
M *= batch_shape.back();
|
||||
batch_count = 1;
|
||||
|
||||
a_batch_strides = {0};
|
||||
b_batch_strides = {0};
|
||||
batch_shape = {1};
|
||||
}
|
||||
|
||||
if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
|
||||
cu::gemv(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
batch_count,
|
||||
batch_shape,
|
||||
a_batch_strides,
|
||||
b_batch_strides,
|
||||
encoder);
|
||||
return;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
CublasGemm gemm(
|
||||
cu::device(s.device),
|
||||
a.dtype(),
|
||||
a_transposed,
|
||||
M,
|
||||
K,
|
||||
lda,
|
||||
b_transposed,
|
||||
K,
|
||||
N,
|
||||
ldb,
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
|
||||
gemm_and_bias(
|
||||
encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b);
|
||||
}
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -136,6 +158,29 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Dispatch to GEMM with epilogue or AddMM
|
||||
|
||||
if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
|
||||
c.data_size() == out.shape(-1)) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
gemm_and_bias(
|
||||
encoder,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_transposed,
|
||||
lda,
|
||||
b_transposed,
|
||||
ldb,
|
||||
out,
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
alpha_);
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t ldc;
|
||||
{
|
||||
auto stx = c.strides()[c.ndim() - 2];
|
||||
@@ -177,7 +222,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Invoke cublasLt
|
||||
// Invoke cublasLt with AddMM settings
|
||||
|
||||
CublasGemm gemm(
|
||||
cu::device(s.device),
|
||||
|
||||
@@ -1,11 +1,47 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
#include "mlx/fast.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
bool is_available() {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
} // namespace cu
|
||||
|
||||
namespace fast {
|
||||
|
||||
CustomKernelFunction cuda_kernel(
|
||||
const std::string&,
|
||||
const std::vector<std::string>&,
|
||||
const std::vector<std::string>&,
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
bool,
|
||||
int) {
|
||||
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
|
||||
}
|
||||
|
||||
std::vector<array> precompiled_cuda_kernel(
|
||||
const std::string&,
|
||||
const std::string&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<Shape>&,
|
||||
const std::vector<Dtype>&,
|
||||
const std::vector<ScalarArg>&,
|
||||
std::tuple<int, int, int>,
|
||||
std::tuple<int, int, int>,
|
||||
int shared_memory,
|
||||
std::optional<float> init_value,
|
||||
bool ensure_row_contiguous,
|
||||
StreamOrDevice) {
|
||||
throw std::runtime_error("[cuda_kernel] No CUDA back-end.");
|
||||
}
|
||||
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -24,8 +24,6 @@ namespace mlx::core {
|
||||
}
|
||||
|
||||
NO_GPU(BlockMaskedMM)
|
||||
NO_GPU(DynamicSlice)
|
||||
NO_GPU(DynamicSliceUpdate)
|
||||
NO_GPU(FFT)
|
||||
NO_GPU(GatherMM)
|
||||
NO_GPU(GatherQMM)
|
||||
@@ -41,12 +39,7 @@ NO_GPU(Cholesky)
|
||||
NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
namespace distributed {
|
||||
NO_GPU_MULTI(AllReduce)
|
||||
NO_GPU_MULTI(AllGather)
|
||||
NO_GPU_MULTI(Send)
|
||||
NO_GPU_MULTI(Recv)
|
||||
|
||||
@@ -46,10 +46,10 @@ inline array ensure_row_contiguous_matrix(
|
||||
|
||||
} // namespace
|
||||
|
||||
void fast::AffineQuantize::eval_gpu(
|
||||
void fast::Quantize::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("AffineQuantize::eval_gpu");
|
||||
nvtx3::scoped_range r("Quantize::eval_gpu");
|
||||
auto& s = stream();
|
||||
auto& d = cu::device(s.device);
|
||||
auto& enc = d.get_command_encoder(s);
|
||||
|
||||
@@ -181,6 +181,47 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = 4>
|
||||
__global__ void col_reduce_small(
|
||||
const T* in,
|
||||
U* out,
|
||||
const __grid_constant__ ColReduceArgs args,
|
||||
size_t total) {
|
||||
Op op;
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
const auto idx = grid.thread_rank() * N_READS;
|
||||
const auto before_axis = idx / args.reduction_stride;
|
||||
const auto after_axis = idx % args.reduction_stride;
|
||||
const auto offset =
|
||||
before_axis * args.reduction_stride * args.reduction_size + after_axis;
|
||||
|
||||
if (idx >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
in += offset;
|
||||
out += idx;
|
||||
|
||||
AlignedVector<U, N_READS> accumulator;
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
accumulator[i] = ReduceInit<Op, T>::value();
|
||||
}
|
||||
|
||||
for (int i = 0; i < args.reduction_size; i++) {
|
||||
auto values = load_vector<N_READS>(in, 0);
|
||||
|
||||
for (int j = 0; j < N_READS; j++) {
|
||||
accumulator[j] = op(accumulator[j], cast_to<U>(values[j]));
|
||||
}
|
||||
|
||||
in += args.reduction_stride;
|
||||
}
|
||||
|
||||
store_vector(out, 0, accumulator);
|
||||
}
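
The index math in col_reduce_small above (divide and modulo by the reduction stride, then step by the stride) can be checked with a scalar host version; the N_READS vectorization is dropped here and the shapes are made up.

#include <cassert>
#include <vector>

int main() {
  // Input viewed as (outer, reduction_size, reduction_stride), row-major.
  constexpr int outer = 3, reduction_size = 5, stride = 4;
  std::vector<int> in(outer * reduction_size * stride);
  for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<int>(i);

  const int total = outer * stride;  // one output per (outer, stride) pair
  std::vector<int> out(total, 0);
  for (int idx = 0; idx < total; ++idx) {
    const int before_axis = idx / stride;
    const int after_axis = idx % stride;
    int offset = before_axis * stride * reduction_size + after_axis;
    int acc = 0;
    for (int i = 0; i < reduction_size; ++i) {
      acc += in[offset];
      offset += stride;  // step over the reduced axis
    }
    out[idx] = acc;
  }

  // Reference: reduce the middle axis of the (outer, size, stride) view.
  for (int o = 0; o < outer; ++o) {
    for (int s = 0; s < stride; ++s) {
      int acc = 0;
      for (int i = 0; i < reduction_size; ++i) {
        acc += in[(o * reduction_size + i) * stride + s];
      }
      assert(acc == out[o * stride + s]);
    }
  }
  return 0;
}
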
|
||||
|
||||
} // namespace cu
|
||||
|
||||
inline auto output_grid_for_col_reduce(
|
||||
@@ -206,7 +247,7 @@ void col_reduce_looped(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
cu::ColReduceArgs args) {
|
||||
const cu::ColReduceArgs& args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -230,12 +271,55 @@ void col_reduce_looped(
|
||||
auto kernel =
|
||||
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, blocks, 0, indata, out.data<U>(), args);
|
||||
kernel,
|
||||
grid,
|
||||
blocks,
|
||||
0,
|
||||
indata,
|
||||
out.data<U>(),
|
||||
static_cast<cu::ColReduceArgs>(args));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void col_reduce_small(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
const cu::ColReduceArgs& args) {
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_all_types(in.dtype(), [&](auto type_tag) {
|
||||
dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
|
||||
constexpr int N_READS = 16 / sizeof(T);
|
||||
auto tmp_grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
auto [grid, block] = get_grid_and_block(tmp_grid.x, tmp_grid.y, 1);
|
||||
auto kernel = cu::col_reduce_small<T, U, OP, N_READS>;
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
static_cast<cu::ColReduceArgs>(args),
|
||||
out.size());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void col_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
const array& in,
|
||||
@@ -258,6 +342,13 @@ void col_reduce(
|
||||
// Make the args struct to help route to the best kernel
|
||||
cu::ColReduceArgs args(in, plan, axes);
|
||||
|
||||
// Small col reduce with a single or contiguous reduction axis
|
||||
if (args.non_col_reductions == 1 && args.reduction_size <= 32 &&
|
||||
args.reduction_stride % (16 / in.itemsize()) == 0) {
|
||||
col_reduce_small(encoder, in, out, reduce_type, axes, plan, args);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback col reduce
|
||||
col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -83,7 +81,8 @@ struct RowReduceArgs {
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
|
||||
__global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
__global__ void
|
||||
row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
@@ -91,8 +90,8 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
const U init = cu::ReduceInit<ReduceOp, T>::value();
|
||||
ReduceOp op;
|
||||
|
||||
T vals[M][N];
|
||||
U accs[M];
|
||||
AlignedVector<T, N> vals[M];
|
||||
AlignedVector<U, M> accs;
|
||||
for (int i = 0; i < M; i++) {
|
||||
accs[i] = init;
|
||||
}
|
||||
@@ -101,43 +100,31 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
|
||||
const size_t full_blocks = size / (block.size() * N);
|
||||
const size_t final_offset = full_blocks * (block.size() * N);
|
||||
in += start_row * size;
|
||||
in += start_row * size + block.thread_rank() * N;
|
||||
out += start_row;
|
||||
|
||||
if (size % N == 0) {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlockedVectorized<T, N>(
|
||||
block.thread_rank(),
|
||||
in + k * size + r * (block.size() * N),
|
||||
vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + k * size + r * (block.size() * N),
|
||||
vals[k]);
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
vals[k] = load_vector<N>(in + k * size, 0);
|
||||
}
|
||||
for (int k = 0; k < M; k++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
}
|
||||
|
||||
in += block.size() * N;
|
||||
}
|
||||
|
||||
if (final_offset < size) {
|
||||
for (int k = 0; k < M; k++) {
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + k * size + final_offset,
|
||||
vals[k],
|
||||
size,
|
||||
cast_to<T>(init));
|
||||
for (int i = 0; i < N; i++) {
|
||||
vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
|
||||
? in[k * size + i]
|
||||
: cast_to<T>(init);
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < M; k++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
|
||||
}
|
||||
@@ -145,13 +132,11 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
}
|
||||
|
||||
__shared__ U shared_accumulators[32 * M];
|
||||
block_reduce(block, warp, accs, shared_accumulators, op, init);
|
||||
block_reduce(block, warp, accs.val, shared_accumulators, op, init);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
if (grid.block_rank() * M + M <= n_rows) {
|
||||
for (int i = 0; i < M; i++) {
|
||||
out[i] = accs[i];
|
||||
}
|
||||
store_vector(out, 0, accs);
|
||||
} else {
|
||||
short offset = grid.block_rank() * M + M - n_rows;
|
||||
for (int i = offset; i < M; i++) {
|
||||
@@ -161,17 +146,10 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
int BLOCK_DIM,
|
||||
int N_READS = 4>
|
||||
template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
|
||||
__global__ void row_reduce_looped(
|
||||
T* in,
|
||||
const T* in,
|
||||
U* out,
|
||||
size_t out_size,
|
||||
const __grid_constant__ RowReduceArgs args) {
|
||||
auto grid = cg::this_grid();
|
||||
auto block = cg::this_thread_block();
|
||||
@@ -185,36 +163,60 @@ __global__ void row_reduce_looped(
|
||||
U init = ReduceInit<Op, T>::value();
|
||||
total[0] = init;
|
||||
LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
|
||||
size_t full_blocks = args.row_size / (BLOCK_DIM * N_READS);
|
||||
size_t final_offset = full_blocks * BLOCK_DIM * N_READS;
|
||||
const size_t full_blocks = args.row_size / (block.size() * N_READS);
|
||||
const size_t final_offset = full_blocks * (block.size() * N_READS);
|
||||
|
||||
in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
|
||||
in += block.thread_rank() * N_READS;
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlockedVectorized<T, N_READS>(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + r * BLOCK_DIM * N_READS,
|
||||
vals);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
// Unaligned reduce
|
||||
if (final_offset < args.row_size) {
|
||||
bool mask[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
mask[i] =
|
||||
(final_offset + block.thread_rank() * N_READS + i) < args.row_size;
|
||||
}
|
||||
if (final_offset < args.row_size) {
|
||||
T vals[N_READS];
|
||||
cub::LoadDirectBlocked(
|
||||
block.thread_rank(),
|
||||
in + loop.location() + final_offset,
|
||||
vals,
|
||||
args.row_size - final_offset,
|
||||
cast_to<T>(init));
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
const T* inlocal = in + loop.location();
|
||||
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
inlocal += block.size() * N_READS;
|
||||
}
|
||||
|
||||
{
|
||||
T vals[N_READS];
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
|
||||
}
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
}
|
||||
|
||||
// Aligned case
|
||||
else {
|
||||
for (size_t n = 0; n < args.non_row_reductions; n++) {
|
||||
const T* inlocal = in + loop.location();
|
||||
|
||||
for (size_t r = 0; r < full_blocks; r++) {
|
||||
auto vals = load_vector<N_READS>(inlocal, 0);
|
||||
for (int i = 0; i < N_READS; i++) {
|
||||
total[0] = op(total[0], cast_to<U>(vals[i]));
|
||||
}
|
||||
inlocal += block.size() * N_READS;
|
||||
}
|
||||
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
// TODO: Maybe block.sync() here?
|
||||
loop.next(args.reduce_shape.data(), args.reduce_strides.data());
|
||||
}
|
||||
|
||||
__shared__ U shared_accumulators[32];
|
||||
@@ -234,8 +236,6 @@ void row_reduce_simple(
|
||||
Reduce::ReduceType reduce_type,
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Allocate data for the output using in's layout to avoid elem_to_loc in the
|
||||
// kernel.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -250,14 +250,15 @@ void row_reduce_simple(
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
constexpr int N_READS = 16 / sizeof(T);
|
||||
|
||||
// Calculate the grid and block dims
|
||||
size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||
warps /= 4;
|
||||
warps = std::max(std::min(warps, 32), 1);
|
||||
int threads = warps * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
@@ -267,6 +268,7 @@ void row_reduce_simple(
|
||||
kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
|
||||
}
|
||||
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
int size = plan.shape.back();
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
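
The new sizing heuristic targets roughly four vector loads per thread, clamped to between 1 and 32 warps. A small sketch of the arithmetic with illustrative row sizes (not MLX code):

#include <algorithm>
#include <cstdio>

constexpr int WARP_SIZE = 32;

int threads_for_row(size_t row_size, int n_reads) {
  size_t reductions = (row_size + n_reads - 1) / n_reads;  // vector loads/row
  int warps = static_cast<int>((reductions + WARP_SIZE - 1) / WARP_SIZE);
  warps /= 4;
  warps = std::max(std::min(warps, 32), 1);
  return warps * WARP_SIZE;
}

int main() {
  // float rows (N_READS = 4): short rows get one warp, long rows cap at 1024.
  for (size_t row : {64, 1024, 16384, 1 << 20}) {
    std::printf("row_size=%zu -> %d threads\n", row, threads_for_row(row, 4));
  }
  return 0;
}
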
|
||||
@@ -282,8 +284,6 @@ void row_reduce_looped(
|
||||
const std::vector<int>& axes,
|
||||
const ReductionPlan& plan,
|
||||
cu::RowReduceArgs args) {
|
||||
constexpr int N_READS = 8;
|
||||
|
||||
// Allocate data for the output using in's layout to access them as
|
||||
// contiguously as possible.
|
||||
allocate_same_layout(out, in, axes);
|
||||
@@ -295,34 +295,27 @@ void row_reduce_looped(
|
||||
using OP = MLX_GET_TYPE(reduce_type_tag);
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
using U = typename cu::ReduceResult<OP, T>::type;
|
||||
// Cub doesn't like const pointers for vectorized loads. (sigh)
|
||||
T* indata = const_cast<T*>(in.data<T>());
|
||||
|
||||
constexpr int N_READS = 16 / sizeof(T);
|
||||
|
||||
// Calculate the grid and block dims
|
||||
args.sort_access_pattern(in, axes);
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||
int threads = std::min(1024UL, reductions);
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||
warps /= 4;
|
||||
warps = std::max(std::min(warps, 32), 1);
|
||||
int threads = warps * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
// Pick the kernel
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
|
||||
auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
|
||||
dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
|
||||
dispatch_block_dim(threads, [&](auto threads_constant) {
|
||||
kernel = cu::row_reduce_looped<
|
||||
T,
|
||||
U,
|
||||
OP,
|
||||
reduce_ndim.value,
|
||||
threads_constant.value,
|
||||
N_READS>;
|
||||
block.x = threads_constant.value;
|
||||
});
|
||||
kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
|
||||
});
|
||||
|
||||
encoder.add_kernel_node(
|
||||
kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
|
||||
kernel, grid, block, 0, in.data<T>(), out.data<U>(), args);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -103,15 +103,21 @@ template <typename T, bool traditional, bool forward, int N = 4>
|
||||
__device__ void rope_impl(
|
||||
const T* in,
|
||||
T* out,
|
||||
int offset,
|
||||
const int* offset,
|
||||
float inv_freq,
|
||||
float scale,
|
||||
const cuda::std::array<int64_t, 3> strides,
|
||||
const cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
int64_t offset_stride,
|
||||
int n_head,
|
||||
uint3 pos,
|
||||
uint3 dims) {
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
auto n_head_up = N * ((n_head + N - 1) / N);
|
||||
auto head_idx = static_cast<int>((pos.z * N) % n_head_up);
|
||||
auto batch_idx = (pos.z * N) / n_head_up;
|
||||
auto batch_offset = offset[batch_idx * offset_stride];
|
||||
float L = scale * static_cast<float>(pos.y + batch_offset);
|
||||
auto mat_idx = batch_idx * n_head + head_idx;
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * inv_freq;
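
The rounded-up head count keeps each z index inside a single batch element, and the head_idx + i < n_head guard skips the padding heads. A host-side check of that mapping (the sizes below are made up; this is not MLX code):

#include <cassert>
#include <vector>

int main() {
  constexpr int N = 4;   // heads handled per z index
  const int n_head = 6;  // deliberately not a multiple of N
  const int n_batch = 3;

  const int n_head_up = N * ((n_head + N - 1) / N);
  const int z_count = n_batch * n_head_up / N;  // grid extent along z

  std::vector<int> visits(n_batch * n_head, 0);
  for (int z = 0; z < z_count; ++z) {
    int head_idx = (z * N) % n_head_up;
    int batch_idx = (z * N) / n_head_up;
    for (int i = 0; i < N && head_idx + i < n_head; ++i) {
      visits[batch_idx * n_head + head_idx + i] += 1;
    }
  }

  // Every (batch, head) pair is processed exactly once.
  for (int v : visits) {
    assert(v == 1);
  }
  return 0;
}
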
|
||||
@@ -123,20 +129,19 @@ __device__ void rope_impl(
|
||||
size_t out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 =
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
|
||||
N * pos.z * out_strides[0];
|
||||
mat_idx * out_strides[0];
|
||||
out_index_2 = out_index_1 + dims.x * out_strides[2];
|
||||
in_index_1 =
|
||||
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
|
||||
in_index_2 = in_index_1 + dims.x * strides[2];
|
||||
}
|
||||
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
|
||||
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
@@ -167,7 +172,8 @@ __global__ void rope(
|
||||
float base,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
int64_t offset_stride,
|
||||
int n_head,
|
||||
uint3 dims) {
|
||||
uint3 pos = make_uint3(
|
||||
blockIdx.x * blockDim.x + threadIdx.x,
|
||||
@@ -182,12 +188,13 @@ __global__ void rope(
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
@@ -202,7 +209,8 @@ __global__ void rope_freqs(
|
||||
float base,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, 3> out_strides,
|
||||
int64_t n_batch,
|
||||
int64_t offset_stride,
|
||||
int n_head,
|
||||
uint3 dims,
|
||||
int64_t freq_stride) {
|
||||
uint3 pos = make_uint3(
|
||||
@@ -217,12 +225,13 @@ __global__ void rope_freqs(
|
||||
rope_impl<T, traditional, forward>(
|
||||
in,
|
||||
out,
|
||||
*offset,
|
||||
offset,
|
||||
inv_freq,
|
||||
scale,
|
||||
strides,
|
||||
out_strides,
|
||||
n_batch,
|
||||
offset_stride,
|
||||
n_head,
|
||||
pos,
|
||||
dims);
|
||||
}
|
||||
@@ -245,23 +254,28 @@ void RoPE::eval_gpu(
|
||||
auto& offset = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() < 3) {
|
||||
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
|
||||
}
|
||||
|
||||
cuda::std::array<int64_t, 3> strides;
|
||||
cuda::std::array<int64_t, 3> out_strides;
|
||||
bool donated = false;
|
||||
int ndim = in.ndim();
|
||||
int dispatch_ndim = in.ndim();
|
||||
|
||||
int B = in.shape(0);
|
||||
int T = in.shape(-2);
|
||||
int D = in.shape(-1);
|
||||
size_t mat_size = T * D;
|
||||
int dispatch_ndim = ndim;
|
||||
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
|
||||
dispatch_ndim--;
|
||||
}
|
||||
size_t mat_size = in.shape(-2) * in.shape(-1);
|
||||
|
||||
int N = 1;
|
||||
for (int i = 1; i < (ndim - 2); ++i) {
|
||||
N *= in.shape(i);
|
||||
}
|
||||
|
||||
// We apply rope to less than the whole vector so copy to output and then
|
||||
// apply in-place.
|
||||
if (dims_ < in.shape(-1)) {
|
||||
if (dims_ < D) {
|
||||
donated = true;
|
||||
auto ctype =
|
||||
(in.flags().row_contiguous) ? CopyType::Vector : CopyType::General;
|
||||
@@ -302,7 +316,7 @@ void RoPE::eval_gpu(
|
||||
out_strides[2] = out.strides()[ndim - 1];
|
||||
|
||||
// Some flags to help us dispatch below
|
||||
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
|
||||
bool single = in.flags().row_contiguous && B == 1 && T == 1;
|
||||
bool with_freqs = inputs.size() == 3;
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
@@ -319,7 +333,7 @@ void RoPE::eval_gpu(
|
||||
if (single && !with_freqs) {
|
||||
auto kernel =
|
||||
cu::rope_single<DataType, traditional.value, forward.value>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
uint2 dims = make_uint2(dims_ / 2, N);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
@@ -336,7 +350,7 @@ void RoPE::eval_gpu(
|
||||
} else if (single) {
|
||||
auto kernel =
|
||||
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
|
||||
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
|
||||
uint2 dims = make_uint2(dims_ / 2, N);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
@@ -354,10 +368,14 @@ void RoPE::eval_gpu(
|
||||
} else if (with_freqs) {
|
||||
auto kernel =
|
||||
cu::rope_freqs<DataType, traditional.value, forward.value>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
int n_per_thread = 4;
|
||||
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
|
||||
uint3 dims = make_uint3(dims_ / 2, T, dimz);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
int64_t offset_stride = 0;
|
||||
if (inputs[1].ndim() > 0) {
|
||||
offset_stride = inputs[1].strides()[0];
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
@@ -371,15 +389,20 @@ void RoPE::eval_gpu(
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
offset_stride,
|
||||
N,
|
||||
dims,
|
||||
inputs[2].strides(0));
|
||||
} else {
|
||||
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
|
||||
uint3 dims =
|
||||
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
|
||||
dims.z = (dims.z + 3) / 4;
|
||||
int n_per_thread = 4;
|
||||
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
|
||||
uint3 dims = make_uint3(dims_ / 2, T, dimz);
|
||||
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
|
||||
int64_t offset_stride = 0;
|
||||
if (inputs[1].ndim() > 0) {
|
||||
offset_stride = inputs[1].strides()[0];
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
grid,
|
||||
@@ -392,7 +415,8 @@ void RoPE::eval_gpu(
|
||||
std::log2(base_),
|
||||
strides,
|
||||
out_strides,
|
||||
in.size() / mat_size,
|
||||
offset_stride,
|
||||
N,
|
||||
dims);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
@@ -46,6 +45,7 @@ __global__ void kernel_sdpav_1pass(
|
||||
const T* K,
|
||||
const T* V,
|
||||
T* O,
|
||||
const T* sinks,
|
||||
__grid_constant__ const AttnParams params) {
|
||||
constexpr int BN = 32;
|
||||
constexpr int BD = 32;
|
||||
@@ -65,7 +65,7 @@ __global__ void kernel_sdpav_1pass(
|
||||
__shared__ U max_scores[BN];
|
||||
__shared__ U sum_exp_scores[BN];
|
||||
|
||||
const U scale_log2 = params.scale * 1.44269504089f;
|
||||
const U scale_log2 = params.scale * M_LOG2E;
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<32>(block);
|
||||
@@ -108,8 +108,12 @@ __global__ void kernel_sdpav_1pass(
|
||||
o[i] = 0.f;
|
||||
}
|
||||
|
||||
U max_score = -INFINITY;
|
||||
U max_score = Limits<U>::finite_min();
|
||||
U sum_exp_score = 0.f;
|
||||
if (sinks && warp_idx == 0) {
|
||||
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
|
||||
sum_exp_score = 1.f;
|
||||
}
|
||||
|
||||
// For each key
|
||||
for (int i = kv_seq_idx; i < params.kL; i += BN) {
|
||||
@@ -167,7 +171,7 @@ __global__ void kernel_sdpav_1pass(
|
||||
U factor = exp2f(max_score - new_max);
|
||||
sum_exp_score =
|
||||
cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
|
||||
sum_exp_score = __frcp_rn(sum_exp_score);
|
||||
sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
|
||||
|
||||
// Now we need to aggregate all the outputs
|
||||
PRAGMA_LOOP_UNROLL
|
||||
@@ -193,6 +197,7 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
const T* Q,
|
||||
const T* K,
|
||||
const T* V,
|
||||
const T* sinks,
|
||||
float* partials,
|
||||
float* sums,
|
||||
float* maxs,
|
||||
@@ -268,8 +273,12 @@ __global__ void kernel_sdpav_2pass_1(
|
||||
o[i] = 0.f;
|
||||
}
|
||||
|
||||
U max_score = -1e9;
|
||||
U max_score = Limits<U>::finite_min();
|
||||
U sum_exp_score = 0.f;
|
||||
if (sinks && warp_idx == 0 && block_idx == 0) {
|
||||
max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
|
||||
sum_exp_score = 1.f;
|
||||
}
|
||||
|
||||
// For each key
|
||||
for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
|
||||
@@ -410,7 +419,7 @@ __global__ void kernel_sdpav_2pass_2(
|
||||
U new_max = cg::reduce(warp, max_score, cg::greater<U>());
|
||||
U factor = exp2f(max_score - new_max);
|
||||
U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
|
||||
sum_exp_score = __frcp_rn(sum_exp_score);
|
||||
sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
|
||||
|
||||
PRAGMA_LOOP_UNROLL
|
||||
for (int i = 0; i < v_per_thread; i++) {
|
||||
@@ -463,10 +472,14 @@ void sdpa_vector_1pass_fallback(
|
||||
const array& v,
|
||||
const float scale,
|
||||
array& o,
|
||||
bool do_causal_ = false) {
|
||||
bool do_causal,
|
||||
const std::optional<array>& sinks) {
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
encoder.set_input_array(v);
|
||||
if (sinks) {
|
||||
encoder.set_input_array(*sinks);
|
||||
}
|
||||
encoder.set_output_array(o);
|
||||
|
||||
cu::AttnParams params{
|
||||
@@ -489,7 +502,7 @@ void sdpa_vector_1pass_fallback(
|
||||
dim3 block_dim(1024, 1, 1);
|
||||
|
||||
dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
|
||||
dispatch_bool(do_causal_, [&](auto do_causal) {
|
||||
dispatch_bool(do_causal, [&](auto do_causal) {
|
||||
dispatch_headdim(params.D, [&](auto headdim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
@@ -504,6 +517,7 @@ void sdpa_vector_1pass_fallback(
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
o.data<DataType>(),
|
||||
sinks ? (*sinks).data<DataType>() : nullptr,
|
||||
params);
|
||||
});
|
||||
});
|
||||
@@ -518,7 +532,8 @@ void sdpa_vector_2pass_fallback(
|
||||
const array& v,
|
||||
const float scale,
|
||||
array& o,
|
||||
bool do_causal_ = false) {
|
||||
bool do_causal,
|
||||
const std::optional<array>& sinks) {
|
||||
cu::AttnParams params{
|
||||
/* int B = */ q.shape(0),
|
||||
/* int H = */ q.shape(1),
|
||||
@@ -559,7 +574,7 @@ void sdpa_vector_2pass_fallback(
|
||||
encoder.add_temporary(maxs);
|
||||
|
||||
dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
|
||||
dispatch_bool(do_causal_, [&](auto do_causal) {
|
||||
dispatch_bool(do_causal, [&](auto do_causal) {
|
||||
dispatch_headdim(params.D, [&](auto headdim) {
|
||||
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
@@ -570,6 +585,10 @@ void sdpa_vector_2pass_fallback(
|
||||
encoder.set_input_array(q);
|
||||
encoder.set_input_array(k);
|
||||
encoder.set_input_array(v);
|
||||
if (sinks) {
|
||||
encoder.set_input_array(*sinks);
|
||||
}
|
||||
|
||||
encoder.set_output_array(intermediate);
|
||||
encoder.set_output_array(sums);
|
||||
encoder.set_output_array(maxs);
|
||||
@@ -585,6 +604,7 @@ void sdpa_vector_2pass_fallback(
|
||||
q.data<DataType>(),
|
||||
k.data<DataType>(),
|
||||
v.data<DataType>(),
|
||||
sinks ? (*sinks).data<DataType>() : nullptr,
|
||||
intermediate.data<float>(),
|
||||
sums.data<float>(),
|
||||
maxs.data<float>(),
|
||||
@@ -627,15 +647,16 @@ void sdpa_vector_fallback(
|
||||
const array& v,
|
||||
const float scale,
|
||||
array& o,
|
||||
bool do_causal_ = false) {
|
||||
bool do_causal,
|
||||
const std::optional<array>& sinks) {
|
||||
int kL = k.shape(2);
|
||||
|
||||
if (kL > 1024) {
|
||||
return sdpa_vector_2pass_fallback(
|
||||
s, encoder, q, k, v, scale, o, do_causal_);
|
||||
s, encoder, q, k, v, scale, o, do_causal, sinks);
|
||||
} else {
|
||||
return sdpa_vector_1pass_fallback(
|
||||
s, encoder, q, k, v, scale, o, do_causal_);
|
||||
s, encoder, q, k, v, scale, o, do_causal, sinks);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -691,7 +712,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
|
||||
// Define some copy functions to ensure the layout of the inputs is as
|
||||
// expected.
|
||||
copies.reserve(3);
|
||||
copies.reserve(inputs.size());
|
||||
auto copy_unless = [&copies, &s](
|
||||
auto predicate, const array& arr) -> const array& {
|
||||
if (!predicate(arr)) {
|
||||
@@ -703,6 +724,16 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
}
|
||||
};
|
||||
|
||||
// Checks that the headdim dimension has stride 1.
|
||||
auto is_matrix_contiguous = [](const array& arr) {
|
||||
return arr.strides(-1) == 1;
|
||||
};
|
||||
|
||||
std::optional<array> sinks = std::nullopt;
|
||||
if (has_sinks_) {
|
||||
sinks = copy_unless(is_matrix_contiguous, inputs.back());
|
||||
}
|
||||
|
||||
// We are in vector mode ie single query
|
||||
if (q_pre.shape(2) < 4) {
|
||||
auto q_copy_unless = [](const array& arr) {
|
||||
@@ -740,10 +771,6 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
const auto& k = copy_unless(kv_copy_unless, k_pre);
|
||||
const auto& v = copy_unless(kv_copy_unless, v_pre);
|
||||
|
||||
for (const auto& cp : copies) {
|
||||
encoder.add_temporary(cp);
|
||||
}
|
||||
|
||||
// Donate the query if possible
|
||||
if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
|
||||
o.copy_shared_buffer(q);
|
||||
@@ -752,22 +779,26 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
int64_t str_oH = o.shape(3);
|
||||
int64_t str_oL = o.shape(1) * str_oH;
|
||||
int64_t str_oB = o.shape(2) * str_oL;
|
||||
size_t data_size = o.shape(0) * str_oB;
|
||||
|
||||
array::Flags flags{
|
||||
/* bool contiguous = */ 1,
|
||||
/* bool row_contiguous = */ o.shape(2) == 1,
|
||||
/* bool col_contiguous = */ 0,
|
||||
/* bool col_contiguous = */ o.size() == o.shape(3),
|
||||
};
|
||||
|
||||
o.set_data(
|
||||
allocator::malloc(o.nbytes()),
|
||||
data_size,
|
||||
o.size(),
|
||||
{str_oB, str_oH, str_oL, str_oD},
|
||||
flags);
|
||||
}
|
||||
|
||||
return sdpa_vector_fallback(s, encoder, q, k, v, scale_, o, do_causal_);
|
||||
for (const auto& cp : copies) {
|
||||
encoder.add_temporary(cp);
|
||||
}
|
||||
|
||||
return sdpa_vector_fallback(
|
||||
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
|
||||
}
|
||||
|
||||
// Full attention mode should never reach here
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
@@ -27,8 +30,7 @@ void concatenate_gpu(
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
// TODO: Handle concurrent outputs:
|
||||
// https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
|
||||
auto concurrent = cu::get_command_encoder(s).concurrent_context();
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis] * sizes[i];
|
||||
@@ -38,4 +40,71 @@ void concatenate_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
array compute_dynamic_offset(
|
||||
const array& indices,
|
||||
const Strides& strides,
|
||||
const std::vector<int>& axes,
|
||||
const Stream& s) {
|
||||
Dtype dtype = indices.dtype();
|
||||
int nidx = axes.size();
|
||||
|
||||
std::string module_name =
|
||||
fmt::format("compute_dynamic_offset_{}_{}", dtype_to_string(dtype), nidx);
|
||||
std::string kernel_name = fmt::format(
|
||||
"mlx::core::cu::compute_dynamic_offset<{}, {}>",
|
||||
dtype_to_cuda_type(dtype),
|
||||
nidx);
|
||||
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
|
||||
std::string source = R"(
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename T, int NIDX>
|
||||
__global__ void compute_dynamic_offset(
|
||||
const T* indices,
|
||||
int64_t* offset,
|
||||
const __grid_constant__ Strides strides,
|
||||
const __grid_constant__ cuda::std::array<int, NIDX> axes) {
|
||||
int64_t acc = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
acc += indices[i] * strides[axes[i]];
|
||||
}
|
||||
*offset = acc;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
)";
|
||||
return std::make_tuple(false, std::move(source), std::vector{kernel_name});
|
||||
});
|
||||
|
||||
// Prepare output.
|
||||
array offset({1}, int64, nullptr, {});
|
||||
bool donate = indices.is_donatable() &&
|
||||
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
|
||||
if (donate) {
|
||||
offset.copy_shared_buffer(indices);
|
||||
} else {
|
||||
offset.set_data(allocator::malloc(offset.itemsize()));
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.add_temporary(offset);
|
||||
encoder.set_input_array(indices);
|
||||
encoder.set_output_array(offset);
|
||||
|
||||
cu::KernelArgs args;
|
||||
args.append(indices);
|
||||
args.append(offset);
|
||||
args.append_ndim(strides);
|
||||
args.append(axes);
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
encoder.add_kernel_node(kernel, 1, 1, 0, args.args());
|
||||
|
||||
return offset;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
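For reference, the offset produced by the JIT kernel above is just the dot product of the dynamic start indices with the strides of the sliced axes. A host-side sketch of the same arithmetic (a hypothetical standalone helper, not part of this change) looks like:

#include <cstdint>
#include <vector>

// Hypothetical host-side equivalent of the compute_dynamic_offset kernel:
//   offset = sum_i indices[i] * strides[axes[i]]
int64_t dynamic_offset_example(
    const std::vector<int64_t>& indices,  // start index per sliced axis
    const std::vector<int64_t>& strides,  // strides of the source array
    const std::vector<int>& axes) {       // which axes are sliced
  int64_t acc = 0;
  for (std::size_t i = 0; i < axes.size(); ++i) {
    acc += indices[i] * strides[axes[i]];
  }
  return acc;
}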
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
@@ -10,7 +9,7 @@
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
#include <cub/device/device_segmented_sort.cuh>
|
||||
#include <cub/device/device_segmented_radix_sort.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
@@ -80,7 +79,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
encoder.add_temporary(discard);
|
||||
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
|
||||
nullptr,
|
||||
size,
|
||||
in.data<Type>(),
|
||||
@@ -91,6 +90,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
0,
|
||||
sizeof(Type) * 8,
|
||||
stream));
|
||||
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
@@ -105,7 +106,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
thrust::device_pointer_cast(indices.data<uint32_t>()),
|
||||
ModOp<uint32_t>{static_cast<uint32_t>(nsort)});
|
||||
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortPairs(
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortPairs(
|
||||
temp.data<void>(),
|
||||
size,
|
||||
in.data<Type>(),
|
||||
@@ -116,10 +117,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
0,
|
||||
sizeof(Type) * 8,
|
||||
stream));
|
||||
} else {
|
||||
size_t size;
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
|
||||
nullptr,
|
||||
size,
|
||||
in.data<Type>(),
|
||||
@@ -128,6 +131,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
0,
|
||||
sizeof(Type) * 8,
|
||||
stream));
|
||||
|
||||
array temp(allocator::malloc(size), {static_cast<int>(size)}, uint8);
|
||||
@@ -135,7 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
|
||||
// Start capturing after allocations
|
||||
auto capture = encoder.capture_context();
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedSort::StableSortKeys(
|
||||
CHECK_CUDA_ERROR(cub::DeviceSegmentedRadixSort::SortKeys(
|
||||
temp.data<void>(),
|
||||
size,
|
||||
in.data<Type>(),
|
||||
@@ -144,6 +149,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||
in.data_size() / nsort,
|
||||
offsets,
|
||||
offsets + 1,
|
||||
0,
|
||||
sizeof(Type) * 8,
|
||||
stream));
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -156,7 +156,25 @@ void ternary_op_gpu_inplace(
|
||||
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
|
||||
auto topt = get_ternary_op_type(a, b, c);
|
||||
if (topt == TernaryOpType::General) {
|
||||
if (topt == TernaryOpType::VectorVectorVector ||
|
||||
topt == TernaryOpType::ScalarScalarScalar) {
|
||||
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
constexpr int N_READS = 16 / sizeof(DType);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
out.data_size(), out.shape(), out.strides(), large(), N_READS);
|
||||
encoder.add_kernel_node(
|
||||
cu::ternary_v<Op, DType, IdxT, N_READS>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size());
|
||||
});
|
||||
} else {
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
@@ -225,23 +243,6 @@ void ternary_op_gpu_inplace(
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
constexpr int N_READS = 16 / sizeof(DType);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
out.data_size(), out.shape(), out.strides(), large(), N_READS);
|
||||
encoder.add_kernel_node(
|
||||
cu::ternary_v<Op, DType, IdxT, N_READS>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<bool>(),
|
||||
b.data<DType>(),
|
||||
c.data<DType>(),
|
||||
out.data<DType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,284 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void unary_v(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||
out[i] = Op{}(in[i]);
|
||||
}
|
||||
} else {
|
||||
auto in_vec = load_vector<N_READS>(in, index);
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(in_vec[i]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void unary_g(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size_rest,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides,
|
||||
int ndim) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto grid = cg::this_grid();
|
||||
IdxT index_rest =
|
||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
||||
if (index_rest >= size_rest) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto shape_x = shape[ndim - 1];
|
||||
auto stride_x = strides[ndim - 1];
|
||||
IdxT index_x =
|
||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
||||
auto idx =
|
||||
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
|
||||
auto in_vec =
|
||||
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(in_vec[i]);
|
||||
}
|
||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_unary_op() {
|
||||
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
|
||||
std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
|
||||
return std::is_same_v<In, Out>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
|
||||
std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
|
||||
std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
|
||||
std::is_same_v<Op, Sigmoid>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
|
||||
return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Conjugate>) {
|
||||
return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
|
||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
|
||||
std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
|
||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
|
||||
std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
|
||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
|
||||
std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
|
||||
std::is_same_v<Op, Tanh>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
|
||||
return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogicalNot>) {
|
||||
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void unary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
auto& in = inputs[0];
|
||||
if (in.size() == 0) {
|
||||
return;
|
||||
}
|
||||
bool contig = in.flags().contiguous;
|
||||
bool large;
|
||||
if (!contig) {
|
||||
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
|
||||
} else {
|
||||
large = in.data_size() > UINT32_MAX;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
dispatch_bool(large, [&](auto large) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
if (contig) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
constexpr int N_READS = 16 / sizeof(OutType);
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
out.data_size(), out.shape(), out.strides(), large, N_READS);
|
||||
encoder.add_kernel_node(
|
||||
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size());
|
||||
} else {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||
auto ndim = shape.size();
|
||||
int work_per_thread = 1;
|
||||
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
|
||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
||||
auto rest = out.size() / dim0;
|
||||
if (dim0 >= 4) {
|
||||
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
|
||||
work_per_thread = 4;
|
||||
}
|
||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
||||
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
||||
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
{num_blocks_x, num_blocks_y},
|
||||
block_dims,
|
||||
0,
|
||||
in.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
rest,
|
||||
const_param(shape),
|
||||
const_param(strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do unary op {} on input of {} with output of {}.",
|
||||
op,
|
||||
dtype_to_string(in.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void unary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
set_unary_output_data(inputs[0], out);
|
||||
unary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
#define UNARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = out.primitive().stream(); \
|
||||
unary_op_gpu<cu::func>(inputs, out, name(), s); \
|
||||
}
|
||||
|
||||
UNARY_GPU(Abs)
|
||||
UNARY_GPU(ArcCos)
|
||||
UNARY_GPU(ArcCosh)
|
||||
UNARY_GPU(ArcSin)
|
||||
UNARY_GPU(ArcSinh)
|
||||
UNARY_GPU(ArcTan)
|
||||
UNARY_GPU(ArcTanh)
|
||||
UNARY_GPU(BitwiseInvert)
|
||||
UNARY_GPU(Ceil)
|
||||
UNARY_GPU(Conjugate)
|
||||
UNARY_GPU(Cos)
|
||||
UNARY_GPU(Cosh)
|
||||
UNARY_GPU(Erf)
|
||||
UNARY_GPU(ErfInv)
|
||||
UNARY_GPU(Exp)
|
||||
UNARY_GPU(Expm1)
|
||||
UNARY_GPU(Floor)
|
||||
UNARY_GPU(Imag)
|
||||
UNARY_GPU(Log1p)
|
||||
UNARY_GPU(LogicalNot)
|
||||
UNARY_GPU(Negative)
|
||||
UNARY_GPU(Real)
|
||||
UNARY_GPU(Sigmoid)
|
||||
UNARY_GPU(Sign)
|
||||
UNARY_GPU(Sin)
|
||||
UNARY_GPU(Sinh)
|
||||
UNARY_GPU(Square)
|
||||
UNARY_GPU(Tan)
|
||||
UNARY_GPU(Tanh)
|
||||
|
||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Log::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
unary_op_gpu<cu::Log>(inputs, out, name(), s);
|
||||
break;
|
||||
case Base::two:
|
||||
unary_op_gpu<cu::Log2>(inputs, out, name(), s);
|
||||
break;
|
||||
case Base::ten:
|
||||
unary_op_gpu<cu::Log10>(inputs, out, name(), s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Round::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
auto& s = out.primitive().stream();
|
||||
if (issubdtype(in.dtype(), inexact)) {
|
||||
unary_op_gpu<cu::Round>(inputs, out, name(), s);
|
||||
} else {
|
||||
// No-op integer types
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
|
||||
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Sort::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
if (recip_) {
|
||||
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
|
||||
} else {
|
||||
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -8,36 +8,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
CudaStream::CudaStream(cu::Device& device) {
|
||||
device.make_current();
|
||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
|
||||
}
|
||||
|
||||
CudaStream::~CudaStream() {
|
||||
CHECK_CUDA_ERROR(cudaStreamDestroy(stream_));
|
||||
}
|
||||
|
||||
CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {}
|
||||
|
||||
CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) {
|
||||
other.handle_ = nullptr;
|
||||
};
|
||||
|
||||
CudaGraphExec::~CudaGraphExec() {
|
||||
reset();
|
||||
}
|
||||
|
||||
void CudaGraphExec::instantiate(cudaGraph_t graph) {
|
||||
CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
|
||||
}
|
||||
|
||||
void CudaGraphExec::reset() {
|
||||
if (handle_ != nullptr) {
|
||||
CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_));
|
||||
handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void check_cublas_error(const char* name, cublasStatus_t err) {
|
||||
if (err != CUBLAS_STATUS_SUCCESS) {
|
||||
// TODO: Use cublasGetStatusString when it is widely available.
|
||||
@@ -96,4 +66,24 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
CudaGraph::CudaGraph(cu::Device& device) {
|
||||
device.make_current();
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0));
|
||||
}
|
||||
|
||||
void CudaGraph::end_capture(cudaStream_t stream) {
|
||||
assert(handle_ == nullptr);
|
||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
|
||||
}
|
||||
|
||||
void CudaGraphExec::instantiate(cudaGraph_t graph) {
|
||||
assert(handle_ == nullptr);
|
||||
CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0));
|
||||
}
|
||||
|
||||
CudaStream::CudaStream(cu::Device& device) {
|
||||
device.make_current();
|
||||
CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file include utilies that are used by C++ code (i.e. .cpp files).
|
||||
// This file include utilities that are used by C++ code (i.e. .cpp files).
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -12,48 +12,11 @@ namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
class Device;
|
||||
|
||||
}
|
||||
|
||||
struct Dtype;
|
||||
|
||||
// Cuda stream managed with RAII.
|
||||
class CudaStream {
|
||||
public:
|
||||
explicit CudaStream(cu::Device& device);
|
||||
~CudaStream();
|
||||
|
||||
CudaStream(const CudaStream&) = delete;
|
||||
CudaStream& operator=(const CudaStream&) = delete;
|
||||
|
||||
operator cudaStream_t() const {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaStream_t stream_;
|
||||
};
|
||||
|
||||
// Move-able RAII handle of cudaGraphExec_t.
|
||||
class CudaGraphExec {
|
||||
public:
|
||||
CudaGraphExec(cudaGraphExec_t handle = nullptr);
|
||||
CudaGraphExec(CudaGraphExec&& other);
|
||||
~CudaGraphExec();
|
||||
|
||||
CudaGraphExec(const CudaGraphExec&) = delete;
|
||||
CudaGraphExec& operator=(const CudaGraphExec&) = delete;
|
||||
|
||||
void instantiate(cudaGraph_t graph);
|
||||
void reset();
|
||||
|
||||
operator cudaGraphExec_t() const {
|
||||
return handle_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaGraphExec_t handle_;
|
||||
};
|
||||
|
||||
// Throw exception if the cuda API does not succeed.
|
||||
void check_cublas_error(const char* name, cublasStatus_t err);
|
||||
void check_cuda_error(const char* name, cudaError_t err);
|
||||
@@ -66,4 +29,75 @@ void check_cuda_error(const char* name, CUresult err);
|
||||
// Convert Dtype to CUDA C++ types.
|
||||
const char* dtype_to_cuda_type(const Dtype& dtype);
|
||||
|
||||
// Base class for RAII managed CUDA resources.
|
||||
template <typename Handle, cudaError_t (*Destroy)(Handle)>
|
||||
class CudaHandle {
|
||||
public:
|
||||
CudaHandle(Handle handle = nullptr) : handle_(handle) {}
|
||||
|
||||
CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
|
||||
assert(this != &other);
|
||||
other.handle_ = nullptr;
|
||||
}
|
||||
|
||||
~CudaHandle() {
|
||||
reset();
|
||||
}
|
||||
|
||||
CudaHandle(const CudaHandle&) = delete;
|
||||
CudaHandle& operator=(const CudaHandle&) = delete;
|
||||
|
||||
CudaHandle& operator=(CudaHandle&& other) {
|
||||
assert(this != &other);
|
||||
reset();
|
||||
std::swap(handle_, other.handle_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
void reset() {
|
||||
if (handle_ != nullptr) {
|
||||
CHECK_CUDA_ERROR(Destroy(handle_));
|
||||
handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
operator Handle() const {
|
||||
return handle_;
|
||||
}
|
||||
|
||||
protected:
|
||||
Handle handle_;
|
||||
};
|
||||
|
||||
// Wrappers of CUDA resources.
|
||||
class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
|
||||
public:
|
||||
using CudaHandle::CudaHandle;
|
||||
explicit CudaGraph(cu::Device& device);
|
||||
void end_capture(cudaStream_t stream);
|
||||
};
|
||||
|
||||
class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
|
||||
public:
|
||||
void instantiate(cudaGraph_t graph);
|
||||
};
|
||||
|
||||
class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
|
||||
public:
|
||||
explicit CudaStream(cu::Device& device);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline uint max_occupancy_block_dim(T kernel) {
|
||||
int _, block_dim;
|
||||
if constexpr (std::is_same_v<T, CUfunction>) {
|
||||
CHECK_CUDA_ERROR(
|
||||
cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0));
|
||||
} else {
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel));
|
||||
}
|
||||
return block_dim;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
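As a rough sketch of how the RAII wrappers above compose (illustrative only: it assumes the CudaStream/CudaGraph/CudaGraphExec interfaces declared in this header, an assumed include path, and standard CUDA graph APIs; it is not code from this change):

#include <cuda_runtime.h>
#include "mlx/backend/cuda/utils.h"  // assumed path of the header shown above

// Capture work submitted to a stream into a graph, then replay it once.
void capture_and_replay_example(mlx::core::cu::Device& device) {
  mlx::core::CudaStream stream(device);
  CHECK_CUDA_ERROR(
      cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
  // ... enqueue kernels on `stream` here ...
  mlx::core::CudaGraph graph{nullptr};  // null handle, filled by end_capture
  graph.end_capture(stream);            // wraps cudaStreamEndCapture
  mlx::core::CudaGraphExec exec;        // null handle, filled by instantiate
  exec.instantiate(graph);              // wraps cudaGraphInstantiate
  CHECK_CUDA_ERROR(cudaGraphLaunch(exec, stream));
  // All three handles are released automatically when they go out of scope.
}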
|
||||
|
||||
@@ -5,8 +5,9 @@
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
Worker::Worker()
|
||||
: signal_stream_(device(mlx::core::Device::gpu)),
|
||||
Worker::Worker(Device& d)
|
||||
: signal_stream_(d),
|
||||
signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
|
||||
worker_(&Worker::thread_fn, this) {}
|
||||
|
||||
Worker::~Worker() {
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
@@ -16,7 +15,7 @@ namespace mlx::core::cu {
|
||||
// Run tasks in worker thread, synchronized with cuda stream.
|
||||
class Worker {
|
||||
public:
|
||||
Worker();
|
||||
explicit Worker(Device& d);
|
||||
~Worker();
|
||||
|
||||
Worker(const Worker&) = delete;
|
||||
|
||||
@@ -52,4 +52,70 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
|
||||
return arr_copy;
|
||||
}
|
||||
|
||||
void reshape_gpu(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
|
||||
int ndim = x.ndim();
|
||||
if (start_axis < 0) {
|
||||
start_axis += ndim;
|
||||
}
|
||||
if (end_axis < 0) {
|
||||
end_axis += ndim;
|
||||
}
|
||||
start_axis = std::max(0, start_axis);
|
||||
end_axis = std::min(ndim - 1, end_axis);
|
||||
|
||||
return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);
|
||||
}
|
||||
|
||||
array reshape_in_eval(const array& x, Shape shape, Stream s) {
|
||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||
reshape_gpu(x, out, s);
|
||||
return out;
|
||||
}
|
||||
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
||||
int ndim = x.ndim();
|
||||
if (axis1 < 0) {
|
||||
axis1 += ndim;
|
||||
}
|
||||
if (axis2 < 0) {
|
||||
axis2 += ndim;
|
||||
}
|
||||
|
||||
auto shape = x.shape();
|
||||
std::swap(shape[axis1], shape[axis2]);
|
||||
auto strides = x.strides();
|
||||
std::swap(strides[axis1], strides[axis2]);
|
||||
|
||||
auto [data_size, row_contiguous, col_contiguous] =
|
||||
check_contiguity(shape, strides);
|
||||
bool contiguous = data_size == x.data_size();
|
||||
|
||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(
|
||||
x,
|
||||
std::move(strides),
|
||||
{contiguous, row_contiguous, col_contiguous},
|
||||
x.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -20,8 +20,8 @@ void copy_gpu_inplace(
|
||||
int64_t o_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
||||
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
||||
std::optional<array> dynamic_i_offset = std::nullopt,
|
||||
std::optional<array> dynamic_o_offset = std::nullopt);
|
||||
|
||||
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
|
||||
void copy_gpu(const array& src, array& out, CopyType ctype);
|
||||
@@ -46,4 +46,12 @@ void fill_gpu(const array& val, array& out, const Stream& s);
|
||||
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||
array contiguous_copy_gpu(const array& arr, const Stream& s);
|
||||
|
||||
// Copy data from |in| and reshape to |out|'s shape.
|
||||
void reshape_gpu(const array& in, array& out, Stream s);
|
||||
|
||||
// Like the normal ops but safe to call in eval_gpu.
|
||||
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
|
||||
array reshape_in_eval(const array& x, Shape shape, Stream s);
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -20,29 +20,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void reshape(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
|
||||
eval(inputs, out);
|
||||
@@ -103,6 +80,74 @@ void Depends::eval_gpu(
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("DynamicSlice::eval_gpu");
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& start = inputs[1];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto s = stream();
|
||||
auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
|
||||
copy_gpu_inplace(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ out.shape(),
|
||||
/* const Strides& i_strides = */ in.strides(),
|
||||
/* const Strides& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
/* const Stream& s = */ s,
|
||||
/* std::optional<array> dynamic_i_offset = */ std::move(in_offset),
|
||||
/* std::optional<array> dynamic_o_offset = */ std::nullopt);
|
||||
}
|
||||
|
||||
void DynamicSliceUpdate::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out) {
|
||||
MLX_PROFILER_RANGE("DynamicSliceUpdate::eval_gpu");
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& in = inputs[0];
|
||||
auto& upd = inputs[1];
|
||||
auto& start_indices = inputs[2];
|
||||
|
||||
if (upd.size() == 0) {
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
// Copy or donate input to output
|
||||
auto s = stream();
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
|
||||
|
||||
auto out_offset =
|
||||
compute_dynamic_offset(start_indices, out.strides(), axes_, s);
|
||||
copy_gpu_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ upd.shape(),
|
||||
/* const Strides& i_strides = */ upd.strides(),
|
||||
/* const Strides& o_strides = */ out.strides(),
|
||||
/* int64_t i_offset = */ 0,
|
||||
/* int64_t o_offset = */ 0,
|
||||
/* CopyType ctype = */ CopyType::GeneralGeneral,
|
||||
/* const Stream& s = */ s,
|
||||
/* std::optional<array> dynamic_i_offset = */ std::nullopt,
|
||||
/* std::optional<array> dynamic_o_offset = */ std::move(out_offset));
|
||||
}
|
||||
|
||||
void ExpandDims::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("ExpandDims::eval_gpu");
|
||||
eval(inputs, out);
|
||||
@@ -124,7 +169,7 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Flatten::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
reshape_gpu(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -150,7 +195,7 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Reshape::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
reshape_gpu(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void Split::eval_gpu(
|
||||
@@ -224,7 +269,7 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
|
||||
reshape(inputs[0], out, stream());
|
||||
reshape_gpu(inputs[0], out, stream());
|
||||
}
|
||||
|
||||
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
@@ -27,4 +27,10 @@ void pad_gpu(
|
||||
const Shape& low_pad_size,
|
||||
const Stream& s);
|
||||
|
||||
array compute_dynamic_offset(
|
||||
const array& indices,
|
||||
const Strides& strides,
|
||||
const std::vector<int>& axes,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -33,10 +33,11 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
|
||||
make_jit_source(binary_ops)
|
||||
make_jit_source(ternary_ops)
|
||||
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
|
||||
make_jit_source(scatter kernels/indexing.h)
|
||||
make_jit_source(gather kernels/indexing.h)
|
||||
make_jit_source(gather_axis)
|
||||
make_jit_source(scatter_axis)
|
||||
make_jit_source(indexing/scatter kernels/indexing/indexing.h)
|
||||
make_jit_source(indexing/gather kernels/indexing/indexing.h)
|
||||
make_jit_source(indexing/gather_front kernels/indexing/indexing.h)
|
||||
make_jit_source(indexing/gather_axis)
|
||||
make_jit_source(indexing/scatter_axis)
|
||||
make_jit_source(hadamard)
|
||||
|
||||
if(MLX_METAL_JIT)
|
||||
@@ -77,7 +78,10 @@ if(MLX_METAL_JIT)
|
||||
make_jit_source(steel/conv/kernels/steel_conv)
|
||||
make_jit_source(steel/conv/kernels/steel_conv_general kernels/steel/defines.h
|
||||
kernels/steel/conv/loaders/loader_general.h)
|
||||
make_jit_source(quantized)
|
||||
|
||||
make_jit_source(quantized_utils)
|
||||
make_jit_source(quantized kernels/quantized_utils.h)
|
||||
make_jit_source(fp4_quantized kernels/quantized_utils.h)
|
||||
make_jit_source(gemv_masked)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
|
||||
<< N;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
|
||||
/* const int swizzle_log = */ swizzle_log};
|
||||
|
||||
// Determine kernel
|
||||
std::ostringstream kname;
|
||||
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
|
||||
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
|
||||
<< (n_channel_specialization ? std::to_string(n_channel_specialization)
|
||||
: "l")
|
||||
<< "_filter_" << (small_filter ? 's' : 'l');
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"implicit_gemm_conv_2d_",
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn,
|
||||
"_channel_",
|
||||
n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
|
||||
"_filter_",
|
||||
small_filter ? 's' : 'l');
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_conv_kernel(
|
||||
d,
|
||||
kname.str(),
|
||||
kname,
|
||||
out,
|
||||
bm,
|
||||
bn,
|
||||
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
|
||||
{
|
||||
int bc = 32;
|
||||
int bo = 4;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname,
|
||||
"winograd_conv_2d_weight_transform_",
|
||||
type_to_name(out),
|
||||
"_bc",
|
||||
bc);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(wt, 0);
|
||||
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
|
||||
int bc = 32;
|
||||
int wm = 2;
|
||||
int wn = 2;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname,
|
||||
"winograd_conv_2d_input_transform_",
|
||||
type_to_name(out),
|
||||
"_bc",
|
||||
bc);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in_padded, 0);
|
||||
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
|
||||
int bc = 32;
|
||||
int wm = 2;
|
||||
int wn = 2;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
|
||||
<< bc;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname,
|
||||
"winograd_conv_2d_output_transform_",
|
||||
type_to_name(out),
|
||||
"_bo",
|
||||
bc);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(out_wg, 0);
|
||||
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
std::ostringstream kname;
|
||||
kname << "depthwise_conv_2d_" << type_to_name(out);
|
||||
std::string base_name = kname.str();
|
||||
std::string base_name;
|
||||
base_name.reserve(32);
|
||||
concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
|
||||
|
||||
const int N = conv_params.N;
|
||||
const int ker_h = conv_params.wS[0];
|
||||
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_ker_h_" << ker_h
|
||||
<< "_ker_w_" << ker_w
|
||||
<< "_str_h_" << str_h
|
||||
<< "_str_w_" << str_w
|
||||
<< "_tgp_h_" << th
|
||||
<< "_tgp_w_" << tw
|
||||
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
std::string hash_name;
|
||||
hash_name.reserve(64);
|
||||
concatenate(
|
||||
hash_name,
|
||||
base_name,
|
||||
"_ker_h_", ker_h,
|
||||
"_ker_w_", ker_w,
|
||||
"_str_h_", str_h,
|
||||
"_str_w_", str_w,
|
||||
"_tgp_h_", th,
|
||||
"_tgp_w_", tw,
|
||||
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void depthwise_conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
array wt,
|
||||
array out) {
|
||||
bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
|
||||
std::string base_name;
|
||||
base_name.reserve(32);
|
||||
concatenate(
|
||||
base_name,
|
||||
"depthwise_conv_1d_",
|
||||
large ? "_large" : "",
|
||||
type_to_name(out));
|
||||
|
||||
if (!wt.flags().row_contiguous) {
|
||||
wt = contiguous_copy_gpu(wt, s);
|
||||
d.add_temporary(wt, s.index);
|
||||
}
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
auto B = in.shape(0);
|
||||
auto Tout = out.shape(1);
|
||||
auto D = in.shape(2);
|
||||
auto K = wt.shape(1);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
if (large) {
|
||||
int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
|
||||
compute_encoder.set_bytes(strides, 3, 3);
|
||||
|
||||
} else {
|
||||
int strides[3] = {
|
||||
static_cast<int>(in.strides(0)),
|
||||
static_cast<int>(in.strides(1)),
|
||||
static_cast<int>(in.strides(2))};
|
||||
compute_encoder.set_bytes(strides, 3, 3);
|
||||
}
|
||||
|
||||
compute_encoder.set_bytes(K, 4);
|
||||
auto group_dims = get_block_dims(D, Tout, B);
|
||||
MTL::Size grid_dims = MTL::Size(D, Tout, B);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -790,8 +873,15 @@ void conv_1D_gpu(
|
||||
bool is_idil_one = in_dilation[0] == 1;
|
||||
int C = in.shape(2);
|
||||
int O = wt.shape(0);
|
||||
const int C_per_group = in.shape(2) / groups;
|
||||
const int O_per_group = wt.shape(0) / groups;
|
||||
// Fast path for fully separable 1D convolution
|
||||
if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
|
||||
wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
|
||||
depthwise_conv_1D_gpu(s, d, in, wt, out);
|
||||
return;
|
||||
}
|
||||
|
||||
const int C_per_group = C / groups;
|
||||
const int O_per_group = O / groups;
|
||||
|
||||
// Direct to implicit gemm conv
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
|
||||
@@ -20,8 +20,8 @@ void copy_gpu_inplace(
    int64_t out_offset,
    CopyType ctype,
    const Stream& s,
    const std::optional<array>& dynamic_i_offset /* = std::nullopt */,
    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
    std::optional<array> dynamic_i_offset /* = std::nullopt */,
    std::optional<array> dynamic_o_offset /* = std::nullopt */) {
  if (out.size() == 0) {
    return;
  }

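A likely reading of the copy_gpu_inplace signature change above: passing std::optional<array> by value instead of by const reference lets callers hand over a temporary that the callee can move from. A small illustrative sketch (types simplified, not the MLX API):

#include <optional>
#include <utility>

struct Payload {};

// Sketch only: a by-value optional parameter can be moved into local state,
// so an rvalue argument costs a move rather than a copy.
void consume(std::optional<Payload> maybe = std::nullopt) {
  std::optional<Payload> kept = std::move(maybe);
  (void)kept;
}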
@@ -172,7 +172,7 @@ std::string write_template(
  return template_def.str();
}

MetalKernelFunction metal_kernel(
CustomKernelFunction metal_kernel(
    const std::string& name,
    const std::vector<std::string>& input_names,
    const std::vector<std::string>& output_names,
@@ -316,7 +316,10 @@ MetalKernelFunction metal_kernel(
            threadgroup,
            shape_infos,
            ensure_row_contiguous,
            init_value),
            init_value,
            std::vector<ScalarArg>{},
            false,
            0),
        std::move(inputs));
  };
}

@@ -108,7 +108,7 @@ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
}

MTL::Library* load_default_library(MTL::Device* device) {
  NS::Error* error[4];
  NS::Error* error[5];
  MTL::Library* lib;
  // First try the colocated mlx.metallib
  std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
@@ -127,12 +127,19 @@ MTL::Library* load_default_library(MTL::Device* device) {
    return lib;
  }

  // Try to load resources from Framework resources if SwiftPM wrapped as a
  // dynamic framework.
  std::tie(lib, error[3]) = load_colocated_library(device, "Resources/default");
  if (lib) {
    return lib;
  }

  // Finally try default_mtllib_path
  std::tie(lib, error[3]) = load_library_from_path(device, default_mtllib_path);
  std::tie(lib, error[4]) = load_library_from_path(device, default_mtllib_path);
  if (!lib) {
    std::ostringstream msg;
    msg << "Failed to load the default metallib. ";
    for (int i = 0; i < 4; i++) {
    for (int i = 0; i < 5; i++) {
      if (error[i] != nullptr) {
        msg << error[i]->localizedDescription()->utf8String() << " ";
      }
@@ -464,6 +471,10 @@ void Device::end_encoding(int index) {
CommandEncoder& Device::get_command_encoder(int index) {
  auto& stream = get_stream_(index);
  if (stream.encoder == nullptr) {
    // Ensure there is an active command buffer
    if (stream.buffer == nullptr) {
      get_command_buffer(index);
    }
    stream.encoder = std::make_unique<CommandEncoder>(stream);
    stream.fence = std::make_shared<Fence>(device_->newFence());
  }

@@ -60,22 +60,12 @@ struct CommandEncoder {
    enc_->updateFence(fence);
  }

  template <typename T>
  void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
    enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
  template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
  void set_vector_bytes(const Vec& vec, size_t nelems, int idx) {
    enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx);
  }
  template <typename T>
  void set_vector_bytes(const SmallVector<T>& vec, int idx) {
    return set_vector_bytes(vec, vec.size(), idx);
  }

  // TODO: Code is duplicated but they should be deleted soon.
  template <typename T>
  void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
    enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
  }
  template <typename T>
  void set_vector_bytes(const std::vector<T>& vec, int idx) {
  template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
  void set_vector_bytes(const Vec& vec, int idx) {
    return set_vector_bytes(vec, vec.size(), idx);
  }

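The new set_vector_bytes overloads above are constrained on an is_vector_v trait that is defined elsewhere in the codebase. A minimal sketch of such a trait (an assumption about its shape, not the in-tree definition) that accepts std::vector and could be extended to SmallVector:

#include <type_traits>
#include <vector>

// Sketch of an is_vector_v detection trait: false by default, specialized
// for the container templates that set_vector_bytes should accept.
template <typename T>
struct is_vector : std::false_type {};

template <typename T, typename Alloc>
struct is_vector<std::vector<T, Alloc>> : std::true_type {};

// A SmallVector specialization would follow the same pattern, e.g.
// template <typename T, unsigned N>
// struct is_vector<SmallVector<T, N>> : std::true_type {};

template <typename T>
inline constexpr bool is_vector_v = is_vector<T>::value;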
@@ -52,8 +52,10 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = stream();
  auto& d = metal::device(s.device);

  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  size_t ndim = src.ndim();
  size_t slice_size = 1;
  for (auto s : slice_sizes_) {
    slice_size *= s;
  }

  bool large_index = nidx && inputs[1].size() > INT32_MAX;
  bool large_src = src.size() > INT32_MAX;
@@ -61,6 +63,55 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  bool large = large_index || large_src || large_out;

  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";

  if (src.flags().row_contiguous && nidx == 1 && axes_[0] == 0 &&
      inputs[1].flags().row_contiguous && slice_size == src.strides()[0]) {
    int work_per_thread = (slice_size > 8 && src.dtype().size() < 4) ? 2 : 1;
    auto& indices = inputs[1];
    std::string kernel_name = fmt::format(
        "gather_front{0}_{1}_{2}_{3}",
        type_to_name(out),
        idx_type_name,
        large ? "int64_t" : "int",
        work_per_thread);
    std::string lib_name = kernel_name;

    auto lib = d.get_library(lib_name, [&]() {
      std::string kernel_source = metal::utils();
      kernel_source += metal::gather_front();
      kernel_source += get_template_definition(
          kernel_name,
          "gather_front",
          get_type_string(out.dtype()),
          get_type_string(indices.dtype()),
          large ? "int64_t" : "int",
          work_per_thread);

      return kernel_source;
    });

    auto& compute_encoder = d.get_command_encoder(s.index);
    auto kernel = d.get_kernel(kernel_name, lib);
    compute_encoder.set_compute_pipeline_state(kernel);

    size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
    size_t dim_y = indices.size();
    auto group_dims = get_block_dims(dim_x, dim_y, 1);
    MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);

    compute_encoder.set_input_array(src, 0);
    compute_encoder.set_input_array(indices, 1);
    compute_encoder.set_output_array(out, 2);
    compute_encoder.set_bytes(slice_size, 3);
    compute_encoder.set_bytes(src.shape(0), 4);
    compute_encoder.dispatch_threads(grid_dims, group_dims);

    return;
  }

  int idx_ndim = nidx ? inputs[1].ndim() : 0;
  size_t ndim = src.ndim();

  std::string kernel_name = fmt::format(
      "gather{0}{1}_{2}_{3}_{4}",
      type_to_name(out),
@@ -96,11 +147,6 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto kernel = d.get_kernel(kernel_name, lib);
  compute_encoder.set_compute_pipeline_state(kernel);

  size_t slice_size = 1;
  for (auto s : slice_sizes_) {
    slice_size *= s;
  }

  // Launch 3D grid of threads
  // First two dimensions for the indices, the last one for the slice
  size_t dim0 = 1;
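The gather_front fast path above applies when a single row-contiguous index array selects whole rows of a row-contiguous source along axis 0. Its semantics, written as a plain CPU reference loop (illustrative only, not the Metal kernel):

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference: out[i, :] = src[idx[i], :], with src viewed as
// [N, slice_size] in row-major order.
template <typename T, typename IdxT>
void gather_front_reference(
    const std::vector<T>& src,
    const std::vector<IdxT>& idx,
    std::vector<T>& out,
    std::size_t slice_size) {
  for (std::size_t i = 0; i < idx.size(); ++i) {
    const std::size_t row = static_cast<std::size_t>(idx[i]);
    for (std::size_t j = 0; j < slice_size; ++j) {
      out[i * slice_size + j] = src[row * slice_size + j];
    }
  }
}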
@@ -332,7 +378,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  }

  if (upd_ndim == 0) {
    // Need placeholders so Metal doesn't compalain
    // Need placeholders so Metal doesn't complain
    int shape_ = 0;
    int64_t stride_ = 0;
    compute_encoder.set_bytes(shape_, 3);
@@ -347,7 +393,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  // Set output info
  size_t out_ndim = out.ndim();
  if (out_ndim == 0) {
    // Need placeholders so Metal doesn't compalain
    // Need placeholders so Metal doesn't complain
    int shape_ = 0;
    int64_t stride_ = 0;
    compute_encoder.set_bytes(shape_, 7);

@@ -19,9 +19,12 @@ const char* binary_two();
const char* copy();
const char* fft();
const char* gather_axis();
const char* gather_front();
const char* hadamard();
const char* logsumexp();
const char* quantized_utils();
const char* quantized();
const char* fp4_quantized();
const char* ternary();
const char* scan();
const char* scatter_axis();

@@ -144,8 +144,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
  auto t_str = get_type_string(type);
  std::string kernel_source = metal::utils();
  concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
  const std::array<std::pair<std::string, std::string>, 4> kernel_types = {{
      {"v2", "ternary_v2"},
  const std::array<std::pair<std::string, std::string>, 3> kernel_types = {{
      {"g1large", "ternary_g_nd1"},
      {"g2large", "ternary_g_nd2"},
      {"g3large", "ternary_g_nd3"},
@@ -154,13 +153,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
    kernel_source +=
        get_template_definition(name + "_" + lib_name, func, t_str, op);
  }

  kernel_source += get_template_definition(
      "v2_" + lib_name, "ternary_v2", t_str, op, false, false);
  kernel_source += get_template_definition(
      "sv2_" + lib_name, "ternary_v2", t_str, op, true, false);
  kernel_source += get_template_definition(
      "vs2_" + lib_name, "ternary_v2", t_str, op, false, true);

  if (get_work_per_thread(type) > 1) {
    kernel_source +=
        get_template_definition("vn_" + lib_name, "ternary_v", t_str, op);
    kernel_source += get_template_definition(
        "vn_" + lib_name, "ternary_v", t_str, op, false, false);
    kernel_source += get_template_definition(
        "svn_" + lib_name, "ternary_v", t_str, op, true, false);
    kernel_source += get_template_definition(
        "vsn_" + lib_name, "ternary_v", t_str, op, false, true);
  }

  kernel_source +=
      get_template_definition("v_" + lib_name, "ternary_v", t_str, op, 1);
  kernel_source += get_template_definition(
      "v_" + lib_name, "ternary_v", t_str, op, false, false, 1);
  kernel_source += get_template_definition(
      "sv_" + lib_name, "ternary_v", t_str, op, true, false, 1);
  kernel_source += get_template_definition(
      "vs_" + lib_name, "ternary_v", t_str, op, false, true, 1);
  kernel_source += get_template_definition(
      "g1_" + lib_name, "ternary_g_nd1", t_str, op, "int");
  kernel_source += get_template_definition(
@@ -804,13 +819,19 @@ MTL::ComputePipelineState* get_fft_kernel(
MTL::ComputePipelineState* get_quantized_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& template_def) {
    const std::string& template_def,
    const std::string& mode) {
  const auto& lib_name = kernel_name;
  auto lib = d.get_library(lib_name, [&]() {
    std::ostringstream kernel_source;
    kernel_source << metal::utils() << metal::gemm() << metal::quantized()
                  << template_def;
    return kernel_source.str();
    std::string kernel_source;
    concatenate(
        kernel_source,
        metal::utils(),
        metal::gemm(),
        metal::quantized_utils(),
        (mode == "affine") ? metal::quantized() : metal::fp4_quantized(),
        template_def);
    return kernel_source;
  });
  return d.get_kernel(kernel_name, lib);
}
@@ -823,6 +844,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
    const array& x,
    int group_size,
    int bits,
    const std::string& mode,
    int bm,
    int bn,
    int bk,
@@ -833,22 +855,40 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
  auto lib = d.get_library(lib_name, [&]() {
    std::string kernel_source;
    concatenate(
        kernel_source,
        metal::utils(),
        metal::gemm(),
        metal::quantized(),
        get_template_definition(
            lib_name,
            "gather_qmm_rhs",
            get_type_string(x.dtype()),
            group_size,
            bits,
            bm,
            bn,
            bk,
            wm,
            wn,
            transpose));
        kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm());
    if (mode == "affine") {
      concatenate(
          kernel_source,
          metal::quantized(),
          get_template_definition(
              lib_name,
              mode + "_gather_qmm_rhs",
              get_type_string(x.dtype()),
              group_size,
              bits,
              bm,
              bn,
              bk,
              wm,
              wn,
              transpose));
    } else {
      concatenate(
          kernel_source,
          metal::fp4_quantized(),
          get_template_definition(
              lib_name,
              mode + "_gather_qmm_rhs",
              get_type_string(x.dtype()),
              group_size,
              "uint8_t",
              bm,
              bn,
              bk,
              wm,
              wn,
              transpose));
    }
    return kernel_source;
  });
  return d.get_kernel(kernel_name, lib, hash_name, func_consts);

@@ -238,7 +238,8 @@ MTL::ComputePipelineState* get_fft_kernel(
MTL::ComputePipelineState* get_quantized_kernel(
    metal::Device& d,
    const std::string& kernel_name,
    const std::string& template_def);
    const std::string& template_def,
    const std::string& mode);

MTL::ComputePipelineState* get_gather_qmm_kernel(
    metal::Device& d,
@@ -248,6 +249,7 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
    const array& x,
    int group_size,
    int bits,
    const std::string& mode,
    int bm,
    int bn,
    int bk,

@@ -108,7 +108,8 @@ if(NOT MLX_METAL_JIT)
      reduction/reduce_all.h
      reduction/reduce_col.h
      reduction/reduce_row.h)
  build_kernel(quantized quantized.h ${STEEL_HEADERS})
  build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS})
  build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS})
  build_kernel(scan scan.h)
  build_kernel(softmax softmax.h)
  build_kernel(logsumexp logsumexp.h)

@@ -223,6 +223,11 @@ struct Power {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
    T res = 1;
    // Undefined to raise integer to negative power
    if (exp < 0) {
      return 0;
    }

    while (exp) {
      if (exp & 1) {
        res *= base;

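The Power hunk above adds a guard for negative integer exponents in front of an exponentiation-by-squaring loop whose tail is cut off by the diff. The complete technique, as a host-side sketch (illustrative only):

#include <cstdint>

// Exponentiation by squaring: O(log exp) multiplies. Negative exponents are
// mapped to 0, mirroring the guard added in the kernel above.
inline int64_t integer_power(int64_t base, int64_t exp) {
  if (exp < 0) {
    return 0;
  }
  int64_t res = 1;
  while (exp) {
    if (exp & 1) {
      res *= base; // fold in the current bit
    }
    exp >>= 1;
    base *= base; // square for the next bit
  }
  return res;
}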
@@ -104,6 +104,27 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return {a.real + b.real, a.imag + b.imag};
}

constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
  a.real += b.real;
  a.imag += b.imag;
  return a;
}

constexpr threadgroup complex64_t& operator+=(
    threadgroup complex64_t& a,
    complex64_t b) {
  a.real += b.real;
  a.imag += b.imag;
  return a;
}

constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
  a.real += b.real;
  a.imag += b.imag;
  return a;
}

constexpr complex64_t operator+(float a, complex64_t b) {
  return {a + b.real, b.imag};
}

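In Metal Shading Language, references carry an address-space qualifier, so a single operator+=(complex64_t&, complex64_t) would only bind to thread-local accumulators; the three overloads added above cover thread, threadgroup, and device storage. A usage sketch (Metal-style, illustrative only):

// Inside a hypothetical Metal kernel body (sketch only):
//   threadgroup complex64_t partial[32];     // shared accumulator
//   thread complex64_t local_sum = {0.0f, 0.0f};
//   local_sum += complex64_t{1.0f, 2.0f};    // thread-space overload
//   partial[lid] += local_sum;               // threadgroup-space overload
//   out[gid] += partial[lid];                // device-space overload, out: device complex64_t*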
@@ -166,115 +166,6 @@ instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    const int BM, /* Threadgroup rows (in threads) */
    const int BN, /* Threadgroup cols (in threads) */
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const int BC = 16>
[[kernel]] void naive_conv_2d(
    const device T* in [[buffer(0)]],
    const device T* wt [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant MLXConvParams<2>& params [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)simd_gid;
  (void)simd_lid;

  out += tid.z * params.out_strides[0];
  in += tid.z * params.in_strides[0];

  int out_o = tid.y * BN * TN + lid.y * TN;
  int out_hw = tid.x * BM * TM + lid.x * TM;

  int out_h[TM];
  int out_w[TN];

  for (int m = 0; m < TM; ++m) {
    int mm = (out_hw + m);
    out_h[m] = mm / params.oS[1];
    out_w[m] = mm % params.oS[1];
  }

  T in_local[TM];
  T wt_local[TN];
  T out_local[TM * TN] = {T(0)};

  for (int h = 0; h < params.wS[0]; ++h) {
    for (int w = 0; w < params.wS[1]; ++w) {
      for (int c = 0; c < params.C; ++c) {
        // Local in
        for (int m = 0; m < TM; m++) {
          int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
          int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];

          bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
          in_local[m] = valid
              ? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
              : T(0);
        }

        // Load weight
        for (int n = 0; n < TN; ++n) {
          int o = out_o + n;
          wt_local[n] = o < params.O
              ? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
                    w * params.wt_strides[2] + c]
              : T(0);
        }

        // Accumulate
        for (int m = 0; m < TM; ++m) {
          for (int n = 0; n < TN; ++n) {
            out_local[m * TN + n] += in_local[m] * wt_local[n];
          }
        }
      }
    }
  }

  for (int m = 0; m < TM; ++m) {
    for (int n = 0; n < TN; ++n) {
      if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
          (out_o + n) < params.O)
        out[out_h[m] * params.out_strides[1] +
            out_w[m] * params.out_strides[2] + out_o + n] =
            out_local[m * TN + n];
    }
  }
}

// Instantiations

#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn)              \
  template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
                       "_tn" #tn)]] [[kernel]] void                         \
  naive_conv_2d<itype, bm, bn, tm, tn>(                                     \
      const device itype* in [[buffer(0)]],                                 \
      const device itype* wt [[buffer(1)]],                                 \
      device itype* out [[buffer(2)]],                                      \
      const constant MLXConvParams<2>& params [[buffer(3)]],                \
      uint3 tid [[threadgroup_position_in_grid]],                           \
      uint3 lid [[thread_position_in_threadgroup]],                         \
      uint simd_gid [[simdgroup_index_in_threadgroup]],                     \
      uint simd_lid [[thread_index_in_simdgroup]]);

#define instantiate_naive_conv_2d_blocks(name, itype) \
  instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
  instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)

instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Depthwise convolution kernels
///////////////////////////////////////////////////////////////////////////////
@@ -397,6 +288,40 @@ instantiate_depthconv2d(float32, float);
instantiate_depthconv2d(float16, half);
instantiate_depthconv2d(bfloat16, bfloat16_t);

template <typename T, typename IdxT>
[[kernel]] void depthwise_conv_1d(
    const device T* in [[buffer(0)]],
    const device T* w [[buffer(1)]],
    device T* out [[buffer(2)]],
    constant const IdxT strides[3],
    constant const int& kernel_size,
    uint3 tid [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
  in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
  w += tid.x * kernel_size;

  float acc = 0.0;
  for (int i = 0; i < kernel_size; ++i) {
    acc += static_cast<float>(in[0]) * w[i];
    in += strides[1];
  }
  *out = static_cast<T>(acc);
}

#define instantiate_depthconv1d(iname, itype)                         \
  instantiate_kernel(                                                 \
      "depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
  instantiate_kernel(                                                 \
      "depthwise_conv_1d_" #iname "_large",                           \
      depthwise_conv_1d,                                              \
      itype,                                                          \
      int64_t)

instantiate_depthconv1d(float32, float);
instantiate_depthconv1d(float16, half);
instantiate_depthconv1d(bfloat16, bfloat16_t);

///////////////////////////////////////////////////////////////////////////////
/// Winograd kernels
///////////////////////////////////////////////////////////////////////////////

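For reference, the depthwise_conv_1d kernel above computes, per batch b, output position t, and channel d, a dot product between a window of the input and that channel's private filter. The same computation as a plain CPU loop for the row-contiguous, stride-1, no-padding case handled by the fast path (illustrative only):

#include <vector>

// Reference: out[b, t, d] = sum_k in[b, t + k, d] * w[d, k]
// with in of shape [B, T, D], w of shape [D, K], out of shape
// [B, T - K + 1, D], all row-contiguous.
inline void depthwise_conv_1d_reference(
    const std::vector<float>& in,
    const std::vector<float>& w,
    std::vector<float>& out,
    int B,
    int T,
    int D,
    int K) {
  const int Tout = T - K + 1;
  for (int b = 0; b < B; ++b) {
    for (int t = 0; t < Tout; ++t) {
      for (int d = 0; d < D; ++d) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          acc += in[(b * T + t + k) * D + d] * w[d * K + k];
        }
        out[(b * Tout + t) * D + d] = acc;
      }
    }
  }
}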
mlx/backend/metal/kernels/fp4_quantized.h (new file, 1791 lines): diff suppressed because it is too large.
Some files were not shown because too many files have changed in this diff.