Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: 47 commits, v0.29.4 ... 3e05cea9f8
Commits in this range:

3e05cea9f8
5b0f047226
618c87af8c
d5f61a93fa
4a09264236
0dbc7e5bee
0d68efd461
f9e1a14135
d8e9ded928
60939d010c
fdcd2923fd
54f1cc6e3e
b3825ac149
7f4b7e553c
ad16f41a7f
f46877bc08
6f35017d1b
b167f0df1c
a9f0d6b160
940f4c7818
35f81728f1
4442ed86c1
698559c231
ecc4879b07
32b18d8b66
472c43a0c8
b7214ff01e
76414c8971
49e4566df3
aad49f932f
86765cce34
1bedcbd556
9ac7dbe877
1bf605d56d
3c622ddd1d
27ff069175
3b2ffcefc3
b65f882df3
b704e9e77a
66519fb348
8973550ff3
3f866be665
23f81ed1c1
3fe2250c00
047114b988
9320eb89a8
75819d70ea
@@ -1,579 +0,0 @@
version: 2.1

orbs:
  apple: ml-explore/pr-approval@0.1.0

parameters:
  nightly_build:
    type: boolean
    default: false
  test_release:
    type: boolean
    default: false

jobs:
  build_documentation:
    parameters:
      upload-docs:
        type: boolean
        default: false
    macos:
      xcode: "26.0.0"
    resource_class: m4pro.medium
    steps:
      - checkout
      - run:
          name: Install
          command: |
            xcodebuild -downloadComponent MetalToolchain
            brew install python@3.10
            brew install doxygen
            python3.10 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install -r docs/requirements.txt
            pip install . -v
      - when:
          condition:
            not: << parameters.upload-docs >>
          steps:
            - run:
                name: Build documentation
                command: |
                  source env/bin/activate
                  cd docs && doxygen && make html O=-W
      - when:
          condition: << parameters.upload-docs >>
          steps:
            - add_ssh_keys:
                fingerprints:
                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
            - run:
                name: Upload documentation
                command: |
                  source env/bin/activate
                  git config user.email "mlx@group.apple.com"
                  git config user.name "CircleCI Docs"
                  git checkout gh-pages
                  git rebase main
                  cd docs
                  git rm -rf build/html
                  doxygen && make html O=-W
                  git add -f build/html
                  git commit -m "rebase"
                  git push -f origin gh-pages

  linux_build_and_test:
    machine:
      image: ubuntu-2204:current
    resource_class: large
    steps:
      - checkout
      - run:
          name: Run style checks
          command: |
            pip install pre-commit
            pre-commit run --all
            if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
      - run:
          name: Install dependencies
          command: |
            export DEBIAN_FRONTEND=noninteractive
            export NEEDRESTART_MODE=a
            sudo apt-get update
            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
            curl -LsSf https://astral.sh/uv/install.sh | sh
      - run:
          name: Install Python package
          command: |
            uv venv
            uv pip install cmake
            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
              uv pip install -e ".[dev]" -v
      - run:
          name: Generate package stubs
          command: |
            uv pip install typing_extensions
            uv run --no-project setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            source .venv/bin/activate
            python -m unittest discover python/tests -v
            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
      - run:
          name: Build CPP only
          command: |
            source .venv/bin/activate
            mkdir -p build && cd build
            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
            make -j `nproc`
      - run:
          name: Run CPP tests
          command: ./build/tests/tests

  mac_build_and_test:
    parameters:
      xcode_version:
        type: string
        default: "26.0.0"
      macosx_deployment_target:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    resource_class: m4pro.medium
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            xcodebuild -downloadComponent MetalToolchain
            HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
              brew install openmpi uv
      - run:
          name: Install Python package
          command: |
            uv venv --python 3.10
            uv pip install \
              nanobind==2.4.0 \
              cmake \
              numpy \
              torch \
              tensorflow \
              unittest-xml-reporting
            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
              uv pip install -e . -v
      - run:
          name: Generate package stubs
          command: |
            uv pip install typing_extensions
            uv run --no-project setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            source .venv/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
      - run:
          name: Build example extension
          command: |
            source .venv/bin/activate
            cd examples/extensions
            uv pip install -r requirements.txt
            uv run --no-project setup.py build_ext --inplace
            uv run --no-project python test.py
      - store_test_results:
          path: test-results
      - run:
          name: Build CPP only
          command: |
            source .venv/bin/activate
            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
      - run:
          name: Run CPP tests
          command: |
            DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
      - run:
          name: Build small binary
          command: |
            source .venv/bin/activate
            cd build/
            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
              -DBUILD_SHARED_LIBS=ON \
              -DMLX_BUILD_CPU=OFF \
              -DMLX_BUILD_SAFETENSORS=OFF \
              -DMLX_BUILD_GGUF=OFF \
              -DMLX_METAL_JIT=ON
            make -j `sysctl -n hw.ncpu`
      - run:
          name: Run Python tests with JIT
          command: |
            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
              uv pip install -e . -v
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
              METAL_DEBUG_ERROR_MODE=0 \
              uv run --no-project python -m xmlrunner discover \
              -v python/tests \
              -o test-results/gpu_jit

  cuda_build_and_test:
    parameters:
      image_date:
        type: string
        default: "2023.11.1"
    machine:
      image: "linux-cuda-12:<< parameters.image_date >>"
    resource_class: gpu.nvidia.small.gen2
    steps:
      - checkout
      - restore_cache:
          keys:
            - cuda-<< parameters.image_date >>-{{ arch }}-
      - run:
          name: Install dependencies
          command: |
            sudo apt-get update
            sudo apt-get install libcudnn9-dev-cuda-12
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install libnccl2 libnccl-dev
            curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
            sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
            rm -rf ccache-4.11.3-linux-x86_64
            curl -LsSf https://astral.sh/uv/install.sh | sh
      - run:
          name: Set CCache size
          command: ccache --max-size 1G
      - run:
          name: Install Python package
          command: |
            uv venv
            uv pip install cmake
            DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
              uv pip install -e ".[dev]" -v
      - run:
          name: Run Python tests
          command: |
            source .venv/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
            LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
      - run:
          name: Build CPP only
          command: |
            source .venv/bin/activate
            cmake . -B build \
              -DMLX_BUILD_CUDA=ON \
              -DCMAKE_CUDA_COMPILER=`which nvcc` \
              -DCMAKE_BUILD_TYPE=DEBUG
            cmake --build build -j `nproc`
      - run:
          name: Run CPP tests
          command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
      - run:
          name: CCache report
          command: |
            ccache --show-stats
            ccache --zero-stats
            ccache --cleanup
      - save_cache:
          key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
          paths:
            - /home/circleci/.cache/ccache

  build_release:
    parameters:
      python_version:
        type: string
        default: "3.10"
      xcode_version:
        type: string
        default: "26.0.0"
      build_env:
        type: string
        default: ""
      macosx_deployment_target:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    resource_class: m4pro.medium
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            xcodebuild -downloadComponent MetalToolchain
            mkdir -p ~/miniconda3
            curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
            bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
            rm ~/miniconda3/miniconda.sh
            source ~/miniconda3/bin/activate
            conda init --all
            conda create -n env python=<< parameters.python_version >> -y
            conda activate env
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install twine
            pip install build
      - run:
          name: Install Python package
          command: |
            conda activate env
            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
              pip install . -v
      - run:
          name: Generate package stubs
          command: |
            conda activate env
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Build Python package
          command: |
            conda activate env
            python setup.py clean --all
            << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
      - when:
          condition:
            equal: ["3.10", << parameters.python_version >>]
          steps:
            - run:
                name: Build common package
                command: |
                  conda activate env
                  python setup.py clean --all
                  << parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
                  conda activate env
                  twine upload dist/*
      - store_artifacts:
          path: dist/

  build_linux_release:
    parameters:
      python_version:
        type: string
        default: "3.10"
      build_env:
        type: string
        default: ""
    machine:
      image: ubuntu-2204:current
    resource_class: large
    steps:
      - checkout
      - run:
          name: Build wheel
          command: |
            PYTHON=python<< parameters.python_version >>
            export DEBIAN_FRONTEND=noninteractive
            export NEEDRESTART_MODE=a
            sudo apt-get update
            TZ=Etc/UTC sudo apt-get -y install tzdata
            sudo add-apt-repository -y ppa:deadsnakes/ppa
            sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            $PYTHON -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
            << parameters.build_env >> pip install ".[dev]" -v
            pip install typing_extensions
            python setup.py generate_stubs
            python setup.py clean --all
            MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
            bash python/scripts/repair_linux.sh
      - when:
          condition:
            equal: ["3.10", << parameters.python_version >>]
          steps:
            - run:
                name: Build common package
                command: |
                  source env/bin/activate
                  python setup.py clean --all
                  << parameters.build_env >> MLX_BUILD_STAGE=2 \
                    python -m build -w
                  auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload packages
                command: |
                  source env/bin/activate
                  twine upload wheelhouse/*.whl
      - store_artifacts:
          path: wheelhouse/

  build_cuda_release:
    parameters:
      build_env:
        type: string
        default: ""
    machine:
      image: ubuntu-2204:current
    resource_class: xlarge
    steps:
      - checkout
      - run:
          name: Build wheel
          command: |
            export DEBIAN_FRONTEND=noninteractive
            export NEEDRESTART_MODE=a
            wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
            sudo dpkg -i cuda-keyring_1.1-1_all.deb
            sudo apt-get update
            sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install zip
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
            export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
            export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
            << parameters.build_env >> MLX_BUILD_STAGE=2 \
              CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
              python -m build -w
            bash python/scripts/repair_cuda.sh
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
                  twine upload wheelhouse/*.whl
      - store_artifacts:
          path: wheelhouse/

workflows:
  build_and_test:
    when:
      and:
        - matches:
            pattern: "^(?!pull/)[-\\w]+$"
            value: << pipeline.git.branch >>
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - mac_build_and_test:
          matrix:
            parameters:
              macosx_deployment_target: ["13.5", "15.0"]
      - linux_build_and_test
      - cuda_build_and_test:
          matrix:
            parameters:
              image_date: ["2023.11.1", "2025.05.1"]
      - build_documentation

  build_pypi_release:
    when:
      and:
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["PYPI_RELEASE=1"]
              xcode_version: ["26.0.0"]
      - build_documentation:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          upload-docs: true
      - build_linux_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
              build_env: ["PYPI_RELEASE=1"]
      - build_cuda_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              build_env: ["PYPI_RELEASE=1"]

  prb:
    when:
      matches:
        pattern: "^pull/\\d+(/head)?$"
        value: << pipeline.git.branch >>
    jobs:
      - hold:
          type: approval
      - apple/authenticate:
          context: pr-approval
      - mac_build_and_test:
          requires: [ hold ]
          matrix:
            parameters:
              macosx_deployment_target: ["13.5", "15.0"]
      - linux_build_and_test:
          requires: [ hold ]
      - cuda_build_and_test:
          requires: [ hold ]
          matrix:
            parameters:
              image_date: ["2023.11.1", "2025.05.1"]

  nightly_build:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.nightly_build >>
    jobs:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              xcode_version: ["26.0.0"]
      - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
      - build_cuda_release

  build_dev_release:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["DEV_RELEASE=1"]
              xcode_version: ["26.0.0"]
      - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
              build_env: ["DEV_RELEASE=1"]
      - build_cuda_release:
          matrix:
            parameters:
              build_env: ["DEV_RELEASE=1"]
@@ -2,8 +2,8 @@ name: 'Build CUDA wheel'
 description: 'Build CUDA wheel'

 inputs:
-  nvcc-location:
-    description: 'Location of nvcc compiler'
+  toolkit:
+    description: 'The CUDA toolkit'
     required: true

 runs:
@@ -12,7 +12,7 @@ runs:
     - name: Build package
       shell: bash
       env:
-        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
+        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
       run: |
         pip install auditwheel build patchelf setuptools
         python setup.py clean --all
.github/actions/build-cuda/action.yml | 29
@@ -2,10 +2,9 @@ name: 'Build and Test with CUDA'
 description: 'Build and test MLX with CUDA'

 inputs:
-  nvcc-location:
-    description: 'Location of nvcc compiler'
+  toolkit:
+    description: 'The CUDA toolkit'
     required: true
-    default: '/usr/local/cuda-12.9/bin/nvcc'

 runs:
   using: "composite"
@@ -14,32 +13,14 @@ runs:
       shell: bash
       env:
         DEBUG: 1
-        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
-      run: pip install -e ".[dev]" -v
-
-    - name: Run Python tests - CPU
-      shell: bash
-      env:
-        LOW_MEMORY: 1
-        DEVICE: cpu
-      run: python -m unittest discover python/tests -v
-
-    - name: Run Python tests - GPU
-      shell: bash
-      env:
-        LOW_MEMORY: 1
-        DEVICE: gpu
-      run: python -m tests discover python/tests -v
+        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
+      run: pip install --no-build-isolation -e ".[dev]" -v

     - name: Build CPP only
       shell: bash
       run: |
         cmake . -B build \
           -DMLX_BUILD_CUDA=ON \
-          -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
+          -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
           -DCMAKE_BUILD_TYPE=DEBUG
         cmake --build build -j $(nproc)

     - name: Run CPP tests
       shell: bash
       run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
.github/actions/build-docs/action.yml | 20
@@ -1,19 +1,19 @@
 name: 'Build Documentation'
-description: 'Build documentation on a mac'
+description: 'Build documentation'

 runs:
   using: "composite"
   steps:
     - name: Setup machine
-      uses: ./.github/actions/setup-macos
+      uses: ./.github/actions/setup-linux

     - name: Install dependencies
-      shell: sh
+      shell: bash
       run: |
-        brew install doxygen
-        uv pip install --upgrade pip cmake
-        uv pip install -r docs/requirements.txt
-        uv pip install . -v
+        sudo apt-get install -y doxygen
+        source .venv/bin/activate
+        pip install -r docs/requirements.txt
+        pip install . -v

     - name: Build documentation
       shell: bash
@@ -24,8 +24,8 @@ runs:
         make html O=-W

     - name: Create artifact tar
-      shell: sh
-      run: tar -cf artifact.tar --cd docs --dereference build/html index.html
+      shell: bash
+      run: tar -cf artifact.tar -C docs --dereference build/html index.html

     # Do it manually because upload-pages-artifact requires gtar
     - name: Upload artifact
@@ -35,4 +35,4 @@ runs:
         name: github-pages
         path: artifact.tar
         retention-days: 1
-        if-no-files-found: error
+        if-no-files-found: error
.github/actions/build-linux-release/action.yml | 11
@@ -7,6 +7,13 @@ inputs:
     type: boolean
     required: false
     default: false
+  arch:
+    description: 'Platform architecture tag'
+    required: true
+    type: choice
+    options:
+      - x86_64
+      - aarch64

 runs:
   using: "composite"
@@ -23,11 +30,11 @@ runs:
         pip install auditwheel patchelf build
         python setup.py clean --all
         MLX_BUILD_STAGE=1 python -m build -w
-        bash python/scripts/repair_linux.sh
+        bash python/scripts/repair_linux.sh ${{ inputs.arch }}
     - name: Build backend package
       if: ${{ inputs.build-backend }}
       shell: bash
       run: |
         python setup.py clean --all
         MLX_BUILD_STAGE=2 python -m build -w
-        auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
+        auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}
.github/actions/build-linux/action.yml | 20
@@ -9,33 +9,17 @@ runs:
       env:
         CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
         DEBUG: 1
-      run: pip install -e ".[dev]" -v
+      run: pip install --no-build-isolation -e ".[dev]" -v

     - name: Generate package stubs
       shell: sh
       run: |
         pip install typing_extensions
         python setup.py generate_stubs

-    - name: Run Python tests
-      shell: bash
-      run: |
-        python -m unittest discover python/tests -v
-        mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
-        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
-        if grep -Fq '[WARN]' stderr.log ; then
-          grep -F '[WARN]' stderr.log
-          echo "Distributed ring test failed";
-          exit 1;
-        fi
-
     - name: Build CPP only
       shell: bash
       run: |
         mkdir -p build && cd build
         cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
         make -j $(nproc)

-    - name: Run CPP tests
-      shell: sh
-      run: ./build/tests/tests
.github/actions/build-macos-release/action.yml | 15
@@ -16,18 +16,19 @@ runs:
   using: "composite"
   steps:
     - name: Build Python package
-      shell: bash
+      shell: bash -l {0}
       env:
         MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
       run: |
-        uv pip install build
-        uv run --no-project setup.py clean --all
-        MLX_BUILD_STAGE=1 uv run -m build -w
+        pip install build
+        python setup.py clean --all
+        MLX_BUILD_STAGE=1 python -m build -w

     - name: Build backend package
       if: ${{ inputs.build-backend }}
-      shell: bash
+      shell: bash -l {0}
       env:
         MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
       run: |
-        uv run --no-project setup.py clean --all
-        MLX_BUILD_STAGE=2 uv run -m build -w
+        python setup.py clean --all
+        MLX_BUILD_STAGE=2 python -m build -w
.github/actions/build-macos/action.yml | 44
@@ -5,47 +5,47 @@ runs:
   using: "composite"
   steps:
     - name: Install dependencies
-      shell: sh
-      env:
-        DEBUG: 1
-        CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
+      shell: bash -l {0}
       run: |
-        uv pip install --upgrade pip
-        uv pip install cmake setuptools nanobind==2.4.0
-        uv pip install -e . -v
+        pip install --upgrade pip
+        pip install cmake setuptools nanobind==2.4.0
+        pip install -e . -v

     - name: Generate package stubs
-      shell: bash
+      shell: bash -l {0}
       run: |
-        uv pip install typing_extensions
-        uv run --no-project setup.py generate_stubs
+        pip install typing_extensions
+        python setup.py generate_stubs

     - name: Install tests dependencies
-      shell: sh
+      shell: bash -l {0}
       run: |
-        uv pip install numpy torch tensorflow unittest-xml-reporting
+        pip install numpy torch tensorflow unittest-xml-reporting

     - name: Run Python tests
-      shell: bash
+      shell: bash -l {0}
       env:
         LOW_MEMORY: 1
       run: |
-        DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
-        DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
+        DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
+        DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
         mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
         mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
         if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi

     - name: Build example extension
-      shell: bash
+      shell: bash -l {0}
       run: |
         cd examples/extensions
-        uv pip install -r requirements.txt
-        uv run --no-project setup.py build_ext --inplace
-        uv run --no-project test.py
+        pip install -r requirements.txt
+        python setup.py build_ext --inplace
+        python test.py

     - name: Build CPP only
-      shell: bash
+      shell: bash -l {0}
       run: |
         mkdir -p build
         cd build
@@ -53,7 +53,7 @@ runs:
         make -j $(sysctl -n hw.ncpu)

     - name: Run CPP tests
-      shell: bash
+      shell: bash -l {0}
       env:
         DEVICE: gpu
         METAL_DEVICE_WRAPPER_TYPE: 1
@@ -61,7 +61,7 @@ runs:
       run: ./build/tests/tests

     - name: Build small binary with JIT
-      shell: bash
+      shell: bash -l {0}
       run: |
         mkdir -p build
         cd build
@@ -74,7 +74,7 @@ runs:
         make -j $(sysctl -n hw.ncpu)

     - name: Run Python tests with JIT
-      shell: bash
+      shell: bash -l {0}
       env:
         LOW_MEMORY: 1
         DEVICE: gpu
@@ -82,7 +82,7 @@ runs:
         METAL_DEBUG_ERROR_MODE: 0
       run: |
         CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-          uv pip install -e . -v
-        uv run -m xmlrunner discover \
+          pip install -e . -v
+        python -m xmlrunner discover \
           -v python/tests \
           -o test-results/gpu_jit
.github/actions/setup-linux/action.yml | 66
@@ -2,14 +2,10 @@ name: 'Setup Linux Environment'
 description: 'Install dependencies for Linux builds'

 inputs:
-  runner-type:
-    description: 'Whether to set this up as a linux or CUDA runner'
+  toolkit:
+    description: 'Which toolkit to install'
     required: false
-    default: 'linux'
-    type: choice
-    options:
-      - linux
-      - cuda
+    default: 'cpu'
   python-version:
     description: 'Version of python to set up'
     required: false
@@ -18,56 +14,62 @@ inputs:
 runs:
   using: "composite"
   steps:
-    - name: Free disk space
-      shell: sh
-      if: inputs.runner-type == 'linux'
-      run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+    - name: Use ccache
+      uses: hendrikmuhs/ccache-action@v1.2
+      with:
+        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
+        max-size: 1GB

     - name: Install common dependencies
       env:
         TZ: Etc/UTC
       shell: bash
       run: |
         sudo apt-get update
-        sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
-        sudo apt autoremove -y
+        sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip

     - uses: actions/setup-python@v6
       with:
         python-version: ${{ inputs.python-version }}
         cache: 'pip'

-    - name: setup python venv
+    - name: Setup Python venv
       shell: bash
       run: |
         python -m venv .venv
         source .venv/bin/activate
-        pip install setuptools cmake nanobind==2.4.0
-        echo PATH=$PATH >> $GITHUB_ENV
+        pip install --upgrade pip cmake
+        # Make cmake search .venv for nanobind
+        echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV

     - name: Install MPI
       if: inputs.runner-type == 'linux'
       shell: bash
       run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev

-    - name: Network CUDA installation from packages
-      id: install-cuda
-      if: inputs.runner-type == 'cuda'
+    - name: Install CUDA toolkit
+      if: ${{ startsWith(inputs.toolkit, 'cuda') }}
+      shell: bash
       env:
         TZ: Etc/UTC
-      shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
+        # Note: the CI machine does not meet CUDA 13's driver requirement.
+        # Compatibility matrix:
+        # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
+        # The `nvcc` is installed into `/usr/local/cuda-VERSION/bin/nvcc` - but
+        # it's *not* on the default toolkit path.
+        PACKAGES: |
+          {
+            "cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
+            "cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
+            "cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
+          }
       run: |
-        wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
+        export ARCH=${{ runner.arch == 'arm64' && 'arm64' || 'x86_64' }}
+        wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
         sudo dpkg -i cuda-keyring_1.1-1_all.deb
         sudo apt-get update
-        sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
-        # Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
-        # cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
-        # Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
-        # This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
+        sudo apt-get install -y \
+          libnccl2 libnccl-dev \
+          ${{ fromJson(env.PACKAGES)[inputs.toolkit] }}

-    - name: Package and Driver Report
-      if: inputs.runner-type == 'cuda'
+    - name: CUDA packages and driver report
+      if: ${{ startsWith(inputs.toolkit, 'cuda') }}
       shell: bash
       run: |
         sudo apt-get install -y ubuntu-drivers-common dkms
.github/actions/setup-macos/action.yml | 9
@@ -17,9 +17,8 @@ runs:
     - name: Verify MetalToolchain installed
       shell: bash
       run: xcodebuild -showComponent MetalToolchain

-    - name: Setup uv
-      uses: astral-sh/setup-uv@v6
-
     - uses: conda-incubator/setup-miniconda@v3
       with:
-        python-version: ${{ inputs.python-version }}
+        activate-environment: true
+        miniconda-version: "latest"
+        python-version: ${{ inputs.python-version }}
.github/actions/test-linux/action.yml | 69 (new file)
@@ -0,0 +1,69 @@
name: 'Run Linux tests'

inputs:
  cpu-only:
    description: 'Skip GPU tests'
    required: false
    default: false

runs:
  using: "composite"
  steps:
    - name: Run MPI tests
      shell: bash
      run: |
        echo "::group::MPI tests"
        mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
        echo "::endgroup::"

    - name: Run distributed tests
      if: ${{ inputs.cpu-only == 'true' }}
      shell: bash
      run: |
        echo "::group::Distributed tests"
        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
        if grep -Fq '[WARN]' stderr.log ; then
          grep -F '[WARN]' stderr.log
          echo "Distributed ring test failed";
          exit 1;
        fi
        echo "::endgroup::"

    - name: Run Python tests - CPU
      if: ${{ inputs.cpu-only == 'true' }}
      shell: bash
      env:
        DEVICE: cpu
      run: |
        echo "::group::Python tests - CPU"
        python -m unittest discover python/tests -v
        echo "::endgroup::"

    - name: Run Python tests - GPU
      if: ${{ inputs.cpu-only == 'false' }}
      shell: bash
      env:
        DEVICE: gpu
      run: |
        echo "::group::Python tests - GPU"
        python -m tests discover python/tests -v
        echo "::endgroup::"

    - name: Run CPP tests - CPU
      shell: bash
      env:
        DEVICE: cpu
      run: |
        echo "::group::CPP tests - CPU"
        ./build/tests/tests
        echo "::endgroup::"

    - name: Run CPP tests - GPU
      if: ${{ inputs.cpu-only == 'false' }}
      shell: bash
      env:
        DEVICE: gpu
      run: |
        echo "::group::CPP tests - GPU"
        ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
        echo "::endgroup::"
.github/workflows/documentation.yml | 4
@@ -8,7 +8,7 @@ permissions:

 jobs:
   build:
-    runs-on: [self-hosted, macos]
+    runs-on: ubuntu-22.04
     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/build-docs
@@ -25,4 +25,4 @@ jobs:
     steps:
       - name: Deploy to GitHub Pages
         id: deployment
-        uses: actions/deploy-pages@v4
+        uses: actions/deploy-pages@v4
.github/workflows/nightly.yml | 50
@@ -21,6 +21,7 @@ jobs:
       - uses: ./.github/actions/build-linux-release
         with:
           build-backend: ${{ matrix.python-version == '3.10' }}
+          arch: "x86_64"
       - name: Upload mlx artifacts
         uses: actions/upload-artifact@v5
         with:
@@ -34,19 +35,25 @@ jobs:
           name: mlx-cpu
           path: wheelhouse/mlx_cpu-*.whl
           retention-days: 7

   build_linux_with_tests:
     strategy:
       fail-fast: false
       matrix:
-        python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
-    runs-on: ubuntu-22.04
+        python_version: ["3.11", "3.12", "3.13", "3.14"]
+        runner:
+          - ubuntu-22.04
+          - ubuntu-22.04-arm
+    runs-on: ${{ matrix.runner }}
     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-linux
         with:
           python-version: ${{ matrix.python_version }}
       - uses: ./.github/actions/build-linux
+      - uses: ./.github/actions/test-linux
+        with:
+          cpu-only: true

   build_mac_release:
     if: github.repository == 'ml-explore/mlx'
@@ -60,7 +67,6 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
       - uses: ./.github/actions/build-macos
-
       - name: Build macOS 15 package
         uses: ./.github/actions/build-macos-release
         with:
@@ -72,16 +78,6 @@ jobs:
           macos-target: 14.0
           build-backend: ${{ matrix.python-version == '3.10' }}

-  build_cuda_with_tests:
-    if: github.repository == 'ml-explore/mlx'
-    runs-on: gpu-t4-4-core
-    steps:
-      - uses: actions/checkout@v5
-      - uses: ./.github/actions/setup-linux
-        with:
-          runner-type: 'cuda'
-      - uses: ./.github/actions/build-cuda
-
   build_cuda_release:
     if: github.repository == 'ml-explore/mlx'
     runs-on: ubuntu-22-large
@@ -89,36 +85,14 @@ jobs:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-linux
         with:
-          runner-type: 'cuda'
+          toolkit: 'cuda-12.9'
       - name: Build Python package
         uses: ./.github/actions/build-cuda-release
         with:
-          nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
+          toolkit: 'cuda-12.9'
       - name: Upload artifacts
         uses: actions/upload-artifact@v5
         with:
           name: mlx-cuda
           path: wheelhouse/mlx_cuda-*.whl
           retention-days: 7

-  linux_fedora_build_cpp:
-    name: Linux Fedora CPP Build (${{ matrix.arch }})
-    strategy:
-      fail-fast: false
-      matrix:
-        include:
-          - host: ubuntu-22.04
-            arch: x86_64
-          - host: ubuntu-22.04-arm
-            arch: aarch64
-
-    runs-on: ${{ matrix.host }}
-    container:
-      image: fedora:42
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v5
-
-      - name: CPP Build Test - No Release
-        run: |
-          bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
.github/workflows/pull_request.yml | 42
@@ -1,28 +1,52 @@
 name: Build and Test

-on: pull_request
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+      # For testing CI without starting a pull request:
+      - test/*

 permissions:
   contents: read

+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
 jobs:
   check_lint:
     runs-on: ubuntu-22.04
     steps:
       - uses: actions/checkout@v5
+      - uses: ./.github/actions/setup-linux
       - uses: pre-commit/action@v3.0.1

   linux_build_and_test:
-    runs-on: ubuntu-22.04
     needs: check_lint
+    strategy:
+      matrix:
+        runner:
+          - ubuntu-22.04
+          - ubuntu-22.04-arm
+      fail-fast: false
+    runs-on: ${{ matrix.runner }}
     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-linux
       - uses: ./.github/actions/build-linux
+      - uses: ./.github/actions/test-linux
+        with:
+          cpu-only: true

   mac_build_and_test:
     if: github.repository == 'ml-explore/mlx'
     strategy:
       matrix:
         macos-target: ["14.0", "15.0"]
     runs-on: [self-hosted, macos]
+    env:
+      MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
     needs: check_lint
     steps:
       - uses: actions/checkout@v5
@@ -31,18 +55,25 @@ jobs:

   cuda_build_and_test:
     if: github.repository == 'ml-explore/mlx'
+    strategy:
+      fail-fast: false
+      matrix:
+        toolkit: ['cuda-12.6', 'cuda-12.9']
     runs-on: gpu-t4-4-core
     needs: check_lint
     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-linux
         with:
-          runner-type: 'cuda'
+          toolkit: ${{ matrix.toolkit }}
       - uses: ./.github/actions/build-cuda
+        with:
+          toolkit: ${{ matrix.toolkit }}
+      - uses: ./.github/actions/test-linux

   build_documentation:
     if: github.repository == 'ml-explore/mlx'
-    runs-on: [self-hosted, macos]
+    runs-on: ubuntu-22.04
     needs: check_lint
     steps:
       - uses: actions/checkout@v5
@@ -50,6 +81,7 @@ jobs:

   linux_fedora_build_cpp:
     name: Linux Fedora CPP Build (${{ matrix.arch }})
+    needs: check_lint
     strategy:
       fail-fast: false
       matrix:
.github/workflows/release.yml | 66
@@ -5,6 +5,11 @@ on:
     tags:
       - 'v*'
+  workflow_dispatch:
+    inputs:
+      dev_release:
+        description: "Do a dev release or regular release"
+        required: true
+        default: "false"

 permissions:
   contents: read
@@ -12,16 +17,13 @@ permissions:
 jobs:
-  setup:
-    runs-on: ubuntu-latest
-    outputs:
-      pypi_env: ${{ github.event_name == 'push' && 'pypi' || 'test-pypi' }}
-      pypi_url: ${{ github.event_name == 'push' && 'https://upload.pypi.org/legacy/' || 'https://test.pypi.org/legacy/' }}
-    steps:
-      - name: Set publishing variables
-        run: echo "Publishing setup complete"
-
   build_documentation:
     if: github.repository == 'ml-explore/mlx'
-    runs-on: [self-hosted, macos]
+    runs-on: ubuntu-22.04
     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/build-docs
@@ -45,9 +47,11 @@ jobs:
     strategy:
       matrix:
         python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
-    runs-on: ubuntu-22.04
+        arch: ['x86_64', 'aarch64']
+    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
     env:
       PYPI_RELEASE: 1
+      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-linux
@@ -56,16 +60,19 @@ jobs:
       - uses: ./.github/actions/build-linux-release
         with:
           build-backend: ${{ matrix.python-version == '3.10' }}
+          arch: ${{ matrix.arch }}
       - name: Upload MLX artifacts
         uses: actions/upload-artifact@v5
         with:
-          name: linux-wheels-${{ matrix.python_version }}
+          overwrite: true
+          name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
           path: wheelhouse/mlx-*.whl
       - name: Upload CPU artifacts
         if: matrix.python_version == '3.10'
         uses: actions/upload-artifact@v5
         with:
-          name: mlx-cpu
+          overwrite: true
+          name: mlx-cpu-${{ matrix.arch }}
           path: wheelhouse/mlx_cpu-*.whl

   build_mac_release:
@@ -76,22 +83,25 @@ jobs:
     runs-on: [self-hosted, macos]
     env:
       PYPI_RELEASE: 1
+      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}

     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-macos
         with:
           python-version: ${{ matrix.python-version }}

       - name: Install dependencies
-        shell: sh
+        shell: bash -l {0}
         run: |
-          uv pip install --upgrade pip
-          uv pip install cmake setuptools nanobind==2.4.0
-          uv pip install -e . -v
+          pip install --upgrade pip
+          pip install cmake setuptools nanobind==2.4.0
+          pip install -e . -v
       - name: Generate package stubs
-        shell: bash
+        shell: bash -l {0}
         run: |
-          uv pip install typing_extensions
-          uv run --no-project setup.py generate_stubs
+          pip install typing_extensions
+          python setup.py generate_stubs
       - name: Build macOS 14 package
         uses: ./.github/actions/build-macos-release
         with:
@@ -105,12 +115,14 @@ jobs:
       - name: Upload MLX artifacts
         uses: actions/upload-artifact@v5
         with:
+          overwrite: true
           name: mac-wheels-${{ matrix.python-version }}
           path: dist/mlx-*.whl
       - name: Upload Metal artifacts
         if: matrix.python-version == '3.10'
         uses: actions/upload-artifact@v5
         with:
+          overwrite: true
           name: mlx-metal
           path: dist/mlx_metal-*.whl

@@ -119,18 +131,20 @@ jobs:
     runs-on: ubuntu-22-large
     env:
       PYPI_RELEASE: 1
+      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
     steps:
       - uses: actions/checkout@v5
       - uses: ./.github/actions/setup-linux
         with:
-          runner-type: 'cuda'
+          toolkit: 'cuda-12.9'
       - name: Build Python package
         uses: ./.github/actions/build-cuda-release
         with:
-          nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
+          toolkit: 'cuda-12.9'
       - name: Upload artifacts
         uses: actions/upload-artifact@v5
         with:
+          overwrite: true
           name: mlx-cuda
           path: wheelhouse/mlx_cuda-*.whl

@@ -141,7 +155,7 @@ jobs:
     permissions:
       id-token: write
     environment:
-      name: ${{ needs.setup.outputs.pypi_env }}
+      name: pypi
       url: https://pypi.org/p/mlx
     steps:
       - uses: actions/download-artifact@v6
@@ -159,7 +173,7 @@ jobs:
       - name: Publish package distributions to PyPI
         uses: pypa/gh-action-pypi-publish@release/v1
         with:
-          repository-url: ${{ needs.setup.outputs.pypi_url }}
+          repository-url: https://upload.pypi.org/legacy/

   pypi-publish-cuda:
     name: Upload CUDA release to PyPI
@@ -168,7 +182,7 @@ jobs:
     permissions:
       id-token: write
     environment:
-      name: ${{ needs.setup.outputs.pypi_env }}
+      name: pypi
       url: https://pypi.org/p/mlx-cuda
     steps:
       - uses: actions/download-artifact@v6
@@ -180,7 +194,7 @@ jobs:
       - name: Publish package distributions to PyPI
         uses: pypa/gh-action-pypi-publish@release/v1
         with:
-          repository-url: ${{ needs.setup.outputs.pypi_url }}
+          repository-url: https://upload.pypi.org/legacy/

   pypi-publish-cpu:
     name: Upload CPU release to PyPI
@@ -189,19 +203,20 @@ jobs:
     permissions:
       id-token: write
     environment:
-      name: ${{ needs.setup.outputs.pypi_env }}
+      name: pypi
       url: https://pypi.org/p/mlx-cpu
     steps:
       - uses: actions/download-artifact@v6
         with:
-          name: mlx-cpu
+          pattern: mlx-cpu-*
+          merge-multiple: true
           path: dist
       - name: Display structure of downloaded files
         run: ls -R dist
       - name: Publish package distributions to PyPI
         uses: pypa/gh-action-pypi-publish@release/v1
         with:
-          repository-url: ${{ needs.setup.outputs.pypi_url }}
+          repository-url: https://upload.pypi.org/legacy/

   pypi-publish-metal:
     name: Upload Metal release to PyPI
@@ -210,7 +225,7 @@ jobs:
     permissions:
       id-token: write
     environment:
-      name: ${{ needs.setup.outputs.pypi_env }}
+      name: pypi
       url: https://pypi.org/p/mlx-metal
     steps:
       - uses: actions/download-artifact@v6
@@ -222,5 +237,4 @@ jobs:
       - name: Publish package distributions to PyPI
         uses: pypa/gh-action-pypi-publish@release/v1
         with:
-          repository-url: ${{ needs.setup.outputs.pypi_url }}
-
+          repository-url: https://upload.pypi.org/legacy/
@@ -74,6 +74,7 @@ endif()
 if(MLX_USE_CCACHE)
   find_program(CCACHE_PROGRAM ccache)
   if(CCACHE_PROGRAM)
     message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
     set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
     set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
@@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() {

 void time_irregular_binary_ops_4D() {
   auto device = mx::default_device();
-  std::vector<int> shape = {8, 8, 512, 512};
+  mx::Shape shape = {8, 8, 512, 512};
   auto a = mx::random::uniform(shape);
   auto b = mx::random::uniform(shape);

@@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() {

 void time_irregular_reshape() {
   auto device = mx::default_device();
-  std::vector<int> shape;
+  mx::Shape shape;
   auto reshape_fn = [&shape, device](const mx::array& a) {
     return mx::reshape(a, shape, device);
   };
@@ -170,7 +170,7 @@ void time_irregular_astype_1D() {
 void time_irregular_astype_2D() {
   auto device = mx::default_device();
   int size = 2048;
-  std::vector<int> shape = {size, size};
+  mx::Shape shape = {size, size};

   auto a = mx::random::uniform(shape);
   TIMEM("2D regular", mx::astype, a, mx::int32, device);
@@ -1,6 +1,5 @@
 # Copyright © 2023 Apple Inc.

 import argparse
 import os
 import subprocess
 import time
benchmarks/python/masked_scatter.py | 212 (new file)
@@ -0,0 +1,212 @@
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.ticker import FuncFormatter
|
||||
|
||||
RESULTS_DIR = "./results"
|
||||
|
||||
|
||||
if not os.path.isdir(RESULTS_DIR):
|
||||
os.mkdir(RESULTS_DIR)
|
||||
|
||||
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
|
||||
|
||||
TORCH_DEVICE = torch.device(
|
||||
"mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
)
|
||||
|
||||
|
||||
N_WARMUP = 5
|
||||
N_ITER_BENCH = 50
|
||||
N_ITER_FUNC = 20
|
||||
|
||||
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
|
||||
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
|
||||
D_TYPES = ("float32", "float16")
|
||||
|
||||
|
||||
def _power_of_two_formatter(value, _position):
|
||||
if value <= 0:
|
||||
return ""
|
||||
exponent = int(round(math.log2(value)))
|
||||
if abs(value - (1 << exponent)) / value > 1e-6:
|
||||
return f"{value:g}"
|
||||
return f"$2^{{{exponent}}}$"
|
||||
|
||||
|
||||
def torch_sync():
|
||||
if TORCH_DEVICE.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif TORCH_DEVICE.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
|
||||
|
||||
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
|
||||
outs = []
|
||||
for _ in range(N_ITER_FUNC):
|
||||
out = copy(self_arr)
|
||||
out[mask_arr] = src_arr
|
||||
outs.append(out)
|
||||
mx.eval(outs)
|
||||
return outs
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
|
||||
outs = []
|
||||
for _ in range(N_ITER_FUNC):
|
||||
out = self_tensor.clone()
|
||||
out.masked_scatter_(mask_tensor, src_tensor)
|
||||
outs.append(out)
|
||||
torch_sync()
|
||||
return outs
|
||||
|
||||
|
||||
def measure(fn):
|
||||
for _ in range(N_WARMUP):
|
||||
fn()
|
||||
start = time.perf_counter_ns()
|
||||
for _ in range(N_ITER_BENCH):
|
||||
fn()
|
||||
end = time.perf_counter_ns()
|
||||
return (end - start) * 1e-9
|
||||
|
||||
|
||||
def bytes_touched(length, true_count, item_size):
|
||||
mask_bytes = length
|
||||
self_bytes = length * item_size * 2 # read + write
|
||||
src_bytes = true_count * item_size
|
||||
return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
|
||||
|
||||
|
||||
def build_case(length, density, np_dtype, torch_dtype):
|
||||
true_count = max(1, int(round(length * density)))
|
||||
|
||||
rng = np.random.default_rng()
|
||||
self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
|
||||
mask_np = np.zeros(length, dtype=bool)
|
||||
mask_np[:true_count] = True
|
||||
rng.shuffle(mask_np)
|
||||
src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
|
||||
|
||||
self_mlx = mx.array(self_np)
|
||||
mask_mlx = mx.array(mask_np)
|
||||
src_mlx = mx.array(src_np)
|
||||
|
||||
self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
|
||||
mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
|
||||
src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
|
||||
|
||||
# Correctness check once per configuration
|
||||
mx_out = mx.array(self_np)
|
||||
mx_out[mask_mlx] = src_mlx
|
||||
mx.eval(mx_out)
|
||||
torch_out = self_torch.clone()
|
||||
torch_out.masked_scatter_(mask_torch, src_torch)
|
||||
|
||||
atol = 5e-3 if np_dtype == np.float16 else 1e-5
|
||||
if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
|
||||
raise AssertionError("masked_scatter results diverged between MLX and Torch")
|
||||
|
||||
return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
|
||||
|
||||
|
||||
def bench_case(length, density, dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
torch_dtype = getattr(torch, dtype)
|
||||
(
|
||||
self_mlx,
|
||||
mask_mlx,
|
||||
src_mlx,
|
||||
self_torch,
|
||||
mask_torch,
|
||||
src_torch,
|
||||
true_count,
|
||||
) = build_case(length, density, np_dtype, torch_dtype)
|
||||
|
||||
time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
|
||||
time_torch = measure(
|
||||
partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
|
||||
)
|
||||
|
||||
total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
|
||||
bytes_per_gb = float(1024**3)
|
||||
mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
|
||||
torch_gbps = (total_bytes / bytes_per_gb) / time_torch
|
||||
|
||||
return time_mlx, time_torch, mlx_gbps, torch_gbps
|
||||
|
||||
|
||||
def plot_density(ax_perf, ax_speedup, density, dtype):
    mlx_gbps = []
    torch_gbps = []
    mlx_times = []
    torch_times = []

    for length in VECTOR_LENGTHS:
        t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
        mlx_gbps.append(gbps_mlx)
        torch_gbps.append(gbps_torch)
        mlx_times.append(t_mlx)
        torch_times.append(t_torch)

    ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
    ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
    ax_perf.set_xscale("log", base=2)
    ax_perf.set_xticks(VECTOR_LENGTHS)
    formatter = FuncFormatter(_power_of_two_formatter)
    ax_perf.xaxis.set_major_formatter(formatter)
    ax_perf.set_title(f"density={density:.2f}")
    ax_perf.set_ylabel("GB/s")
    ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
    ax_perf.legend()

    speedup = np.array(torch_times) / np.array(mlx_times)
    ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
    ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
    ax_speedup.set_xscale("log", base=2)
    ax_speedup.set_xticks(VECTOR_LENGTHS)
    ax_speedup.xaxis.set_major_formatter(formatter)
    ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
    ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)

def main():
    for dtype in D_TYPES:
        fig, axs = plt.subplots(
            len(MASK_DENSITIES),
            2,
            figsize=(10, 12),
            layout="constrained",
            sharex=True,
        )

        for i, density in enumerate(MASK_DENSITIES):
            plot_density(axs[i][0], axs[i][1], density, dtype)
            axs[i][0].set_xlabel("vector length")
            axs[i][1].set_xlabel("vector length")

        fig.suptitle(
            f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
        )
        output_path = os.path.join(
            RESULTS_DIR,
            f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
        )
        fig.savefig(output_path)
        plt.close(fig)


if __name__ == "__main__":
    main()
cmake/Findnvpl.cmake (new file, 3 lines)
@@ -0,0 +1,3 @@
+# This file does nothing but to suppress the cmake warning: "By not providing
+# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
+# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.
@@ -70,7 +70,8 @@ Differences from NumPy
 
 * Indexing does not perform bounds checking. Indexing out of bounds is
   undefined behavior.
-* Boolean mask based indexing is not yet supported.
+* Boolean mask based indexing is supported for assignment only (see
+  :ref:`boolean-mask-assignment`).
 
 The reason for the lack of bounds checking is that exceptions cannot propagate
 from the GPU. Performing bounds checking for array indices before launching the
@@ -143,3 +144,51 @@ expected. For example:
 
 In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
 and ones elsewhere.
+
+.. _boolean-mask-assignment:
+
+Boolean Mask Assignment
+-----------------------
+
+MLX supports boolean mask indices for assignment using NumPy syntax. A mask
+must already be a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray``
+with ``dtype=bool``. Other index types are routed through the standard
+scatter code.
+
+.. code-block:: shell
+
+  >>> a = mx.array([1.0, 2.0, 3.0])
+  >>> mask = mx.array([True, False, True])
+  >>> updates = mx.array([5.0, 6.0])
+  >>> a[mask] = updates
+  >>> a
+  array([5.0, 2.0, 6.0], dtype=float32)
+
+Scalar assignments broadcast to every ``True`` entry in ``mask``. For
+non-scalar assignments, ``updates`` must provide at least as many elements
+as there are ``True`` entries in ``mask``.
+
+.. code-block:: shell
+
+  >>> a = mx.zeros((2, 3))
+  >>> mask = mx.array([[True, False, True],
+  ...                  [False, False, True]])
+  >>> a[mask] = 1.0
+  >>> a
+  array([[1.0, 0.0, 1.0],
+         [0.0, 0.0, 1.0]], dtype=float32)
+
+Boolean masks follow NumPy semantics:
+
+- The mask shape must match the shape of the axes it indexes exactly. No mask
+  broadcasting occurs.
+- Any axes not covered by the mask are taken in full.
+
+.. code-block:: shell
+
+  >>> a = mx.arange(1000).reshape(10, 10, 10)
+  >>> a[mx.random.randn(10, 10) > 0.0] = 0  # valid: mask covers axes 0 and 1
+
+The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
+selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
+Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
+axes and therefore raise errors.
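To make the last rule concrete, here is a sketch of the failing case (the
exact exception text is not shown here and may differ):

.. code-block:: shell

  >>> a = mx.arange(1000).reshape(10, 10, 10)
  >>> a[mx.random.randn(1, 10, 10) > 0.0] = 0   # error: (1, 10, 10) mask
  >>> a[mx.random.randn(10, 10, 1) > 0.0] = 0   # error: (10, 10, 1) mask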
@@ -167,7 +167,7 @@ void array::copy_shared_buffer(
     const Strides& strides,
     Flags flags,
     size_t data_size,
-    size_t offset /* = 0 */) {
+    int64_t offset /* = 0 */) {
   array_desc_->data = other.array_desc_->data;
   array_desc_->strides = strides;
   array_desc_->flags = flags;
 
@@ -439,7 +439,7 @@ class array {
       const Strides& strides,
       Flags flags,
       size_t data_size,
-      size_t offset = 0);
+      int64_t offset = 0);
 
   void copy_shared_buffer(const array& other);
 
@@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
     data_offset += start_indices[i] * in.strides()[i];
     inp_strides[i] = in.strides()[i] * strides[i];
   }
-  // Normalize the offset
-  if (data_offset < 0) {
-    data_offset += in.data_size();
-  }
   return std::make_tuple(data_offset, inp_strides);
 }
 
 void shared_buffer_slice(
     const array& in,
     const Strides& out_strides,
-    size_t data_offset,
+    int64_t data_offset,
     size_t data_size,
     array& out) {
   // Compute row/col contiguity
@@ -51,17 +47,24 @@ void slice(
 
   // Calculate out strides, initial offset
   auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
-  int64_t data_end = 1;
-  for (int i = 0; i < start_indices.size(); ++i) {
-    if (in.shape()[i] > 1) {
-      auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
-      data_end += end_idx * in.strides()[i];
+
+  // Get the location of the end based on the inp strides and out.shape()
+  int64_t low_idx = 0;
+  int64_t high_idx = 0;
+  for (int i = 0; i < inp_strides.size(); ++i) {
+    auto delta = inp_strides[i] * (out.shape()[i] - 1);
+    if (inp_strides[i] > 0) {
+      high_idx += delta;
+    } else {
+      low_idx += delta;
     }
   }
-  if (data_end < 0) {
-    data_end += in.data_size();
+  int64_t data_size = (high_idx - low_idx) + 1;
+  if (data_size < 0) {
+    std::ostringstream msg;
+    msg << "[slice] Computed invalid data size: " << data_size << ".";
+    throw std::runtime_error(msg.str());
   }
-  size_t data_size = (data_end - data_offset);
   shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
 }
 
@@ -12,6 +12,167 @@ namespace mlx::core {
 
 namespace {
 
+template <typename T>
+complex64_t to_complex(T r, T i) {
+  return {static_cast<float>(r), static_cast<float>(i)};
+}
+
+template <typename T, class Enable = void>
+struct EigWork {};
+
+template <typename T>
+struct EigWork<
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using O = complex64_t;
+
+  char jobl;
+  char jobr;
+  int N;
+  int lwork;
+  int info;
+  std::vector<array::Data> buffers;
+
+  EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
+      : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
+    T work;
+    int n_vecs_l = compute_eigenvectors ? N_ : 1;
+    int n_vecs_r = 1;
+    geev<T>(
+        &jobl,
+        &jobr,
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        nullptr,
+        nullptr,
+        &n_vecs_l,
+        nullptr,
+        &n_vecs_r,
+        &work,
+        &lwork,
+        &info);
+    lwork = static_cast<int>(work);
+
+    buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
+    if (compute_eigenvectors) {
+      buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
+    }
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+  }
+
+  void run(T* a, O* values, O* vectors) {
+    auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
+    T* vec_tmp = nullptr;
+    if (vectors) {
+      vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
+    }
+    auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
+
+    int n_vecs_l = vectors ? N : 1;
+    int n_vecs_r = 1;
+    geev<T>(
+        &jobl,
+        &jobr,
+        &N,
+        a,
+        &N,
+        eig_tmp,
+        eig_tmp + N,
+        vectors ? vec_tmp : nullptr,
+        &n_vecs_l,
+        nullptr,
+        &n_vecs_r,
+        work,
+        &lwork,
+        &info);
+
+    for (int i = 0; i < N; ++i) {
+      values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
+    }
+
+    if (vectors) {
+      for (int i = 0; i < N; ++i) {
+        if (values[i].imag() != 0) {
+          for (int j = 0; j < N; ++j) {
+            vectors[i * N + j] =
+                to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
+            vectors[(i + 1) * N + j] =
+                to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
+          }
+          i += 1;
+        } else {
+          for (int j = 0; j < N; ++j) {
+            vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
+          }
+        }
+      }
+    }
+  }
+};
+
+template <>
+struct EigWork<std::complex<float>> {
+  using T = std::complex<float>;
+  using R = float;
+  using O = T;
+
+  char jobl;
+  char jobr;
+  int N;
+  int lwork;
+  int lrwork;
+  int info;
+  std::vector<array::Data> buffers;
+
+  EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
+      : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
+    T work;
+    R rwork;
+    int n_vecs_l = compute_eigenvectors ? N_ : 1;
+    int n_vecs_r = 1;
+    geev<T>(
+        &jobl,
+        &jobr,
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        nullptr,
+        &n_vecs_l,
+        nullptr,
+        &n_vecs_r,
+        &work,
+        &lwork,
+        &rwork,
+        &info);
+    lwork = static_cast<int>(work.real());
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
+  }
+
+  void run(T* a, T* values, T* vectors) {
+    int n_vecs_l = vectors ? N : 1;
+    int n_vecs_r = 1;
+    geev<T>(
+        &jobl,
+        &jobr,
+        &N,
+        a,
+        &N,
+        values,
+        vectors,
+        &n_vecs_l,
+        nullptr,
+        &n_vecs_r,
+        static_cast<T*>(buffers[0].buffer.raw_ptr()),
+        &lwork,
+        static_cast<R*>(buffers[1].buffer.raw_ptr()),
+        &info);
+  }
+};
+
 template <typename T>
 void eig_impl(
     array& a,
@@ -19,101 +180,39 @@ void eig_impl(
     array& values,
     bool compute_eigenvectors,
     Stream stream) {
-  using OT = std::complex<T>;
   auto a_ptr = a.data<T>();
-  auto eig_ptr = values.data<OT>();
+  auto val_ptr = values.data<complex64_t>();
 
   auto& encoder = cpu::get_command_encoder(stream);
   encoder.set_input_array(a);
   encoder.set_output_array(values);
-  OT* vec_ptr = nullptr;
+  complex64_t* vec_ptr = nullptr;
   if (compute_eigenvectors) {
     encoder.set_output_array(vectors);
-    vec_ptr = vectors.data<OT>();
+    vec_ptr = vectors.data<complex64_t>();
   }
   encoder.dispatch([a_ptr,
+                    val_ptr,
                     vec_ptr,
-                    eig_ptr,
                     compute_eigenvectors,
                     N = vectors.shape(-1),
                     size = vectors.size()]() mutable {
     // Work query
     char jobr = 'N';
     char jobl = compute_eigenvectors ? 'V' : 'N';
-    int n_vecs_r = 1;
-    int n_vecs_l = compute_eigenvectors ? N : 1;
-    int lwork = -1;
-    int info;
-    {
-      T work;
-      geev<T>(
-          &jobl,
-          &jobr,
-          &N,
-          nullptr,
-          &N,
-          nullptr,
-          nullptr,
-          nullptr,
-          &n_vecs_l,
-          nullptr,
-          &n_vecs_r,
-          &work,
-          &lwork,
-          &info);
-      lwork = static_cast<int>(work);
-    }
-
-    auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
-    auto vec_tmp_data =
-        array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
-    auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
-    auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
-    auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
+    EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
 
     for (size_t i = 0; i < size / (N * N); ++i) {
-      geev<T>(
-          &jobl,
-          &jobr,
-          &N,
-          a_ptr,
-          &N,
-          eig_tmp,
-          eig_tmp + N,
-          vec_tmp,
-          &n_vecs_l,
-          nullptr,
-          &n_vecs_r,
-          static_cast<T*>(work_buf.buffer.raw_ptr()),
-          &lwork,
-          &info);
-      for (int i = 0; i < N; ++i) {
-        eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
-      }
+      work.run(a_ptr, val_ptr, vec_ptr);
+      a_ptr += N * N;
+      val_ptr += N;
       if (vec_ptr) {
-        for (int i = 0; i < N; ++i) {
-          if (eig_ptr[i].imag() != 0) {
-            // This vector and the next are a pair
-            for (int j = 0; j < N; ++j) {
-              vec_ptr[i * N + j] = {
-                  vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
-              vec_ptr[(i + 1) * N + j] = {
-                  vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
-            }
-            i += 1;
-          } else {
-            for (int j = 0; j < N; ++j) {
-              vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
-            }
-          }
-        }
         vec_ptr += N * N;
       }
-      a_ptr += N * N;
-      eig_ptr += N;
-      if (info != 0) {
+      if (work.info != 0) {
        std::stringstream msg;
        msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
-            << info;
+            << work.info;
        throw std::runtime_error(msg.str());
      }
    }
@@ -165,8 +264,17 @@ void Eig::eval_cpu(
     case float32:
       eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
       break;
+    case float64:
+      eig_impl<double>(
+          a_copy, vectors, values, compute_eigenvectors_, stream());
+      break;
+    case complex64:
+      eig_impl<std::complex<float>>(
+          a_copy, vectors, values, compute_eigenvectors_, stream());
+      break;
     default:
-      throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
+      throw std::runtime_error(
+          "[Eig::eval_cpu] only supports float32, float64, or complex64.");
   }
 }
 
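The pairing loop above unpacks LAPACK's packed output convention for real
input: when an eigenvalue has a nonzero imaginary part, two consecutive real
output vectors hold the real and imaginary parts of one eigenvector of a
conjugate pair. A minimal sketch of that convention using SciPy's LAPACK
bindings (assuming SciPy is available; MLX calls the C symbols directly):

import numpy as np
from scipy.linalg import lapack

# A 2x2 rotation has eigenvalues +/- i, so sgeev returns a packed pair.
A = np.array([[0.0, -1.0], [1.0, 0.0]], dtype=np.float32)
wr, wi, vl, vr, info = lapack.sgeev(A, compute_vl=0, compute_vr=1)
assert info == 0 and wi[0] != 0.0
# Columns 0 and 1 of vr hold Re(v) and Im(v) of one eigenvector; the pair
# of complex eigenvectors is vr[:, 0] +/- 1j * vr[:, 1].
v = vr[:, 0] + 1j * vr[:, 1]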
@@ -747,4 +747,108 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
   });
 }
 
+template <typename T>
+void masked_scatter_impl(const array& mask, const array& src, array& out) {
+  ContiguousIterator mask_it(mask);
+  ContiguousIterator src_it(src);
+  ContiguousIterator out_it(out);
+
+  const bool* mask_ptr = mask.data<bool>();
+  const T* src_ptr = src.data<T>();
+  T* dst_ptr = out.data<T>();
+
+  const size_t batch_count = mask.shape(0);
+  const size_t mask_batch_size = mask.size() / batch_count;
+  const size_t src_batch_size = src.size() / batch_count;
+
+  for (uint b = 0; b < batch_count; ++b) {
+    size_t src_consumed = 0;
+    src_it.seek(b * src_batch_size);
+
+    for (size_t i = 0; i < mask_batch_size; ++i) {
+      if (mask_ptr[mask_it.loc]) {
+        if (src_consumed >= src_batch_size) {
+          throw std::runtime_error(
+              "[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
+        }
+        dst_ptr[out_it.loc] = src_ptr[src_it.loc];
+        src_it.step();
+        ++src_consumed;
+      }
+      mask_it.step();
+      out_it.step();
+    }
+  }
+}
+
+void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 3);
+
+  auto& dst = inputs[0];
+  auto& mask = inputs[1];
+  auto& src = inputs[2];
+
+  // Copy dst into out (copy allocates memory for out)
+  auto ctype =
+      dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
+  copy_cpu(dst, out, ctype, stream());
+
+  if (mask.size() == 0) {
+    return;
+  }
+
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.set_input_array(mask);
+  encoder.set_input_array(src);
+  encoder.set_output_array(out);
+  encoder.dispatch([mask = array::unsafe_weak_copy(mask),
+                    src = array::unsafe_weak_copy(src),
+                    out = array::unsafe_weak_copy(out)]() mutable {
+    switch (out.dtype()) {
+      case bool_:
+        masked_scatter_impl<bool>(mask, src, out);
+        break;
+      case uint8:
+        masked_scatter_impl<uint8_t>(mask, src, out);
+        break;
+      case uint16:
+        masked_scatter_impl<uint16_t>(mask, src, out);
+        break;
+      case uint32:
+        masked_scatter_impl<uint32_t>(mask, src, out);
+        break;
+      case uint64:
+        masked_scatter_impl<uint64_t>(mask, src, out);
+        break;
+      case int8:
+        masked_scatter_impl<int8_t>(mask, src, out);
+        break;
+      case int16:
+        masked_scatter_impl<int16_t>(mask, src, out);
+        break;
+      case int32:
+        masked_scatter_impl<int32_t>(mask, src, out);
+        break;
+      case int64:
+        masked_scatter_impl<int64_t>(mask, src, out);
+        break;
+      case float16:
+        masked_scatter_impl<float16_t>(mask, src, out);
+        break;
+      case float32:
+        masked_scatter_impl<float>(mask, src, out);
+        break;
+      case float64:
+        masked_scatter_impl<double>(mask, src, out);
+        break;
+      case bfloat16:
+        masked_scatter_impl<bfloat16_t>(mask, src, out);
+        break;
+      case complex64:
+        masked_scatter_impl<complex64_t>(mask, src, out);
+        break;
+    }
+  });
+}
+
 } // namespace mlx::core
 
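For reference, the per-batch walk above matches this NumPy sketch (the batch
dimension is omitted, and masked_scatter_ref is a hypothetical helper, not an
MLX API):

import numpy as np

def masked_scatter_ref(dst, mask, src):
    # Consume src values in order, one per True entry of mask.
    out = dst.copy()
    idx = np.flatnonzero(mask)
    if idx.size > src.size:
        raise RuntimeError("Source does not have enough elements for mask.")
    out.flat[idx] = src.ravel()[: idx.size]
    return out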
@@ -45,9 +45,7 @@
 INSTANTIATE_LAPACK_REAL(geqrf)
 INSTANTIATE_LAPACK_REAL(orgqr)
 INSTANTIATE_LAPACK_REAL(syevd)
-INSTANTIATE_LAPACK_REAL(geev)
 INSTANTIATE_LAPACK_REAL(potrf)
-INSTANTIATE_LAPACK_REAL(gesdd)
 INSTANTIATE_LAPACK_REAL(getrf)
 INSTANTIATE_LAPACK_REAL(getri)
 INSTANTIATE_LAPACK_REAL(trtri)
@@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri)
 }
 
 INSTANTIATE_LAPACK_COMPLEX(heevd)
+
+#define INSTANTIATE_LAPACK_ALL(FUNC)                                  \
+  template <typename T, typename... Args>                             \
+  void FUNC(Args... args) {                                           \
+    if constexpr (std::is_same_v<T, float>) {                         \
+      MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...);          \
+    } else if constexpr (std::is_same_v<T, double>) {                 \
+      MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...);          \
+    } else if constexpr (std::is_same_v<T, std::complex<float>>) {    \
+      MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...);          \
+    } else if constexpr (std::is_same_v<T, std::complex<double>>) {   \
+      MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...);          \
+    }                                                                 \
+  }
+
+INSTANTIATE_LAPACK_ALL(geev)
+INSTANTIATE_LAPACK_ALL(gesdd)
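The macro encodes LAPACK's standard s/d/c/z type-prefix convention. The same
dispatch, sketched in Python with SciPy's bindings (illustrative only; MLX
resolves the C symbols at compile time):

import numpy as np
from scipy.linalg import lapack

# Map an element type to its LAPACK routine prefix, then look up the symbol.
PREFIX = {np.float32: "s", np.float64: "d", np.complex64: "c", np.complex128: "z"}
geev = getattr(lapack, PREFIX[np.float64] + "geev")  # -> dgeev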
@@ -8,6 +8,183 @@
 
 namespace mlx::core {
 
+template <typename T, class Enable = void>
+struct SVDWork {};
+
+template <typename T>
+struct SVDWork<
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using R = T;
+
+  int N;
+  int M;
+  int K;
+  int lda;
+  int ldu;
+  int ldvt;
+  char jobz;
+  std::vector<array::Data> buffers;
+  int lwork;
+
+  SVDWork(int N, int M, int K, char jobz)
+      : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
+    T workspace_dimension = 0;
+
+    // Will contain the indices of eigenvectors that failed to converge (not
+    // used here but required by lapack).
+    buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
+
+    int lwork_query = -1;
+    int info;
+
+    // Compute workspace size.
+    gesdd<T>(
+        /* jobz = */ &jobz,
+        // M and N are swapped since lapack expects column-major.
+        /* m = */ &N,
+        /* n = */ &M,
+        /* a = */ nullptr,
+        /* lda = */ &lda,
+        /* s = */ nullptr,
+        /* u = */ nullptr,
+        /* ldu = */ &ldu,
+        /* vt = */ nullptr,
+        /* ldvt = */ &ldvt,
+        /* work = */ &workspace_dimension,
+        /* lwork = */ &lwork_query,
+        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    lwork = workspace_dimension;
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+  }
+
+  void run(T* a, R* s, T* u, T* vt) {
+    int info;
+    gesdd<T>(
+        /* jobz = */ &jobz,
+        // M and N are swapped since lapack expects column-major.
+        /* m = */ &N,
+        /* n = */ &M,
+        /* a = */ a,
+        /* lda = */ &lda,
+        /* s = */ s,
+        // According to the identity above, lapack will write Vᵀᵀ as U.
+        /* u = */ u,
+        /* ldu = */ &ldu,
+        // According to the identity above, lapack will write Uᵀ as Vᵀ.
+        /* vt = */ vt,
+        /* ldvt = */ &ldvt,
+        /* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
+        /* lwork = */ &lwork,
+        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "svd_impl: sgesvdx_ failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+  }
+};
+
+template <>
+struct SVDWork<std::complex<float>> {
+  using T = std::complex<float>;
+  using R = float;
+
+  int N;
+  int M;
+  int K;
+  int lda;
+  int ldu;
+  int ldvt;
+  char jobz;
+  std::vector<array::Data> buffers;
+  int lwork;
+
+  SVDWork(int N, int M, int K, char jobz)
+      : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
+    T workspace_dimension = 0;
+
+    // Will contain the indices of eigenvectors that failed to converge (not
+    // used here but required by lapack).
+    buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
+
+    const int lrwork =
+        jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
+    buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
+
+    int lwork_query = -1;
+    int info;
+
+    // Compute workspace size.
+    gesdd<T>(
+        /* jobz = */ &jobz,
+        // M and N are swapped since lapack expects column-major.
+        /* m = */ &N,
+        /* n = */ &M,
+        /* a = */ nullptr,
+        /* lda = */ &lda,
+        /* s = */ nullptr,
+        /* u = */ nullptr,
+        /* ldu = */ &ldu,
+        /* vt = */ nullptr,
+        /* ldvt = */ &ldvt,
+        /* work = */ &workspace_dimension,
+        /* lwork = */ &lwork_query,
+        /* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
+        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    lwork = workspace_dimension.real();
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+  }
+
+  void run(T* a, R* s, T* u, T* vt) {
+    int info;
+    gesdd<T>(
+        /* jobz = */ &jobz,
+        // M and N are swapped since lapack expects column-major.
+        /* m = */ &N,
+        /* n = */ &M,
+        /* a = */ a,
+        /* lda = */ &lda,
+        /* s = */ s,
+        // According to the identity above, lapack will write Vᵀᵀ as U.
+        /* u = */ u,
+        /* ldu = */ &ldu,
+        // According to the identity above, lapack will write Uᵀ as Vᵀ.
+        /* vt = */ vt,
+        /* ldvt = */ &ldvt,
+        /* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
+        /* lwork = */ &lwork,
+        /* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
+        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "svd_impl: sgesvdx_ failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+  }
+};
+
 template <typename T>
 void svd_impl(
     const array& a,
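The "identity above" that the gesdd comments refer to is the transpose
identity

    A = U Σ Vᵀ  implies  Aᵀ = V Σ Uᵀ,

so handing LAPACK the row-major buffer (which it reads column-major, i.e. as
Aᵀ) means its u output actually holds V and its vt output holds Uᵀ; the call
sites in the next hunk therefore swap the u/vt destination pointers.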
@@ -27,6 +204,8 @@ void svd_impl(
   const int N = a.shape(-1);
   const int K = std::min(M, N);
 
+  using R = typename SVDWork<T>::R;
+
   size_t num_matrices = a.size() / (M * N);
 
   // lapack clobbers the input, so we have to make a copy.
@@ -42,7 +221,7 @@ void svd_impl(
   encoder.set_input_array(a);
   auto in_ptr = in.data<T>();
   T* u_ptr;
-  T* s_ptr;
+  R* s_ptr;
   T* vt_ptr;
 
   if (compute_uv) {
@@ -58,7 +237,7 @@ void svd_impl(
     encoder.set_output_array(s);
     encoder.set_output_array(vt);
 
-    s_ptr = s.data<T>();
+    s_ptr = s.data<R>();
     u_ptr = u.data<T>();
     vt_ptr = vt.data<T>();
   } else {
@@ -68,96 +247,26 @@ void svd_impl(
 
     encoder.set_output_array(s);
 
-    s_ptr = s.data<T>();
+    s_ptr = s.data<R>();
     u_ptr = nullptr;
     vt_ptr = nullptr;
   }
 
   encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
-    // A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
-    const int lda = N;
-    // U of shape M x M. (N x N in lapack).
-    const int ldu = N;
-    // Vᵀ of shape N x N. (M x M in lapack).
-    const int ldvt = M;
-
-    auto jobz = (u_ptr) ? "A" : "N";
-
-    T workspace_dimension = 0;
-
-    // Will contain the indices of eigenvectors that failed to converge (not
-    // used here but required by lapack).
-    auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
-
-    static const int lwork_query = -1;
-
-    int info;
-
-    // Compute workspace size.
-    gesdd<T>(
-        /* jobz = */ jobz,
-        // M and N are swapped since lapack expects column-major.
-        /* m = */ &N,
-        /* n = */ &M,
-        /* a = */ nullptr,
-        /* lda = */ &lda,
-        /* s = */ nullptr,
-        /* u = */ nullptr,
-        /* ldu = */ &ldu,
-        /* vt = */ nullptr,
-        /* ldvt = */ &ldvt,
-        /* work = */ &workspace_dimension,
-        /* lwork = */ &lwork_query,
-        /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
-        /* info = */ &info);
-
-    if (info != 0) {
-      std::stringstream ss;
-      ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
-      throw std::runtime_error(ss.str());
-    }
-
-    const int lwork = workspace_dimension;
-    auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
-
+    auto jobz = (u_ptr) ? 'A' : 'N';
+    SVDWork<T> svd_work(N, M, K, jobz);
     // Loop over matrices.
     for (int i = 0; i < num_matrices; i++) {
-      gesdd<T>(
-          /* jobz = */ jobz,
-          // M and N are swapped since lapack expects column-major.
-          /* m = */ &N,
-          /* n = */ &M,
-          /* a = */ in_ptr + M * N * i,
-          /* lda = */ &lda,
-          /* s = */ s_ptr + K * i,
-          // According to the identity above, lapack will write Vᵀᵀ as U.
-          /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
-          /* ldu = */ &ldu,
-          // According to the identity above, lapack will write Uᵀ as Vᵀ.
-          /* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
-          /* ldvt = */ &ldvt,
-          /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
-          /* lwork = */ &lwork,
-          /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
-          /* info = */ &info);
-
-      if (info != 0) {
-        std::stringstream ss;
-        ss << "svd_impl: sgesvdx_ failed with code " << info;
-        throw std::runtime_error(ss.str());
-      }
+      svd_work.run(
+          in_ptr + M * N * i,
+          s_ptr + K * i,
+          vt_ptr ? vt_ptr + N * N * i : nullptr,
+          u_ptr ? u_ptr + M * M * i : nullptr);
     }
   });
   encoder.add_temporary(in);
 }
 
 template <typename T>
 void compute_svd(
     const array& a,
     bool compute_uv,
     std::vector<array>& outputs,
     Stream stream) {}
 
 void SVD::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -168,9 +277,12 @@ void SVD::eval_cpu(
     case float64:
       svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
       break;
+    case complex64:
+      svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
+      break;
     default:
       throw std::runtime_error(
-          "[SVD::eval_cpu] only supports float32 or float64.");
+          "[SVD::eval_cpu] only supports float32, float64, or complex64.");
   }
 }
 
@@ -44,6 +44,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
@@ -125,7 +126,11 @@
 # Compute capability >= 7.0 is required for synchronization between CPU/GPU with
 # managed memory.
 if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
-  set(MLX_CUDA_ARCHITECTURES "native")
+  execute_process(
+    COMMAND bash detect_cuda_arch.sh
+    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+    OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
 endif()
 message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
 set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -137,6 +142,7 @@
   URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
 FetchContent_MakeAvailable(cccl)
 target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
+set_target_properties(mlx PROPERTIES CCCL_DIR "${cccl_SOURCE_DIR}/include")
 
 # Use fixed version of NVTX.
 FetchContent_Declare(
@@ -162,7 +168,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
 FetchContent_Declare(
   cudnn
   GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
-  GIT_TAG v1.14.0
+  GIT_TAG v1.16.0
   GIT_SHALLOW TRUE
   EXCLUDE_FROM_ALL)
 set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
 
@@ -92,7 +92,7 @@ CudaAllocator::CudaAllocator()
           [this](CudaBuffer* buf) { cuda_free(buf); }) {
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.95;
+  memory_limit_ = total * 0.9;
   max_pool_size_ = memory_limit_;
 
   int device_count = 0;
@@ -119,7 +119,8 @@ void copy_to_managed(CudaBuffer& buf) {
   buf.data = new_data;
 }
 
-Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
+Buffer
+CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
   if (size == 0) {
     return Buffer{new CudaBuffer{nullptr, 0, -1}};
   }
@@ -134,9 +135,8 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
     size = page_size * ((size + page_size - 1) / page_size);
   }
 
-  int device = -1;
-  if (size > small_block_size && stream != nullptr) {
-    CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
+  if (size <= small_block_size || stream == nullptr) {
+    device = -1;
   }
 
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
@@ -176,18 +176,14 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
     buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
   }
   // Copy to managed here if the buffer is not on the right device
-  if (buf->device != device) {
+  if (buf->device >= 0 && buf->device != device) {
     copy_to_managed(*buf);
   }
   return Buffer{buf};
 }
 
-Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
-  return malloc_impl(size, stream);
-}
-
 Buffer CudaAllocator::malloc(size_t size) {
-  return malloc_impl(size, nullptr);
+  return malloc_async(size, -1, nullptr);
 }
 
 void CudaAllocator::free(Buffer buffer) {
@@ -223,9 +219,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
     scalar_pool_.free(buf);
   } else {
     if (buf->device >= 0) {
-      cudaFreeAsync(buf->data, free_streams_[buf->device]);
+      CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
    } else {
-      cudaFree(buf->data);
+      CHECK_CUDA_ERROR(cudaFree(buf->data));
    }
    delete buf;
  }
@@ -277,8 +273,9 @@ CudaAllocator& allocator() {
   return *allocator_;
 }
 
-Buffer malloc_async(size_t size, cudaStream_t stream) {
-  auto buffer = allocator().malloc_async(size, stream);
+Buffer malloc_async(size_t size, CommandEncoder& encoder) {
+  auto buffer = allocator().malloc_async(
+      size, encoder.device().cuda_device(), encoder.stream());
   if (size && !buffer.ptr()) {
     std::ostringstream msg;
     msg << "[malloc_async] Unable to allocate " << size << " bytes.";
 
@@ -13,6 +13,8 @@
 
 namespace mlx::core::cu {
 
+class CommandEncoder;
+
 using allocator::Buffer;
 
 // Stores cuda-managed unified memory.
@@ -48,7 +50,7 @@ class SmallSizePool {
 class CudaAllocator : public allocator::Allocator {
  public:
   Buffer malloc(size_t size) override;
-  Buffer malloc_async(size_t size, cudaStream_t stream);
+  Buffer malloc_async(size_t size, int device, cudaStream_t stream);
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;
 
@@ -62,7 +64,6 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
 
  private:
-  Buffer malloc_impl(size_t size, cudaStream_t stream);
   void cuda_free(CudaBuffer* buf);
 
   CudaAllocator();
@@ -80,6 +81,6 @@ class CudaAllocator : public allocator::Allocator {
 
 CudaAllocator& allocator();
 
-Buffer malloc_async(size_t size, cudaStream_t stream);
+Buffer malloc_async(size_t size, CommandEncoder& encoder);
 
 } // namespace mlx::core::cu
 
@@ -42,7 +42,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
   auto& encoder = cu::get_command_encoder(stream());
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   encoder.set_output_array(out);
 
   dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
 
@@ -143,7 +143,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
 
   // Prepare the shapes, strides and axis arguments.
   Shape shape = remove_index(in.shape(), axis_);
 
@@ -367,9 +367,8 @@ void binary_op_gpu(
   auto bopt = get_binary_op_type(a, b);
   auto& encoder = cu::get_command_encoder(s);
 
-  set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  set_binary_op_output_data(
+      a, b, out, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
   binary_op_gpu_inplace<Op>(inputs, out, op, s);
 }
 
@@ -246,12 +246,10 @@ void binary_two_op_gpu_inplace(
   auto& out_b = outputs[1];
   auto bopt = get_binary_op_type(a, b);
   auto& encoder = cu::get_command_encoder(s);
-  set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
-  set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  set_binary_op_output_data(
+      a, b, out_a, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
+  set_binary_op_output_data(
+      a, b, out_b, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
 
   if (out_a.size() == 0) {
     return;
@@ -298,7 +298,7 @@ void Compiled::eval_gpu(
   // Put outputs.
   compiled_allocate_outputs(
       inputs, outputs, is_constant_, contiguous, [&](auto n) {
-        return cu::malloc_async(n, encoder.stream());
+        return cu::malloc_async(n, encoder);
       });
   for (auto& x : outputs) {
     args.append(x);
 
@@ -277,11 +277,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   array in = inputs[0];
   array wt = inputs[1];
   array out = out_;
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   Dtype dtype = out.dtype();
 
   // Search cache.
-  ConvCacheKey cache_key{
+  BytesKey<ConvCacheKey> cache_key;
+  cache_key.pod = {
       encoder.device().cuda_device(),
       dtype_to_cudnn_type(dtype),
       vector_key(in.shape()),
 
@@ -86,7 +86,7 @@ array unfold_inputs_nd(
     int mat_N,
     ConvParams<NDIM>& params) {
   array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
-  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
+  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
   encoder.add_temporary(unfolded);
 
   int filter_size = params.C;
 
@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
     int mat_N,
     ConvParams<NDIM>& params) {
   array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
-  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
+  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
   encoder.add_temporary(unfolded);
 
   int filter_size = params.C;
@@ -7,9 +7,8 @@ namespace mlx::core {
 
 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   auto& encoder = cu::get_command_encoder(s);
-  bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  bool donated = set_copy_output_data(
+      in, out, ctype, [&](auto n) { return cu::malloc_async(n, encoder); });
   if (donated && in.dtype() == out.dtype()) {
     // If the output has the same type as the input then there is nothing to
     // copy, just use the buffer.
@@ -104,7 +103,7 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
     return;
   }
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
@@ -114,7 +113,7 @@ void reshape_gpu(const array& in, array& out, Stream s) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
     auto& encoder = cu::get_command_encoder(s);
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
     copy_gpu_inplace(
         in,
         out,
 
@@ -135,9 +135,7 @@ bool prepare_cudnn_plan(
   void* workspace_ptr = nullptr;
   if (workspace_size > 0) {
     array workspace(
-        cu::malloc_async(workspace_size, encoder.stream()),
-        {workspace_size},
-        uint8);
+        cu::malloc_async(workspace_size, encoder), {workspace_size}, uint8);
     encoder.add_temporary(workspace);
     workspace_ptr = gpu_ptr<void>(workspace);
   }
 
@@ -44,13 +44,13 @@ inline SmallVector<T> convert_vector(const Vec& vec) {
 // There are 2 differences from the const_param util from kernel_utils.cuh:
 // 1. The rest of array is filled with 0.
 // 2. This util can be used in .cpp files.
-template <typename T, template <typename U> class Vec>
-inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
-  if (vec.size() > MAX_NDIM) {
+template <int NDIM = MAX_NDIM, typename T, template <typename U> class Vec>
+inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
+  if (vec.size() > NDIM) {
     throw std::runtime_error(
-        fmt::format("ndim can not be larger than {}.", MAX_NDIM));
+        fmt::format("ndim can not be larger than {}.", NDIM));
   }
-  std::array<T, MAX_NDIM> result = {};
+  std::array<T, NDIM> result = {};
   std::copy_n(vec.begin(), vec.size(), result.begin());
   return result;
 }
 
@@ -57,7 +57,7 @@ std::string build_kernel(
     const std::vector<std::string>& output_names,
     const std::vector<Dtype>& output_dtypes,
     const std::vector<std::pair<std::string, TemplateArg>>& template_args,
-    const std::vector<CustomKernelShapeInfo>& shape_infos) {
+    const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
   std::string kernel_source;
   kernel_source.reserve(header.size() + source.size() + 8192);
   kernel_source += default_header;
@@ -81,17 +81,17 @@ std::string build_kernel(
     kernel_source += ",\n";
     // Add input shape, strides and ndim if present in the source
     if (arr.ndim() > 0) {
-      if (shape_infos[i].shape) {
+      if (std::get<0>(shape_infos[i])) {
         kernel_source += "  const __grid_constant__ Shape ";
         kernel_source += name;
         kernel_source += "_shape,\n";
       }
-      if (shape_infos[i].strides) {
+      if (std::get<1>(shape_infos[i])) {
         kernel_source += "  const __grid_constant__ Strides ";
         kernel_source += name;
         kernel_source += "_strides,\n";
       }
-      if (shape_infos[i].ndim) {
+      if (std::get<2>(shape_infos[i])) {
        kernel_source += "  const __grid_constant__ int ";
        kernel_source += name;
        kernel_source += "_ndim,\n";
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
         "[custom_kernel] Must specify at least one output.");
   }
 
-  std::vector<CustomKernelShapeInfo> shape_infos;
+  std::vector<std::tuple<bool, bool, bool>> shape_infos;
   for (auto& n : input_names) {
-    CustomKernelShapeInfo shape_info;
-    shape_info.shape = source.find(n + "_shape") != std::string::npos;
-    shape_info.strides = source.find(n + "_strides") != std::string::npos;
-    shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
+    std::tuple<bool, bool, bool> shape_info;
+    std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
+    std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
+    std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
     shape_infos.push_back(shape_info);
   }
 
@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
     std::optional<float> init_value,
     bool ensure_row_contiguous,
     StreamOrDevice s) {
-  std::vector<CustomKernelShapeInfo> shape_infos(
-      inputs.size(), CustomKernelShapeInfo{false, false, false});
+  std::vector<std::tuple<bool, bool, bool>> shape_infos(
+      inputs.size(), {false, false, false});
   return array::make_arrays(
       output_shapes,
       output_dtypes,
@@ -289,7 +289,7 @@ void CustomKernel::eval_gpu(
       copies.emplace_back(init_value_.value(), out.dtype());
       fill_gpu(copies.back(), out, s);
     } else {
-      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder));
     }
   }
 
@@ -327,13 +327,13 @@ void CustomKernel::eval_gpu(
     const array& in = checked_inputs[i];
     auto& shape_info = shape_infos_[i];
     args.append(in);
-    if (shape_info.shape) {
+    if (std::get<0>(shape_info)) {
       args.append_ndim(in.shape());
     }
-    if (shape_info.strides) {
+    if (std::get<1>(shape_info)) {
      args.append_ndim(in.strides());
    }
-    if (shape_info.ndim) {
+    if (std::get<2>(shape_info)) {
      args.append<int32_t>(in.ndim());
    }
  }
 
mlx/backend/cuda/detect_cuda_arch.sh (new file, 13 lines)
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+arch=`__nvcc_device_query`
+case "$arch" in
+  "90")
+    echo "90a" ;;
+  "100")
+    echo "100a" ;;
+  "121")
+    echo "121a" ;;
+  *)
+    echo "native" ;;
+esac
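What CMake's execute_process receives can be checked by hand; a small sketch
(assuming __nvcc_device_query is on PATH, as in a CUDA toolkit install, and
that you run it from the backend directory):

import subprocess

# Run the script the same way CMake does and read the architecture string.
arch = subprocess.run(
    ["bash", "detect_cuda_arch.sh"],
    capture_output=True,
    text=True,
    check=True,
).stdout.strip()
print(arch)  # e.g. "90a" on compute capability 9.0, otherwise "native"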
@@ -46,6 +46,7 @@ Device::Device(int device) : device_(device) {
         "Device {} does not support synchronization in managed memory.",
         device_));
   }
+
   // The cublasLt handle is used by matmul.
   make_current();
   CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
@@ -114,18 +115,17 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
   }
 
   // Use an empty graph node for synchronization
-  CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
-  enc.empty_node_count_++;
+  CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
   CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
 
   // Insert the concurrent -> empty node dependencies
   for (auto& from : enc.concurrent_nodes_) {
     enc.from_nodes_.push_back(from.node);
     enc.to_nodes_.push_back(empty.node);
-    enc.graph_key_ += from.id;
-    enc.graph_key_ += from.node_type;
-    enc.graph_key_ += empty.id;
-    enc.graph_key_ += empty.node_type;
+    enc.graph_deps_key_ += from.id;
+    enc.graph_deps_key_ += "-";
+    enc.graph_deps_key_ += empty.id;
+    enc.graph_deps_key_ += "-";
   }
 
   // Insert the input -> concurrent node dependencies without updating output
@@ -140,9 +140,6 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
 }
 
 void CommandEncoder::insert_graph_dependencies(GraphNode node) {
-  if (node.node_type == 'G') {
-    graph_node_count_++;
-  }
   node.id = std::to_string(node_count_++);
   if (in_concurrent_) {
     concurrent_nodes_.push_back(std::move(node));
@@ -154,6 +151,10 @@ void CommandEncoder::insert_graph_dependencies(GraphNode node) {
 }
 
 void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
+  for (auto& node : nodes) {
+    graph_nodes_key_ += node.node_type;
+    graph_nodes_key_ += "-";
+  }
   std::vector<GraphNode> deps;
   {
     // Dependencies must be added in the same order to produce a consistent
@@ -181,20 +182,49 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
     for (auto& to : nodes) {
       from_nodes_.push_back(from.node);
       to_nodes_.push_back(to.node);
-      graph_key_ += from.id;
-      graph_key_ += from.node_type;
-      graph_key_ += to.id;
-      graph_key_ += to.node_type;
+      graph_deps_key_ += from.id;
+      graph_deps_key_ += "-";
+      graph_deps_key_ += to.id;
+      graph_deps_key_ += "-";
     }
   }
 }
 
+// Can be tuned with MLX_MAX_OPS_PER_BUFFER, MLX_MAX_MB_PER_BUFFER
+std::pair<int, int> get_graph_limits(Device& d) {
+  auto cc =
+      d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;
+  int ops = 20;
+  int mb = 100;
+  switch (cc) {
+    case 800: // A100
+      ops = 20;
+      mb = 400;
+      break;
+    case 900: // H100
+      ops = 30;
+      mb = 400;
+      break;
+    case 1000: // B200
+      ops = 50;
+      mb = 500;
+      break;
+    case 1210: // DGX Spark
+      ops = 20;
+      mb = 25;
+      break;
+  }
+  return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)};
+}
+
 CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
       worker_(d),
-      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
+      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {
+  std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);
+}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
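Since get_graph_limits only supplies defaults, both limits can be overridden
through the environment; a minimal sketch (the values are illustrative, not
recommendations):

import os

# Raise the per-graph limits before MLX evaluates anything.
os.environ["MLX_MAX_OPS_PER_BUFFER"] = "50"
os.environ["MLX_MAX_MB_PER_BUFFER"] = "500"

import mlx.core as mx  # import after setting the environment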
@@ -204,6 +234,7 @@ void CommandEncoder::set_input_array(const array& arr) {
   if (!use_cuda_graphs()) {
     return;
   }
+  bytes_in_graph_ += arr.data_size();
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
 }
@@ -278,13 +309,46 @@ void CommandEncoder::add_kernel_node(
 void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
   cudaGraphNode_t node;
   CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
-  insert_graph_dependencies(GraphNode{node, 'K'});
+  insert_graph_dependencies(GraphNode{node, "K"});
 }
 
 void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
   CUgraphNode node;
   CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
-  insert_graph_dependencies(GraphNode{node, 'K'});
+  insert_graph_dependencies(GraphNode{node, "K"});
 }
 
+bool is_graph_updatable(cudaGraph_t graph, int& cluster_dim_x) {
+  // CUDA graphs do not get updated correctly if a kernel node getting updated
+  // has a different cluster shape than the node it's being updated with.
+  size_t num_nodes = 0;
+  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
+  if (num_nodes == 0) {
+    return true;
+  }
+
+  std::vector<cudaGraphNode_t> nodes(num_nodes);
+  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
+  for (const auto& node : nodes) {
+    cudaGraphNodeType type;
+    CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
+    if (type != cudaGraphNodeTypeKernel) {
+      return false;
+    }
+    cudaLaunchAttributeValue cluster_dim;
+    CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
+        node, cudaLaunchAttributeClusterDimension, &cluster_dim));
+    // Only dim.x can be greater than 1
+    if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
+      return false;
+    }
+    // Only one child node allowed when subgraph uses clusters
+    if (cluster_dim.clusterDim.x > 0 && num_nodes > 1) {
+      return false;
+    }
+    cluster_dim_x = cluster_dim.clusterDim.x;
+  }
+  return true;
+}
+
 void CommandEncoder::add_graph_node(cudaGraph_t child) {
@@ -297,12 +361,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
     return;
   }
   cudaGraphNode_t node;
+  int cluster_dim_x = 0;
+  is_graph_updatable_ = is_graph_updatable(child, cluster_dim_x);
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
-  insert_graph_dependencies(GraphNode{node, 'G'});
+  insert_graph_dependencies(
+      GraphNode{node, "G" + std::to_string(cluster_dim_x)});
 }
 
-int CommandEncoder::get_num_ops() {
-  return node_count_;
+bool CommandEncoder::needs_commit() {
+  return (node_count_ > max_ops_per_graph_) ||
+      ((bytes_in_graph_ >> 20) > max_mb_per_graph_);
 }
 
 void CommandEncoder::commit() {
@@ -322,53 +390,55 @@ void CommandEncoder::commit() {
         from_nodes_.size()));
   }
 
-  graph_key_ += ".";
-  graph_key_ += std::to_string(node_count_);
-  graph_key_ += ".";
-  graph_key_ += std::to_string(graph_node_count_);
-  graph_key_ += ".";
-  graph_key_ += std::to_string(empty_node_count_);
-
-  CudaGraphExec& graph_exec = graph_cache_[graph_key_];
-
-  if (graph_exec != nullptr) {
-    cudaGraphExecUpdateResult update_result;
-#if CUDART_VERSION >= 12000
-    cudaGraphExecUpdateResultInfo info;
-    cudaGraphExecUpdate(graph_exec, graph_, &info);
-    update_result = info.result;
-#else
-    cudaGraphNode_t error_node;
-    cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
-#endif // CUDART_VERSION >= 12000
-    if (update_result != cudaGraphExecUpdateSuccess) {
-      cudaGetLastError(); // reset error
-      graph_exec.reset();
-    }
-  }
-  if (graph_exec == nullptr) {
-    graph_exec.instantiate(graph_);
-  }
   device_.make_current();
-  CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+  if (!is_graph_updatable_) {
+    CudaGraphExec graph_exec;
+    graph_exec.instantiate(graph_);
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+  } else {
+    auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_;
+    auto& graph_exec = graph_cache_[graph_key];
+
+    if (graph_exec != nullptr) {
+      cudaGraphExecUpdateResult update_result;
+#if CUDART_VERSION >= 12000
+      cudaGraphExecUpdateResultInfo info;
+      cudaGraphExecUpdate(graph_exec, graph_, &info);
+      update_result = info.result;
+#else
+      cudaGraphNode_t error_node;
+      cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
+#endif // CUDART_VERSION >= 12000
+      if (update_result != cudaGraphExecUpdateSuccess) {
+        cudaGetLastError(); // reset error
+        graph_exec.reset();
+      }
+    }
+    if (graph_exec == nullptr) {
+      graph_exec.instantiate(graph_);
+    }
+
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+  }
   // Reset state
-  graph_node_count_ = 0;
-  empty_node_count_ = 0;
   from_nodes_.clear();
   to_nodes_.clear();
-  graph_key_.clear();
+  graph_deps_key_.clear();
+  graph_nodes_key_.clear();
   node_map_.clear();
   graph_ = CudaGraph(device_);
+  is_graph_updatable_ = true;
 }
 
 // Put completion handlers in a batch.
 worker_.commit(stream_);
 node_count_ = 0;
+bytes_in_graph_ = 0;
}
 
void CommandEncoder::synchronize() {
-  cudaStreamSynchronize(stream_);
+  CHECK_CUDA_ERROR(cudaStreamSynchronize(stream_));
  auto p = std::make_shared<std::promise<void>>();
  std::future<void> f = p->get_future();
  add_completed_handler([p = std::move(p)]() { p->set_value(); });
 
@@ -84,7 +84,7 @@ class CommandEncoder {
   }
 
   void add_completed_handler(std::function<void()> task);
-  int get_num_ops();
+  bool needs_commit();
   void commit();
 
   Device& device() {
@@ -106,8 +106,9 @@ class CommandEncoder {
     cudaGraphNode_t node;
     // K = kernel
     // E = empty
-    // G = subgraph
-    char node_type;
+    // G* = subgraph (with metadata)
+    // Symbols ':', '-' are reserved as separators
+    std::string node_type;
     std::string id;
   };
 
@@ -119,18 +120,21 @@ class CommandEncoder {
   CudaGraph graph_;
   Worker worker_;
   char node_count_{0};
-  char graph_node_count_{0};
-  char empty_node_count_{0};
   bool in_concurrent_{false};
   std::vector<cudaGraphNode_t> from_nodes_;
   std::vector<cudaGraphNode_t> to_nodes_;
-  std::string graph_key_;
+  std::string graph_nodes_key_;
+  std::string graph_deps_key_;
   std::vector<GraphNode> concurrent_nodes_;
   std::vector<std::shared_ptr<array::Data>> temporaries_;
   LRUCache<std::string, CudaGraphExec> graph_cache_;
   std::vector<std::uintptr_t> active_deps_;
   std::vector<std::uintptr_t> active_outputs_;
   std::unordered_map<std::uintptr_t, GraphNode> node_map_;
+  size_t bytes_in_graph_{0};
+  bool is_graph_updatable_{true};
+  int max_ops_per_graph_;
+  int max_mb_per_graph_;
 };
 
 class Device {
@@ -166,6 +170,7 @@ class Device {
   int device_;
   int compute_capability_major_;
   int compute_capability_minor_;
+  std::string device_name_;
   cublasLtHandle_t lt_;
   cudnnHandle_t cudnn_;
   std::unordered_map<int, CommandEncoder> encoders_;
 
@@ -26,7 +26,7 @@ void AllReduce::eval_gpu(
|
||||
out.copy_shared_buffer(in);
|
||||
return {in, out};
|
||||
} else {
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder));
|
||||
return {in, out};
|
||||
}
|
||||
};
|
||||
@@ -74,7 +74,7 @@ void AllGather::eval_gpu(
|
||||
};
|
||||
|
||||
auto input = ensure_contiguous(inputs[0]);
|
||||
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
|
||||
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));
|
||||
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(outputs[0]);
|
||||
@@ -103,7 +103,7 @@ void ReduceScatter::eval_gpu(
|
||||
};
|
||||
|
||||
auto input = ensure_contiguous(inputs[0]);
|
||||
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
|
||||
outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));
|
||||
|
||||
encoder.set_input_array(input);
|
||||
encoder.set_output_array(outputs[0]);
|
||||
|
||||
@@ -11,9 +11,6 @@
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
||||
constexpr int default_max_nodes_per_graph = 20;
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
@@ -53,8 +50,7 @@ void eval(array& arr) {
|
||||
encoder.add_temporary(s);
|
||||
}
|
||||
|
||||
if (encoder.get_num_ops() >=
|
||||
env::max_ops_per_buffer(default_max_nodes_per_graph)) {
|
||||
if (encoder.needs_commit()) {
|
||||
scheduler::notify_new_task(stream);
|
||||
encoder.add_completed_handler(
|
||||
[stream]() { scheduler::notify_task_completion(stream); });
|
||||
|
||||
@@ -370,7 +370,7 @@ void CublasGemm::execute(
|
||||
// Ensure workspace is 256-byte aligned
|
||||
int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
|
||||
array workspace(
|
||||
cu::malloc_async(nbytes, encoder.stream()),
|
||||
cu::malloc_async(nbytes, encoder),
|
||||
{static_cast<int>(heuristic_.workspaceSize)},
|
||||
int8);
|
||||
encoder.add_temporary(workspace);
|
||||
|
||||
@@ -163,7 +163,7 @@ void CublasGemm::run_batched(
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
|
||||
cu::malloc_async(batch_count * sizeof(void*) * 3, encoder),
|
||||
{batch_count * 3},
|
||||
uint64);
|
||||
|
||||
@@ -251,7 +251,7 @@ void CublasGemm::run_batched(
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
|
||||
cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder),
|
||||
{batch_count * 4},
|
||||
uint64);
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -241,7 +241,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder));
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -279,11 +279,14 @@ void compile(
|
||||
// Compile program.
|
||||
std::vector<const char*> args;
|
||||
bool use_sass = compiler_supports_device_sass(device);
|
||||
auto cc = device.compute_capability_major();
|
||||
std::string arch_tag = (cc == 90 || cc == 100 || cc == 121) ? "a" : "";
|
||||
std::string compute = fmt::format(
|
||||
"--gpu-architecture={}_{}{}",
|
||||
"--gpu-architecture={}_{}{}{}",
|
||||
use_sass ? "sm" : "compute",
|
||||
device.compute_capability_major(),
|
||||
device.compute_capability_minor());
|
||||
cc,
|
||||
device.compute_capability_minor(),
|
||||
arch_tag);
|
||||
args.push_back(compute.c_str());
|
||||
std::string cccl_include = cccl_dir();
|
||||
if (!cccl_include.empty()) {
|
||||
|
||||
@@ -244,7 +244,7 @@ void LayerNorm::eval_gpu(
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
|
||||
cu::malloc_async(x.data_size() * x.itemsize(), encoder),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
|
||||
gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
encoder.add_temporary(g);
|
||||
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
|
||||
g_in_gw = true;
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
|
||||
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
|
||||
encoder.add_temporary(gw_temp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
auto size = out.size();
|
||||
auto nbytes = size * out.itemsize();
|
||||
out.set_data(cu::malloc_async(nbytes, encoder.stream()));
|
||||
out.set_data(cu::malloc_async(nbytes, encoder));
|
||||
auto out_ptr = malloc(nbytes);
|
||||
reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
|
||||
if (swap_endianness_) {
|
||||
|
||||
@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
|
||||
out.set_data(cu::malloc_async(out.nbytes(), encoder));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
cu::malloc_async(in.nbytes() / n, encoder.stream()),
|
||||
cu::malloc_async(in.nbytes() / n, encoder),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
|
||||
@@ -135,12 +135,19 @@ class LRUCache {
|
||||
};
|
||||
|
||||
// Turn a POD struct into a container key by doing bytes compare.
|
||||
//
|
||||
// Usage:
|
||||
// BytesKey<MyKey> key;
|
||||
// key.pod = { ... };
|
||||
template <typename T>
|
||||
struct BytesKey {
|
||||
T pod;
|
||||
static_assert(std::is_standard_layout_v<T>, "T is not POD");
|
||||
|
||||
BytesKey(T pod) : pod(std::move(pod)) {}
|
||||
BytesKey() {
|
||||
// Make sure the paddings between members are filled with 0.
|
||||
memset(&pod, 0, sizeof(T));
|
||||
}
|
||||
|
||||
BytesKey(const BytesKey& other) {
|
||||
memcpy(&pod, &other.pod, sizeof(T));
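// Aside: BytesKey is byte-comparable only because the default constructor
// zeroes the whole struct, making padding bytes deterministic. A minimal
// usage sketch with a hypothetical PodKey type:
//
//   struct PodKey { int device; bool transposed; }; // standard layout
//   BytesKey<PodKey> key;
//   key.pod.device = 0;
//   key.pod.transposed = true;
//   cache.find(key); // hash/compare over the raw sizeof(PodKey) bytes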

@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
return;
}

out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));

int M = a_pre.shape(-2);
int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {

if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
c.data_size() == out.shape(-1)) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
gemm_and_bias(
encoder,
M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);

@@ -37,6 +37,7 @@ NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
NO_GPU(MaskedScatter)

namespace distributed {
NO_GPU_MULTI(Send)

@@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu(
auto scales = ensure_row_contiguous(inputs[1], enc, s);
auto& w = outputs[0];

w.set_data(cu::malloc_async(w.nbytes(), enc.stream()));
w.set_data(cu::malloc_async(w.nbytes(), enc));

if (mode_ == QuantizationMode::Affine) {
auto biases = ensure_row_contiguous(inputs[2], enc, s);
@@ -72,11 +72,11 @@ void fast::Quantize::eval_gpu(
auto& wq = outputs[0];
auto& scales = outputs[1];

wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream()));
scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
wq.set_data(cu::malloc_async(wq.nbytes(), enc));
scales.set_data(cu::malloc_async(scales.nbytes(), enc));
if (mode_ == QuantizationMode::Affine) {
auto& biases = outputs[2];
biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
biases.set_data(cu::malloc_async(biases.nbytes(), enc));
affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
} else {
fp_quantize(w, wq, scales, group_size_, bits_, enc, s);

@@ -139,30 +139,36 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
uint32_t num_keys = keys.size() / 2;
size_t num_keys = keys.size() / 2;

uint32_t elems_per_key = out.size() / num_keys;
uint32_t bytes_per_key = out.itemsize() * elems_per_key;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
if (out.size() == 0) {
return;
}

uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
uint32_t half_size = out_per_key / 2;
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
size_t half_size = out_per_key / 2;

bool odd = out_per_key % 2;
if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
}

encoder.set_input_array(keys);
encoder.set_output_array(out);
dim3 grid_dims{num_keys, half_size + odd};
int64_t total = grid_dims.x * grid_dims.y;
int32_t threads_y = 1;
while ((total / threads_y) >= (1U << 31)) {
int64_t total = num_keys * (half_size + odd);
uint32_t threads_y = 1;
while ((total / threads_y) >= UINT_MAX) {
threads_y *= 2;
}
int32_t threads_x = cuda::ceil_div(total, threads_y);
uint32_t threads_x = cuda::ceil_div(total, threads_y);

dim3 grid_dims{
static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
auto& stream = encoder.stream();
if (keys.flags().row_contiguous) {
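// Aside: the RandomBits change above sizes the launch in 64-bit math first,
// then splits the work so each grid dimension fits in 32 bits. The split,
// extracted as a sketch (ceil_div mirrors cuda::ceil_div):
//
//   int64_t total = num_keys * (half_size + odd);
//   uint32_t threads_y = 1;
//   while ((total / threads_y) >= UINT_MAX) {
//     threads_y *= 2; // shrink the x extent until it fits
//   }
//   uint32_t threads_x = (total + threads_y - 1) / threads_y; // ceil_div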
@@ -66,7 +66,7 @@ void all_reduce(
Reduce::ReduceType reduce_type) {
constexpr int N_READS = 8;

out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));

auto get_args = [](size_t size, int N) {
int threads = std::min(512UL, (size + N - 1) / N);
@@ -107,8 +107,7 @@ void all_reduce(
encoder.set_input_array(in);
if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate);
dispatch_all_types(dt, [&](auto type_tag) {

@@ -28,7 +28,7 @@ void init_reduce(
Reduce::ReduceType reduce_type) {
// Allocate if needed
if (out.data_shared_ptr() == nullptr) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
}

encoder.set_output_array(out);

@@ -96,7 +96,7 @@ inline void allocate_same_layout(
const std::vector<int>& axes,
cu::CommandEncoder& encoder) {
if (in.flags().row_contiguous) {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
return;
}

@@ -135,7 +135,7 @@ inline void allocate_same_layout(
fl.col_contiguous = cc;
fl.contiguous = true;
out.set_data(
cu::malloc_async(out.nbytes(), encoder.stream()),
cu::malloc_async(out.nbytes(), encoder),
data_size,
final_strides,
fl,

@@ -190,7 +190,7 @@ void RMSNorm::eval_gpu(
out.copy_shared_buffer(x);
} else {
out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(),
x.strides(),
x.flags());
@@ -274,7 +274,7 @@ void RMSNormVJP::eval_gpu(
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
}
if (g_copied && !g_in_gx) {
encoder.add_temporary(g);
@@ -292,7 +292,7 @@ void RMSNormVJP::eval_gpu(
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
encoder.add_temporary(gw_temp);
}
}

@@ -292,14 +292,14 @@ void RoPE::eval_gpu(
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (dispatch_ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
mlx/backend/cuda/scaled_dot_product_attention.cpp (new file, 537 lines)
@@ -0,0 +1,537 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/cudnn_utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"

#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace fe = cudnn_frontend;

namespace {

#define CHECK_CUDNN_FE_ERROR(cmd) \
do { \
auto error = cmd; \
if (!error.is_good()) { \
throw std::runtime_error( \
fmt::format("{} failed: {}.", #cmd, error.get_message())); \
} \
} while (0)

std::vector<int64_t> normalized_strides(const array& x) {
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
if (std::all_of(
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
strides.back() = 1;
return strides;
}
if (!x.flags().row_contiguous || x.ndim() < 2) {
return strides;
}
for (int i = x.ndim() - 2; i >= 0; --i) {
if (x.shape(i) == 1) {
strides[i] = x.shape(i + 1) * strides[i + 1];
}
}
return strides;
}
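// Aside: a worked reading of normalized_strides. For a row-contiguous array
// of shape (2, 1, 4), a size-1 axis may carry an arbitrary stride; the loop
// rewrites it to shape(i + 1) * strides[i + 1] so cuDNN sees fully packed
// strides:
//
//   shape   (2, 1, 4)
//   strides (4, 0, 1)  ->  (4, 4, 1)
//
// Fully broadcast inputs (all strides zero) only get the last stride forced
// to 1, matching the SDPA input requirements noted below.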

void set_tensor_attrs(
std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
int64_t uid,
const array& x) {
tensor->set_uid(uid)
.set_dim({x.shape().begin(), x.shape().end()})
.set_stride(normalized_strides(x));
}

array prepare_sdpa_input(const array& x, Stream s) {
// SDPA kernel's requirements on inputs:
// 1. last dim's stride be 1;
// 2. pointer be aligned.
if (x.strides(-1) != 1 || get_alignment(x) < 16) {
array x_copy = contiguous_copy_gpu(x, s);
auto& encoder = cu::get_command_encoder(s);
encoder.add_temporary(x_copy);
return x_copy;
}
return x;
}

constexpr int QKV_NDIM = 4;

struct SDPACacheKey {
int device_id;
cudnnDataType_t cudnn_dtype;
std::array<int, QKV_NDIM> q_shape;
std::array<int, QKV_NDIM> k_shape;
std::array<int, QKV_NDIM> v_shape;
std::array<int64_t, QKV_NDIM> q_strides;
std::array<int64_t, QKV_NDIM> k_strides;
std::array<int64_t, QKV_NDIM> v_strides;
bool do_causal;
bool output_logsumexp;
};

inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
cu::CommandEncoder& encoder,
const array& q,
const array& k,
const array& v,
bool do_causal,
bool output_logsumexp = true) {
BytesKey<SDPACacheKey> cache_key;
cache_key.pod = {
encoder.device().cuda_device(),
dtype_to_cudnn_type(q.dtype()),
vector_key<QKV_NDIM>(q.shape()),
vector_key<QKV_NDIM>(k.shape()),
vector_key<QKV_NDIM>(v.shape()),
vector_key<QKV_NDIM>(q.strides()),
vector_key<QKV_NDIM>(k.strides()),
vector_key<QKV_NDIM>(v.strides()),
do_causal,
output_logsumexp,
};
return cache_key;
}

auto& sdpa_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 16);
return cache;
}

auto& sdpa_backward_cache() {
static LRUBytesKeyCache<SDPACacheKey, fe::graph::Graph> cache(
"MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 16);
return cache;
}
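// Aside: both caches are consumed with the find-or-build pattern used later
// in this file (the interface is assumed from that use):
//
//   auto key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
//   auto it = sdpa_cache().find(key);
//   if (it == sdpa_cache().end()) {
//     it = sdpa_cache().emplace(key, build_sdpa_graph(/* ... */)).first;
//   }
//   fe::graph::Graph& graph = it->second; // reused across matching calls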

enum UIDS {
Q,
K,
V,
SCALE,
O,
STATS,
// Backward graph:
D_Q,
D_K,
D_V,
D_O,
};

fe::graph::Graph build_sdpa_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
const array& v,
bool do_causal,
bool output_logsumexp,
const array& o,
const array& stats) {
auto dtype = fe::DataType_t::HALF;
if (q.dtype() == bfloat16) {
dtype = fe::DataType_t::BFLOAT16;
}

fe::graph::Graph graph;
graph.set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);

auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
set_tensor_attrs(q_, Q, q);
set_tensor_attrs(k_, K, k);
set_tensor_attrs(v_, V, v);

auto scale = graph.tensor(fe::graph::Tensor_attributes()
.set_name("Scale")
.set_uid(SCALE)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));

auto options = fe::graph::SDPA_attributes()
.set_name("sdpa_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal)
.set_generate_stats(output_logsumexp);

auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
o_->set_output(true);
set_tensor_attrs(o_, O, o);
if (output_logsumexp) {
stats_->set_output(true).set_data_type(fe::DataType_t::FLOAT);
set_tensor_attrs(stats_, STATS, stats);
}

CHECK_CUDNN_FE_ERROR(graph.validate());
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));

return graph;
}

fe::graph::Graph build_sdpa_backward_graph(
cudnnHandle_t handle,
const array& q,
const array& k,
const array& v,
bool do_causal,
const array& o,
const array& d_o,
const array& stats,
array& d_q,
array& d_k,
array& d_v) {
auto dtype = fe::DataType_t::HALF;
if (q.dtype() == bfloat16) {
dtype = fe::DataType_t::BFLOAT16;
}

fe::graph::Graph graph;
graph.set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);

auto q_ = graph.tensor(fe::graph::Tensor_attributes().set_name("Q"));
auto k_ = graph.tensor(fe::graph::Tensor_attributes().set_name("K"));
auto v_ = graph.tensor(fe::graph::Tensor_attributes().set_name("V"));
auto o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("O"));
auto d_o_ = graph.tensor(fe::graph::Tensor_attributes().set_name("D_O"));
auto stats_ = graph.tensor(fe::graph::Tensor_attributes().set_name("STATS"));
set_tensor_attrs(q_, Q, q);
set_tensor_attrs(k_, K, k);
set_tensor_attrs(v_, V, v);
set_tensor_attrs(o_, O, o);
set_tensor_attrs(d_o_, D_O, d_o);
set_tensor_attrs(stats_, STATS, stats);
stats_->set_data_type(fe::DataType_t::FLOAT);

auto scale = graph.tensor(fe::graph::Tensor_attributes()
.set_name("Scale")
.set_uid(SCALE)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));

auto options = fe::graph::SDPA_backward_attributes()
.set_name("sdpa_backward_cudnn")
.set_attn_scale(scale)
.set_causal_mask(do_causal);

auto [d_q_, d_k_, d_v_] =
graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
d_q_->set_output(true);
d_k_->set_output(true);
d_v_->set_output(true);
set_tensor_attrs(d_q_, D_Q, d_q);
set_tensor_attrs(d_k_, D_K, d_k);
set_tensor_attrs(d_v_, D_V, d_v);

CHECK_CUDNN_FE_ERROR(graph.validate());
CHECK_CUDNN_FE_ERROR(graph.build_operation_graph(handle));
CHECK_CUDNN_FE_ERROR(graph.create_execution_plans({fe::HeurMode_t::A}));
graph.select_behavior_notes(
{fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
CHECK_CUDNN_FE_ERROR(graph.check_support(handle));
CHECK_CUDNN_FE_ERROR(graph.build_plans(handle));

return graph;
}

void execute_graph(
cu::CommandEncoder& encoder,
cudnnHandle_t handle,
fe::graph::Graph& graph,
std::unordered_map<int64_t, void*>& variant_pack) {
int64_t workspace_size = 0;
CHECK_CUDNN_FE_ERROR(graph.get_workspace_size(workspace_size));
void* workspace_ptr = nullptr;
if (workspace_size > 0) {
array workspace(
cu::malloc_async(workspace_size, encoder),
{static_cast<int>(workspace_size)},
uint8);
encoder.add_temporary(workspace);
workspace_ptr = gpu_ptr<void>(workspace);
}

cudnnSetStream(handle, encoder.stream());

CudaGraph cuda_graph(encoder.device());
CHECK_CUDNN_FE_ERROR(graph.populate_cuda_graph(
handle, variant_pack, workspace_ptr, cuda_graph));
encoder.add_graph_node(cuda_graph);
}
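// Aside: populate_cuda_graph fills a standalone CUDA graph from the compiled
// cuDNN plan, and add_graph_node embeds it into the encoder's graph. In
// plain CUDA terms that embedding is roughly (a sketch, not MLX's exact
// implementation):
//
//   cudaGraphNode_t node;
//   cudaGraphAddChildGraphNode(
//       &node, parent_graph, deps, n_deps, child_graph);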

} // namespace

bool supports_sdpa_cudnn(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool do_causal,
Stream s) {
static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
if (!enabled) {
return false;
}

// cuDNN SDPA requires Ampere and later.
if (cu::device(s.device).compute_capability_major() < 8) {
return false;
}

if (has_mask) {
// TODO: Support array masks.
if (!do_causal) {
return false;
}
// FIXME: Causal mask generates wrong results when L_Q != L_K.
if (q.shape(2) != k.shape(2)) {
return false;
}
}

// Only use cuDNN for prefilling and training.
if (q.shape(2) != k.shape(2)) {
return false;
}

// D_qk and D_v must be a multiple of 8 with maximum value 128.
if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) ||
(v.shape(-1) > 128)) {
return false;
}

Dtype dtype = q.dtype();
return dtype == float16 || dtype == bfloat16;
}
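// Aside: reading the gate above, the cuDNN path is taken for fp16/bf16
// inputs on SM80+ with equal query/key lengths, head dims that are multiples
// of 8 up to 128, and at most a causal mask. For example, q, k, v of shape
// (B, H, 4096, 128) in bfloat16 qualify, while a decode-shaped query of
// length 1 falls through to the vector kernel.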

void sdpa_cudnn(
const array& q,
const array& k,
const array& v,
float scale,
array& o,
array& stats,
bool do_causal,
bool output_logsumexp,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();

// TODO: Handle donation.
// TODO: Make O use same memory layout with Q.
o.set_data(cu::malloc_async(o.nbytes(), encoder));

encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_output_array(o);

if (output_logsumexp) {
stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
encoder.set_output_array(stats);
}

// Search cache.
auto cache_key =
build_sdpa_cache_key(encoder, q, k, v, do_causal, output_logsumexp);
auto it = sdpa_cache().find(cache_key);
if (it == sdpa_cache().end()) {
auto graph = build_sdpa_graph(
handle, q, k, v, do_causal, output_logsumexp, o, stats);
it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;

std::unordered_map<int64_t, void*> variant_pack{
{Q, const_cast<void*>(gpu_ptr<void>(q))},
{K, const_cast<void*>(gpu_ptr<void>(k))},
{V, const_cast<void*>(gpu_ptr<void>(v))},
{SCALE, &scale},
{O, gpu_ptr<void>(o)}};
if (output_logsumexp) {
variant_pack[STATS] = gpu_ptr<void>(stats);
}

execute_graph(encoder, handle, graph, variant_pack);
}

void sdpa_backward_cudnn(
const array& q,
const array& k,
const array& v,
float scale,
const array& o,
const array& stats,
bool do_causal,
const array& d_o,
array& d_q,
array& d_k,
array& d_v,
Stream s) {
auto& encoder = cu::get_command_encoder(s);
auto handle = encoder.device().cudnn_handle();

// TODO: Handle donation.
d_q.set_data(cu::malloc_async(d_q.nbytes(), encoder));
d_k.set_data(cu::malloc_async(d_k.nbytes(), encoder));
d_v.set_data(cu::malloc_async(d_v.nbytes(), encoder));

encoder.set_input_array(q);
encoder.set_input_array(k);
encoder.set_input_array(v);
encoder.set_input_array(o);
encoder.set_input_array(stats);
encoder.set_input_array(d_o);
encoder.set_output_array(d_q);
encoder.set_output_array(d_k);
encoder.set_output_array(d_v);

// Search cache.
auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal);
auto it = sdpa_backward_cache().find(cache_key);
if (it == sdpa_backward_cache().end()) {
auto graph = build_sdpa_backward_graph(
handle, q, k, v, do_causal, o, d_o, stats, d_q, d_k, d_v);
it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
}
auto& graph = it->second;

std::unordered_map<int64_t, void*> variant_pack{
{Q, const_cast<void*>(gpu_ptr<void>(q))},
{K, const_cast<void*>(gpu_ptr<void>(k))},
{V, const_cast<void*>(gpu_ptr<void>(v))},
{SCALE, &scale},
{O, const_cast<void*>(gpu_ptr<void>(o))},
{STATS, const_cast<void*>(gpu_ptr<void>(stats))},
{D_O, const_cast<void*>(gpu_ptr<void>(d_o))},
{D_Q, gpu_ptr<void>(d_q)},
{D_K, gpu_ptr<void>(d_k)},
{D_V, gpu_ptr<void>(d_v)}};

execute_graph(encoder, handle, graph, variant_pack);
}

// Defined in scaled_dot_product_attention.cu file.
bool supports_sdpa_vector(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool output_logsumexp);
void sdpa_vector(
const array& q,
const array& k,
const array& v,
float scale,
array& o,
bool do_causal,
const std::optional<array>& sinks,
Stream s);

namespace fast {

bool ScaledDotProductAttention::use_fallback(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
bool is_training,
bool output_logsumexp,
Stream s) {
if (s.device == Device::cpu) {
return true;
}

return !supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
!supports_sdpa_cudnn(q, k, v, has_mask, do_causal, s);
}

void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");

auto& s = stream();

array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
auto& out = outputs[0];
auto& stats = outputs[1];
bool has_mask = inputs.size() - has_sinks_ > 3;
bool has_arr_mask = has_mask && !do_causal_;

if (supports_sdpa_vector(
q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
if (has_sinks_) {
sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s);
} else {
sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
}
} else {
sdpa_cudnn(q, k, v, scale_, out, stats, do_causal_, output_logsumexp_, s);
}
}

bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
// The frontend adds a padding mask when sequence length is not a multiple of
// tile size.
if (q.shape(2) % 128 != 0) {
return true;
}
return s.device == Device::cpu;
}

void ScaledDotProductAttentionVJP::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("ScaledDotProductAttentionVJP::eval_gpu");

auto& s = stream();

assert(inputs.size() == 6);
array q = prepare_sdpa_input(inputs[0], s);
array k = prepare_sdpa_input(inputs[1], s);
array v = prepare_sdpa_input(inputs[2], s);
array o = prepare_sdpa_input(inputs[3], s);
array stats = prepare_sdpa_input(inputs[4], s);
array d_o = prepare_sdpa_input(inputs[5], s);

assert(outputs.size() == 3);
auto& d_q = outputs[0];
auto& d_k = outputs[1];
auto& d_v = outputs[2];

sdpa_backward_cudnn(
q, k, v, scale_, o, stats, do_causal_, d_o, d_q, d_k, d_v, s);
}

} // namespace fast

} // namespace mlx::core
@@ -6,10 +6,6 @@
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/transforms_impl.h"

#include <nvtx3/nvtx3.hpp>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
@@ -565,10 +561,9 @@ void sdpa_vector_2pass_fallback(
array sums(intermediate_shape, float32, nullptr, {});
array maxs(std::move(intermediate_shape), float32, nullptr, {});

intermediate.set_data(
cu::malloc_async(intermediate.nbytes(), encoder.stream()));
sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
sums.set_data(cu::malloc_async(sums.nbytes(), encoder));
maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder));

encoder.add_temporary(intermediate);
encoder.add_temporary(sums);
@@ -663,21 +658,16 @@ void sdpa_vector_fallback(

} // namespace

namespace fast {

bool ScaledDotProductAttention::use_fallback(
bool supports_sdpa_vector(
const array& q,
const array& k,
const array& v,
bool has_mask,
bool has_arr_mask,
bool do_causal,
Stream s) {
if (detail::in_grad_tracing()) {
return true;
}
if (s.device == Device::cpu) {
return true;
bool output_logsumexp) {
if (output_logsumexp) {
return false;
}

const int value_head_dim = v.shape(-1);
@@ -691,29 +681,24 @@ bool ScaledDotProductAttention::use_fallback(
const bool supported_vector_config =
sdpa_supported_head_dim && query_sequence_length < 4;

const bool supported_config = supported_vector_config;

return has_arr_mask || !supported_config;
return supported_vector_config && !has_arr_mask;
}

void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");

auto& s = stream();
void sdpa_vector(
const array& q_pre,
const array& k_pre,
const array& v_pre,
float scale,
array& o,
bool do_causal,
const std::optional<array>& sinks_pre,
Stream s) {
auto& encoder = cu::get_command_encoder(s);

auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;

std::vector<array> copies;

// Define some copy functions to ensure the layout of the inputs is as
// expected.
copies.reserve(inputs.size());
copies.reserve(4);
auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) {
@@ -731,8 +716,8 @@ void ScaledDotProductAttention::eval_gpu(
};

std::optional<array> sinks = std::nullopt;
if (has_sinks_) {
sinks = copy_unless(is_matrix_contiguous, inputs.back());
if (sinks_pre) {
sinks = copy_unless(is_matrix_contiguous, sinks_pre.value());
}

// We are in vector mode ie single query
@@ -788,7 +773,7 @@ void ScaledDotProductAttention::eval_gpu(
};

o.set_data(
cu::malloc_async(o.nbytes(), encoder.stream()),
cu::malloc_async(o.nbytes(), encoder),
o.size(),
{str_oB, str_oH, str_oL, str_oD},
flags);
@@ -798,8 +783,7 @@ void ScaledDotProductAttention::eval_gpu(
encoder.add_temporary(cp);
}

return sdpa_vector_fallback(
s, encoder, q, k, v, scale_, o, do_causal_, sinks);
sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks);
}

// Full attention mode should never reach here
@@ -808,6 +792,4 @@ void ScaledDotProductAttention::eval_gpu(
}
}

} // namespace fast

} // namespace mlx::core

@@ -374,7 +374,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(in);
} else {
out.set_data(
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder),
in.data_size(),
in.strides(),
in.flags());

@@ -24,7 +24,7 @@ void concatenate_gpu(
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());

auto& encoder = cu::get_command_encoder(s);
out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
out.set_data(cu::malloc_async(out.nbytes(), encoder));

auto strides = out.strides();
auto flags = out.flags();
@@ -89,7 +89,7 @@ array compute_dynamic_offset(
if (donate) {
offset.copy_shared_buffer(indices);
} else {
offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
offset.set_data(cu::malloc_async(offset.itemsize(), encoder));
}

encoder.add_temporary(offset);

@@ -118,7 +118,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(x);
} else {
out.set_data(
cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
cu::malloc_async(x.data_size() * x.itemsize(), encoder),
x.data_size(),
x.strides(),
x.flags());

@@ -49,14 +49,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array trans = swapaxes_in_eval(in, axis, last_dim);
in = contiguous_copy_gpu(trans, s);
encoder.add_temporary(in);
out = array(
cu::malloc_async(out.nbytes(), encoder.stream()),
in.shape(),
out.dtype());
out =
array(cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
encoder.add_temporary(out);
} else {
out.set_data(
cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
cu::malloc_async(in.data_size() * out.itemsize(), encoder),
in.data_size(),
in.strides(),
in.flags());
@@ -74,17 +72,13 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
if (argsort) {
// Indices in the sorted dimension.
array indices(
cu::malloc_async(out.nbytes(), encoder.stream()),
in.shape(),
out.dtype());
cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
encoder.add_temporary(indices);

// In argsort though we don't need the result of sorted values, the
// API requires us to provide an array to store it.
array discard(
cu::malloc_async(in.nbytes(), encoder.stream()),
in.shape(),
in.dtype());
cu::malloc_async(in.nbytes(), encoder), in.shape(), in.dtype());
encoder.add_temporary(discard);

size_t size;
@@ -104,9 +98,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
stream));

array temp(
cu::malloc_async(size, encoder.stream()),
{static_cast<int>(size)},
uint8);
cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);

// Start capturing after allocations
@@ -148,9 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
stream));

array temp(
cu::malloc_async(size, encoder.stream()),
{static_cast<int>(size)},
uint8);
cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
encoder.add_temporary(temp);

// Start capturing after allocations

@@ -257,9 +257,8 @@ void ternary_op_gpu(
auto& c = inputs[2];
auto topt = get_ternary_op_type(a, b, c);
auto& encoder = cu::get_command_encoder(s);
set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_ternary_op_output_data(
a, b, c, out, topt, [&](auto n) { return cu::malloc_async(n, encoder); });
ternary_op_gpu_inplace<Op>(inputs, out, s);
}

@@ -208,9 +208,8 @@ void unary_op_gpu(
const char* op,
const Stream& s) {
auto& encoder = cu::get_command_encoder(s);
set_unary_output_data(inputs[0], out, [&](auto n) {
return cu::malloc_async(n, encoder.stream());
});
set_unary_output_data(
inputs[0], out, [&](auto n) { return cu::malloc_async(n, encoder); });
unary_op_gpu_inplace<Op>(inputs, out, op, s);
}
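// Aside: both refactors above thread the encoder, rather than its raw
// stream, through an allocator callback, keeping the shared
// set_*_output_data helpers allocation-agnostic. The minimal shape of the
// pattern, with a hypothetical helper:
//
//   template <typename Alloc>
//   void set_output(array& out, Alloc alloc) {
//     out.set_data(alloc(out.nbytes()));
//   }
//   set_output(out, [&](auto n) { return cu::malloc_async(n, encoder); });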

@@ -5,6 +5,7 @@
#include "mlx/dtype_utils.h"

#include <fmt/format.h>
#include <vector>

namespace mlx::core {

@@ -60,7 +61,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
case float64:
return "double";
case complex64:
return "complex64_t";
return "mlx::core::cu::complex64_t";
default:
return "unknown";
}

@@ -44,7 +44,7 @@ void Worker::commit(cudaStream_t stream) {
}
signal_event_.record(stream);
signal_event_.wait(signal_stream_);
cudaLaunchHostFunc(signal_stream_, signal, this);
CHECK_CUDA_ERROR(cudaLaunchHostFunc(signal_stream_, signal, this));
}

void Worker::thread_fn() {

@@ -11,7 +11,7 @@ void slice_gpu(
array& out,
const Shape& start_indices,
const Shape& strides,
const Stream& s) {
const Stream&) {
slice(in, out, start_indices, strides);
}

@@ -28,6 +28,7 @@ make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(indexing/scatter kernels/indexing/indexing.h)
make_jit_source(indexing/masked_scatter)
make_jit_source(indexing/gather kernels/indexing/indexing.h)
make_jit_source(indexing/gather_front kernels/indexing/indexing.h)
make_jit_source(indexing/gather_axis)

@@ -32,7 +32,7 @@ std::string write_signature(
const std::vector<Dtype>& output_dtypes,
const std::vector<std::pair<std::string, TemplateArg>>& template_args,
const std::vector<std::string>& attributes,
const std::vector<CustomKernelShapeInfo>& shape_infos,
const std::vector<std::tuple<bool, bool, bool>>& shape_infos,
bool atomic_outputs) {
std::string kernel_source;
kernel_source.reserve(header.size() + source.size() + 16384);
@@ -88,19 +88,19 @@ std::string write_signature(
index++;
// Add input shape, strides and ndim if present in the source
if (arr.ndim() > 0) {
if (shape_infos[i].shape) {
if (std::get<0>(shape_infos[i])) {
kernel_source +=
(" const constant int* " + name + "_shape [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].strides) {
if (std::get<1>(shape_infos[i])) {
kernel_source +=
(" const constant int64_t* " + name + "_strides [[buffer(" +
std::to_string(index) + ")]],\n");
index++;
}
if (shape_infos[i].ndim) {
if (std::get<2>(shape_infos[i])) {
kernel_source +=
(" const constant int& " + name + "_ndim [[buffer(" +
std::to_string(index) + ")]],\n");
@@ -184,12 +184,12 @@ CustomKernelFunction metal_kernel(
throw std::invalid_argument(
"[metal_kernel] Must specify at least one output.");
}
std::vector<CustomKernelShapeInfo> shape_infos;
std::vector<std::tuple<bool, bool, bool>> shape_infos;
for (auto& n : input_names) {
CustomKernelShapeInfo shape_info;
shape_info.shape = source.find(n + "_shape") != std::string::npos;
shape_info.strides = source.find(n + "_strides") != std::string::npos;
shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
std::tuple<bool, bool, bool> shape_info;
std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
shape_infos.push_back(shape_info);
}
const std::vector<std::pair<std::string, std::string>> metal_attributes = {
@@ -388,15 +388,15 @@ void CustomKernel::eval_gpu(
index++;
if (in.ndim() > 0) {
int ndim = in.ndim();
if (shape_info.shape) {
if (std::get<0>(shape_info)) {
compute_encoder.set_vector_bytes(in.shape(), ndim, index);
index++;
}
if (shape_info.strides) {
if (std::get<1>(shape_info)) {
compute_encoder.set_vector_bytes(in.strides(), ndim, index);
index++;
}
if (shape_info.ndim) {
if (std::get<2>(shape_info)) {
compute_encoder.set_bytes(ndim, index);
index++;
}

@@ -382,11 +382,8 @@ MTL::CommandQueue* Device::get_queue(Stream stream) {

bool Device::command_buffer_needs_commit(int index) {
auto& stream = get_stream_(index);
if (stream.buffer_ops > max_ops_per_buffer_ ||
(stream.buffer_sizes >> 20) > max_mb_per_buffer_) {
return true;
}
return false;
return (stream.buffer_ops > max_ops_per_buffer_) ||
((stream.buffer_sizes >> 20) > max_mb_per_buffer_);
}

MTL::CommandBuffer* Device::get_command_buffer(int index) {

@@ -265,4 +265,19 @@ Device& device(mlx::core::Device);

std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();

inline bool is_nax_available() {
auto _check_nax = []() {
bool can_use_nax = false;
if (__builtin_available(
macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) {
can_use_nax = true;
}
can_use_nax &=
metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17;
return can_use_nax;
};
static bool is_nax_available_ = _check_nax();
return is_nax_available_;
}
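// Aside: is_nax_available relies on the C++11 static-local idiom for a
// thread-safe, compute-once feature probe; the generic shape is:
//
//   inline bool feature_available() {
//     static const bool cached = [] { return /* expensive probe */ true; }();
//     return cached;
//   }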

} // namespace mlx::core::metal

@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.

#include <fmt/format.h>

#include "mlx/backend/common/compiled.h"
@@ -8,7 +9,9 @@
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/scan.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/dtype.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

@@ -641,4 +644,84 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
const array& dst = inputs[0];
const array& mask = inputs[1];
const array& src = inputs[2];

auto& s = stream();
auto& d = metal::device(s.device);

const size_t total = mask.size();
const CopyType ct = (total == 1)
? CopyType::Scalar
: (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General);
copy_gpu(dst, out, ct, s);
if (total == 0) {
return;
}

array mask_flat = flatten_in_eval(mask, 1, -1, s);
if (mask_flat.data<void>() != mask.data<void>()) {
d.add_temporary(mask_flat, s.index);
}

if (!mask_flat.flags().row_contiguous) {
mask_flat = contiguous_copy_gpu(mask_flat, s);
d.add_temporary(mask_flat, s.index);
}

// Prefix (exclusive) of mask → scatter_offsets
array scatter_offsets(mask_flat.shape(), uint32, nullptr, {});
scatter_offsets.set_data(allocator::malloc(scatter_offsets.nbytes()));
d.add_temporary(scatter_offsets, s.index);

scan_gpu_inplace(
mask_flat,
scatter_offsets,
Scan::Sum,
/*axis=*/1,
/*reverse=*/false,
/*inclusive=*/false,
s);

// Kernel selection/build
static constexpr std::string_view kBaseName = "masked_assign";
const std::string dtype_tag = type_to_name(out.dtype());
const std::string value_type = get_type_string(out.dtype());
const std::string contiguous =
(src.flags().row_contiguous) ? "true" : "false";
const std::string kernel_name =
fmt::format("{}_{}_{}", kBaseName, dtype_tag, contiguous);

auto lib = d.get_library(kernel_name, [&]() {
std::string source = metal::utils();
source += metal::masked_scatter();
source += fmt::format(
std::string(masked_assign_kernel), kernel_name, value_type, contiguous);
return source;
});
auto kernel = d.get_kernel(kernel_name, lib);

// Binding
int bind_idx = 0;
const int ndim = static_cast<int>(src.ndim());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(mask_flat, bind_idx++);
compute_encoder.set_input_array(scatter_offsets, bind_idx++);
compute_encoder.set_input_array(src, bind_idx++);
compute_encoder.set_output_array(out, bind_idx++);
compute_encoder.set_vector_bytes(src.shape(), bind_idx++);
compute_encoder.set_vector_bytes(src.strides(), bind_idx++);
compute_encoder.set_bytes(ndim, bind_idx++);
compute_encoder.set_bytes(src.size() / src.shape(0), bind_idx++);
compute_encoder.set_bytes(mask_flat.size() / mask.shape(0), bind_idx++);

// Dispatch
auto group_dims = get_block_dims(total, 1, 1);
MTL::Size grid_dims(total, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

} // namespace mlx::core

@@ -11,6 +11,7 @@ const char* ternary_ops();
const char* reduce_utils();
const char* gather();
const char* scatter();
const char* masked_scatter();

const char* arange();
const char* unary();

@@ -70,3 +70,7 @@ constexpr std::string_view scatter_kernels = R"(
gid);
}}
)";

constexpr std::string_view masked_assign_kernel = R"(
template [[host_name("{0}")]] [[kernel]] decltype(masked_assign_impl<{1}, {2}>) masked_assign_impl<{1}, {2}>;
)";

@@ -9,7 +9,14 @@ set(BASE_HEADERS
utils.h)

function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
set(METAL_FLAGS
-x
metal
-Wall
-Wextra
-fno-fast-math
-Wno-c++17-extensions
-Wno-c++20-extensions)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif()
@@ -120,6 +127,30 @@ if(NOT MLX_METAL_JIT)
build_kernel(gemv_masked steel/utils.h)
endif()

if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
26.2))
set(STEEL_NAX_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/transforms.h
steel/gemm/nax.h
steel/gemm/gemm_nax.h
steel/utils/type_traits.h
steel/utils/integral_constant.h)

build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})

build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS})
build_kernel(fp_quantized_nax fp_quantized_nax.h ${STEEL_NAX_HEADERS})

set(STEEL_NAX_ATTN_HEADERS
steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h
steel/utils/integral_constant.h)

build_kernel(steel/attn/kernels/steel_attention_nax ${STEEL_NAX_ATTN_HEADERS})
endif()

add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
mlx/backend/metal/kernels/fp_quantized_nax.h (new file, 1066 lines; diff suppressed because it is too large)

mlx/backend/metal/kernels/fp_quantized_nax.metal (new file, 74 lines)
@@ -0,0 +1,74 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
||||
#include "mlx/backend/metal/kernels/quantized_utils.h"
|
||||
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
|
||||
#include "mlx/backend/metal/kernels/fp_quantized_nax.h"
|
||||
|
||||
|
||||
#define instantiate_quantized_batched(mode, name, type, bm, bn, bk, wm, wn, batched) \
|
||||
instantiate_kernel( \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
4, \
|
||||
batched)
|
||||
|
||||
#define instantiate_quantized_aligned(mode, name, type, bm, bn, bk, wm, wn, aligned) \
|
||||
instantiate_kernel( \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
4, \
|
||||
aligned)
|
||||
|
||||
#define instantiate_quantized_aligned_batched(mode, name, type, bm, bn, bk, wm, wn, aligned, batched) \
|
||||
instantiate_kernel( \
|
||||
#mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
|
||||
fp_ ## name, \
|
||||
type, \
|
||||
32, \
|
||||
4, \
|
||||
aligned, \
|
||||
batched)
|
||||
|
||||
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
|
||||
instantiate_kernel( \
|
||||
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
|
||||
func, \
|
||||
type, \
|
||||
32, \
|
||||
4, \
|
||||
bm, \
|
||||
bn, \
|
||||
bk, \
|
||||
wm, \
|
||||
wn, \
|
||||
transpose)
|
||||
|
||||
|
||||
#define instantiate_quantized_all_aligned(type) \
|
||||
instantiate_quantized_aligned(mxfp4, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, true) \
|
||||
instantiate_quantized_aligned(mxfp4, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, false) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 1) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 0) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 1) \
|
||||
instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 0)
|
||||
|
||||
|
||||
#define instantiate_quantized_all_rhs(type) \
|
||||
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, mxfp4_gather_qmm_rhs_nax_nt, type, 64, 64, 64, 2, 2, true) \
|
||||
instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, mxfp4_gather_qmm_rhs_nax_nn, type, 64, 64, 64, 2, 2, false)
|
||||
|
||||
#define instantiate_quantized_types(type) \
|
||||
instantiate_quantized_all_aligned(type) \
|
||||
instantiate_quantized_all_rhs(type)
|
||||
|
||||
instantiate_quantized_types(float)
|
||||
instantiate_quantized_types(bfloat16_t)
|
||||
instantiate_quantized_types(float16_t)
|
||||
// clang-format on
|
||||
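The kernel names above are assembled entirely by preprocessor stringification and string-literal concatenation. A minimal host-side C++ sketch of the same naming scheme (MAKE_NAME is a hypothetical stand-in for instantiate_kernel, used here only to show the string that gets generated):

#include <cstdio>

// Hypothetical stand-in: only builds the name the macro above would produce.
#define MAKE_NAME(mode, name, type, bm, bn, bk, wm, wn, aligned, batched)     \
  #mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm \
      "_wn" #wn "_alN_" #aligned "_batch_" #batched

int main() {
  // Prints:
  // mxfp4_qmm_t_nax_float16_t_gs_32_b_4_bm64_bn64_bk64_wm2_wn2_alN_true_batch_1
  std::puts(MAKE_NAME(mxfp4, qmm_t_nax, float16_t, 64, 64, 64, 2, 2, true, 1));
}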
38  mlx/backend/metal/kernels/indexing/masked_scatter.h  (new file)
@@ -0,0 +1,38 @@
// Copyright © 2025 Apple Inc.

#pragma once

template <typename T, bool src_contiguous>
[[kernel]] void masked_assign_impl(
    const device bool* mask [[buffer(0)]],
    const device uint* scatter_offsets [[buffer(1)]],
    const device T* src [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int* src_shapes [[buffer(4)]],
    const constant int64_t* src_strides [[buffer(5)]],
    const constant int& src_ndim [[buffer(6)]],
    const constant int64_t& src_batch_size [[buffer(7)]],
    const constant int64_t& mask_batch_size [[buffer(8)]],
    uint idx [[thread_position_in_grid]]) {
  const bool mask_value = mask[idx];
  if (!mask_value) {
    return;
  }

  const uint src_index = scatter_offsets[idx];
  if (src_index >= src_batch_size) {
    return;
  }

  const uint batch_idx = idx / mask_batch_size;

  if (src_contiguous) {
    out[idx] = src[batch_idx * src_batch_size + src_index];
  } else {
    out[idx] = src[elem_to_loc<uint>(
        batch_idx * src_batch_size + src_index,
        src_shapes,
        src_strides,
        src_ndim)];
  }
}
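For intuition, the kernel above launches one thread per output element: a thread writes only where the mask is set, pulling its value from src at the precomputed scatter offset within its batch. A scalar C++ sketch of the contiguous-src path (the flat layout and float element type are assumptions for illustration):

#include <cstdint>
#include <vector>

// Scalar reference for masked_assign_impl with contiguous src:
// each output position i is overwritten only where mask[i] is true, with
// the element of src selected by scatter_offsets[i] inside its batch.
void masked_assign_ref(
    const std::vector<bool>& mask,
    const std::vector<uint32_t>& scatter_offsets,
    const std::vector<float>& src,
    std::vector<float>& out,
    int64_t src_batch_size,
    int64_t mask_batch_size) {
  for (size_t i = 0; i < out.size(); ++i) {
    if (!mask[i]) continue;                     // masked-off: keep out[i]
    uint32_t src_index = scatter_offsets[i];
    if (src_index >= src_batch_size) continue;  // out-of-range guard
    size_t batch = i / mask_batch_size;         // which batch row i belongs to
    out[i] = src[batch * src_batch_size + src_index];
  }
}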
1705  mlx/backend/metal/kernels/quantized_nax.h  (new file)
File diff suppressed because it is too large
106  mlx/backend/metal/kernels/quantized_nax.metal  (new file)
@@ -0,0 +1,106 @@
// Copyright © 2023-2024 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/loader.h"
#include "mlx/backend/metal/kernels/quantized_nax.h"

#define instantiate_quantized(name, type, group_size, bits, bm, bn, bk, wm, wn) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits, \
      name, \
      type, \
      group_size, \
      bits, bm, bk, bn, wm, wn)

#define instantiate_quantized_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \
      name, \
      type, \
      group_size, \
      bits, \
      batched, bm, bk, bn, wm, wn)

#define instantiate_quantized_aligned(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \
      name, \
      type, \
      group_size, \
      bits, \
      aligned, bm, bk, bn, wm, wn)

#define instantiate_quantized_aligned_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \
      name, \
      type, \
      group_size, \
      bits, \
      aligned, \
      batched, bm, bk, bn, wm, wn)

#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
      func, \
      type, \
      group_size, \
      bits, \
      bm, \
      bn, \
      bk, \
      wm, \
      wn, \
      transpose)

#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
  instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 1) \
  instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 0)

#define instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_batched_wrap(affine_qmm_n_nax, type, group_size, bits)

#define instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized(affine_gather_qmm_n_nax, type, group_size, bits, 64, 64, 64, 2, 2)

#define instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true) \
  instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false) \
  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 1) \
  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 0) \
  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 1) \
  instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 0)

#define instantiate_quantized_all_rhs(type, group_size, bits) \
  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nt, type, group_size, bits, 64, 64, 64, 2, 2, true) \
  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nn, type, group_size, bits, 64, 64, 64, 2, 2, false)

#define instantiate_quantized_funcs(type, group_size, bits) \
  instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_all_rhs(type, group_size, bits)

#define instantiate_quantized_types(group_size, bits) \
  instantiate_quantized_funcs(float, group_size, bits) \
  instantiate_quantized_funcs(float16_t, group_size, bits) \
  instantiate_quantized_funcs(bfloat16_t, group_size, bits)

#define instantiate_quantized_groups(bits) \
  instantiate_quantized_types(128, bits) \
  instantiate_quantized_types(64, bits) \
  instantiate_quantized_types(32, bits)

#define instantiate_quantized_all() \
  instantiate_quantized_groups(2) \
  instantiate_quantized_groups(3) \
  instantiate_quantized_groups(4) \
  instantiate_quantized_groups(5) \
  instantiate_quantized_groups(6) \
  instantiate_quantized_groups(8)

instantiate_quantized_all() // clang-format on
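The macro pyramid above fans out combinatorially, and counting the leaves makes the compile cost concrete. A back-of-the-envelope C++ sketch, with the per-combination kernel counts read directly off the instantiate_quantized_funcs expansion (instantiate_quantized_all_single is defined but not included in that expansion, so it is not counted):

#include <cstdio>

int main() {
  const int bit_widths = 6;   // 2, 3, 4, 5, 6, 8
  const int group_sizes = 3;  // 128, 64, 32
  const int types = 3;        // float, float16_t, bfloat16_t
  // Per (type, group_size, bits): 2 batched qmm_n variants,
  // 6 aligned/batched qmm_t + gather_qmm_t variants, 2 gather-rhs variants.
  const int kernels_per_combo = 2 + 6 + 2;
  std::printf(
      "%d NAX affine-quantized kernel instantiations\n",
      bit_widths * group_sizes * types * kernels_per_combo);  // 540
}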
@@ -51,6 +51,7 @@ using namespace metal;
  instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)

instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_bool__uint32, bool, uint32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
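For reference, the newly added sum_bool__uint32 variant (like the existing sum_bool__int32) is a cumulative sum that widens bool inputs to a 32-bit accumulator. A scalar C++ sketch of the inclusive-scan contract these helpers instantiate (the exclusive variant stores the running total before adding the current element):

#include <cstdint>
#include <vector>

// Inclusive scan: out[i] = in[0] + ... + in[i], with bool promoted to uint32.
std::vector<uint32_t> cumsum_bool_uint32(const std::vector<bool>& in) {
  std::vector<uint32_t> out(in.size());
  uint32_t acc = 0;
  for (size_t i = 0; i < in.size(); ++i) {
    acc += in[i] ? 1u : 0u;
    out[i] = acc;  // exclusive scan would store acc *before* the add
  }
  return out;
}
// cumsum_bool_uint32({true, false, true, true}) == {1, 1, 2, 3}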
476  mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h  (new file)
@@ -0,0 +1,476 @@
// Copyright © 2024-25 Apple Inc.

using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// Attention kernels
///////////////////////////////////////////////////////////////////////////////

constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];

constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];

template <typename T>
struct TransformScale {
  T scale;
  METAL_FUNC TransformScale(T scale_) : scale(scale_) {}

  METAL_FUNC T apply(T x) const {
    return scale * x;
  }
};

struct MaxOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T x, T y) {
    return metal::max(x, y);
  }
};

struct SumOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x + y;
  }
};

struct MulOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x * y;
  }
};

struct SubOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x - y;
  }
};

struct ExpSubOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T x, T y) {
    return fast::exp2(x - y);
  }
};

struct DivOp {
  template <typename T>
  METAL_FUNC static constexpr T apply(T x, T y) {
    return x / y;
  }
};

// clang-format off
template <
    typename T,
    int BQ,
    int BK,
    int BD,
    int WM,
    int WN,
    typename MaskType = float,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax(
    const device T* Q [[buffer(0)]],
    const device T* K [[buffer(1)]],
    const device T* V [[buffer(2)]],
    device T* O [[buffer(3)]],
    const constant AttnParams* params [[buffer(4)]],
    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
    const device T* sinks [[buffer(7), function_constant(has_sinks)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on

  // Pacifying compiler
  (void)lid;
  (void)simd_lane_id;

  // Move to correct block
  ulong3 tidl{tid.x, tid.y, tid.z};

  Q += tidl.z * params->Q_strides[0] + // Batch
      tidl.y * params->Q_strides[1] + // Head
      tidl.x * BQ * params->Q_strides[2]; // Sequence

  ulong kv_head_idx = int(tid.y) / params->gqa_factor;
  K += tidl.z * params->K_strides[0] + // Batch
      kv_head_idx * params->K_strides[1]; // Head

  V += tidl.z * params->V_strides[0] + // Batch
      kv_head_idx * params->V_strides[1]; // Head

  O += tidl.z * params->O_strides[0] + // Batch
      tidl.y * params->O_strides[1] + // Head
      tidl.x * BQ * params->O_strides[2]; // Sequence

  if (has_mask) {
    mask += tidl.z * mask_params->M_strides[0] + // Batch
        tidl.y * mask_params->M_strides[1]; // Head
  }

  const metal::uniform<float> scale2 =
      make_uniform(params->scale) * make_uniform(1.44269504089f);

  // Prepare MMA tiles
  constexpr short UQ = 16;
  constexpr short UD = 32;

  constexpr int kNWarps = WM * WN;
  static_assert(
      BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0,
      "Each simdgroup must host at least 1 simdgroup matrix along Q sequence.");

  // Q seq frags per warp
  constexpr int TQ = BQ / (kNWarps * UQ);
  // HeadDim frags (all warps load the same frags)
  constexpr int TD = BD / UD;

  static_assert(TQ == 1, "Check TQ");

  using OSubTile = NAXSubTile<AccumType, UQ, UD>;
  NAXTile<AccumType, TQ, TD, OSubTile> Otile;

  Otile.clear();

  // Prepare mma tile offsets
  const short2 simd_coord = OSubTile::NAXFrag_t::get_coord();
  const short sm = simd_coord.y;
  const short sn = simd_coord.x;
  const short tm = UQ * TQ * simd_group_id;

  Q += (tm + sm) * int(params->Q_strides[2]) + sn;
  K += sm * int(params->K_strides[2]) + sn;
  V += sm * int(params->V_strides[2]) + sn;

  // Init row reduction variables
  constexpr short kRowsPT = decltype(Otile)::kRowsPerThread;

  metal::vec<AccumType, kRowsPT> max_score;
  metal::vec<AccumType, kRowsPT> sum_score{0};

  // Init to -Inf
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    max_score[i] = Limits<AccumType>::finite_min;
  }

  if (has_sinks) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
      sum_score[i] = 1;
    }
  }

  int kb_lim = params->NK;

  if (do_causal) {
    int q_max = (tid.x + 1) * BQ + params->qL_off;
    kb_lim = (q_max + BK - 1) / BK;
    kb_lim = min(params->NK, kb_lim);
  }

  const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
  // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
  const bool is_last_q = is_last_bq;

  const short lim_rows_q = params->qL_rem - (tm + sm);
  const short lim_rows_k = params->kL_rem - sm;

  // Loop over KV seq length
  for (int kb = 0; kb < kb_lim; kb++) {
    const int is_last_k = (kb == (params->NK_aligned));

    // Do S = Q @ K.T
    constexpr short UDs = 16;
    constexpr short UKs = 32;

    constexpr short TDs = BD / UDs;
    constexpr short TKs = BK / UKs;

    using SSubTile = NAXSubTile<AccumType, UQ, UKs>;
    using QSubTile = NAXSubTile<T, UQ, UDs>;
    using KSubTile = NAXSubTile<T, UKs, UDs>;

    NAXTile<AccumType, TQ, TKs, SSubTile> Stile;

    Stile.clear();

    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short ik = 0; ik < TKs; ik++) {
        STEEL_PRAGMA_UNROLL
        for (short id = 0; id < TDs; id++) {
          NAXTile<T, 1, 1, QSubTile> Qtile;
          NAXTile<T, 1, 1, KSubTile> Ktile;

          const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs;
          const int K_load_off =
              ik * UKs * int(params->K_strides[2]) + id * UDs;

          if (!align_Q && is_last_q) {
            // Qtile.load_rows(
            //     Q + Q_load_off,
            //     int(params->Q_strides[2]),
            //     lim_rows_q - iq * UQ);
            Qtile.load_safe(
                Q + Q_load_off,
                int(params->Q_strides[2]),
                short2(BD, lim_rows_q - iq * UQ));
          } else {
            Qtile.load(Q + Q_load_off, int(params->Q_strides[2]));
          }

          if (!align_K && is_last_k) {
            // Ktile.load_rows(
            //     K + K_load_off,
            //     int(params->K_strides[2]),
            //     lim_rows_k - ik * UKs);
            Ktile.load_safe(
                K + K_load_off,
                int(params->K_strides[2]),
                short2(BD, lim_rows_k - ik * UKs));
          } else {
            Ktile.load(K + K_load_off, int(params->K_strides[2]));
          }

          subtile_matmad_nax(
              Stile.subtile_at(iq, ik),
              Qtile.subtile_at(0, 0),
              metal::false_type{},
              Ktile.subtile_at(0, 0),
              metal::true_type{});
        }
      }
    }

    // Scale S
    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
      Stile.elems()[ii] *= float(scale2);
    }

    // Scale and Retile S
    constexpr short UK = 16;
    constexpr short TK = BK / UK;
    using PSubTile = NAXSubTile<AccumType, UQ, UK>;

    NAXTile<AccumType, TQ, TK, PSubTile> Ptile;

    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
      Ptile.elems()[ii] = Stile.elems()[ii];
    }

    // Mask out the sequence-length remainder
    if (!align_K && is_last_k) {
      constexpr auto neg_inf = Limits<AccumType>::finite_min;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          const short col_pos = sn + ik * UK;

          thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

          STEEL_PRAGMA_UNROLL
          for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
            STEEL_PRAGMA_UNROLL
            for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
              const auto loc = ii * PSubTile::kFragThrCols + jj;
              fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc];
            }
          }
        }
      }
    }

    // Mask out if causal
    if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
      constexpr auto neg_inf = Limits<AccumType>::finite_min;

      const int base_row = tid.x * BQ + params->qL_off + tm;
      const int base_col = kb * BK;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          const short row_pos = base_row + iq * UQ;
          const short col_pos = base_col + ik * UK;

          thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

          STEEL_PRAGMA_UNROLL
          for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
            STEEL_PRAGMA_UNROLL
            for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
              const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm;
              const auto c = col_pos + jj + sn;
              const auto loc = ii * PSubTile::kFragThrCols + jj;
              fg[loc] = (r < c) ? neg_inf : fg[loc];
            }
          }
        }
      }
    }

    // Other masking as needed
    if (has_mask) {
      constexpr auto neg_inf = Limits<AccumType>::finite_min;

      const int base_row = tid.x * BQ + tm;
      const int base_col = kb * BK;

      constexpr bool is_bool = is_same_v<MaskType, bool>;
      using melem_t = typename metal::conditional_t<is_bool, bool, AccumType>;
      using MSubTile = NAXSubTile<melem_t, UQ, UK>;

      STEEL_PRAGMA_UNROLL
      for (short iq = 0; iq < TQ; iq++) {
        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          const short row_pos = base_row + iq * UQ + sm;
          const short col_pos = base_col + ik * UK + sn;

          MSubTile mfrag;
          mfrag.load_safe(
              mask,
              int(mask_params->M_strides[2]),
              Int<1>{},
              params->qL,
              params->kL,
              row_pos,
              col_pos);

          thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);

          STEEL_PRAGMA_UNROLL
          for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) {
            if constexpr (is_bool) {
              fg[jj] = mfrag.elems()[jj] ? fg[jj] : neg_inf;
            } else {
              fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]);
            }
          }
        }
      }
    }

    // Do softmax

    // Temp variables
    metal::vec<AccumType, kRowsPT> new_max;
    metal::vec<AccumType, kRowsPT> factor;
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      new_max[i] = max_score[i];
    }

    // Row max
    Ptile.template row_reduce<MaxOp>(new_max);

    // exp(Si - rowmax(Si))
    Ptile.template row_bin_op<ExpSubOp>(new_max);

    // Factor exp(rowmax(Si) - rowmax(Si-1))
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      factor[i] = fast::exp2(max_score[i] - new_max[i]);
      max_score[i] = new_max[i];
    }

    // Row Sum
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kRowsPT; ++i) {
      sum_score[i] = sum_score[i] * factor[i];
    }

    Ptile.template row_reduce<SumOp>(sum_score);

    // Update O
    Otile.template row_bin_op<MulOp>(factor);

    simdgroup_barrier(mem_flags::mem_none);

    // Do O = P @ V
    STEEL_PRAGMA_UNROLL
    for (short iq = 0; iq < TQ; iq++) {
      STEEL_PRAGMA_UNROLL
      for (short id = 0; id < TD; id++) {
        if constexpr (BD == 128) {
          if (id == 2) {
            threadgroup_barrier(mem_flags::mem_none);
          }
        }

        STEEL_PRAGMA_UNROLL
        for (short ik = 0; ik < TK; ik++) {
          using VSubTile = NAXSubTile<T, UK, UD>;
          NAXTile<T, 1, 1, VSubTile> Vtile;

          const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD;

          if (!align_K && is_last_k) {
            // Vtile.load_rows(
            //     V + V_load_off,
            //     int(params->V_strides[2]),
            //     lim_rows_k - ik * UK);
            Vtile.load_safe(
                V + V_load_off,
                int(params->V_strides[2]),
                short2(BD, lim_rows_k - ik * UK));
          } else {
            Vtile.load(V + V_load_off, int(params->V_strides[2]));
          }

          subtile_matmad_nax(
              Otile.subtile_at(iq, id),
              Ptile.subtile_at(iq, ik),
              metal::bool_constant<false>{},
              Vtile.subtile_at(0, 0),
              metal::bool_constant<false>{});
        }
      }
    }

    // Prepare for next iteration
    K += BK * int(params->K_strides[2]);
    V += BK * int(params->V_strides[2]);
  }

  // Normalize output

  threadgroup_barrier(mem_flags::mem_none);

  metal::vec<AccumType, kRowsPT> rcp;
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < kRowsPT; ++i) {
    rcp[i] = (1.f / sum_score[i]);
  }

  Otile.template row_bin_op<MulOp>(rcp);

  // Store results
  O += (tm + sm) * int(params->O_strides[2]) + sn;

  if (!align_Q && is_last_q) {
    if (lim_rows_q <= 0)
      return;

    // Otile.store_rows(O, params->O_strides[2], lim_rows_q);
    Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q));
  } else {
    Otile.store(O, int(params->O_strides[2]));
  }
}
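The running max_score/sum_score/factor bookkeeping above is the standard streaming (online) softmax: each new block of scores may raise the row maximum, so the previous partial sum and the accumulated output are rescaled by exp2(old_max - new_max) before the new block is folded in. A scalar C++ sketch of that recurrence in the same base-2 domain the kernel uses (it assumes scores were already multiplied by scale * log2(e); 'acc' stands in for the O accumulator, with the V contribution elided):

#include <algorithm>
#include <cmath>
#include <vector>

// One softmax row processed block-by-block, mirroring the kernel's
// max_score / sum_score / factor updates.
void online_softmax_row(const std::vector<float>& scores, size_t block) {
  float max_score = -INFINITY, sum_score = 0.f, acc = 0.f;
  for (size_t b = 0; b < scores.size(); b += block) {
    size_t e = std::min(scores.size(), b + block);
    float new_max = max_score;
    for (size_t i = b; i < e; ++i) new_max = std::max(new_max, scores[i]);
    float factor = std::exp2(max_score - new_max);  // rescale old state
    sum_score *= factor;
    acc *= factor;  // Otile.row_bin_op<MulOp>(factor)
    for (size_t i = b; i < e; ++i) {
      float p = std::exp2(scores[i] - new_max);  // ExpSubOp
      sum_score += p;
      acc += p /* * v_i in the real kernel */;
    }
    max_score = new_max;
  }
  acc /= sum_score;  // final 1/sum normalization, as in the rcp loop
}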
33  mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal  (new file)
@@ -0,0 +1,33 @@
// Copyright © 2024-25 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/attn/nax.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h"

#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
  instantiate_kernel( \
      "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
      "_wm" #wm "_wn" #wn "_mask" #mname, \
      attention_nax, dtype, bq, bk, bd, wm, wn, mtype, float)

#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
  instantiate_attn(iname, itype, 64, 32, 128, 4, 1, mname, mtype) \
  instantiate_attn(iname, itype, 64, 32, 64, 4, 1, mname, mtype) \
  instantiate_attn(iname, itype, 64, 64, 128, 4, 1, mname, mtype) \
  instantiate_attn(iname, itype, 64, 64, 64, 4, 1, mname, mtype)

#define instantiate_attn_mask_helper(iname, itype) \
  instantiate_attn_shapes_helper(iname, itype, iname, itype) \
  instantiate_attn_shapes_helper(iname, itype, bool_, bool)

instantiate_attn_mask_helper(float16, half);
instantiate_attn_mask_helper(bfloat16, bfloat);

instantiate_attn_mask_helper(float32, float);
// clang-format on
1076  mlx/backend/metal/kernels/steel/attn/nax.h  (new file)
File diff suppressed because it is too large
mlx/backend/metal/kernels/steel/defines.h
@@ -1,4 +1,7 @@
// Copyright © 2024 Apple Inc.

#pragma once

#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)")
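The new STEEL_PRAGMA_NO_UNROLL macro wraps clang's loop pragma so the NAX GEMM main loops below can opt out of full unrolling. A minimal C++ illustration of what the macro expands to at a use site:

#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)")

float sum_first_n(const float* x, int n) {
  float acc = 0.f;
  // Expands to: #pragma clang loop unroll(disable)
  STEEL_PRAGMA_NO_UNROLL
  for (int i = 0; i < n; ++i) {
    acc += x[i];
  }
  return acc;
}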
154  mlx/backend/metal/kernels/steel/gemm/gemm_nax.h  (new file)
@@ -0,0 +1,154 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"

using namespace metal;

namespace mlx::steel {

template <
    typename T,
    short SM,
    short SN,
    short SK,
    short BK,
    bool transpose_a,
    bool transpose_b,
    bool kAlignedM,
    bool kAlignedN,
    bool kAlignedK,
    short UM,
    short UN,
    short UK,
    typename AccumType = float>
auto gemm_loop(
    const device T* A,
    const device T* B,
    const constant GEMMParams* params [[buffer(4)]],
    const short sgp_sm,
    const short sgp_sn) {
  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;
  constexpr short TK = SK / UK;

  constexpr int RA = transpose_a ? TK : TM;
  constexpr int CA = transpose_a ? TM : TK;

  constexpr int RB = transpose_b ? TN : TK;
  constexpr int CB = transpose_b ? TK : TN;

  using DSubTile = NAXSubTile<AccumType, UM, UN>;
  using ASubTile =
      NAXSubTile<T, (transpose_a ? UK : UM), (transpose_a ? UM : UK)>;
  using BSubTile =
      NAXSubTile<T, (transpose_b ? UN : UK), (transpose_b ? UK : UN)>;

  NAXTile<AccumType, TM, TN, DSubTile> Dtile;
  Dtile.clear();

  int gemm_k_iterations_ = params->gemm_k_iterations_aligned;

  STEEL_PRAGMA_NO_UNROLL
  for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) {
    threadgroup_barrier(mem_flags::mem_none);

    STEEL_PRAGMA_NO_UNROLL
    for (int kk1 = 0; kk1 < BK; kk1 += SK) {
      NAXTile<T, RA, CA, ASubTile> Atile;
      NAXTile<T, RB, CB, BSubTile> Btile;
      const int k = kk1;

      volatile int compiler_barrier;

      const int A_offset = transpose_a ? k * params->lda : k;
      const int B_offset = transpose_b ? k : k * params->ldb;

      if constexpr (kAlignedM) {
        Atile.load(A + A_offset, params->lda);
      } else {
        const short rmax = transpose_a ? SK : sgp_sm;
        const short cmax = transpose_a ? sgp_sm : SK;
        Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
      }

      if constexpr (kAlignedN) {
        Btile.load(B + B_offset, params->ldb);
      } else {
        const short rmax = transpose_b ? sgp_sn : SK;
        const short cmax = transpose_b ? SK : sgp_sn;
        Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
      }

      tile_matmad_nax(
          Dtile,
          Atile,
          metal::bool_constant<transpose_a>{},
          Btile,
          metal::bool_constant<transpose_b>{});

      (void)compiler_barrier;
    }

    A += transpose_a ? (BK * params->lda) : BK;
    B += transpose_b ? BK : (BK * params->ldb);
  }

  if constexpr (!kAlignedK) {
    simdgroup_barrier(mem_flags::mem_none);

    const short rem_bk = params->K - gemm_k_iterations_ * BK;

    STEEL_PRAGMA_NO_UNROLL
    for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) {
      NAXTile<T, 1, 1, ASubTile> Atile;
      NAXTile<T, 1, 1, BSubTile> Btile;

      STEEL_PRAGMA_UNROLL
      for (int mm = 0; mm < TM; mm++) {
        STEEL_PRAGMA_UNROLL
        for (int nn = 0; nn < TN; nn++) {
          STEEL_PRAGMA_UNROLL
          for (int kk = 0; kk < TK; kk++) {
            const int m = mm * UM;
            const int n = nn * UN;
            const int k = kk1 + kk * UK;
            const short psk = max(0, rem_bk - k);

            const int A_offset =
                transpose_a ? (m + k * params->lda) : (m * params->lda + k);
            const int B_offset =
                transpose_b ? (k + n * params->ldb) : (k * params->ldb + n);

            {
              const short psm = kAlignedM ? SM : max(0, sgp_sm - m);
              const short rmax = transpose_a ? psk : psm;
              const short cmax = transpose_a ? psm : psk;
              Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax));
            }

            {
              const short psn = kAlignedN ? SN : max(0, sgp_sn - n);
              const short rmax = transpose_b ? psn : psk;
              const short cmax = transpose_b ? psk : psn;
              Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax));
            }

            subtile_matmad_nax(
                Dtile.subtile_at(mm, nn),
                Atile.subtile_at(0, 0),
                metal::bool_constant<transpose_a>{},
                Btile.subtile_at(0, 0),
                metal::bool_constant<transpose_b>{});
          }
        }
      }
    }
  }

  return Dtile;
}

} // namespace mlx::steel
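Stripped of the NAX tile machinery, gemm_loop is a register-blocked GEMM: march over K in BK-sized chunks with unchecked loads, accumulate into a D tile, then run a slower bounds-checked cleanup pass for the K remainder when kAlignedK is false. A scalar C++ sketch of the same loop structure (row-major, no transposes, illustrative names):

#include <vector>

// Blocked C += A * B with an aligned main loop over K plus a remainder pass,
// echoing gemm_loop's gemm_k_iterations_ main loop and !kAlignedK cleanup.
void gemm_ref(const std::vector<float>& A, const std::vector<float>& B,
              std::vector<float>& C, int M, int N, int K, int BK) {
  int k_iters = K / BK;  // analogous to params->gemm_k_iterations_aligned
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int kb = 0; kb < k_iters * BK; kb += BK)  // aligned main loop
        for (int k = kb; k < kb + BK; ++k)
          acc += A[m * K + k] * B[k * N + n];
      for (int k = k_iters * BK; k < K; ++k)         // K remainder
        acc += A[m * K + k] * B[k * N + n];
      C[m * N + n] += acc;
    }
  }
}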
207  mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h  (new file)
@@ -0,0 +1,207 @@
// Copyright © 2025 Apple Inc.

using namespace mlx::steel;

constant bool has_batch [[function_constant(10)]];

constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

// clang-format off
template <
    bool kAlignedM,
    bool kAlignedN,
    typename NAXTile_t,
    typename T>
void gemm_epilogue(
    thread NAXTile_t& Dtile,
    const device T* C,
    const constant GEMMParams* params,
    const constant GEMMAddMMParams* addmm_params,
    const short sgp_sm,
    const short sgp_sn) { // clang-format on

  (void)params;

  constexpr short UM = NAXTile_t::kSubTileRows;
  constexpr short UN = NAXTile_t::kSubTileCols;
  using CSubTile = NAXSubTile<T, UM, UN>;

  using V = typename NAXTile_t::elem_type;

  constexpr short TM = NAXTile_t::kTileRows;
  constexpr short TN = NAXTile_t::kTileCols;
  constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile;

  STEEL_PRAGMA_UNROLL
  for (short mm = 0; mm < TM; mm++) {
    STEEL_PRAGMA_UNROLL
    for (short nn = 0; nn < TN; nn++) {
      const short m = mm * UM;
      const short n = nn * UN;

      CSubTile CTile;

      if constexpr (kAlignedM && kAlignedN) {
        CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n);
      } else {
        CTile.load_safe(
            C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n);
      }

      auto delems = Dtile.subtile_at(mm, nn).elems();
      auto celems = CTile.elems();

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < kElemsPerSubTile; i++) {
        if (do_axpby) {
          delems[i] = addmm_params->alpha * delems[i] +
              addmm_params->beta * static_cast<V>(celems[i]);
        } else {
          delems[i] += static_cast<V>(celems[i]);
        }
      }
    }
  }
}

// clang-format off
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device T* C [[buffer(2), function_constant(use_out_source)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on
  // Find block
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Exit early if out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  // Adjust for batch
  if (has_batch) {
    const constant auto* A_bstrides = batch_strides;
    const constant auto* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

    if (use_out_source) {
      const constant auto* C_bstrides = B_bstrides + params->batch_ndim;
      C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
    }
  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;

    if (use_out_source) {
      C += addmm_params->batch_stride_c * tid.z;
    }
  }

  D += params->batch_stride_d * tid.z;

  // Prepare threadgroup memory
  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  if (use_out_source) {
    C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
  }

  constexpr short UM = 16;
  constexpr short UN = 32;
  constexpr short UK = 16;
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;

  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;

  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

  const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

  A += transpose_a ? tm : (tm * params->lda);
  B += transpose_b ? (tn * params->ldb) : tn;
  D += tm * params->ldd + tn;

  if (use_out_source) {
    C += tm * addmm_params->ldc + tn * addmm_params->fdc;
  }

  using DSubTile = NAXSubTile<AccumType, UM, UN>;
  NAXTile<AccumType, TM, TN, DSubTile> Dtile;

  dispatch_bool(align_K, [&](auto kAlignedK) {
    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
      dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
        Dtile = gemm_loop<
            T,
            SM,
            SN,
            SK,
            BK,
            transpose_a,
            transpose_b,
            kAlignedM.value,
            kAlignedN.value,
            kAlignedK.value,
            UM,
            UN,
            UK,
            AccumType>(A, B, params, sgp_sm, sgp_sn);
        if (use_out_source) {
          gemm_epilogue<kAlignedM.value, kAlignedN.value>(
              Dtile, C, params, addmm_params, sgp_sm, sgp_sn);
        }
        if constexpr (kAlignedM && kAlignedN) {
          Dtile.store(D, int(params->ldd));
        } else {
          Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm));
        }
      });
    });
  });
}
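The nested dispatch_bool calls above turn three runtime alignment flags into compile-time template parameters, so each of the eight alignment combinations gets a fully specialized gemm_loop with no per-iteration branching. A minimal C++ sketch of the pattern (this dispatch_bool is a stand-in for illustration, not the MLX helper itself):

#include <cstdio>
#include <type_traits>

// Call f with std::true_type or std::false_type depending on a runtime bool.
template <typename F>
void dispatch_bool(bool b, F&& f) {
  if (b) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

int main() {
  bool align_m = true, align_n = false;
  dispatch_bool(align_m, [&](auto kAlignedM) {
    dispatch_bool(align_n, [&](auto kAlignedN) {
      // Inside the lambda the flags are compile-time constants, usable
      // as template arguments or in `if constexpr`.
      if constexpr (kAlignedM.value && kAlignedN.value) {
        std::puts("fast path: fully aligned");
      } else {
        std::puts("safe path: bounds-checked loads/stores");
      }
    });
  });
}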
35  mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal  (new file)
@@ -0,0 +1,35 @@
// Copyright © 2025 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/utils.h"

#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h"

// clang-format off
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_kernel( \
      "steel_gemm_fused_nax_" #tname "_" #iname "_" #oname \
      "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \
      gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
instantiate_gemm_shapes_helper(float32, float, float32, float);
// clang-format on
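The tid swizzle in the gemm kernel above remaps the launch grid so that runs of 2^swizzle_log consecutive threadgroups cover a column of output tiles rather than a row; a common motivation for this kind of swizzle is better reuse of the B operand across neighboring threadgroups. A small C++ demonstration of the index math (values are illustrative):

#include <cstdio>

int main() {
  const int swizzle_log = 2;  // groups of 4 tile-rows, as an example
  for (unsigned x = 0; x < 8; ++x) {
    unsigned y = 0;  // one grid row, for brevity
    int tid_y = (y << swizzle_log) + (x & ((1u << swizzle_log) - 1));
    int tid_x = x >> swizzle_log;
    // x = 0..3 map to tile column 0, rows 0..3; x = 4..7 to column 1.
    std::printf("launch (%u,%u) -> tile (%d,%d)\n", x, y, tid_x, tid_y);
  }
}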
132  mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h  (new file)
@@ -0,0 +1,132 @@
// Copyright © 2024 Apple Inc.

using namespace mlx::steel;

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
gather_mm_rhs_nax(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device uint32_t* rhs_indices [[buffer(2)]],
    device T* C [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  constexpr short UM = 16;
  constexpr short UN = 32;
  constexpr short UK = 16;
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;
  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;

  if (params->tiles_n <= static_cast<int>(tid.x) ||
      params->tiles_m <= static_cast<int>(tid.y)) {
    return;
  }

  // Find the block in A, B, C
  const int c_row = tid.y * BM;
  const int c_col = tid.x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  C += c_row_long * params->ldd + c_col_long;
  rhs_indices += c_row;

  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

  const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

  A += transpose_a ? tm : (tm * params->lda);
  B += transpose_b ? (tn * params->ldb) : tn;
  C += tm * params->ldd + tn;
  rhs_indices += tm;

  // Do as many matmuls as necessary
  uint32_t index;
  short offset;
  uint32_t index_next = rhs_indices[0];
  short offset_next = 0;
  int n = 0;
  while (n < sgp_sm) {
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = sgp_sm;
    for (; n < sgp_sm; n++) {
      if (rhs_indices[n] != index) {
        offset_next = n;
        index_next = rhs_indices[n];
        break;
      }
    }
    threadgroup_barrier(mem_flags::mem_none);

    using DSubTile = NAXSubTile<AccumType, UM, UN>;
    NAXTile<AccumType, TM, TN, DSubTile> Ctile;

    dispatch_bool(align_K, [&](auto kAlignedK) {
      dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
        dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
          auto do_gemm = gemm_loop<
              T,
              SM,
              SN,
              SK,
              BK,
              transpose_a,
              transpose_b,
              kAlignedM.value,
              kAlignedN.value,
              kAlignedK.value,
              UM,
              UN,
              UK,
              AccumType>;
          Ctile = do_gemm(
              A, B + index * params->batch_stride_b, params, sgp_sm, sgp_sn);

          if constexpr (kAlignedN.value) {
            if (offset_next - offset == SM) {
              Ctile.store(C, int(params->ldd));
            } else {
              Ctile.store_slice(
                  C,
                  int(params->ldd),
                  short2(0, offset),
                  short2(SN, offset_next));
            }
          } else {
            Ctile.store_slice(
                C,
                int(params->ldd),
                short2(0, offset),
                short2(sgp_sn, offset_next));
          }
        });
      });
    });
  }
}
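The while loop in gather_mm_rhs_nax is a run-length grouping over rhs_indices: consecutive rows that select the same B matrix are handled by a single gemm_loop call, and only the [offset, offset_next) row slice of the result is stored. A scalar C++ sketch of just the grouping logic (index values are illustrative):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Sorted per-row B indices for one block of rows (illustrative values).
  std::vector<uint32_t> rhs_indices = {3, 3, 3, 7, 7, 9};
  int sgp_sm = static_cast<int>(rhs_indices.size());

  int n = 0;
  uint32_t index_next = rhs_indices[0];
  int offset_next = 0;
  while (n < sgp_sm) {
    n++;
    int offset = offset_next;
    uint32_t index = index_next;
    offset_next = sgp_sm;
    for (; n < sgp_sm; n++) {
      if (rhs_indices[n] != index) {  // run ends where the index changes
        offset_next = n;
        index_next = rhs_indices[n];
        break;
      }
    }
    // One matmul against B[index], storing rows [offset, offset_next).
    // Prints: B[3] rows [0, 3), B[7] rows [3, 5), B[9] rows [5, 6).
    std::printf("B[%u] covers rows [%d, %d)\n", index, offset, offset_next);
  }
}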
39  mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal  (new file)
@@ -0,0 +1,39 @@
// Copyright © 2024 Apple Inc.

#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/nax.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"

// clang-format off
#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_kernel( \
      "steel_gather_mm_rhs_nax_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \
      "_bk" #bk "_wm" #wm "_wn" #wn, \
      gather_mm_rhs_nax, \
      itype, \
      bm, \
      bn, \
      bk, \
      wm, \
      wn, \
      trans_a, \
      trans_b, \
      float)

#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 128, 128, 1, 4) \
  instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 32, 128, 128, 1, 4) \
  instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 64, 128, 128, 2, 4)
// clang-format on

instantiate_gather_mm_shapes_helper(float16, half, float16, half);
instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat);
1084  mlx/backend/metal/kernels/steel/gemm/nax.h  (new file)
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff.