mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Compare commits: 89 commits, v0.29.4 ... 1b591ec736
Commits in this range (newest first):
1b591ec736, 47d2505ea9, bedefed784, ccaaa7d6df, f3e5ca5414, 81dfe5f137,
012fb220a1, e1fee0074b, 3c8ce9b00e, 937ce79660, 208f5441a7, b862d842e1,
f7a400951a, 27232db1ba, a4b3bc969b, 667c0f3bb9, 6245824d42, 39289ef025,
aefc9bd3f6, 997cfc7699, 1fa8dc5797, a6d6717181, 941cfe23d7, 9abb0b8123,
50d3914c67, cacbdbf995, 193cdcd81a, d8ceae7b77, eff0e31f00, 6c5785bc2f,
8879ee00eb, 6e762fe2e2, 2b95d0c270, b054838780, dd79d3c465, 704fd1ae28,
c9f4dc851f, f8bd675655, 23a9168d34, bca205e287, 1d4eacb737, 8abd37ad05,
3e05cea9f8, 5b0f047226, 618c87af8c, d5f61a93fa, 4a09264236, 0dbc7e5bee,
0d68efd461, f9e1a14135, d8e9ded928, 60939d010c, fdcd2923fd, 54f1cc6e3e,
b3825ac149, 7f4b7e553c, ad16f41a7f, f46877bc08, 6f35017d1b, b167f0df1c,
a9f0d6b160, 940f4c7818, 35f81728f1, 4442ed86c1, 698559c231, ecc4879b07,
32b18d8b66, 472c43a0c8, b7214ff01e, 76414c8971, 49e4566df3, aad49f932f,
86765cce34, 1bedcbd556, 9ac7dbe877, 1bf605d56d, 3c622ddd1d, 27ff069175,
3b2ffcefc3, b65f882df3, b704e9e77a, 66519fb348, 8973550ff3, 3f866be665,
23f81ed1c1, 3fe2250c00, 047114b988, 9320eb89a8, 75819d70ea
.circleci/config.yml (deleted, 579 lines)

```yaml
version: 2.1

orbs:
  apple: ml-explore/pr-approval@0.1.0

parameters:
  nightly_build:
    type: boolean
    default: false
  test_release:
    type: boolean
    default: false

jobs:
  build_documentation:
    parameters:
      upload-docs:
        type: boolean
        default: false
    macos:
      xcode: "26.0.0"
    resource_class: m4pro.medium
    steps:
      - checkout
      - run:
          name: Install
          command: |
            xcodebuild -downloadComponent MetalToolchain
            brew install python@3.10
            brew install doxygen
            python3.10 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install -r docs/requirements.txt
            pip install . -v
      - when:
          condition:
            not: << parameters.upload-docs >>
          steps:
            - run:
                name: Build documentation
                command: |
                  source env/bin/activate
                  cd docs && doxygen && make html O=-W
      - when:
          condition: << parameters.upload-docs >>
          steps:
            - add_ssh_keys:
                fingerprints:
                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
            - run:
                name: Upload documentation
                command: |
                  source env/bin/activate
                  git config user.email "mlx@group.apple.com"
                  git config user.name "CircleCI Docs"
                  git checkout gh-pages
                  git rebase main
                  cd docs
                  git rm -rf build/html
                  doxygen && make html O=-W
                  git add -f build/html
                  git commit -m "rebase"
                  git push -f origin gh-pages

  linux_build_and_test:
    machine:
      image: ubuntu-2204:current
    resource_class: large
    steps:
      - checkout
      - run:
          name: Run style checks
          command: |
            pip install pre-commit
            pre-commit run --all
            if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
      - run:
          name: Install dependencies
          command: |
            export DEBIAN_FRONTEND=noninteractive
            export NEEDRESTART_MODE=a
            sudo apt-get update
            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
            curl -LsSf https://astral.sh/uv/install.sh | sh
      - run:
          name: Install Python package
          command: |
            uv venv
            uv pip install cmake
            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
              uv pip install -e ".[dev]" -v
      - run:
          name: Generate package stubs
          command: |
            uv pip install typing_extensions
            uv run --no-project setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            source .venv/bin/activate
            python -m unittest discover python/tests -v
            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
      - run:
          name: Build CPP only
          command: |
            source .venv/bin/activate
            mkdir -p build && cd build
            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
            make -j `nproc`
      - run:
          name: Run CPP tests
          command: ./build/tests/tests

  mac_build_and_test:
    parameters:
      xcode_version:
        type: string
        default: "26.0.0"
      macosx_deployment_target:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    resource_class: m4pro.medium
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            xcodebuild -downloadComponent MetalToolchain
            HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
              brew install openmpi uv
      - run:
          name: Install Python package
          command: |
            uv venv --python 3.10
            uv pip install \
              nanobind==2.4.0 \
              cmake \
              numpy \
              torch \
              tensorflow \
              unittest-xml-reporting
            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
              uv pip install -e . -v
      - run:
          name: Generate package stubs
          command: |
            uv pip install typing_extensions
            uv run --no-project setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            source .venv/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
      - run:
          name: Build example extension
          command: |
            source .venv/bin/activate
            cd examples/extensions
            uv pip install -r requirements.txt
            uv run --no-project setup.py build_ext --inplace
            uv run --no-project python test.py
      - store_test_results:
          path: test-results
      - run:
          name: Build CPP only
          command: |
            source .venv/bin/activate
            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
      - run:
          name: Run CPP tests
          command: |
            DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
      - run:
          name: Build small binary
          command: |
            source .venv/bin/activate
            cd build/
            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
              -DBUILD_SHARED_LIBS=ON \
              -DMLX_BUILD_CPU=OFF \
              -DMLX_BUILD_SAFETENSORS=OFF \
              -DMLX_BUILD_GGUF=OFF \
              -DMLX_METAL_JIT=ON
            make -j `sysctl -n hw.ncpu`
      - run:
          name: Run Python tests with JIT
          command: |
            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
              uv pip install -e . -v
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
              METAL_DEBUG_ERROR_MODE=0 \
              uv run --no-project python -m xmlrunner discover \
              -v python/tests \
              -o test-results/gpu_jit

  cuda_build_and_test:
    parameters:
      image_date:
        type: string
        default: "2023.11.1"
    machine:
      image: "linux-cuda-12:<< parameters.image_date >>"
    resource_class: gpu.nvidia.small.gen2
    steps:
      - checkout
      - restore_cache:
          keys:
            - cuda-<< parameters.image_date >>-{{ arch }}-
      - run:
          name: Install dependencies
          command: |
            sudo apt-get update
            sudo apt-get install libcudnn9-dev-cuda-12
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install libnccl2 libnccl-dev
            curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
            sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
            rm -rf ccache-4.11.3-linux-x86_64
            curl -LsSf https://astral.sh/uv/install.sh | sh
      - run:
          name: Set CCache size
          command: ccache --max-size 1G
      - run:
          name: Install Python package
          command: |
            uv venv
            uv pip install cmake
            DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
              uv pip install -e ".[dev]" -v
      - run:
          name: Run Python tests
          command: |
            source .venv/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
            LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
      - run:
          name: Build CPP only
          command: |
            source .venv/bin/activate
            cmake . -B build \
              -DMLX_BUILD_CUDA=ON \
              -DCMAKE_CUDA_COMPILER=`which nvcc` \
              -DCMAKE_BUILD_TYPE=DEBUG
            cmake --build build -j `nproc`
      - run:
          name: Run CPP tests
          command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
      - run:
          name: CCache report
          command: |
            ccache --show-stats
            ccache --zero-stats
            ccache --cleanup
      - save_cache:
          key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
          paths:
            - /home/circleci/.cache/ccache

  build_release:
    parameters:
      python_version:
        type: string
        default: "3.10"
      xcode_version:
        type: string
        default: "26.0.0"
      build_env:
        type: string
        default: ""
      macosx_deployment_target:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    resource_class: m4pro.medium
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            xcodebuild -downloadComponent MetalToolchain
            mkdir -p ~/miniconda3
            curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
            bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
            rm ~/miniconda3/miniconda.sh
            source ~/miniconda3/bin/activate
            conda init --all
            conda create -n env python=<< parameters.python_version >> -y
            conda activate env
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install twine
            pip install build
      - run:
          name: Install Python package
          command: |
            conda activate env
            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
              pip install . -v
      - run:
          name: Generate package stubs
          command: |
            conda activate env
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Build Python package
          command: |
            conda activate env
            python setup.py clean --all
            << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
      - when:
          condition:
            equal: ["3.10", << parameters.python_version >>]
          steps:
            - run:
                name: Build common package
                command: |
                  conda activate env
                  python setup.py clean --all
                  << parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
                  conda activate env
                  twine upload dist/*
      - store_artifacts:
          path: dist/

  build_linux_release:
    parameters:
      python_version:
        type: string
        default: "3.10"
      build_env:
        type: string
        default: ""
    machine:
      image: ubuntu-2204:current
    resource_class: large
    steps:
      - checkout
      - run:
          name: Build wheel
          command: |
            PYTHON=python<< parameters.python_version >>
            export DEBIAN_FRONTEND=noninteractive
            export NEEDRESTART_MODE=a
            sudo apt-get update
            TZ=Etc/UTC sudo apt-get -y install tzdata
            sudo add-apt-repository -y ppa:deadsnakes/ppa
            sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            $PYTHON -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
            << parameters.build_env >> pip install ".[dev]" -v
            pip install typing_extensions
            python setup.py generate_stubs
            python setup.py clean --all
            MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
            bash python/scripts/repair_linux.sh
      - when:
          condition:
            equal: ["3.10", << parameters.python_version >>]
          steps:
            - run:
                name: Build common package
                command: |
                  source env/bin/activate
                  python setup.py clean --all
                  << parameters.build_env >> MLX_BUILD_STAGE=2 \
                    python -m build -w
                  auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload packages
                command: |
                  source env/bin/activate
                  twine upload wheelhouse/*.whl
      - store_artifacts:
          path: wheelhouse/

  build_cuda_release:
    parameters:
      build_env:
        type: string
        default: ""
    machine:
      image: ubuntu-2204:current
    resource_class: xlarge
    steps:
      - checkout
      - run:
          name: Build wheel
          command: |
            export DEBIAN_FRONTEND=noninteractive
            export NEEDRESTART_MODE=a
            wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
            sudo dpkg -i cuda-keyring_1.1-1_all.deb
            sudo apt-get update
            sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install zip
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
            export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
            export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
            << parameters.build_env >> MLX_BUILD_STAGE=2 \
              CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
              python -m build -w
            bash python/scripts/repair_cuda.sh
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
                  twine upload wheelhouse/*.whl
      - store_artifacts:
          path: wheelhouse/

workflows:
  build_and_test:
    when:
      and:
        - matches:
            pattern: "^(?!pull/)[-\\w]+$"
            value: << pipeline.git.branch >>
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - mac_build_and_test:
          matrix:
            parameters:
              macosx_deployment_target: ["13.5", "15.0"]
      - linux_build_and_test
      - cuda_build_and_test:
          matrix:
            parameters:
              image_date: ["2023.11.1", "2025.05.1"]
      - build_documentation

  build_pypi_release:
    when:
      and:
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["PYPI_RELEASE=1"]
              xcode_version: ["26.0.0"]
      - build_documentation:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          upload-docs: true
      - build_linux_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
              build_env: ["PYPI_RELEASE=1"]
      - build_cuda_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              build_env: ["PYPI_RELEASE=1"]

  prb:
    when:
      matches:
        pattern: "^pull/\\d+(/head)?$"
        value: << pipeline.git.branch >>
    jobs:
      - hold:
          type: approval
      - apple/authenticate:
          context: pr-approval
      - mac_build_and_test:
          requires: [ hold ]
          matrix:
            parameters:
              macosx_deployment_target: ["13.5", "15.0"]
      - linux_build_and_test:
          requires: [ hold ]
      - cuda_build_and_test:
          requires: [ hold ]
          matrix:
            parameters:
              image_date: ["2023.11.1", "2025.05.1"]

  nightly_build:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.nightly_build >>
    jobs:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              xcode_version: ["26.0.0"]
      - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
      - build_cuda_release

  build_dev_release:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["DEV_RELEASE=1"]
              xcode_version: ["26.0.0"]
      - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
              build_env: ["DEV_RELEASE=1"]
      - build_cuda_release:
          matrix:
            parameters:
              build_env: ["DEV_RELEASE=1"]
```
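The two-stage packaging pattern used throughout this config (MLX_BUILD_STAGE=1 builds the front-end `mlx` wheel, MLX_BUILD_STAGE=2 the backend `mlx_cpu`/`mlx_cuda` wheel) carries over into the GitHub Actions files below. A minimal sketch of the pattern as a single workflow step; the step name and comments are illustrative, while the commands and MLX_BUILD_STAGE values come from the config above:

```yaml
# Sketch only: assumes this repository's setup.py honors MLX_BUILD_STAGE.
- name: Build both wheel stages
  shell: bash
  run: |
    pip install build
    python setup.py clean --all
    MLX_BUILD_STAGE=1 python -m build -w   # front-end (mlx) wheel
    python setup.py clean --all
    MLX_BUILD_STAGE=2 python -m build -w   # backend wheel (mlx_cpu / mlx_cuda)
```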
.github/actions/build-cuda-release/action.yml (12 changed lines)

```diff
@@ -2,9 +2,13 @@ name: 'Build CUDA wheel'
 description: 'Build CUDA wheel'
 
 inputs:
-  nvcc-location:
-    description: 'Location of nvcc compiler'
+  arch:
+    description: 'Platform architecture tag'
     required: true
+    type: choice
+    options:
+      - x86_64
+      - aarch64
 
 runs:
   using: "composite"
@@ -12,9 +16,9 @@ runs:
     - name: Build package
       shell: bash
       env:
-        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
+        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
       run: |
         pip install auditwheel build patchelf setuptools
         python setup.py clean --all
         MLX_BUILD_STAGE=2 python -m build -w
-        bash python/scripts/repair_cuda.sh
+        bash python/scripts/repair_cuda.sh ${{ inputs.arch }}
```
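Callers now pass a target architecture instead of an nvcc path. One way to invoke the action with its new input; the surrounding job is illustrative, only the action path and the `arch` input come from the diff above:

```yaml
# Hypothetical caller step.
- uses: ./.github/actions/build-cuda-release
  with:
    arch: aarch64   # or x86_64; forwarded to python/scripts/repair_cuda.sh
```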
.github/actions/build-cuda/action.yml (deleted, 45 lines)

```yaml
name: 'Build and Test with CUDA'
description: 'Build and test MLX with CUDA'

inputs:
  nvcc-location:
    description: 'Location of nvcc compiler'
    required: true
    default: '/usr/local/cuda-12.9/bin/nvcc'

runs:
  using: "composite"
  steps:
    - name: Install Python package
      shell: bash
      env:
        DEBUG: 1
        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }}
      run: pip install -e ".[dev]" -v

    - name: Run Python tests - CPU
      shell: bash
      env:
        LOW_MEMORY: 1
        DEVICE: cpu
      run: python -m unittest discover python/tests -v

    - name: Run Python tests - GPU
      shell: bash
      env:
        LOW_MEMORY: 1
        DEVICE: gpu
      run: python -m tests discover python/tests -v

    - name: Build CPP only
      shell: bash
      run: |
        cmake . -B build \
          -DMLX_BUILD_CUDA=ON \
          -DCMAKE_CUDA_COMPILER=${{ inputs.nvcc-location }} \
          -DCMAKE_BUILD_TYPE=DEBUG
        cmake --build build -j $(nproc)

    - name: Run CPP tests
      shell: bash
      run: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
```
.github/actions/build-docs/action.yml (18 changed lines)

```diff
@@ -1,19 +1,19 @@
 name: 'Build Documentation'
-description: 'Build documentation on a mac'
+description: 'Build documentation'
 
 runs:
   using: "composite"
   steps:
     - name: Setup machine
-      uses: ./.github/actions/setup-macos
+      uses: ./.github/actions/setup-linux
 
     - name: Install dependencies
-      shell: sh
+      shell: bash
       run: |
-        brew install doxygen
-        uv pip install --upgrade pip cmake
-        uv pip install -r docs/requirements.txt
-        uv pip install . -v
+        sudo apt-get install -y doxygen
+        source .venv/bin/activate
+        pip install -r docs/requirements.txt
+        pip install . -v
 
     - name: Build documentation
       shell: bash
@@ -24,8 +24,8 @@ runs:
         make html O=-W
 
     - name: Create artifact tar
-      shell: sh
-      run: tar -cf artifact.tar --cd docs --dereference build/html index.html
+      shell: bash
+      run: tar -cf artifact.tar -C docs --dereference build/html index.html
 
     # Do it manually because upload-pages-artifact requires gtar
     - name: Upload artifact
```
.github/actions/build-linux-release/action.yml (11 changed lines)

```diff
@@ -7,6 +7,13 @@ inputs:
     type: boolean
     required: false
     default: false
+  arch:
+    description: 'Platform architecture tag'
+    required: true
+    type: choice
+    options:
+      - x86_64
+      - aarch64
 
 runs:
   using: "composite"
@@ -23,11 +30,11 @@ runs:
         pip install auditwheel patchelf build
         python setup.py clean --all
         MLX_BUILD_STAGE=1 python -m build -w
-        bash python/scripts/repair_linux.sh
+        bash python/scripts/repair_linux.sh ${{ inputs.arch }}
     - name: Build backend package
       if: ${{ inputs.build-backend }}
       shell: bash
       run: |
         python setup.py clean --all
         MLX_BUILD_STAGE=2 python -m build -w
-        auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
+        auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}
```
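The `arch` input flows straight into the manylinux platform tag that auditwheel stamps on the repaired wheel. A hedged caller sketch; the input names come from the action above, the job context is illustrative:

```yaml
# Hypothetical caller step.
- uses: ./.github/actions/build-linux-release
  with:
    build-backend: true
    arch: aarch64   # produces wheels repaired for manylinux_2_35_aarch64
```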
.github/actions/build-linux/action.yml (44 changed lines)

```diff
@@ -1,15 +1,32 @@
 name: 'Build and Test on Linux'
-description: 'Build and test MLX on Linux'
+
+inputs:
+  toolkit:
+    description: 'The toolkit to build with'
+    required: false
+    default: 'cpu'
 
 runs:
   using: "composite"
   steps:
     - name: Install Python package
+      id: python_build
       shell: sh
       env:
-        CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
         DEBUG: 1
-      run: pip install -e ".[dev]" -v
+        CMAKE_ARGS: >-
+          -DCMAKE_COMPILE_WARNING_AS_ERROR=ON
+          -DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
+      run: |
+        if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
+          # There is no GPU in arm64 runner, use a common arch.
+          CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
+          # Can not build tests when the built executables can not run.
+          CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
+        fi
+        pip install --no-build-isolation -e ".[dev]" -v
+        # Pass the CMAKE_ARGS to following steps.
+        echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT
 
     - name: Generate package stubs
       shell: sh
@@ -17,25 +34,8 @@ runs:
         pip install typing_extensions
         python setup.py generate_stubs
 
-    - name: Run Python tests
-      shell: bash
-      run: |
-        python -m unittest discover python/tests -v
-        mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
-        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
-        if grep -Fq '[WARN]' stderr.log ; then
-          grep -F '[WARN]' stderr.log
-          echo "Distributed ring test failed";
-          exit 1;
-        fi
-
     - name: Build CPP only
       shell: bash
       run: |
-        mkdir -p build && cd build
-        cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
-        make -j $(nproc)
-
-    - name: Run CPP tests
-      shell: sh
-      run: ./build/tests/tests
+        cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
+        cmake --build build -j $(nproc)
```
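The build step now publishes its final CMAKE_ARGS through `$GITHUB_OUTPUT` so the later CMake configure reuses exactly the flags the pip build used. A standalone sketch of that handoff pattern; the step id matches the diff, the flag value is illustrative:

```yaml
# Minimal sketch of passing a value from one step to the next via outputs.
steps:
  - id: python_build
    shell: bash
    run: echo CMAKE_ARGS="-DEXAMPLE_FLAG=ON" >> "$GITHUB_OUTPUT"
  - shell: bash
    run: cmake . -B build ${{ steps.python_build.outputs.CMAKE_ARGS }}
```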
.github/actions/build-macos-release/action.yml (15 changed lines)

```diff
@@ -16,18 +16,19 @@ runs:
   using: "composite"
   steps:
     - name: Build Python package
-      shell: bash
+      shell: bash -l {0}
       env:
         MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
       run: |
-        uv pip install build
-        uv run --no-project setup.py clean --all
-        MLX_BUILD_STAGE=1 uv run -m build -w
+        pip install build
+        python setup.py clean --all
+        MLX_BUILD_STAGE=1 python -m build -w
 
     - name: Build backend package
       if: ${{ inputs.build-backend }}
-      shell: bash
+      shell: bash -l {0}
       env:
         MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
       run: |
-        uv run --no-project setup.py clean --all
-        MLX_BUILD_STAGE=2 uv run -m build -w
+        python setup.py clean --all
+        MLX_BUILD_STAGE=2 python -m build -w
```
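The switch from `shell: bash` to `shell: bash -l {0}` pairs with the move to miniconda in setup-macos: conda activation lives in the shell profile, and only a login shell sources it, so plain `python`/`pip` resolve inside the conda environment. A minimal sketch of the pairing; the step body is illustrative:

```yaml
# Sketch, assuming the conda env is configured as in setup-macos below.
steps:
  - uses: conda-incubator/setup-miniconda@v3
    with:
      miniconda-version: "latest"
      python-version: "3.10"
  - name: Runs inside the conda-initialized shell
    shell: bash -l {0}   # '-l' makes bash source the profile that conda init wrote
    run: python --version
```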
.github/actions/build-macos/action.yml (44 changed lines)

```diff
@@ -5,47 +5,47 @@ runs:
   using: "composite"
   steps:
     - name: Install dependencies
-      shell: sh
       env:
         DEBUG: 1
         CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
+      shell: bash -l {0}
       run: |
-        uv pip install --upgrade pip
-        uv pip install cmake setuptools nanobind==2.4.0
-        uv pip install -e . -v
+        pip install --upgrade pip
+        pip install cmake setuptools nanobind==2.10.2
+        pip install -e . -v
 
     - name: Generate package stubs
-      shell: bash
+      shell: bash -l {0}
       run: |
-        uv pip install typing_extensions
-        uv run --no-project setup.py generate_stubs
+        pip install typing_extensions
+        python setup.py generate_stubs
 
     - name: Install tests dependencies
-      shell: sh
+      shell: bash -l {0}
       run: |
-        uv pip install numpy torch tensorflow unittest-xml-reporting
+        pip install numpy torch tensorflow unittest-xml-reporting
 
     - name: Run Python tests
-      shell: bash
+      shell: bash -l {0}
       env:
         LOW_MEMORY: 1
       run: |
-        DEVICE=cpu uv run -m xmlrunner discover -v python/tests -o test-results/cpu
-        DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 uv run -m xmlrunner discover -v python/tests -o test-results/gpu
+        DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
+        DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
         mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
         mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
         if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
 
     - name: Build example extension
-      shell: bash
+      shell: bash -l {0}
       run: |
         cd examples/extensions
-        uv pip install -r requirements.txt
-        uv run --no-project setup.py build_ext --inplace
-        uv run --no-project test.py
+        pip install -r requirements.txt
+        python setup.py build_ext --inplace
+        python test.py
 
     - name: Build CPP only
-      shell: bash
+      shell: bash -l {0}
       run: |
         mkdir -p build
         cd build
@@ -53,7 +53,7 @@ runs:
         make -j $(sysctl -n hw.ncpu)
 
     - name: Run CPP tests
-      shell: bash
+      shell: bash -l {0}
       env:
         DEVICE: gpu
         METAL_DEVICE_WRAPPER_TYPE: 1
@@ -61,7 +61,7 @@ runs:
       run: ./build/tests/tests
 
     - name: Build small binary with JIT
-      shell: bash
+      shell: bash -l {0}
       run: |
         mkdir -p build
         cd build
@@ -74,7 +74,7 @@ runs:
         make -j $(sysctl -n hw.ncpu)
 
     - name: Run Python tests with JIT
-      shell: bash
+      shell: bash -l {0}
       env:
         LOW_MEMORY: 1
         DEVICE: gpu
@@ -82,7 +82,7 @@ runs:
         METAL_DEBUG_ERROR_MODE: 0
       run: |
         CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-          uv pip install -e . -v
-        uv run -m xmlrunner discover \
+          pip install -e . -v
+        python -m xmlrunner discover \
           -v python/tests \
           -o test-results/gpu_jit
```
.github/actions/setup-linux/action.yml (76 changed lines)

```diff
@@ -2,72 +2,82 @@ name: 'Setup Linux Environment'
 description: 'Install dependencies for Linux builds'
 
 inputs:
-  runner-type:
-    description: 'Whether to set this up as a linux or CUDA runner'
+  toolkit:
+    description: 'Which toolkit to install'
     required: false
-    default: 'linux'
-    type: choice
-    options:
-      - linux
-      - cuda
+    default: 'cpu'
   python-version:
     description: 'Version of python to set up'
     required: false
     default: '3.10'
+  use-ccache:
+    description: 'Whether to enable ccache'
+    required: false
+    default: 'true'
 
 runs:
   using: "composite"
   steps:
-    - name: Free disk space
-      shell: sh
-      if: inputs.runner-type == 'linux'
-      run: sudo rm -rf "$AGENT_TOOLSDIRECTORY"
-
     - name: Install common dependencies
-      env:
-        TZ: Etc/UTC
       shell: bash
       run: |
         sudo apt-get update
-        sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev tzdata zip
-        sudo apt autoremove -y
+        sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
+
+    - name: Use ccache
+      if: ${{ inputs.use-ccache == 'true' }}
+      uses: hendrikmuhs/ccache-action@v1.2
+      with:
+        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}
+        max-size: 1GB
+        # ccache-action bug: running "apt-get update" fails on large arm runner.
+        update-package-index: false
 
     - uses: actions/setup-python@v6
       with:
         python-version: ${{ inputs.python-version }}
-        cache: 'pip'
 
-    - name: setup python venv
+    - name: Setup Python venv
      shell: bash
      run: |
        python -m venv .venv
        source .venv/bin/activate
+        pip install setuptools cmake nanobind==2.10.2
        echo PATH=$PATH >> $GITHUB_ENV
-        pip install --upgrade pip cmake
+        # Make cmake search .venv for nanobind
+        echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
 
     - name: Install MPI
-      if: inputs.runner-type == 'linux'
       shell: bash
       run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
 
-    - name: Network CUDA installation from packages
-      id: install-cuda
-      if: inputs.runner-type == 'cuda'
+    - name: Install CUDA toolkit
+      if: ${{ startsWith(inputs.toolkit, 'cuda') }}
+      shell: bash
       env:
-        TZ: Etc/UTC
-      shell: bash ## Specific to Ubuntu 22.04 & Architecture x86_64
+        # Note: the CI machine does not meet CUDA 13's driver requirement.
+        # Compatibility matrix:
+        # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
+        PACKAGES: |
+          {
+            "cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
+            "cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
+            "cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
+          }
       run: |
-        wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
+        # The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
+        # Jetson specific. SBSA means Arm Server Base System Architecture.
+        ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
+        wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
         sudo dpkg -i cuda-keyring_1.1-1_all.deb
         sudo apt-get update
-        sudo apt-get install -y libcudnn9-dev-cuda-12 libnccl2 libnccl-dev cuda-toolkit-12-9
-        # Note: This installs CUDA 12.9, which is the latest supported by cuDNN 9.x and works with the NVidia 570 drivers
-        # cuda-toolkit by itself installs version 13 (+) and requires updated drives (580+), which require a reboot to function properly.
-        # Compatibility matrix: https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
-        # This also drops `nvcc` into `/usr/local/cuda-12.9/bin/nvcc` - but it's *not* on the default PATH
+        sudo apt-get install -y \
+          libnccl2 libnccl-dev \
+          ${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
+        echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH
 
-    - name: Package and Driver Report
-      if: inputs.runner-type == 'cuda'
+    - name: CUDA packages and driver report
+      if: ${{ startsWith(inputs.toolkit, 'cuda') }}
       shell: bash
       run: |
         sudo apt-get install -y ubuntu-drivers-common dkms
```
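The toolkit-to-packages lookup works because `fromJson` parses the JSON string in the env context at expression time. A standalone sketch of the idiom with an abbreviated map; it assumes a composite-action context where `inputs.toolkit` is defined:

```yaml
# Sketch of the fromJson lookup used above; package map abbreviated.
- name: Install toolkit packages
  shell: bash
  env:
    PACKAGES: |
      {
        "cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9"
      }
  run: sudo apt-get install -y ${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
```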
.github/actions/setup-macos/action.yml (7 changed lines)

```diff
@@ -18,8 +18,7 @@ runs:
       shell: bash
       run: xcodebuild -showComponent MetalToolchain
 
-    - name: Setup uv
-      uses: astral-sh/setup-uv@v6
+    - uses: conda-incubator/setup-miniconda@v3
       with:
-        python-version: ${{ inputs.python-version }}
-        activate-environment: true
+        miniconda-version: "latest"
+        python-version: ${{ inputs.python-version }}
```
.github/actions/test-linux/action.yml (new file, 69 lines)

```yaml
name: 'Run Linux tests'

inputs:
  has-gpu:
    description: 'Run GPU tests'
    required: false
    default: false

runs:
  using: "composite"
  steps:
    - name: Run MPI tests
      shell: bash
      run: |
        echo "::group::MPI tests"
        mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
        echo "::endgroup::"

    - name: Run distributed tests
      if: ${{ inputs.has-gpu == 'false' }}
      shell: bash
      run: |
        echo "::group::Distributed tests"
        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
        if grep -Fq '[WARN]' stderr.log ; then
          grep -F '[WARN]' stderr.log
          echo "Distributed ring test failed";
          exit 1;
        fi
        echo "::endgroup::"

    - name: Run Python tests - CPU
      if: ${{ inputs.has-gpu == 'false' }}
      shell: bash
      env:
        DEVICE: cpu
      run: |
        echo "::group::Python tests - CPU"
        python -m unittest discover python/tests -v
        echo "::endgroup::"

    - name: Run Python tests - GPU
      if: ${{ inputs.has-gpu == 'true' }}
      shell: bash
      env:
        DEVICE: gpu
      run: |
        echo "::group::Python tests - GPU"
        python -m tests discover python/tests -v
        echo "::endgroup::"

    - name: Run CPP tests - CPU
      shell: bash
      env:
        DEVICE: cpu
      run: |
        echo "::group::CPP tests - CPU"
        ./build/tests/tests
        echo "::endgroup::"

    - name: Run CPP tests - GPU
      if: ${{ inputs.has-gpu == 'true' }}
      shell: bash
      env:
        DEVICE: gpu
      run: |
        echo "::group::CPP tests - GPU"
        ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
        echo "::endgroup::"
```
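Composite-action inputs arrive as strings, which is why the conditions above compare against the literals 'true' and 'false'. A hedged caller sketch; the surrounding job is illustrative:

```yaml
# Hypothetical caller step; the YAML boolean is coerced to the string 'true'.
- uses: ./.github/actions/test-linux
  with:
    has-gpu: true
```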
.github/workflows/build_and_test.yml (new file, 108 lines)

```yaml
name: Build and Test

on:
  pull_request:
  push:
    branches:
      - main
      # For testing CI without starting a pull request:
      - test/*

permissions:
  contents: read

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

jobs:
  check_lint:
    name: Check Lint
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v6
      - uses: pre-commit/action@v3.0.1

  linux_build_and_test:
    name: Linux (cpu, ${{ matrix.arch }})
    needs: check_lint
    strategy:
      fail-fast: false
      matrix:
        arch: ['x86_64', 'aarch64']
    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
      - uses: ./.github/actions/build-linux
      - uses: ./.github/actions/test-linux

  cuda_build_and_test:
    name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
    if: github.repository == 'ml-explore/mlx'
    needs: check_lint
    strategy:
      fail-fast: false
      matrix:
        arch: ['x86_64', 'aarch64']
        toolkit: ['cuda-12.6', 'cuda-12.9']
    runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          toolkit: ${{ matrix.toolkit }}
      - uses: ./.github/actions/build-linux
        with:
          toolkit: ${{ matrix.toolkit }}
      - uses: ./.github/actions/test-linux
        if: matrix.arch == 'x86_64'
        with:
          has-gpu: true

  mac_build_and_test:
    name: macOS (${{ matrix.macos-target }})
    if: github.repository == 'ml-explore/mlx'
    strategy:
      matrix:
        macos-target: ["14.0", "15.0"]
    runs-on: [self-hosted, macos]
    env:
      MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
    needs: check_lint
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
      - uses: ./.github/actions/build-macos

  build_documentation:
    name: Build Documentation
    if: github.repository == 'ml-explore/mlx'
    runs-on: ubuntu-22.04
    needs: check_lint
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/build-docs

  linux_fedora_build_cpp:
    name: Linux Fedora (${{ matrix.arch }})
    needs: check_lint
    strategy:
      fail-fast: false
      matrix:
        include:
          - host: ubuntu-22.04
            arch: x86_64
          - host: ubuntu-22.04-arm
            arch: aarch64

    runs-on: ${{ matrix.host }}
    container:
      image: fedora:42
    steps:
      - name: Checkout code
        uses: actions/checkout@v6

      - name: CPP Build Test - No Release
        run: |
          bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
```
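The workflow maps each matrix architecture to a runner label with the `&& ... || ...` expression idiom, which acts as a ternary because the middle operand is a non-empty string. A standalone sketch with illustrative job and step names:

```yaml
# Minimal demo of the runner-selection idiom used above.
jobs:
  demo:
    strategy:
      matrix:
        arch: ['x86_64', 'aarch64']
    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
    steps:
      - run: uname -m
```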
.github/workflows/documentation.yml (4 changed lines)

```diff
@@ -8,9 +8,9 @@ permissions:
 
 jobs:
   build:
-    runs-on: [self-hosted, macos]
+    runs-on: ubuntu-22.04
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
       - uses: ./.github/actions/build-docs
 
   deploy:
```
.github/workflows/nightly.yml (54 changed lines)

```diff
@@ -16,11 +16,12 @@ jobs:
         python_version: ["3.10", "3.14"]
     runs-on: ubuntu-22.04
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
       - uses: ./.github/actions/setup-linux
       - uses: ./.github/actions/build-linux-release
         with:
           build-backend: ${{ matrix.python-version == '3.10' }}
+          arch: "x86_64"
       - name: Upload mlx artifacts
         uses: actions/upload-artifact@v5
         with:
@@ -39,14 +40,18 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
-    runs-on: ubuntu-22.04
+        python_version: ["3.11", "3.12", "3.13", "3.14"]
+        runner:
+          - ubuntu-22.04
+          - ubuntu-22.04-arm
+    runs-on: ${{ matrix.runner }}
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
       - uses: ./.github/actions/setup-linux
         with:
           python-version: ${{ matrix.python_version }}
       - uses: ./.github/actions/build-linux
+      - uses: ./.github/actions/test-linux
 
   build_mac_release:
     if: github.repository == 'ml-explore/mlx'
@@ -55,12 +60,11 @@ jobs:
         python-version: ["3.10", "3.13"]
     runs-on: [self-hosted, macos]
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
       - uses: ./.github/actions/setup-macos
         with:
           python-version: ${{ matrix.python-version }}
       - uses: ./.github/actions/build-macos
-
       - name: Build macOS 15 package
         uses: ./.github/actions/build-macos-release
         with:
@@ -72,53 +76,21 @@ jobs:
           macos-target: 14.0
           build-backend: ${{ matrix.python-version == '3.10' }}
 
-  build_cuda_with_tests:
-    if: github.repository == 'ml-explore/mlx'
-    runs-on: gpu-t4-4-core
-    steps:
-      - uses: actions/checkout@v5
-      - uses: ./.github/actions/setup-linux
-        with:
-          runner-type: 'cuda'
-      - uses: ./.github/actions/build-cuda
-
   build_cuda_release:
     if: github.repository == 'ml-explore/mlx'
     runs-on: ubuntu-22-large
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
       - uses: ./.github/actions/setup-linux
         with:
-          runner-type: 'cuda'
+          toolkit: 'cuda-12.9'
       - name: Build Python package
         uses: ./.github/actions/build-cuda-release
         with:
-          nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
+          toolkit: 'cuda-12.9'
       - name: Upload artifacts
         uses: actions/upload-artifact@v5
         with:
           name: mlx-cuda
           path: wheelhouse/mlx_cuda-*.whl
           retention-days: 7
-
-  linux_fedora_build_cpp:
-    name: Linux Fedora CPP Build (${{ matrix.arch }})
-    strategy:
-      fail-fast: false
-      matrix:
-        include:
-          - host: ubuntu-22.04
-            arch: x86_64
-          - host: ubuntu-22.04-arm
-            arch: aarch64
-
-    runs-on: ${{ matrix.host }}
-    container:
-      image: fedora:42
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v5
-
-      - name: CPP Build Test - No Release
-        run: |
-          bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
```
.github/workflows/pull_request.yml (deleted, 71 lines)

```yaml
name: Build and Test

on: pull_request

permissions:
  contents: read

jobs:
  check_lint:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
      - uses: pre-commit/action@v3.0.1

  linux_build_and_test:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
      - uses: ./.github/actions/build-linux

  mac_build_and_test:
    if: github.repository == 'ml-explore/mlx'
    runs-on: [self-hosted, macos]
    needs: check_lint
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-macos
      - uses: ./.github/actions/build-macos

  cuda_build_and_test:
    if: github.repository == 'ml-explore/mlx'
    runs-on: gpu-t4-4-core
    needs: check_lint
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/setup-linux
        with:
          runner-type: 'cuda'
      - uses: ./.github/actions/build-cuda

  build_documentation:
    if: github.repository == 'ml-explore/mlx'
    runs-on: [self-hosted, macos]
    needs: check_lint
    steps:
      - uses: actions/checkout@v5
      - uses: ./.github/actions/build-docs

  linux_fedora_build_cpp:
    name: Linux Fedora CPP Build (${{ matrix.arch }})
    strategy:
      fail-fast: false
      matrix:
        include:
          - host: ubuntu-22.04
            arch: x86_64
          - host: ubuntu-22.04-arm
            arch: aarch64

    runs-on: ${{ matrix.host }}
    container:
      image: fedora:42
    steps:
      - name: Checkout code
        uses: actions/checkout@v5

      - name: CPP Build Test - No Release
        run: |
          bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
```
82  .github/workflows/release.yml (vendored)
@@ -5,6 +5,11 @@ on:
     tags:
       - 'v*'
   workflow_dispatch:
+    inputs:
+      dev_release:
+        description: "Do a dev release or regular release"
+        required: true
+        default: "false"
 
 permissions:
   contents: read
@@ -12,18 +17,15 @@ permissions:
 jobs:
   setup:
     runs-on: ubuntu-latest
-    outputs:
-      pypi_env: ${{ github.event_name == 'push' && 'pypi' || 'test-pypi' }}
-      pypi_url: ${{ github.event_name == 'push' && 'https://upload.pypi.org/legacy/' || 'https://test.pypi.org/legacy/' }}
     steps:
      - name: Set publishing variables
        run: echo "Publishing setup complete"
 
   build_documentation:
     if: github.repository == 'ml-explore/mlx'
-    runs-on: [self-hosted, macos]
+    runs-on: ubuntu-22.04
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
       - uses: ./.github/actions/build-docs
 
   deploy_documentation:
@@ -45,27 +47,33 @@ jobs:
     strategy:
       matrix:
         python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
-    runs-on: ubuntu-22.04
+        arch: ['x86_64', 'aarch64']
+    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
     env:
       PYPI_RELEASE: 1
+      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          python-version: ${{ matrix.python_version }}
+          use-ccache: false
      - uses: ./.github/actions/build-linux-release
        with:
          build-backend: ${{ matrix.python-version == '3.10' }}
+          arch: ${{ matrix.arch }}
      - name: Upload MLX artifacts
        uses: actions/upload-artifact@v5
        with:
-          name: linux-wheels-${{ matrix.python_version }}
+          overwrite: true
+          name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
          path: wheelhouse/mlx-*.whl
      - name: Upload CPU artifacts
        if: matrix.python_version == '3.10'
        uses: actions/upload-artifact@v5
        with:
-          name: mlx-cpu
+          overwrite: true
+          name: mlx-cpu-${{ matrix.arch }}
          path: wheelhouse/mlx_cpu-*.whl
 
   build_mac_release:
@@ -76,22 +84,25 @@ jobs:
     runs-on: [self-hosted, macos]
     env:
       PYPI_RELEASE: 1
+      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
 
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
        with:
          python-version: ${{ matrix.python-version }}
 
      - name: Install dependencies
-        shell: sh
+        shell: bash -l {0}
        run: |
-          uv pip install --upgrade pip
-          uv pip install cmake setuptools nanobind==2.4.0
-          uv pip install -e . -v
+          pip install --upgrade pip
+          pip install cmake setuptools nanobind==2.10.2
+          pip install -e . -v
      - name: Generate package stubs
-        shell: bash
+        shell: bash -l {0}
        run: |
-          uv pip install typing_extensions
-          uv run --no-project setup.py generate_stubs
+          pip install typing_extensions
+          python setup.py generate_stubs
      - name: Build macOS 14 package
        uses: ./.github/actions/build-macos-release
        with:
@@ -105,32 +116,41 @@ jobs:
      - name: Upload MLX artifacts
        uses: actions/upload-artifact@v5
        with:
+          overwrite: true
          name: mac-wheels-${{ matrix.python-version }}
          path: dist/mlx-*.whl
      - name: Upload Metal artifacts
        if: matrix.python-version == '3.10'
        uses: actions/upload-artifact@v5
        with:
+          overwrite: true
          name: mlx-metal
          path: dist/mlx_metal-*.whl
 
   build_cuda_release:
     if: github.repository == 'ml-explore/mlx'
-    runs-on: ubuntu-22-large
+    strategy:
+      matrix:
+        arch: ['x86_64', 'aarch64']
+        toolkit: ['cuda-12.9', 'cuda-13.0']
+    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
     env:
       PYPI_RELEASE: 1
+      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
-          runner-type: 'cuda'
+          toolkit: ${{ matrix.toolkit }}
+          use-ccache: false
      - name: Build Python package
        uses: ./.github/actions/build-cuda-release
        with:
-          nvcc-location: '/usr/local/cuda-12.9/bin/nvcc'
+          arch: ${{ matrix.arch }}
      - name: Upload artifacts
        uses: actions/upload-artifact@v5
        with:
+          overwrite: true
          name: mlx-cuda
          path: wheelhouse/mlx_cuda-*.whl
 
@@ -141,7 +161,7 @@ jobs:
     permissions:
       id-token: write
     environment:
-      name: ${{ needs.setup.outputs.pypi_env }}
+      name: pypi
       url: https://pypi.org/p/mlx
     steps:
      - uses: actions/download-artifact@v6
@@ -159,7 +179,7 @@ jobs:
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
-          repository-url: ${{ needs.setup.outputs.pypi_url }}
+          repository-url: https://upload.pypi.org/legacy/
 
   pypi-publish-cuda:
     name: Upload CUDA release to PyPI
@@ -168,7 +188,7 @@ jobs:
     permissions:
       id-token: write
     environment:
-      name: ${{ needs.setup.outputs.pypi_env }}
+      name: pypi
       url: https://pypi.org/p/mlx-cuda
     steps:
      - uses: actions/download-artifact@v6
@@ -180,7 +200,7 @@ jobs:
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
-          repository-url: ${{ needs.setup.outputs.pypi_url }}
+          repository-url: https://upload.pypi.org/legacy/
 
   pypi-publish-cpu:
     name: Upload CPU release to PyPI
@@ -189,19 +209,20 @@ jobs:
     permissions:
       id-token: write
     environment:
-      name: ${{ needs.setup.outputs.pypi_env }}
+      name: pypi
       url: https://pypi.org/p/mlx-cpu
     steps:
      - uses: actions/download-artifact@v6
        with:
-          name: mlx-cpu
+          pattern: mlx-cpu-*
+          merge-multiple: true
          path: dist
      - name: Display structure of downloaded files
        run: ls -R dist
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
-          repository-url: ${{ needs.setup.outputs.pypi_url }}
+          repository-url: https://upload.pypi.org/legacy/
 
   pypi-publish-metal:
     name: Upload Metal release to PyPI
@@ -210,7 +231,7 @@ jobs:
     permissions:
       id-token: write
     environment:
-      name: ${{ needs.setup.outputs.pypi_env }}
+      name: pypi
       url: https://pypi.org/p/mlx-metal
     steps:
      - uses: actions/download-artifact@v6
@@ -222,5 +243,4 @@ jobs:
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
-          repository-url: ${{ needs.setup.outputs.pypi_url }}
+          repository-url: https://upload.pypi.org/legacy/
@@ -74,6 +74,7 @@ endif()
 if(MLX_USE_CCACHE)
   find_program(CCACHE_PROGRAM ccache)
   if(CCACHE_PROGRAM)
+    message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
     set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
     set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
     set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
@@ -272,7 +273,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
 if(MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
   find_package(
-    Python 3.8
+    Python 3.10
     COMPONENTS Interpreter Development.Module
     REQUIRED)
   execute_process(
@@ -75,7 +75,7 @@ void time_irregular_binary_ops_3D() {
 
 void time_irregular_binary_ops_4D() {
   auto device = mx::default_device();
-  std::vector<int> shape = {8, 8, 512, 512};
+  mx::Shape shape = {8, 8, 512, 512};
   auto a = mx::random::uniform(shape);
   auto b = mx::random::uniform(shape);
 
@@ -115,7 +115,7 @@ void time_irregular_binary_ops_4D() {
 
 void time_irregular_reshape() {
   auto device = mx::default_device();
-  std::vector<int> shape;
+  mx::Shape shape;
   auto reshape_fn = [&shape, device](const mx::array& a) {
     return mx::reshape(a, shape, device);
   };
@@ -170,7 +170,7 @@ void time_irregular_astype_1D() {
 void time_irregular_astype_2D() {
   auto device = mx::default_device();
   int size = 2048;
-  std::vector<int> shape = {size, size};
+  mx::Shape shape = {size, size};
 
   auto a = mx::random::uniform(shape);
   TIMEM("2D regular", mx::astype, a, mx::int32, device);
@@ -1,6 +1,5 @@
 # Copyright © 2023 Apple Inc.
 
-import argparse
 import os
 import subprocess
 import time
212  benchmarks/python/masked_scatter.py (new file)
@@ -0,0 +1,212 @@
+import math
+import os
+import subprocess
+import time
+from copy import copy
+from functools import partial
+
+import matplotlib.pyplot as plt
+import mlx.core as mx
+import numpy as np
+import torch
+from matplotlib.ticker import FuncFormatter
+
+RESULTS_DIR = "./results"
+
+
+if not os.path.isdir(RESULTS_DIR):
+    os.mkdir(RESULTS_DIR)
+
+DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
+DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
+
+TORCH_DEVICE = torch.device(
+    "mps"
+    if torch.backends.mps.is_available()
+    else ("cuda" if torch.cuda.is_available() else "cpu")
+)
+
+
+N_WARMUP = 5
+N_ITER_BENCH = 50
+N_ITER_FUNC = 20
+
+VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
+MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
+D_TYPES = ("float32", "float16")
+
+
+def _power_of_two_formatter(value, _position):
+    if value <= 0:
+        return ""
+    exponent = int(round(math.log2(value)))
+    if abs(value - (1 << exponent)) / value > 1e-6:
+        return f"{value:g}"
+    return f"$2^{{{exponent}}}$"
+
+
+def torch_sync():
+    if TORCH_DEVICE.type == "cuda":
+        torch.cuda.synchronize()
+    elif TORCH_DEVICE.type == "mps":
+        torch.mps.synchronize()
+
+
+def masked_scatter_mlx(self_arr, mask_arr, src_arr):
+    outs = []
+    for _ in range(N_ITER_FUNC):
+        out = copy(self_arr)
+        out[mask_arr] = src_arr
+        outs.append(out)
+    mx.eval(outs)
+    return outs
+
+
+@torch.no_grad()
+def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
+    outs = []
+    for _ in range(N_ITER_FUNC):
+        out = self_tensor.clone()
+        out.masked_scatter_(mask_tensor, src_tensor)
+        outs.append(out)
+    torch_sync()
+    return outs
+
+
+def measure(fn):
+    for _ in range(N_WARMUP):
+        fn()
+    start = time.perf_counter_ns()
+    for _ in range(N_ITER_BENCH):
+        fn()
+    end = time.perf_counter_ns()
+    return (end - start) * 1e-9
+
+
+def bytes_touched(length, true_count, item_size):
+    mask_bytes = length
+    self_bytes = length * item_size * 2  # read + write
+    src_bytes = true_count * item_size
+    return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
+
+
+def build_case(length, density, np_dtype, torch_dtype):
+    true_count = max(1, int(round(length * density)))
+
+    rng = np.random.default_rng()
+    self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
+    mask_np = np.zeros(length, dtype=bool)
+    mask_np[:true_count] = True
+    rng.shuffle(mask_np)
+    src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
+
+    self_mlx = mx.array(self_np)
+    mask_mlx = mx.array(mask_np)
+    src_mlx = mx.array(src_np)
+
+    self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
+    mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
+    src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
+
+    # Correctness check once per configuration
+    mx_out = mx.array(self_np)
+    mx_out[mask_mlx] = src_mlx
+    mx.eval(mx_out)
+    torch_out = self_torch.clone()
+    torch_out.masked_scatter_(mask_torch, src_torch)
+
+    atol = 5e-3 if np_dtype == np.float16 else 1e-5
+    if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
+        raise AssertionError("masked_scatter results diverged between MLX and Torch")
+
+    return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
+
+
+def bench_case(length, density, dtype):
+    np_dtype = getattr(np, dtype)
+    torch_dtype = getattr(torch, dtype)
+    (
+        self_mlx,
+        mask_mlx,
+        src_mlx,
+        self_torch,
+        mask_torch,
+        src_torch,
+        true_count,
+    ) = build_case(length, density, np_dtype, torch_dtype)
+
+    time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
+    time_torch = measure(
+        partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
+    )
+
+    total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
+    bytes_per_gb = float(1024**3)
+    mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
+    torch_gbps = (total_bytes / bytes_per_gb) / time_torch
+
+    return time_mlx, time_torch, mlx_gbps, torch_gbps
+
+
+def plot_density(ax_perf, ax_speedup, density, dtype):
+    mlx_gbps = []
+    torch_gbps = []
+    mlx_times = []
+    torch_times = []
+
+    for length in VECTOR_LENGTHS:
+        t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
+        mlx_gbps.append(gbps_mlx)
+        torch_gbps.append(gbps_torch)
+        mlx_times.append(t_mlx)
+        torch_times.append(t_torch)
+
+    ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
+    ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
+    ax_perf.set_xscale("log", base=2)
+    ax_perf.set_xticks(VECTOR_LENGTHS)
+    formatter = FuncFormatter(_power_of_two_formatter)
+    ax_perf.xaxis.set_major_formatter(formatter)
+    ax_perf.set_title(f"density={density:.2f}")
+    ax_perf.set_ylabel("GB/s")
+    ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
+    ax_perf.legend()
+
+    speedup = np.array(torch_times) / np.array(mlx_times)
+    ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
+    ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
+    ax_speedup.set_xscale("log", base=2)
+    ax_speedup.set_xticks(VECTOR_LENGTHS)
+    ax_speedup.xaxis.set_major_formatter(formatter)
+    ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
+    ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
+
+
+def main():
+    for dtype in D_TYPES:
+        fig, axs = plt.subplots(
+            len(MASK_DENSITIES),
+            2,
+            figsize=(10, 12),
+            layout="constrained",
+            sharex=True,
+        )
+
+        for i, density in enumerate(MASK_DENSITIES):
+            plot_density(axs[i][0], axs[i][1], density, dtype)
+            axs[i][0].set_xlabel("vector length")
+            axs[i][1].set_xlabel("vector length")
+
+        fig.suptitle(
+            f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
+        )
+        output_path = os.path.join(
+            RESULTS_DIR,
+            f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
+        )
+        fig.savefig(output_path)
+        plt.close(fig)
+
+
+if __name__ == "__main__":
+    main()
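As a quick worked check of the bytes_touched() accounting above (one mask read, a read and a write of the destination, and one source read per True entry), here is the same arithmetic in plain Python with illustrative values:

# length 4096, density 0.25 -> true_count 1024, float32 item size 4
length, true_count, item_size = 4096, 1024, 4
per_call = length + length * item_size * 2 + true_count * item_size
total = per_call * 20 * 50  # N_ITER_FUNC * N_ITER_BENCH
print(per_call, total)      # 40960 40960000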
3  cmake/Findnvpl.cmake (new file)
@@ -0,0 +1,3 @@
+# This file does nothing but to suppress the cmake warning: "By not providing
+# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
+# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.
@@ -29,17 +29,20 @@ MLX has a CUDA backend which you can install with:
 
 .. code-block:: shell
 
-  pip install mlx[cuda]
+  pip install mlx[cuda12]
 
 To install the CUDA package from PyPi your system must meet the following
 requirements:
 
-- Nvidia architecture >= SM 7.0 (Volta)
+- Nvidia architecture >= SM 7.5
 - Nvidia driver >= 550.54.14
 - CUDA toolkit >= 12.0
 - Linux distribution with glibc >= 2.35
 - Python >= 3.10
 
+For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
+an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
+
 CPU-only (Linux)
 ^^^^^^^^^^^^^^^^
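A minimal post-install sanity check, assuming only the standard mlx.core API:

import mlx.core as mx

print(mx.default_device())  # should report the GPU on a working CUDA install
x = mx.ones((4, 4))
mx.eval(x @ x)              # run a small kernel to confirm the backend works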
@@ -70,7 +70,8 @@ Differences from NumPy
 
 * Indexing does not perform bounds checking. Indexing out of bounds is
   undefined behavior.
-* Boolean mask based indexing is not yet supported.
+* Boolean mask based indexing is supported for assignment only (see
+  :ref:`boolean-mask-assignment`).
 
 The reason for the lack of bounds checking is that exceptions cannot propagate
 from the GPU. Performing bounds checking for array indices before launching the
@@ -143,3 +144,51 @@ expected. For example:
 
 In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
 and ones elsewhere.
+
+.. _boolean-mask-assignment:
+
+Boolean Mask Assignment
+-----------------------
+
+MLX supports boolean indices using NumPy syntax. A mask must already be
+a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
+Other index types are routed through the standard scatter code.
+
+.. code-block:: shell
+
+  >>> a = mx.array([1.0, 2.0, 3.0])
+  >>> mask = mx.array([True, False, True])
+  >>> updates = mx.array([5.0, 6.0])
+  >>> a[mask] = updates
+  >>> a
+  array([5.0, 2.0, 6.0], dtype=float32)
+
+Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
+assignments, ``updates`` must provide at least as many elements as there are
+``True`` entries in ``mask``.
+
+.. code-block:: shell
+
+  >>> a = mx.zeros((2, 3))
+  >>> mask = mx.array([[True, False, True],
+                       [False, False, True]])
+  >>> a[mask] = 1.0
+  >>> a
+  array([[1.0, 0.0, 1.0],
+         [0.0, 0.0, 1.0]], dtype=float32)
+
+Boolean masks follow NumPy semantics:
+
+- The mask shape must match the shape of the axes it indexes exactly. The only
+  exception is a scalar boolean mask, which broadcasts to the full array.
+- Any axes not covered by the mask are taken in full.
+
+.. code-block:: shell
+
+  >>> a = mx.arange(1000).reshape(10, 10, 10)
+  >>> a[mx.random.normal((10, 10)) > 0.0] = 0  # valid: mask covers axes 0 and 1
+
+The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
+selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
+Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
+axes and therefore raise errors.
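One more example of the update-count rule documented above, a sketch using only the semantics stated in the docs:

import mlx.core as mx

a = mx.zeros((4,))
mask = mx.array([True, True, False, True])
a[mask] = mx.array([1.0, 2.0, 3.0])  # three True entries consume three updates
print(a)  # expected: array([1, 2, 0, 3], dtype=float32)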
@@ -3,6 +3,6 @@ requires = [
     "setuptools>=42",
     "cmake>=3.25",
     "mlx>=0.18.0",
-    "nanobind==2.4.0",
+    "nanobind==2.10.2",
 ]
 build-backend = "setuptools.build_meta"
@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.25
 mlx>=0.21.0
-nanobind==2.4.0
+nanobind==2.10.2
@@ -1,7 +1,6 @@
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
@@ -1,24 +0,0 @@
-// Copyright © 2023 Apple Inc.
-
-#include <cstdlib>
-#include <sstream>
-
-#include "mlx/allocator.h"
-
-namespace mlx::core::allocator {
-
-Buffer malloc(size_t size) {
-  auto buffer = allocator().malloc(size);
-  if (size && !buffer.ptr()) {
-    std::ostringstream msg;
-    msg << "[malloc] Unable to allocate " << size << " bytes.";
-    throw std::runtime_error(msg.str());
-  }
-  return buffer;
-}
-
-void free(Buffer buffer) {
-  allocator().free(buffer);
-}
-
-} // namespace mlx::core::allocator
@@ -28,16 +28,16 @@ class Buffer {
   };
 };
 
-Buffer malloc(size_t size);
-
-void free(Buffer buffer);
-
 class Allocator {
   /** Abstract base class for a memory allocator. */
  public:
   virtual Buffer malloc(size_t size) = 0;
   virtual void free(Buffer buffer) = 0;
   virtual size_t size(Buffer buffer) const = 0;
+  virtual Buffer make_buffer(void* ptr, size_t size) {
+    return Buffer{nullptr};
+  };
+  virtual void release(Buffer buffer) {}
 
   Allocator() = default;
   Allocator(const Allocator& other) = delete;
@@ -49,4 +49,25 @@ class Allocator {
 
 Allocator& allocator();
 
+inline Buffer malloc(size_t size) {
+  return allocator().malloc(size);
+}
+
+inline void free(Buffer buffer) {
+  allocator().free(buffer);
+}
+
+// Make a Buffer from a raw pointer of the given size without a copy. If a
+// no-copy conversion is not possible then the returned buffer.ptr() will be
+// nullptr. Any buffer created with this function must be released with
+// release(buffer)
+inline Buffer make_buffer(void* ptr, size_t size) {
+  return allocator().make_buffer(ptr, size);
+};
+
+// Release a buffer from the allocator made with make_buffer
+inline void release(Buffer buffer) {
+  allocator().release(buffer);
+}
+
 } // namespace mlx::core::allocator
@@ -82,6 +82,28 @@ array::array(std::initializer_list<int> data, Dtype dtype)
   init(data.begin());
 }
 
+array::array(
+    void* data,
+    Shape shape,
+    Dtype dtype,
+    const std::function<void(void*)>& deleter)
+    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
+  auto buffer = allocator::make_buffer(data, nbytes());
+  if (buffer.ptr() == nullptr) {
+    set_data(allocator::malloc(nbytes()));
+    auto ptr = static_cast<char*>(data);
+    std::copy(ptr, ptr + nbytes(), this->data<char>());
+    deleter(data);
+  } else {
+    auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
+      auto ptr = buffer.ptr();
+      allocator::release(buffer);
+      return deleter(ptr);
+    };
+    set_data(buffer, std::move(wrapped_deleter));
+  }
+}
+
 /* Build an array from a shared buffer */
 array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
@@ -167,7 +189,7 @@ void array::copy_shared_buffer(
     const Strides& strides,
     Flags flags,
     size_t data_size,
-    size_t offset /* = 0 */) {
+    int64_t offset /* = 0 */) {
   array_desc_->data = other.array_desc_->data;
   array_desc_->strides = strides;
   array_desc_->flags = flags;
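The constructor above is the plumbing for no-copy imports of external host memory: make_buffer wraps the pointer when the allocator supports it, and otherwise the data is copied and the deleter runs immediately. At the Python layer the only observable effect is that conversions from host buffers may skip a copy; whether they do is allocator-dependent. A minimal sketch:

import mlx.core as mx
import numpy as np

host = np.arange(16, dtype=np.float32)
a = mx.array(host)  # may wrap the host buffer via make_buffer, else it copies
mx.eval(a)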
12  mlx/array.h
@@ -57,6 +57,16 @@ class array {
       Shape shape,
       Dtype dtype = TypeToDtype<T>());
 
+  /* Build an array from a raw pointer. The constructor will attempt to use the
+   * input data without a copy. The deleter will be called when the array no
+   * longer needs the underlying memory - after the array is destroyed in the
+   * no-copy case and after the copy otherwise. */
+  explicit array(
+      void* data,
+      Shape shape,
+      Dtype dtype,
+      const std::function<void(void*)>& deleter);
+
   /* Build an array from a buffer */
   explicit array(
       allocator::Buffer data,
@@ -439,7 +449,7 @@ class array {
       const Strides& strides,
       Flags flags,
       size_t data_size,
-      size_t offset = 0);
+      int64_t offset = 0);
 
   void copy_shared_buffer(const array& other);
 
@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
     // - Donatable
     // - Not a constant
     if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-        in.is_donatable() && is_constant(i)) {
+        in.is_donatable() && !is_constant(i)) {
       outputs[o++].copy_shared_buffer(in);
     }
     // Get representative input flags to properly set non-donated outputs
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
     // - Not a constant
     if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
         in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-        is_constant(i)) {
+        !is_constant(i)) {
       outputs[o].copy_shared_buffer(
           in, outputs[o].strides(), in.flags(), in.data_size());
       o++;
@@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
     data_offset += start_indices[i] * in.strides()[i];
     inp_strides[i] = in.strides()[i] * strides[i];
   }
-  // Normalize the offset
-  if (data_offset < 0) {
-    data_offset += in.data_size();
-  }
   return std::make_tuple(data_offset, inp_strides);
 }
 
 void shared_buffer_slice(
     const array& in,
     const Strides& out_strides,
-    size_t data_offset,
+    int64_t data_offset,
     size_t data_size,
     array& out) {
   // Compute row/col contiguity
@@ -51,17 +47,24 @@ void slice(
 
   // Calculate out strides, initial offset
   auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
-  int64_t data_end = 1;
-  for (int i = 0; i < start_indices.size(); ++i) {
-    if (in.shape()[i] > 1) {
-      auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
-      data_end += end_idx * in.strides()[i];
+  // Get the location of the end based on the inp strides and out.shape()
+  int64_t low_idx = 0;
+  int64_t high_idx = 0;
+  for (int i = 0; i < inp_strides.size(); ++i) {
+    auto delta = inp_strides[i] * (out.shape()[i] - 1);
+    if (inp_strides[i] > 0) {
+      high_idx += delta;
+    } else {
+      low_idx += delta;
     }
   }
-  if (data_end < 0) {
-    data_end += in.data_size();
+  int64_t data_size = (high_idx - low_idx) + 1;
+  if (data_size < 0) {
+    std::ostringstream msg;
+    msg << "[slice] Computed invalid data size: " << data_size << ".";
+    throw std::runtime_error(msg.str());
   }
-  size_t data_size = (data_end - data_offset);
   shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
 }
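The new data-size computation walks each axis once and accumulates the most negative and most positive reachable displacement, which also covers negative strides. A worked instance of the same arithmetic in plain Python (values are illustrative):

shape_out = (3,)     # out.shape
inp_strides = (-2,)  # input stride times the (negative) slice step
low = sum(s * (n - 1) for s, n in zip(inp_strides, shape_out) if s < 0)   # -4
high = sum(s * (n - 1) for s, n in zip(inp_strides, shape_out) if s > 0)  # 0
data_size = (high - low) + 1  # 5: offsets 0, -2, -4 span five slots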
@@ -12,6 +12,167 @@ namespace mlx::core {
 
 namespace {
 
+template <typename T>
+complex64_t to_complex(T r, T i) {
+  return {static_cast<float>(r), static_cast<float>(i)};
+}
+
+template <typename T, class Enable = void>
+struct EigWork {};
+
+template <typename T>
+struct EigWork<
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using O = complex64_t;
+
+  char jobl;
+  char jobr;
+  int N;
+  int lwork;
+  int info;
+  std::vector<array::Data> buffers;
+
+  EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
+      : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1) {
+    T work;
+    int n_vecs_l = compute_eigenvectors ? N_ : 1;
+    int n_vecs_r = 1;
+    geev<T>(
+        &jobl,
+        &jobr,
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        nullptr,
+        nullptr,
+        &n_vecs_l,
+        nullptr,
+        &n_vecs_r,
+        &work,
+        &lwork,
+        &info);
+    lwork = static_cast<int>(work);
+
+    buffers.emplace_back(allocator::malloc(sizeof(T) * N * 2));
+    if (compute_eigenvectors) {
+      buffers.emplace_back(allocator::malloc(sizeof(T) * N * N * 2));
+    }
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+  }
+
+  void run(T* a, O* values, O* vectors) {
+    auto eig_tmp = static_cast<T*>(buffers[0].buffer.raw_ptr());
+    T* vec_tmp = nullptr;
+    if (vectors) {
+      vec_tmp = static_cast<T*>(buffers[1].buffer.raw_ptr());
+    }
+    auto work = static_cast<T*>(buffers.back().buffer.raw_ptr());
+
+    int n_vecs_l = vectors ? N : 1;
+    int n_vecs_r = 1;
+    geev<T>(
+        &jobl,
+        &jobr,
+        &N,
+        a,
+        &N,
+        eig_tmp,
+        eig_tmp + N,
+        vectors ? vec_tmp : nullptr,
+        &n_vecs_l,
+        nullptr,
+        &n_vecs_r,
+        work,
+        &lwork,
+        &info);
+
+    for (int i = 0; i < N; ++i) {
+      values[i] = to_complex(eig_tmp[i], eig_tmp[N + i]);
+    }
+
+    if (vectors) {
+      for (int i = 0; i < N; ++i) {
+        if (values[i].imag() != 0) {
+          for (int j = 0; j < N; ++j) {
+            vectors[i * N + j] =
+                to_complex(vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]);
+            vectors[(i + 1) * N + j] =
+                to_complex(vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]);
+          }
+          i += 1;
+        } else {
+          for (int j = 0; j < N; ++j) {
+            vectors[i * N + j] = to_complex(vec_tmp[i * N + j], T(0.0));
+          }
+        }
+      }
+    }
+  }
+};
+
+template <>
+struct EigWork<std::complex<float>> {
+  using T = std::complex<float>;
+  using R = float;
+  using O = T;
+
+  char jobl;
+  char jobr;
+  int N;
+  int lwork;
+  int lrwork;
+  int info;
+  std::vector<array::Data> buffers;
+
+  EigWork(char jobl_, char jobr_, int N_, bool compute_eigenvectors)
+      : jobl(jobl_), jobr(jobr_), N(N_), lwork(-1), lrwork(2 * N_) {
+    T work;
+    R rwork;
+    int n_vecs_l = compute_eigenvectors ? N_ : 1;
+    int n_vecs_r = 1;
+    geev<T>(
+        &jobl,
+        &jobr,
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        nullptr,
+        &n_vecs_l,
+        nullptr,
+        &n_vecs_r,
+        &work,
+        &lwork,
+        &rwork,
+        &info);
+    lwork = static_cast<int>(work.real());
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
+  }
+
+  void run(T* a, T* values, T* vectors) {
+    int n_vecs_l = vectors ? N : 1;
+    int n_vecs_r = 1;
+    geev<T>(
+        &jobl,
+        &jobr,
+        &N,
+        a,
+        &N,
+        values,
+        vectors,
+        &n_vecs_l,
+        nullptr,
+        &n_vecs_r,
+        static_cast<T*>(buffers[0].buffer.raw_ptr()),
+        &lwork,
+        static_cast<R*>(buffers[1].buffer.raw_ptr()),
+        &info);
+  }
+};
+
 template <typename T>
 void eig_impl(
     array& a,
@@ -19,101 +180,39 @@ void eig_impl(
     array& values,
     bool compute_eigenvectors,
     Stream stream) {
-  using OT = std::complex<T>;
   auto a_ptr = a.data<T>();
-  auto eig_ptr = values.data<OT>();
+  auto val_ptr = values.data<complex64_t>();
 
   auto& encoder = cpu::get_command_encoder(stream);
   encoder.set_input_array(a);
   encoder.set_output_array(values);
-  OT* vec_ptr = nullptr;
+  complex64_t* vec_ptr = nullptr;
   if (compute_eigenvectors) {
     encoder.set_output_array(vectors);
-    vec_ptr = vectors.data<OT>();
+    vec_ptr = vectors.data<complex64_t>();
   }
   encoder.dispatch([a_ptr,
+                    val_ptr,
                     vec_ptr,
-                    eig_ptr,
                     compute_eigenvectors,
                     N = vectors.shape(-1),
                     size = vectors.size()]() mutable {
-    // Work query
     char jobr = 'N';
     char jobl = compute_eigenvectors ? 'V' : 'N';
-    int n_vecs_r = 1;
-    int n_vecs_l = compute_eigenvectors ? N : 1;
-    int lwork = -1;
-    int info;
-    {
-      T work;
-      geev<T>(
-          &jobl,
-          &jobr,
-          &N,
-          nullptr,
-          &N,
-          nullptr,
-          nullptr,
-          nullptr,
-          &n_vecs_l,
-          nullptr,
-          &n_vecs_r,
-          &work,
-          &lwork,
-          &info);
-      lwork = static_cast<int>(work);
-    }
-
-    auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
-    auto vec_tmp_data =
-        array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
-    auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
-    auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
-    auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
+    EigWork<T> work(jobl, jobr, N, compute_eigenvectors);
     for (size_t i = 0; i < size / (N * N); ++i) {
-      geev<T>(
-          &jobl,
-          &jobr,
-          &N,
-          a_ptr,
-          &N,
-          eig_tmp,
-          eig_tmp + N,
-          vec_tmp,
-          &n_vecs_l,
-          nullptr,
-          &n_vecs_r,
-          static_cast<T*>(work_buf.buffer.raw_ptr()),
-          &lwork,
-          &info);
-      for (int i = 0; i < N; ++i) {
-        eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
-      }
+      work.run(a_ptr, val_ptr, vec_ptr);
+      a_ptr += N * N;
+      val_ptr += N;
       if (vec_ptr) {
-        for (int i = 0; i < N; ++i) {
-          if (eig_ptr[i].imag() != 0) {
-            // This vector and the next are a pair
-            for (int j = 0; j < N; ++j) {
-              vec_ptr[i * N + j] = {
-                  vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
-              vec_ptr[(i + 1) * N + j] = {
-                  vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
-            }
-            i += 1;
-          } else {
-            for (int j = 0; j < N; ++j) {
-              vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
-            }
-          }
-        }
         vec_ptr += N * N;
       }
-      a_ptr += N * N;
-      eig_ptr += N;
-      if (info != 0) {
+      if (work.info != 0) {
         std::stringstream msg;
         msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
-            << info;
+            << work.info;
         throw std::runtime_error(msg.str());
       }
     }
@@ -165,8 +264,17 @@ void Eig::eval_cpu(
     case float32:
       eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
       break;
+    case float64:
+      eig_impl<double>(
+          a_copy, vectors, values, compute_eigenvectors_, stream());
+      break;
+    case complex64:
+      eig_impl<std::complex<float>>(
+          a_copy, vectors, values, compute_eigenvectors_, stream());
+      break;
     default:
-      throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
+      throw std::runtime_error(
+          "[Eig::eval_cpu] only supports float32, float64, or complex64.");
   }
 }
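With the EigWork specializations above, Eig::eval_cpu now also handles float64 and complex64 inputs. A quick check through the Python binding (assuming it is exposed as mx.linalg.eig; eig runs on the CPU stream):

import mlx.core as mx

a = mx.array([[0.0, -1.0], [1.0, 0.0]], dtype=mx.complex64)
vals, vecs = mx.linalg.eig(a, stream=mx.cpu)
print(vals)  # eigenvalues of this rotation generator are +/- 1j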
@@ -747,4 +747,108 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
   });
 }
 
+template <typename T>
+void masked_scatter_impl(const array& mask, const array& src, array& out) {
+  ContiguousIterator mask_it(mask);
+  ContiguousIterator src_it(src);
+  ContiguousIterator out_it(out);
+
+  const bool* mask_ptr = mask.data<bool>();
+  const T* src_ptr = src.data<T>();
+  T* dst_ptr = out.data<T>();
+
+  const size_t batch_count = mask.shape(0);
+  const size_t mask_batch_size = mask.size() / batch_count;
+  const size_t src_batch_size = src.size() / batch_count;
+
+  for (uint b = 0; b < batch_count; ++b) {
+    size_t src_consumed = 0;
+    src_it.seek(b * src_batch_size);
+
+    for (size_t i = 0; i < mask_batch_size; ++i) {
+      if (mask_ptr[mask_it.loc]) {
+        if (src_consumed >= src_batch_size) {
+          throw std::runtime_error(
+              "[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
+        }
+        dst_ptr[out_it.loc] = src_ptr[src_it.loc];
+        src_it.step();
+        ++src_consumed;
+      }
+      mask_it.step();
+      out_it.step();
+    }
+  }
+}
+
+void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 3);
+
+  auto& dst = inputs[0];
+  auto& mask = inputs[1];
+  auto& src = inputs[2];
+
+  // Copy src into out (copy allocates memory for out)
+  auto ctype =
+      dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
+  copy_cpu(dst, out, ctype, stream());
+
+  if (mask.size() == 0) {
+    return;
+  }
+
+  auto& encoder = cpu::get_command_encoder(stream());
+  encoder.set_input_array(mask);
+  encoder.set_input_array(src);
+  encoder.set_output_array(out);
+  encoder.dispatch([mask = array::unsafe_weak_copy(mask),
+                    src = array::unsafe_weak_copy(src),
+                    out = array::unsafe_weak_copy(out)]() mutable {
+    switch (out.dtype()) {
+      case bool_:
+        masked_scatter_impl<bool>(mask, src, out);
+        break;
+      case uint8:
+        masked_scatter_impl<uint8_t>(mask, src, out);
+        break;
+      case uint16:
+        masked_scatter_impl<uint16_t>(mask, src, out);
+        break;
+      case uint32:
+        masked_scatter_impl<uint32_t>(mask, src, out);
+        break;
+      case uint64:
+        masked_scatter_impl<uint64_t>(mask, src, out);
+        break;
+      case int8:
+        masked_scatter_impl<int8_t>(mask, src, out);
+        break;
+      case int16:
+        masked_scatter_impl<int16_t>(mask, src, out);
+        break;
+      case int32:
+        masked_scatter_impl<int32_t>(mask, src, out);
+        break;
+      case int64:
+        masked_scatter_impl<int64_t>(mask, src, out);
+        break;
+      case float16:
+        masked_scatter_impl<float16_t>(mask, src, out);
+        break;
+      case float32:
+        masked_scatter_impl<float>(mask, src, out);
+        break;
+      case float64:
+        masked_scatter_impl<double>(mask, src, out);
+        break;
+      case bfloat16:
+        masked_scatter_impl<bfloat16_t>(mask, src, out);
+        break;
+      case complex64:
+        masked_scatter_impl<complex64_t>(mask, src, out);
+        break;
+    }
+  });
+}
+
 } // namespace mlx::core
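A NumPy reference for the per-batch semantics the kernel above implements: True entries in each batch of the mask consume source elements in order, and running out of source elements is an error. A sketch for batched inputs; the names here are illustrative, not part of MLX:

import numpy as np

def masked_scatter_ref(dst, mask, src):
    out = dst.copy()
    for b in range(mask.shape[0]):  # axis 0 is the batch axis
        m = mask[b].ravel()
        n = int(m.sum())
        if n > src[b].size:
            raise ValueError("source does not have enough elements for mask")
        out[b].ravel()[m] = src[b].ravel()[:n]
    return out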
@@ -45,9 +45,7 @@
 INSTANTIATE_LAPACK_REAL(geqrf)
 INSTANTIATE_LAPACK_REAL(orgqr)
 INSTANTIATE_LAPACK_REAL(syevd)
-INSTANTIATE_LAPACK_REAL(geev)
 INSTANTIATE_LAPACK_REAL(potrf)
-INSTANTIATE_LAPACK_REAL(gesdd)
 INSTANTIATE_LAPACK_REAL(getrf)
 INSTANTIATE_LAPACK_REAL(getri)
 INSTANTIATE_LAPACK_REAL(trtri)
@@ -63,3 +61,20 @@ INSTANTIATE_LAPACK_REAL(trtri)
 }
 
 INSTANTIATE_LAPACK_COMPLEX(heevd)
+
+#define INSTANTIATE_LAPACK_ALL(FUNC)                                  \
+  template <typename T, typename... Args>                             \
+  void FUNC(Args... args) {                                           \
+    if constexpr (std::is_same_v<T, float>) {                         \
+      MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...);          \
+    } else if constexpr (std::is_same_v<T, double>) {                 \
+      MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...);          \
+    } else if constexpr (std::is_same_v<T, std::complex<float>>) {    \
+      MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...);          \
+    } else if constexpr (std::is_same_v<T, std::complex<double>>) {   \
+      MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...);          \
+    }                                                                 \
+  }
+
+INSTANTIATE_LAPACK_ALL(geev)
+INSTANTIATE_LAPACK_ALL(gesdd)
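For reference, INSTANTIATE_LAPACK_ALL just maps the standard LAPACK naming convention onto C++ types (illustration only):

# LAPACK routine prefixes selected by INSTANTIATE_LAPACK_ALL:
prefix = {
    "float": "s",                 # sgeev, sgesdd, ...
    "double": "d",                # dgeev, dgesdd, ...
    "std::complex<float>": "c",   # cgeev, cgesdd, ...
    "std::complex<double>": "z",  # zgeev, zgesdd, ...
}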
@@ -291,6 +291,17 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
       num_keys,
       kshape = keys.shape(),
       kstrides = keys.strides()]() mutable {
+    auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
+      if (4 * loc + 4 <= bytes_per_key) {
+        reinterpret_cast<uint32_t*>(cptr)[loc] = v;
+      } else {
+        std::copy(
+            reinterpret_cast<char*>(&v),
+            reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,
+            cptr + 4 * loc);
+      }
+    };
+
     size_t out_skip = (bytes_per_key + 4 - 1) / 4;
     auto half_size = out_skip / 2;
     bool even = out_skip % 2 == 0;
@@ -310,18 +321,12 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
       if (count.first < half_size) {
         auto rb = random::threefry2x32_hash(key, count);
         ptr[count.first++] = rb.first;
-        if (bytes_per_key % 4 > 0) {
-          std::copy(
-              reinterpret_cast<char*>(&rb.second),
-              reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
-              cptr + 4 * count.second);
-        } else {
-          ptr[count.second] = rb.second;
-        }
+        copy_remaining(cptr, count.second, rb.second);
       }
       if (!even) {
         count.second = 0;
-        ptr[half_size] = random::threefry2x32_hash(key, count).first;
+        copy_remaining(
+            cptr, half_size, random::threefry2x32_hash(key, count).first);
       }
     }
   });
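The copy_remaining helper above writes a whole 32-bit word when it fits inside bytes_per_key and only the tail bytes otherwise. The same logic in a small Python sketch (struct stands in for the raw reinterpret casts; values are illustrative):

import struct

def copy_remaining(buf, loc, v, bytes_per_key):
    if 4 * loc + 4 <= bytes_per_key:
        buf[4 * loc : 4 * loc + 4] = struct.pack("<I", v)
    else:  # last word is partial: keep only the remaining tail bytes
        n = bytes_per_key - 4 * loc
        buf[4 * loc : 4 * loc + n] = struct.pack("<I", v)[:n]

buf = bytearray(6)                      # bytes_per_key = 6
copy_remaining(buf, 0, 0x11223344, 6)   # full 4-byte word
copy_remaining(buf, 1, 0x55667788, 6)   # only 2 tail bytes fit
print(buf.hex())                        # 443322118877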
@@ -3,5 +3,9 @@
 #include "mlx/backend/cpu/simd/base_simd.h"
 
 #ifdef MLX_USE_ACCELERATE
+#if defined(__x86_64__)
+// the accelerate_simd implementation require neon -- use base implementation
+#else
 #include "mlx/backend/cpu/simd/accelerate_simd.h"
 #endif
+#endif
|||||||
@@ -8,6 +8,183 @@

 namespace mlx::core {

+template <typename T, class Enable = void>
+struct SVDWork {};
+
+template <typename T>
+struct SVDWork<
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using R = T;
+
+  int N;
+  int M;
+  int K;
+  int lda;
+  int ldu;
+  int ldvt;
+  char jobz;
+  std::vector<array::Data> buffers;
+  int lwork;
+
+  SVDWork(int N, int M, int K, char jobz)
+      : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
+    T workspace_dimension = 0;
+
+    // Will contain the indices of eigenvectors that failed to converge (not
+    // used here but required by lapack).
+    buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
+
+    int lwork_query = -1;
+    int info;
+
+    // Compute workspace size.
+    gesdd<T>(
+        /* jobz = */ &jobz,
+        // M and N are swapped since lapack expects column-major.
+        /* m = */ &N,
+        /* n = */ &M,
+        /* a = */ nullptr,
+        /* lda = */ &lda,
+        /* s = */ nullptr,
+        /* u = */ nullptr,
+        /* ldu = */ &ldu,
+        /* vt = */ nullptr,
+        /* ldvt = */ &ldvt,
+        /* work = */ &workspace_dimension,
+        /* lwork = */ &lwork_query,
+        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    lwork = workspace_dimension;
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+  }
+
+  void run(T* a, R* s, T* u, T* vt) {
+    int info;
+    gesdd<T>(
+        /* jobz = */ &jobz,
+        // M and N are swapped since lapack expects column-major.
+        /* m = */ &N,
+        /* n = */ &M,
+        /* a = */ a,
+        /* lda = */ &lda,
+        /* s = */ s,
+        // According to the identity above, lapack will write Vᵀᵀ as U.
+        /* u = */ u,
+        /* ldu = */ &ldu,
+        // According to the identity above, lapack will write Uᵀ as Vᵀ.
+        /* vt = */ vt,
+        /* ldvt = */ &ldvt,
+        /* work = */ static_cast<T*>(buffers[1].buffer.raw_ptr()),
+        /* lwork = */ &lwork,
+        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "svd_impl: sgesvdx_ failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+  }
+};
+
+template <>
+struct SVDWork<std::complex<float>> {
+  using T = std::complex<float>;
+  using R = float;
+
+  int N;
+  int M;
+  int K;
+  int lda;
+  int ldu;
+  int ldvt;
+  char jobz;
+  std::vector<array::Data> buffers;
+  int lwork;
+
+  SVDWork(int N, int M, int K, char jobz)
+      : N(N), M(M), K(K), lda(N), ldu(N), ldvt(M), jobz(jobz) {
+    T workspace_dimension = 0;
+
+    // Will contain the indices of eigenvectors that failed to converge (not
+    // used here but required by lapack).
+    buffers.emplace_back(allocator::malloc(sizeof(int) * 8 * K));
+
+    const int lrwork =
+        jobz == 'A' ? std::max(1, 5 * K * K + 5 * K) : std::max(1, 7 * K);
+    buffers.emplace_back(allocator::malloc(sizeof(float) * lrwork));
+
+    int lwork_query = -1;
+    int work_query = -1;
+    int info;
+
+    // Compute workspace size.
+    gesdd<T>(
+        /* jobz = */ &jobz,
+        // M and N are swapped since lapack expects column-major.
+        /* m = */ &N,
+        /* n = */ &M,
+        /* a = */ nullptr,
+        /* lda = */ &lda,
+        /* s = */ nullptr,
+        /* u = */ nullptr,
+        /* ldu = */ &ldu,
+        /* vt = */ nullptr,
+        /* ldvt = */ &ldvt,
+        /* work = */ &workspace_dimension,
+        /* lwork = */ &lwork_query,
+        /* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
+        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    lwork = workspace_dimension.real();
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+  }
+
+  void run(T* a, R* s, T* u, T* vt) {
+    int info;
+    gesdd<T>(
+        /* jobz = */ &jobz,
+        // M and N are swapped since lapack expects column-major.
+        /* m = */ &N,
+        /* n = */ &M,
+        /* a = */ a,
+        /* lda = */ &lda,
+        /* s = */ s,
+        // According to the identity above, lapack will write Vᵀᵀ as U.
+        /* u = */ u,
+        /* ldu = */ &ldu,
+        // According to the identity above, lapack will write Uᵀ as Vᵀ.
+        /* vt = */ vt,
+        /* ldvt = */ &ldvt,
+        /* work = */ static_cast<T*>(buffers[2].buffer.raw_ptr()),
+        /* lwork = */ &lwork,
+        /* rwork = */ static_cast<float*>(buffers[1].buffer.raw_ptr()),
+        /* iwork = */ static_cast<int*>(buffers[0].buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "svd_impl: sgesvdx_ failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+  }
+};
+
 template <typename T>
 void svd_impl(
     const array& a,
@@ -27,6 +204,8 @@ void svd_impl(
   const int N = a.shape(-1);
   const int K = std::min(M, N);

+  using R = typename SVDWork<T>::R;
+
   size_t num_matrices = a.size() / (M * N);

   // lapack clobbers the input, so we have to make a copy.
@@ -42,7 +221,7 @@ void svd_impl(
   encoder.set_input_array(a);
   auto in_ptr = in.data<T>();
   T* u_ptr;
-  T* s_ptr;
+  R* s_ptr;
   T* vt_ptr;

   if (compute_uv) {
@@ -58,7 +237,7 @@ void svd_impl(
     encoder.set_output_array(s);
     encoder.set_output_array(vt);

-    s_ptr = s.data<T>();
+    s_ptr = s.data<R>();
     u_ptr = u.data<T>();
     vt_ptr = vt.data<T>();
   } else {
@@ -68,96 +247,26 @@ void svd_impl(

     encoder.set_output_array(s);

-    s_ptr = s.data<T>();
+    s_ptr = s.data<R>();
     u_ptr = nullptr;
     vt_ptr = nullptr;
   }

   encoder.dispatch([in_ptr, u_ptr, s_ptr, vt_ptr, M, N, K, num_matrices]() {
-    // A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
-    const int lda = N;
-    // U of shape M x M. (N x N in lapack).
-    const int ldu = N;
-    // Vᵀ of shape N x N. (M x M in lapack).
-    const int ldvt = M;
-
-    auto jobz = (u_ptr) ? "A" : "N";
-
-    T workspace_dimension = 0;
-
-    // Will contain the indices of eigenvectors that failed to converge (not
-    // used here but required by lapack).
-    auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)};
-
-    static const int lwork_query = -1;
-
-    int info;
-
-    // Compute workspace size.
-    gesdd<T>(
-        /* jobz = */ jobz,
-        // M and N are swapped since lapack expects column-major.
-        /* m = */ &N,
-        /* n = */ &M,
-        /* a = */ nullptr,
-        /* lda = */ &lda,
-        /* s = */ nullptr,
-        /* u = */ nullptr,
-        /* ldu = */ &ldu,
-        /* vt = */ nullptr,
-        /* ldvt = */ &ldvt,
-        /* work = */ &workspace_dimension,
-        /* lwork = */ &lwork_query,
-        /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
-        /* info = */ &info);
-
-    if (info != 0) {
-      std::stringstream ss;
-      ss << "[SVD::eval_cpu] workspace calculation failed with code " << info;
-      throw std::runtime_error(ss.str());
-    }
-
-    const int lwork = workspace_dimension;
-    auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
-
+    auto jobz = (u_ptr) ? 'A' : 'N';
+    SVDWork<T> svd_work(N, M, K, jobz);
     // Loop over matrices.
     for (int i = 0; i < num_matrices; i++) {
-      gesdd<T>(
-          /* jobz = */ jobz,
-          // M and N are swapped since lapack expects column-major.
-          /* m = */ &N,
-          /* n = */ &M,
-          /* a = */ in_ptr + M * N * i,
-          /* lda = */ &lda,
-          /* s = */ s_ptr + K * i,
-          // According to the identity above, lapack will write Vᵀᵀ as U.
-          /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr,
-          /* ldu = */ &ldu,
-          // According to the identity above, lapack will write Uᵀ as Vᵀ.
-          /* vt = */ u_ptr ? u_ptr + M * M * i : nullptr,
-          /* ldvt = */ &ldvt,
-          /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
-          /* lwork = */ &lwork,
-          /* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
-          /* info = */ &info);
-
-      if (info != 0) {
-        std::stringstream ss;
-        ss << "svd_impl: sgesvdx_ failed with code " << info;
-        throw std::runtime_error(ss.str());
-      }
+      svd_work.run(
+          in_ptr + M * N * i,
+          s_ptr + K * i,
+          vt_ptr ? vt_ptr + N * N * i : nullptr,
+          u_ptr ? u_ptr + M * M * i : nullptr);
     }
   });
   encoder.add_temporary(in);
 }

-template <typename T>
-void compute_svd(
-    const array& a,
-    bool compute_uv,
-    std::vector<array>& outputs,
-    Stream stream) {}
-
 void SVD::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -168,9 +277,12 @@ void SVD::eval_cpu(
     case float64:
       svd_impl<double>(inputs[0], outputs, compute_uv_, stream());
       break;
+    case complex64:
+      svd_impl<std::complex<float>>(inputs[0], outputs, compute_uv_, stream());
+      break;
     default:
       throw std::runtime_error(
-          "[SVD::eval_cpu] only supports float32 or float64.");
+          "[SVD::eval_cpu] only supports float32, float64, or complex64.");
   }
 }
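Both `SVDWork` constructors above lean on the standard LAPACK workspace-query convention: calling `?gesdd` with `lwork == -1` skips the factorization entirely and writes the optimal workspace size into `work[0]`. A self-contained sketch of the same query against the raw Fortran symbol (MLX's `gesdd<T>` wrapper is assumed to forward to it; names here are illustrative):

#include <algorithm>
#include <vector>

// LAPACK's Fortran entry point for the divide-and-conquer SVD.
extern "C" void sgesdd_(
    const char* jobz, const int* m, const int* n, float* a, const int* lda,
    float* s, float* u, const int* ldu, float* vt, const int* ldvt,
    float* work, const int* lwork, int* iwork, int* info);

// Query-only call: lwork == -1 makes sgesdd_ report the optimal
// workspace size in work[0] without touching the matrix.
int query_sgesdd_lwork(char jobz, int m, int n, int lda, int ldu, int ldvt) {
  float wkopt = 0;
  int lwork = -1;
  int info = 0;
  std::vector<int> iwork(8 * std::min(m, n));
  sgesdd_(&jobz, &m, &n, /*a=*/nullptr, &lda, /*s=*/nullptr, /*u=*/nullptr,
          &ldu, /*vt=*/nullptr, &ldvt, &wkopt, &lwork, iwork.data(), &info);
  return info == 0 ? static_cast<int>(wkopt) : -1;
}

Doing the query once in the constructor is what lets the hot loop above call `svd_work.run(...)` per matrix with no per-iteration allocation.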
@@ -44,6 +44,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
@@ -122,10 +123,21 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
     mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
 endif()

-# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
-# managed memory.
+# Use native CUDA arch by default.
 if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
-  set(MLX_CUDA_ARCHITECTURES "native")
+  execute_process(
+    COMMAND __nvcc_device_query
+    OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
+  set(UPGRADABLE_ARCHITECTURES "90;100;121")
+  if(MLX_CUDA_ARCHITECTURES STREQUAL "")
+    message(
+      FATAL_ERROR
+        "Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
+  elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
+    # Use arch-specific compute capability whenever possible.
+    set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
+  endif()
 endif()
 message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
 set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
@@ -137,6 +149,7 @@ FetchContent_Declare(
   URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
 FetchContent_MakeAvailable(cccl)
 target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
+set_target_properties(mlx PROPERTIES CCCL_DIR "${cccl_SOURCE_DIR}/include")

 # Use fixed version of NVTX.
 FetchContent_Declare(
@@ -162,7 +175,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
 FetchContent_Declare(
   cudnn
   GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
-  GIT_TAG v1.14.0
+  GIT_TAG v1.16.0
   GIT_SHALLOW TRUE
   EXCLUDE_FROM_ALL)
 set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
@@ -20,6 +20,19 @@ constexpr int page_size = 16384;
 // Any allocations smaller than this will try to use the small pool
 constexpr int small_block_size = 8;

+#if CUDART_VERSION >= 13000
+inline cudaMemLocation cuda_mem_loc(int i) {
+  cudaMemLocation loc;
+  loc.type = cudaMemLocationTypeDevice;
+  loc.id = i;
+  return loc;
+}
+#else
+inline int cuda_mem_loc(int i) {
+  return i;
+}
+#endif // CUDART_VERSION >= 13000
+
 // The small pool size in bytes. This should be a multiple of the host page
 // size and small_block_size.
 constexpr int small_pool_size = 4 * page_size;
@@ -35,13 +48,7 @@ SmallSizePool::SmallSizePool() {
   int device_count = 0;
   CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
   for (int i = 0; i < device_count; ++i) {
-#if CUDART_VERSION >= 13000
-    cudaMemLocation loc;
-    loc.type = cudaMemLocationTypeDevice;
-    loc.id = i;
-#else
-    int loc = i;
-#endif // CUDART_VERSION >= 13000
+    auto loc = cuda_mem_loc(i);
     CHECK_CUDA_ERROR(
         cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
   }
@@ -90,9 +97,10 @@ CudaAllocator::CudaAllocator()
           page_size,
           [](CudaBuffer* buf) { return buf->size; },
          [this](CudaBuffer* buf) { cuda_free(buf); }) {
-  size_t free, total;
-  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.95;
+  size_t free;
+  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
+  memory_limit_ = total_memory_ * 0.95;
+  free_limit_ = total_memory_ - memory_limit_;
   max_pool_size_ = memory_limit_;

   int device_count = 0;
@@ -104,6 +112,10 @@ CudaAllocator::CudaAllocator()
     cudaStream_t s;
     CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking));
     free_streams_.push_back(s);
+
+    cudaMemPool_t mem_pool;
+    CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pool, i));
+    mem_pools_.push_back(mem_pool);
   }
   CHECK_CUDA_ERROR(cudaSetDevice(curr));
 }
@@ -119,7 +131,8 @@ void copy_to_managed(CudaBuffer& buf) {
   buf.data = new_data;
 }

-Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
+Buffer
+CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
   if (size == 0) {
     return Buffer{new CudaBuffer{nullptr, 0, -1}};
   }
@@ -134,9 +147,8 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
     size = page_size * ((size + page_size - 1) / page_size);
   }

-  int device = -1;
-  if (size > small_block_size && stream != nullptr) {
-    CHECK_CUDA_ERROR(cudaStreamGetDevice(stream, &device));
+  if (size <= small_block_size || stream == nullptr) {
+    device = -1;
   }

   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
@@ -154,19 +166,35 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
   }
   lock.unlock();
   if (!buf) {
-    buf = new CudaBuffer{nullptr, size, device};
-    cudaError_t err;
+    void* data = nullptr;
     if (device == -1) {
-      err = cudaMallocManaged(&buf->data, size);
+      CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
     } else {
-      err = cudaMallocAsync(&buf->data, size, stream);
+      CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
     }
-    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
-      throw std::runtime_error(fmt::format(
-          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+    if (!data) {
+      std::ostringstream msg;
+      msg << "[malloc] Unable to allocate " << size << " bytes.";
+      throw std::runtime_error(msg.str());
     }
+    buf = new CudaBuffer{data, size, device};
   }
   lock.lock();
+
+    // If any cuda memory pool has too much reserved memory, clear some
+    // memory from the cache. This prevents graph / kernel execution failing
+    // from OOM
+    if (get_cache_memory() > 0) {
+      for (auto p : mem_pools_) {
+        size_t used = 0;
+        CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
+            p, cudaMemPoolAttrReservedMemCurrent, &used));
+        if (used > (total_memory_ - free_limit_)) {
+          buffer_cache_.release_cached_buffers(free_limit_);
+          break;
+        }
+      }
+    }
  }
   active_memory_ += buf->size;
   peak_memory_ = std::max(active_memory_, peak_memory_);
@@ -176,18 +204,14 @@ Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) {
     buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
   }
   // Copy to managed here if the buffer is not on the right device
-  if (buf->device != device) {
+  if (buf->device >= 0 && buf->device != device) {
     copy_to_managed(*buf);
   }
   return Buffer{buf};
 }

-Buffer CudaAllocator::malloc_async(size_t size, cudaStream_t stream) {
-  return malloc_impl(size, stream);
-}
-
 Buffer CudaAllocator::malloc(size_t size) {
-  return malloc_impl(size, nullptr);
+  return malloc_async(size, -1, nullptr);
 }

 void CudaAllocator::free(Buffer buffer) {
@@ -223,9 +247,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
     scalar_pool_.free(buf);
   } else {
     if (buf->device >= 0) {
-      cudaFreeAsync(buf->data, free_streams_[buf->device]);
+      CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
     } else {
-      cudaFree(buf->data);
+      CHECK_CUDA_ERROR(cudaFree(buf->data));
     }
     delete buf;
   }
@@ -277,8 +301,9 @@ CudaAllocator& allocator() {
   return *allocator_;
 }

-Buffer malloc_async(size_t size, cudaStream_t stream) {
-  auto buffer = allocator().malloc_async(size, stream);
+Buffer malloc_async(size_t size, CommandEncoder& encoder) {
+  auto buffer = allocator().malloc_async(
+      size, encoder.device().cuda_device(), encoder.stream());
   if (size && !buffer.ptr()) {
     std::ostringstream msg;
     msg << "[malloc_async] Unable to allocate " << size << " bytes.";
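The pressure check added above reads `cudaMemPoolAttrReservedMemCurrent` from each device's default stream-ordered pool and releases cached buffers when the pools hold too much. A standalone sketch of that query; `budget` stands in for the allocator's `total_memory_ - free_limit_`:

#include <cstddef>
#include <cuda_runtime.h>

// Query how much memory a device's default stream-ordered pool currently
// holds reserved, and compare it against a budget.
bool pool_over_budget(int device, size_t budget) {
  cudaMemPool_t pool;
  if (cudaDeviceGetDefaultMemPool(&pool, device) != cudaSuccess) {
    return false;
  }
  size_t reserved = 0;  // mirrors the size_t the commit uses
  if (cudaMemPoolGetAttribute(
          pool, cudaMemPoolAttrReservedMemCurrent, &reserved) != cudaSuccess) {
    return false;
  }
  return reserved > budget;
}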
@@ -13,6 +13,8 @@

 namespace mlx::core::cu {

+class CommandEncoder;
+
 using allocator::Buffer;

 // Stores cuda-managed unified memory.
@@ -48,7 +50,7 @@ class SmallSizePool {
 class CudaAllocator : public allocator::Allocator {
  public:
   Buffer malloc(size_t size) override;
-  Buffer malloc_async(size_t size, cudaStream_t stream);
+  Buffer malloc_async(size_t size, int device, cudaStream_t stream);
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;

@@ -62,7 +64,6 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();

  private:
-  Buffer malloc_impl(size_t size, cudaStream_t stream);
   void cuda_free(CudaBuffer* buf);

   CudaAllocator();
@@ -70,16 +71,19 @@ class CudaAllocator : public allocator::Allocator {

   std::mutex mutex_;
   size_t memory_limit_;
+  size_t free_limit_;
+  size_t total_memory_;
   size_t max_pool_size_;
   BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
   std::vector<cudaStream_t> free_streams_;
+  std::vector<cudaMemPool_t> mem_pools_;
   SmallSizePool scalar_pool_;
 };

 CudaAllocator& allocator();

-Buffer malloc_async(size_t size, cudaStream_t stream);
+Buffer malloc_async(size_t size, CommandEncoder& encoder);

 } // namespace mlx::core::cu
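With the header change above there are two distinct entry points: synchronous `malloc` (managed memory) and stream-ordered `malloc_async` keyed off a command encoder. A hedged usage sketch; the include path and surrounding function are assumptions:

#include "mlx/backend/cuda/allocator.h"  // assumed header path

namespace cu = mlx::core::cu;

// `encoder` would be a live command encoder inside some eval_gpu.
void alloc_example(cu::CommandEncoder& encoder) {
  // Synchronous path: device == -1 internally, i.e. cudaMallocManaged.
  auto managed = cu::allocator().malloc(1024);
  // Stream-ordered path: cudaMallocAsync on the encoder's device/stream.
  auto streamed = cu::malloc_async(1024, encoder);
  cu::allocator().free(streamed);
  cu::allocator().free(managed);
}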
@@ -42,7 +42,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
   auto& encoder = cu::get_command_encoder(stream());
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   encoder.set_output_array(out);

   dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {

@@ -143,7 +143,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {

   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));

   // Prepare the shapes, strides and axis arguments.
   Shape shape = remove_index(in.shape(), axis_);

@@ -367,9 +367,8 @@ void binary_op_gpu(
   auto bopt = get_binary_op_type(a, b);
   auto& encoder = cu::get_command_encoder(s);

-  set_binary_op_output_data(a, b, out, bopt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  set_binary_op_output_data(
+      a, b, out, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
   binary_op_gpu_inplace<Op>(inputs, out, op, s);
 }

@@ -246,12 +246,10 @@ void binary_two_op_gpu_inplace(
   auto& out_b = outputs[1];
   auto bopt = get_binary_op_type(a, b);
   auto& encoder = cu::get_command_encoder(s);
-  set_binary_op_output_data(a, b, out_a, bopt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
-  set_binary_op_output_data(a, b, out_b, bopt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  set_binary_op_output_data(
+      a, b, out_a, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });
+  set_binary_op_output_data(
+      a, b, out_b, bopt, [&](auto n) { return cu::malloc_async(n, encoder); });

   if (out_a.size() == 0) {
     return;
@@ -298,7 +298,7 @@ void Compiled::eval_gpu(
   // Put outputs.
   compiled_allocate_outputs(
       inputs, outputs, is_constant_, contiguous, [&](auto n) {
-        return cu::malloc_async(n, encoder.stream());
+        return cu::malloc_async(n, encoder);
       });
   for (auto& x : outputs) {
     args.append(x);
@@ -15,19 +15,16 @@ namespace mlx::core {

 namespace {

-// Alias for better readability.
-#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
-#define CONV_BACKWARD_INPUT \
-  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
-#define CONV_BACKWARD_WEIGHT \
-  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
-
-// Custom placeholder representing fallback kernel.
-#define CONV_FALLBACK static_cast<cudnnBackendDescriptorType_t>(-1)
+enum ConvBackendType {
+  CONV_FALLBACK,
+  CONV_FORWARD,
+  CONV_BACKWARD_INPUT,
+  CONV_BACKWARD_WEIGHT,
+};

 struct ConvCacheKey {
   int device_id;
-  cudnnDataType_t cudnn_dtype;
+  fe::DataType_t cudnn_dtype;
   std::array<int, MAX_NDIM> input_shape;
   std::array<int, MAX_NDIM> weight_shape;
   std::array<int, MAX_NDIM> stride;
@@ -44,15 +41,13 @@ struct ConvCacheKey {
 auto& conv_cache() {
   static LRUBytesKeyCache<
       ConvCacheKey,
-      std::pair<
-          cudnnBackendDescriptorType_t,
-          std::optional<cudnn_frontend::ExecutionPlan>>>
+      std::pair<ConvBackendType, std::optional<DnnGraph>>>
       cache("MLX_CUDA_CONV_CACHE_SIZE", /* default_capacity */ 128);
   return cache;
 }

-auto get_conv_op_settings(
-    cudnnBackendDescriptorType_t backend_type,
+auto get_conv_settings(
+    ConvBackendType backend_type,
     array& x,
     array& w,
     array& y,
@@ -68,8 +63,8 @@ auto get_conv_op_settings(
   for (int i = 0; i < padding_lo.size(); ++i) {
     int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
     padding_lo[i] = wt_size - padding_lo[i] - 1;
-    int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
-    int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
+    int in_size = 1 + kernel_strides[i] * (y.shape(1 + i) - 1);
+    int out_size = 1 + input_dilation[i] * (x.shape(1 + i) - 1);
     padding_hi[i] = out_size - in_size + padding_hi[i];
   }
   return std::make_tuple(
@@ -95,49 +90,57 @@ auto get_conv_op_settings(
   }
 }

-std::optional<cudnn_frontend::OperationGraph> build_conv_op_graph(
+std::optional<DnnGraph> build_conv_graph(
     cu::CommandEncoder& encoder,
-    cudnnBackendDescriptorType_t backend_type,
+    ConvBackendType backend_type,
     Dtype dtype,
     array& x,
     array& w,
     array& y,
-    const SmallVector<int64_t>& stride,
-    const SmallVector<int64_t>& padding_lo,
-    const SmallVector<int64_t>& padding_hi,
-    const SmallVector<int64_t>& dilation) {
-  try {
-    auto compute_dtype = (dtype == float16 || dtype == bfloat16)
-        ? CUDNN_DATA_FLOAT
-        : dtype_to_cudnn_type(dtype);
-    auto conv_desc = cudnn_frontend::ConvDescBuilder()
-                         .setDataType(compute_dtype)
-                         .setMathMode(CUDNN_CROSS_CORRELATION)
-                         .setNDims(stride.size())
-                         .setStrides(stride.size(), stride.data())
-                         .setPrePadding(padding_lo.size(), padding_lo.data())
-                         .setPostPadding(padding_hi.size(), padding_hi.data())
-                         .setDilation(dilation.size(), dilation.data())
-                         .build();
+    const std::vector<int64_t>& stride,
+    const std::vector<int64_t>& padding_lo,
+    const std::vector<int64_t>& padding_hi,
+    const std::vector<int64_t>& dilation) {
+  auto compute_dtype =
+      (dtype == float16 || dtype == bfloat16) ? float32 : dtype;
+  DnnGraph graph(encoder.device().cudnn_handle(), dtype, compute_dtype);
+  auto x_ = graph.tensor_nchw("X", 'x', x);
+  auto w_ = graph.tensor_nchw("W", 'w', w);

-    auto op = cudnn_frontend::OperationBuilder(backend_type)
-                  .setxDesc(build_cudnn_tensor_nchw('x', x))
-                  .setwDesc(build_cudnn_tensor_nchw('w', w))
-                  .setyDesc(build_cudnn_tensor_nchw('y', y))
-                  .setcDesc(conv_desc)
-                  .build();
+  auto set_options = [&](auto& options) {
+    options.set_compute_data_type(dtype_to_cudnn_type(compute_dtype))
+        .set_convolution_mode(fe::ConvolutionMode_t::CROSS_CORRELATION)
+        .set_stride(stride)
+        .set_pre_padding(padding_lo)
+        .set_post_padding(padding_hi)
+        .set_dilation(dilation);
+  };

-    std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
-    return cudnn_frontend::OperationGraphBuilder()
-        .setHandle(encoder.device().cudnn_handle())
-        .setOperationGraph(ops.size(), ops.data())
-        .build();
-  } catch (cudnn_frontend::cudnnException& error) {
-    if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
-      throw;
-    }
+  std::shared_ptr<fe::graph::Tensor_attributes> y_;
+  if (backend_type == CONV_FORWARD) {
+    auto options = fe::graph::Conv_fprop_attributes();
+    set_options(options);
+    y_ = graph.conv_fprop(x_, w_, options);
+  } else if (backend_type == CONV_BACKWARD_INPUT) {
+    auto options = fe::graph::Conv_dgrad_attributes();
+    set_options(options);
+    y_ = graph.conv_dgrad(x_, w_, options);
+  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
+    auto options = fe::graph::Conv_wgrad_attributes();
+    set_options(options);
+    y_ = graph.conv_wgrad(w_, x_, options);
+  }
+  graph.tensor_nchw(y_, 'y', y)->set_output(true);
+
+  if (graph.prepare().is_bad()) {
     return std::nullopt;
   }
+  graph.deselect_numeric_notes({fe::NumericalNote_t::DOWN_CONVERT_INPUTS});
+  if (dtype == float32 && !env::enable_tf32()) {
+    graph.deselect_numeric_notes({fe::NumericalNote_t::TENSOR_CORE});
+  }
+  CHECK_CUDNN_FE_ERROR(graph.build());
+  return graph;
 }

 // Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups).
@@ -181,7 +184,7 @@ array group_transpose(
 // eval_gpu, with cost of possible redundant copies.
 std::tuple<array, array, array> prepare_args(
     cu::CommandEncoder& encoder,
-    cudnnBackendDescriptorType_t backend_type,
+    ConvBackendType backend_type,
     array in,
     array wt,
     array out,
@@ -221,27 +224,11 @@ std::tuple<array, array, array> prepare_args(
   return {std::move(in), std::move(wt), std::move(out)};
 }

-// Get the x/w/y args from the in/wt/out args depending on backend type.
-inline std::tuple<array&, array&, array&> dispatch_args(
-    cudnnBackendDescriptorType_t backend_type,
-    array& in,
-    array& wt,
-    array& out) {
-  switch (backend_type) {
-    case CONV_BACKWARD_INPUT:
-      return {out, wt, in};
-    case CONV_BACKWARD_WEIGHT:
-      return {in, out, wt};
-    default:
-      return {in, wt, out};
-  }
-}
-
 // Register inputs and outputs before actually running conv op. Can only be
 // called once per eval_gpu.
 void register_args(
     cu::CommandEncoder& encoder,
-    cudnnBackendDescriptorType_t backend_type,
+    ConvBackendType backend_type,
     array& in,
     array& wt,
     array& intermediate_out,
@@ -277,11 +264,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   array in = inputs[0];
   array wt = inputs[1];
   array out = out_;
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   Dtype dtype = out.dtype();

   // Search cache.
-  ConvCacheKey cache_key{
+  BytesKey<ConvCacheKey> cache_key;
+  cache_key.pod = {
       encoder.device().cuda_device(),
       dtype_to_cudnn_type(dtype),
       vector_key(in.shape()),
@@ -296,16 +284,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
       get_alignment(wt),
       get_alignment(out)};
   if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
-    auto& [backend_type, plan] = it->second;
-    if (plan) {
-      // Run cached plan.
+    auto& [backend_type, graph] = it->second;
+    if (graph) {
+      // Run cached graph.
       std::tie(in, wt, out) =
           prepare_args(encoder, backend_type, in, wt, out, groups_, s);
       register_args(encoder, backend_type, in, wt, out, out_);
-      auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-      if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
-        throw std::runtime_error("[conv] Cached plan failed to execute.");
-      }
+      CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
+          encoder,
+          {
+              {'x', gpu_ptr<void>(in)},
+              {'w', gpu_ptr<void>(wt)},
+              {'y', gpu_ptr<void>(out)},
+          }));
     } else {
       // Run fallback kernel.
       gemm_conv(
@@ -326,7 +317,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {

   // There is no reliable way to deduce the proper cuDNN backend for the
   // convolution, so we make a best guess and then try.
-  SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
+  SmallVector<ConvBackendType, 2> try_backends;
   if (flip_) {
     // When weight is flipped, we assume it is backward input convolution.
     try_backends.push_back(CONV_BACKWARD_INPUT);
@@ -344,13 +335,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
   }

   // Try to build op graph.
-  cudnnBackendDescriptorType_t backend_type;
-  std::optional<cudnn_frontend::OperationGraph> op_graph;
+  ConvBackendType backend_type;
+  std::optional<DnnGraph> graph;
   for (auto try_backend : try_backends) {
-    auto [in_copy, wt_copy, out_copy] =
+    auto [x, w, y] =
         prepare_args(encoder, try_backend, in, wt, out, groups_, s);
-    auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
-    auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
+    auto [stride, padding_lo, padding_hi, dilation] = get_conv_settings(
         try_backend,
         x,
        w,
@@ -360,7 +350,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
         padding_hi_,
         kernel_dilation_,
         input_dilation_);
-    op_graph = build_conv_op_graph(
+    graph = build_conv_graph(
         encoder,
         try_backend,
         dtype,
@@ -371,30 +361,27 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
         padding_lo,
         padding_hi,
         dilation);
-    if (op_graph) {
+    if (graph) {
       backend_type = try_backend;
-      in = std::move(in_copy);
-      wt = std::move(wt_copy);
-      out = std::move(out_copy);
+      in = std::move(x);
+      wt = std::move(w);
+      out = std::move(y);
       break;
     }
   }

-  if (op_graph) {
-    // Find a plan for the graph and execute it.
-    auto plan = find_cudnn_plan_from_op_graph(
-        encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
-    if (plan) {
-      // Setup inputs and outputs.
-      register_args(encoder, backend_type, in, wt, out, out_);
-
-      auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
-      if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
-        conv_cache().emplace(
-            cache_key, std::make_pair(backend_type, std::move(*plan)));
-        return;
-      }
-    }
-  }
+  if (graph) {
+    register_args(encoder, backend_type, in, wt, out, out_);
+    CHECK_CUDNN_FE_ERROR(graph->encode_capturing(
+        encoder,
+        {
+            {'x', gpu_ptr<void>(in)},
+            {'w', gpu_ptr<void>(wt)},
+            {'y', gpu_ptr<void>(out)},
+        }));
+    conv_cache().emplace(
+        cache_key, std::make_pair(backend_type, std::move(*graph)));
+    return;
+  }

   // Use fallback kernel for settings not supported by cuDNN.
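The `DnnGraph` wrapper introduced here sits on the cudnn-frontend v1 graph API, whose build pipeline is validate → build_operation_graph → create_execution_plans → check_support → build_plans, exactly the stages `DnnGraph::prepare()` and `DnnGraph::build()` split between them. A condensed, wrapper-free sketch of the same pipeline for a forward convolution; shapes, strides, and data types are placeholders:

#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Sketch of the cudnn-frontend v1 build pipeline that DnnGraph wraps.
fe::error_t build_fprop_graph(cudnnHandle_t handle, fe::graph::Graph& graph) {
  auto X = graph.tensor(fe::graph::Tensor_attributes()
                            .set_name("X")
                            .set_dim({8, 64, 56, 56})
                            .set_stride({64 * 56 * 56, 1, 64 * 56, 64})
                            .set_data_type(fe::DataType_t::HALF));
  auto W = graph.tensor(fe::graph::Tensor_attributes()
                            .set_name("W")
                            .set_dim({32, 64, 3, 3})
                            .set_stride({64 * 3 * 3, 1, 64 * 3, 64})
                            .set_data_type(fe::DataType_t::HALF));
  auto conv = fe::graph::Conv_fprop_attributes()
                  .set_compute_data_type(fe::DataType_t::FLOAT)
                  .set_padding({1, 1})
                  .set_stride({1, 1})
                  .set_dilation({1, 1});
  auto Y = graph.conv_fprop(X, W, conv);
  Y->set_output(true).set_data_type(fe::DataType_t::HALF);

  // The same five stages DnnGraph::prepare()/build() run above.
  if (auto e = graph.validate(); e.is_bad()) return e;
  if (auto e = graph.build_operation_graph(handle); e.is_bad()) return e;
  if (auto e = graph.create_execution_plans({fe::HeurMode_t::A}); e.is_bad()) return e;
  if (auto e = graph.check_support(handle); e.is_bad()) return e;
  return graph.build_plans(handle);
}

Caching the finished `DnnGraph` (rather than a raw `ExecutionPlan`) is what lets the cache-hit path above execute with a single `encode_capturing` call.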
@@ -86,7 +86,7 @@ array unfold_inputs_nd(
     int mat_N,
     ConvParams<NDIM>& params) {
   array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {});
-  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
+  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
   encoder.add_temporary(unfolded);

   int filter_size = params.C;

@@ -89,7 +89,7 @@ array grouped_unfold_transpose_inputs_nd(
     int mat_N,
     ConvParams<NDIM>& params) {
   array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {});
-  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder.stream()));
+  unfolded.set_data(cu::malloc_async(unfolded.nbytes(), encoder));
   encoder.add_temporary(unfolded);

   int filter_size = params.C;
@@ -7,9 +7,8 @@ namespace mlx::core {

 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   auto& encoder = cu::get_command_encoder(s);
-  bool donated = set_copy_output_data(in, out, ctype, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  bool donated = set_copy_output_data(
+      in, out, ctype, [&](auto n) { return cu::malloc_async(n, encoder); });
   if (donated && in.dtype() == out.dtype()) {
     // If the output has the same type as the input then there is nothing to
     // copy, just use the buffer.
@@ -104,7 +103,7 @@ void fill_gpu(const array& in, array& out, const Stream& s) {
     return;
   }
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
@@ -114,7 +113,7 @@ void reshape_gpu(const array& in, array& out, Stream s) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
     auto& encoder = cu::get_command_encoder(s);
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
     copy_gpu_inplace(
         in,
         out,
@@ -95,11 +95,14 @@ void copy_general_input(
   const InType* in_ptr = gpu_ptr<InType>(in) + offset_in;
   OutType* out_ptr = gpu_ptr<OutType>(out) + offset_out;
   int ndim = shape.size();
-  int work_per_thread = 1;
+
+  int work_per_thread = 8;
   auto dim0 = ndim > 0 ? shape.back() : 1;
   auto rest = out.size() / dim0;
-  if (dim0 >= 4) {
+  if (dim0 >= 4 && dim0 < 8) {
     work_per_thread = 4;
+  } else if (dim0 < 4) {
+    work_per_thread = 1;
   }
   dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
   auto block_dims = get_block_dims(dim0, rest, 1);
@@ -110,7 +113,10 @@ void copy_general_input(
     dispatch_1_2_3(ndim, [&](auto dims_constant) {
       auto kernel =
           cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 1>;
-      if (work_per_thread == 4) {
+      if (work_per_thread == 8) {
+        kernel =
+            cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 8>;
+      } else if (work_per_thread == 4) {
         kernel =
             cu::copy_g_nd<InType, OutType, IdxT, dims_constant(), 4>;
       }
@@ -127,7 +133,9 @@ void copy_general_input(
     });
   } else { // ndim >= 4
     auto kernel = cu::copy_g<InType, OutType, IdxT, 1>;
-    if (work_per_thread == 4) {
+    if (work_per_thread == 8) {
+      kernel = cu::copy_g<InType, OutType, IdxT, 8>;
+    } else if (work_per_thread == 4) {
       kernel = cu::copy_g<InType, OutType, IdxT, 4>;
     }
     encoder.add_kernel_node(
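The copy change above widens the vectorization tiers from {1, 4} to {1, 4, 8}: `work_per_thread` is a template parameter of `copy_g` / `copy_g_nd`, so the launch path selects one of three compiled instantiations from the innermost dimension. A host-side mirror of that selection:

// Host-side mirror of the tier selection: work_per_thread is a template
// parameter of copy_g / copy_g_nd, so this just picks the instantiation.
inline int choose_work_per_thread(long dim0) {
  if (dim0 >= 8) {
    return 8;
  }
  if (dim0 >= 4) {
    return 4;
  }
  return 1;
}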
@@ -5,6 +5,7 @@
 #include <cublasLt.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <cudnn.h>

 namespace mlx::core {

@@ -12,10 +13,12 @@ namespace mlx::core {
 void check_cublas_error(const char* name, cublasStatus_t err);
 void check_cuda_error(const char* name, cudaError_t err);
 void check_cuda_error(const char* name, CUresult err);
+void check_cudnn_error(const char* name, cudnnStatus_t err);

 // The macro version that prints the command that failed.
 #define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
 #define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
+#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))

 // Base class for RAII managed CUDA resources.
 template <typename Handle, cudaError_t (*Destroy)(Handle)>
@@ -29,6 +32,10 @@ class CudaHandle {
   }

   ~CudaHandle() {
+    // Skip if there was an error to avoid throwing in the destructors
+    if (cudaPeekAtLastError() != cudaSuccess) {
+      return;
+    }
     reset();
   }
@@ -7,32 +7,26 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Create a cudnn tensor descriptor.
|
#define RETURN_IF_ERROR(cmd) \
|
||||||
template <typename Vec>
|
if (auto ret = cmd; ret.is_bad()) { \
|
||||||
inline cudnn_frontend::Tensor build_cudnn_tensor(
|
return ret; \
|
||||||
int64_t id,
|
}
|
||||||
const array& x,
|
|
||||||
const Vec& shape,
|
|
||||||
const Vec& strides) {
|
|
||||||
return cudnn_frontend::TensorBuilder()
|
|
||||||
.setDim(shape.size(), shape.data())
|
|
||||||
.setStrides(strides.size(), strides.data())
|
|
||||||
.setId(id)
|
|
||||||
.setAlignment(get_alignment(x))
|
|
||||||
.setDataType(dtype_to_cudnn_type(x.dtype()))
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
|
// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
|
||||||
// whether a tensor is contiguous is determined with:
|
// whether a tensor is contiguous is determined with:
|
||||||
// shape[dim] == shape[dim + 1] * strides[dim + 1]
|
// shape[dim] == shape[dim + 1] * strides[dim + 1]
|
||||||
// So a contiguous array with singleton dims in MLX may be mistakenly treated
|
// So a contiguous array with singleton dims in MLX may be mistakenly treated
|
||||||
// as strided in cuDNN, and we work around it by normalizing the strides.
|
// as strided in cuDNN, and we work around it by normalizing the strides.
|
||||||
Strides normalized_strides(const array& x) {
|
std::vector<int64_t> normalized_strides(const array& x) {
|
||||||
if (!x.flags().row_contiguous || x.ndim() < 2) {
|
std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
|
||||||
return x.strides();
|
if (std::all_of(
|
||||||
|
strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
|
||||||
|
strides.back() = 1;
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
if (!x.flags().row_contiguous || x.ndim() < 2) {
|
||||||
|
return strides;
|
||||||
}
|
}
|
||||||
Strides strides = x.strides();
|
|
||||||
for (int i = x.ndim() - 2; i >= 0; --i) {
|
for (int i = x.ndim() - 2; i >= 0; --i) {
|
||||||
if (x.shape(i) == 1) {
|
if (x.shape(i) == 1) {
|
||||||
strides[i] = x.shape(i + 1) * strides[i + 1];
|
strides[i] = x.shape(i + 1) * strides[i + 1];
|
||||||
@@ -42,7 +36,9 @@ Strides normalized_strides(const array& x) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the shape and strides after transposing from NHWC to NCHW.
|
// Return the shape and strides after transposing from NHWC to NCHW.
|
||||||
auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
|
inline auto nhwc_to_nchw(const array& x) {
|
||||||
|
auto shape = convert_vector<int64_t>(x.shape());
|
||||||
|
auto strides = normalized_strides(x);
|
||||||
assert(shape.size() >= 3);
|
assert(shape.size() >= 3);
|
||||||
shape.insert(shape.begin() + 1, shape.back());
|
shape.insert(shape.begin() + 1, shape.back());
|
||||||
shape.erase(shape.end() - 1);
|
shape.erase(shape.end() - 1);
|
||||||
@@ -51,228 +47,95 @@ auto nhwc_to_nchw(SmallVector<int64_t> shape, SmallVector<int64_t> strides) {
|
|||||||
return std::make_tuple(std::move(shape), std::move(strides));
|
-  return std::make_tuple(std::move(shape), std::move(strides));
 }

-inline auto nhwc_to_nchw(const array& x) {
-  return nhwc_to_nchw(
-      convert_vector<int64_t>(x.shape()), normalized_strides(x));
-}
-
-// Return available engines for a |op_graph|.
-cudnn_frontend::EngineConfigList get_cudnn_engine_configs(
-    cudnnBackendDescriptorType_t backend_type,
-    Dtype dtype,
-    cudnn_frontend::OperationGraph& op_graph,
-    bool use_fallback = true) {
-  SmallVector<cudnn_frontend::GeneratorSource, 2> sources;
-  sources.push_back([](auto& op_graph) {
-    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
-                          .setOperationGraph(op_graph)
-                          .setHeurMode(CUDNN_HEUR_MODE_A)
-                          .build();
-    return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
-  });
-  if (use_fallback) {
-    sources.push_back([&backend_type](auto& op_graph) {
-      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
-                          .setOperationGraph(op_graph)
-                          .setOperation(backend_type)
-                          .build();
-      return fallback.getFallbackList();
-    });
-  }
-
-  auto configs =
-      cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data())
-          .generate_engine_config(op_graph);
-
-  cudnn_frontend::EngineConfigList filtered_configs;
-  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
-    if (cudnn_frontend::hasNumericalNote<
-            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
-      return true;
-    }
-    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
-        dtype == float32 && !env::enable_tf32()) {
-      return true;
-    }
-    return false;
-  });
-  return filtered_configs;
-}
-
-// Take |engine_configs| and |op_graph| and find a working execution plan
-// from them.
-std::optional<cudnn_frontend::ExecutionPlan>
-find_cudnn_plan_from_engine_configs(
-    cudnnHandle_t handle,
-    const cudnn_frontend::EngineConfigList& engine_configs,
-    const cudnn_frontend::OperationGraph& op_graph) {
-  auto op_graph_tag = op_graph.getTag();
-  for (const auto& config : engine_configs) {
-    try {
-      return cudnn_frontend::ExecutionPlanBuilder()
-          .setHandle(handle)
-          .setEngineConfig(config, op_graph_tag)
-          .build();
-    } catch (cudnn_frontend::cudnnException& error) {
-      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
-        throw;
-      }
-    }
-  }
-  return std::nullopt;
-}
-
-// Prepare workspace and args to execute plan.
-template <typename F>
-bool prepare_cudnn_plan(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    int num_args,
-    const int64_t* uids,
-    void** data_ptrs,
-    F&& execute) {
-  int workspace_size = plan.getWorkspaceSize();
-  void* workspace_ptr = nullptr;
-  if (workspace_size > 0) {
-    array workspace(
-        cu::malloc_async(workspace_size, encoder.stream()),
-        {workspace_size},
-        uint8);
-    encoder.add_temporary(workspace);
-    workspace_ptr = gpu_ptr<void>(workspace);
-  }
-
-  auto args = cudnn_frontend::VariantPackBuilder()
-                  .setWorkspacePointer(workspace_ptr)
-                  .setDataPointers(num_args, data_ptrs)
-                  .setUids(num_args, uids)
-                  .build();
-
-  auto handle = encoder.device().cudnn_handle();
-  cudnnSetStream(handle, encoder.stream());
-
-  if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) {
-    return false;
-  }
-
-  return true;
-}
-
 } // namespace

-cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) {
-  auto shape = convert_vector<int64_t>(x.shape());
-  return build_cudnn_tensor(id, x, shape, normalized_strides(x));
-}
-
-cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) {
-  auto [shape, strides] = nhwc_to_nchw(x);
-  return build_cudnn_tensor(id, x, shape, strides);
-}
-
-cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) {
-  if (x.ndim() == 0) {
-    SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
-    return build_cudnn_tensor(id, x, scalar_dims, scalar_dims);
-  }
-  if (x.ndim() == 1) {
-    int64_t s = x.shape(0);
-    SmallVector<int64_t, 4> shape = {1, x.shape(0), 1, 1};
-    SmallVector<int64_t, 4> strides = {s, 1, s, s};
-    return build_cudnn_tensor(id, x, shape, strides);
-  }
-  if (x.ndim() == 2) {
-    int64_t s =
-        x.flags().row_contiguous ? x.shape(1) * x.strides(1) : x.strides(0);
-    SmallVector<int64_t, 4> shape = {x.shape(0), x.shape(1), 1, 1};
-    SmallVector<int64_t, 4> strides = {s, x.strides(1), s, s};
-    return build_cudnn_tensor(id, x, shape, strides);
-  }
-  if (x.ndim() == 3 || x.ndim() == 4) {
-    return build_cudnn_tensor_nchw(id, x);
-  }
-  throw std::runtime_error(
-      fmt::format("Unsupported array with {} dims.", x.ndim()));
-}
-
-cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) {
-  SmallVector<int64_t, 4> scalar_dims = {1, 1, 1, 1};
-  return cudnn_frontend::TensorBuilder()
-      .setDim(scalar_dims.size(), scalar_dims.data())
-      .setStrides(scalar_dims.size(), scalar_dims.data())
-      .setId(id)
-      .setAlignment(16)
-      .setDataType(dtype_to_cudnn_type(dtype))
-      .setByValue(true)
-      .build();
-}
-
-std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
-    cudnnHandle_t handle,
-    cudnnBackendDescriptorType_t backend_type,
-    Dtype dtype,
-    cudnn_frontend::OperationGraph& op_graph) {
-  auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
-  if (engine_configs.empty()) {
-    return std::nullopt;
-  }
-  return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
-}
-
-bool encode_cudnn_plan_with_capturing(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    int num_args,
-    const int64_t* uids,
-    void** data_ptrs) {
-  return prepare_cudnn_plan(
-      encoder,
-      plan,
-      num_args,
-      uids,
-      data_ptrs,
-      [&](auto handle, auto plan, auto args) {
-        auto capture = encoder.capture_context();
-        if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) {
-          // Discard the captured graph when failed.
-          capture.discard = true;
-          return false;
-        }
-        return true;
-      });
-}
-
-#if CUDNN_VERSION >= 90500
-bool encode_cudnn_plan_with_graph_api(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    CudaGraph& graph,
-    int num_args,
-    const int64_t* uids,
-    void** data_ptrs) {
-  return prepare_cudnn_plan(
-      encoder,
-      plan,
-      num_args,
-      uids,
-      data_ptrs,
-      [&](auto handle, auto plan, auto args) {
-        if (!graph) {
-          graph = CudaGraph(encoder.device());
-          if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) !=
-              CUDNN_STATUS_SUCCESS) {
-            return false;
-          }
-        } else {
-          if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) !=
-              CUDNN_STATUS_SUCCESS) {
-            return false;
-          }
-        }
-        encoder.add_graph_node(graph);
-        return true;
-      });
-}
-#endif
+fe::error_t DnnGraph::prepare() {
+  RETURN_IF_ERROR(validate());
+  try {
+    RETURN_IF_ERROR(build_operation_graph(handle_));
+  } catch (cudnn_frontend::cudnnException& error) {
+    // cuDNN bug: they did not catch all exceptions in the API.
+    return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()};
+  }
+  RETURN_IF_ERROR(create_execution_plans({fe::HeurMode_t::A}));
+  return {};
+}
+
+fe::error_t DnnGraph::build() {
+  RETURN_IF_ERROR(check_support(handle_));
+  RETURN_IF_ERROR(build_plans(handle_));
+  return {};
+}
+
+fe::error_t DnnGraph::encode_graph(
+    cu::CommandEncoder& encoder,
+    std::unordered_map<int64_t, void*> variant_pack) {
+  cudnnSetStream(handle_, encoder.stream());
+  CudaGraph cuda_graph(encoder.device());
+  RETURN_IF_ERROR(populate_cuda_graph(
+      handle_, variant_pack, prepare_workspace(encoder), cuda_graph));
+  encoder.add_graph_node(cuda_graph);
+  return {};
+}
+
+fe::error_t DnnGraph::encode_capturing(
+    cu::CommandEncoder& encoder,
+    std::unordered_map<int64_t, void*> variant_pack) {
+  auto* workspace_ptr = prepare_workspace(encoder);
+  auto capture = encoder.capture_context();
+  cudnnSetStream(handle_, encoder.stream());
+  auto ret = execute(handle_, variant_pack, workspace_ptr);
+  if (ret.is_bad()) {
+    capture.discard = true;
+  }
+  return ret;
+}
+
+void* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) {
+  int64_t workspace_size = 0;
+  CHECK_CUDNN_FE_ERROR(get_workspace_size(workspace_size));
+  if (workspace_size > 0) {
+    array workspace(
+        cu::malloc_async(workspace_size, encoder),
+        {static_cast<int>(workspace_size)},
+        uint8);
+    encoder.add_temporary(workspace);
+    return gpu_ptr<void>(workspace);
+  }
+  return nullptr;
+}
+
+void DnnGraph::set_tensor_attrs(
+    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+    int64_t uid,
+    const array& x,
+    const std::vector<int64_t>& shape,
+    const std::vector<int64_t>& strides) {
+  tensor->set_uid(uid)
+      .set_alignment(get_alignment(x))
+      .set_data_type(dtype_to_cudnn_type(x.dtype()))
+      .set_dim(shape)
+      .set_stride(strides);
+}
+
+void DnnGraph::set_tensor_attrs(
+    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+    int64_t uid,
+    const array& x) {
+  set_tensor_attrs(
+      tensor,
+      uid,
+      x,
+      convert_vector<int64_t>(x.shape()),
+      normalized_strides(x));
+}
+
+void DnnGraph::set_tensor_attrs_nchw(
+    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+    int64_t uid,
+    const array& x) {
+  auto [shape, strides] = nhwc_to_nchw(x);
+  set_tensor_attrs(tensor, uid, x, shape, strides);
+}

 } // namespace mlx::core
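Editor's note: the new methods compose into a short call sequence. The following is a minimal sketch, not code from the commit; it assumes an existing encoder, arrays a, b, out, arbitrary uids 1/2/3, and uses pointwise() and Pointwise_attributes from the stock cudnn_frontend v1 graph API. The const_cast mirrors what the removed get_data_ptr helper did for input arrays.

    // Sketch only: build, prepare, and encode a pointwise add with DnnGraph.
    DnnGraph graph(encoder.device().cudnn_handle(), a.dtype());
    auto A = graph.tensor_nchw("a", 1, a);
    auto B = graph.tensor_nchw("b", 2, b);
    auto C = graph.pointwise(
        A, B, fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD));
    graph.tensor_nchw(C, 3, out); // attach uid/layout to the output attributes
    C->set_output(true);
    CHECK_CUDNN_FE_ERROR(graph.prepare());
    // ... set engine/numerical notes here if needed ...
    CHECK_CUDNN_FE_ERROR(graph.build());
    std::unordered_map<int64_t, void*> variant_pack = {
        {1, const_cast<void*>(gpu_ptr<void>(a))},
        {2, const_cast<void*>(gpu_ptr<void>(b))},
        {3, gpu_ptr<void>(out)}};
    CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, variant_pack));

encode_capturing() is the stream-capture fallback for setups where the graph cannot be inserted natively into the CUDA graph.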
@@ -2,25 +2,30 @@

 #pragma once

-#include "mlx/array.h"
-#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device/config.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/dtype_utils.h"

 #include <cudnn_frontend.h>
-#include <cudnn_frontend_find_plan.h>
 #include <fmt/format.h>

-#include <algorithm>
-#include <array>

 namespace mlx::core {

 namespace cu {
 class CommandEncoder;
 }

+namespace fe = cudnn_frontend;
+
+#define CHECK_CUDNN_FE_ERROR(cmd)                                       \
+  do {                                                                  \
+    auto error = cmd;                                                   \
+    if (!error.is_good()) {                                             \
+      throw std::runtime_error(                                         \
+          fmt::format("{} failed: {}.", #cmd, error.get_message()));    \
+    }                                                                   \
+  } while (0)

 // Return pointer alignment of |x|'s data.
 inline uint8_t get_alignment(const array& x) {
   uint8_t alignment = 1;
@@ -35,8 +40,31 @@ inline uint8_t get_alignment(const array& x) {

 // Convert the type of elements in |vec| to |T|.
 template <typename T, typename Vec>
-inline SmallVector<T> convert_vector(const Vec& vec) {
-  return SmallVector<T>(vec.begin(), vec.end());
+inline std::vector<T> convert_vector(const Vec& vec) {
+  return std::vector<T>(vec.begin(), vec.end());
 }

+// Map dtype to cudnn data type.
+inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
+  switch (dtype) {
+    case int8:
+      return fe::DataType_t::INT8;
+    case int32:
+      return fe::DataType_t::INT32;
+    case uint8:
+      return fe::DataType_t::UINT8;
+    case float16:
+      return fe::DataType_t::HALF;
+    case bfloat16:
+      return fe::DataType_t::BFLOAT16;
+    case float32:
+      return fe::DataType_t::FLOAT;
+    case float64:
+      return fe::DataType_t::DOUBLE;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype)));
+  }
+}
+
 // Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
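Editor's note: a quick check of the new helpers (a sketch; x is assumed to be any MLX array in scope):

    fe::DataType_t dt = dtype_to_cudnn_type(float16); // fe::DataType_t::HALF
    auto dims = convert_vector<int64_t>(x.shape());   // Shape -> std::vector<int64_t>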
@@ -44,122 +72,100 @@ inline SmallVector<T> convert_vector(const Vec& vec) {
|
|||||||
// There are 2 differences from the const_param util from kernel_utils.cuh:
|
// There are 2 differences from the const_param util from kernel_utils.cuh:
|
||||||
// 1. The rest of array is filled with 0.
|
// 1. The rest of array is filled with 0.
|
||||||
// 2. This util can be used in .cpp files.
|
// 2. This util can be used in .cpp files.
|
||||||
template <typename T, template <typename U> class Vec>
|
template <int NDIM = MAX_NDIM, typename T, template <typename U> class Vec>
|
||||||
inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
|
inline std::array<T, NDIM> vector_key(const Vec<T>& vec) {
|
||||||
if (vec.size() > MAX_NDIM) {
|
if (vec.size() > NDIM) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||||
}
|
}
|
||||||
std::array<T, MAX_NDIM> result = {};
|
std::array<T, NDIM> result = {};
|
||||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
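Editor's note: the new NDIM parameter lets a caller size the key to the actual rank instead of always paying for MAX_NDIM slots. A hypothetical call (names invented):

    // Sketch: zero-padded fixed-size key for a rank-2 shape.
    auto key = vector_key<2>(x.shape()); // std::array of 2 elements; throws if ndim > 2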
-// Helpers used by get_data_ptrs to get pointers.
-inline void* get_data_ptr(const array& arr) {
-  return const_cast<void*>(gpu_ptr<void>(arr));
-}
-
-template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>
-inline void* get_data_ptr(T& scalar) {
-  return &scalar;
-}
-
-// Return an array filled with data pointers of args.
-template <typename... Args>
-inline std::array<void*, sizeof...(Args)> get_data_ptrs(Args&... args) {
-  return {get_data_ptr(args)...};
-}
-
-// Map dtype to cudnn data type.
-inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
-  switch (dtype) {
-    case int8:
-      return CUDNN_DATA_INT8;
-    case int32:
-      return CUDNN_DATA_INT32;
-    case uint8:
-      return CUDNN_DATA_UINT8;
-    case float16:
-      return CUDNN_DATA_HALF;
-    case bfloat16:
-      return CUDNN_DATA_BFLOAT16;
-    case float32:
-      return CUDNN_DATA_FLOAT;
-    case float64:
-      return CUDNN_DATA_DOUBLE;
-    default:
-      throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
-  }
-}
-
-// Create a tensor descriptor from |x|.
-cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x);
-
-// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW.
-cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x);
-
-// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it
-// from NHWC to NCHW.
-cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x);
-
-// Create a 4D scalar tensor descriptor, which is passed by value.
-cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype);
-
-// Find a working plan for |op_graph|.
-std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
-    cudnnHandle_t handle,
-    cudnnBackendDescriptorType_t backend_type,
-    Dtype dtype,
-    cudnn_frontend::OperationGraph& op_graph);
-
-// Encode the plan to command buffer by capturing.
-bool encode_cudnn_plan_with_capturing(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    int num_args,
-    const int64_t* uids,
-    void** data_ptrs);
-
-#if CUDNN_VERSION >= 90500
-// Encode the plan to command buffer by using native graph api of cudnn. If the
-// |graph| is empty it will be populated, otherwise it will be updated.
-bool encode_cudnn_plan_with_graph_api(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    CudaGraph& graph,
-    int num_args,
-    const int64_t* uids,
-    void** data_ptrs);
-#endif
-
-// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z).
-template <typename... Args>
-bool encode_cudnn_plan(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    std::initializer_list<int64_t> uids,
-    Args&... args) {
-  assert(uids.size() == sizeof...(args));
-  auto data_ptrs = get_data_ptrs(args...);
-  return encode_cudnn_plan_with_capturing(
-      encoder, plan, uids.size(), uids.begin(), data_ptrs.data());
-}
-
-#if CUDNN_VERSION >= 90500
-template <typename... Args>
-bool encode_cudnn_plan(
-    cu::CommandEncoder& encoder,
-    cudnn_frontend::ExecutionPlan& plan,
-    CudaGraph& graph,
-    std::initializer_list<int64_t> uids,
-    Args&... args) {
-  assert(uids.size() == sizeof...(args));
-  auto data_ptrs = get_data_ptrs(args...);
-  return encode_cudnn_plan_with_graph_api(
-      encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data());
-}
-#endif
+// Extends cuDNN graph with helpers.
+class DnnGraph : public fe::graph::Graph {
+ public:
+  DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32)
+      : handle_(handle) {
+    set_io_data_type(dtype_to_cudnn_type(io_dtype));
+    set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype));
+    set_compute_data_type(dtype_to_cudnn_type(compute_dtype));
+  }
+
+  // Create a cuDNN tensor description from MLX array |x|.
+  auto& tensor(
+      std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
+      int64_t uid,
+      const array& x) {
+    set_tensor_attrs(attrs, uid, x);
+    return attrs;
+  }
+  auto tensor(const char* name, int64_t uid, const array& x) {
+    auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
+    tensor(attrs, uid, x);
+    return attrs;
+  }
+
+  // Create a cuDNN tensor description from MLX array |x|, and transpose it from
+  // NHWC layout to NCHW.
+  auto& tensor_nchw(
+      std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
+      int64_t uid,
+      const array& x) {
+    set_tensor_attrs_nchw(attrs, uid, x);
+    return attrs;
+  }
+  auto tensor_nchw(const char* name, int64_t uid, const array& x) {
+    auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
+    tensor_nchw(attrs, uid, x);
+    return attrs;
+  }
+
+  // Create a cuDNN tensor for scalar.
+  auto scalar(const char* name, int64_t uid, Dtype dtype) {
+    return Graph::tensor(fe::graph::Tensor_attributes()
+                             .set_name(name)
+                             .set_uid(uid)
+                             .set_dim({1, 1, 1, 1})
+                             .set_stride({1, 1, 1, 1})
+                             .set_is_pass_by_value(true)
+                             .set_data_type(dtype_to_cudnn_type(dtype)));
+  }
+
+  // Call this before setting notes.
+  fe::error_t prepare();
+  // Call this after setting notes.
+  fe::error_t build();
+
+  // Add cuDNN graph to CUDA graph, using native CUDA graph API.
+  fe::error_t encode_graph(
+      cu::CommandEncoder& encoder,
+      std::unordered_map<int64_t, void*> variant_pack);
+  // Add cuDNN graph to CUDA graph, using stream capture.
+  fe::error_t encode_capturing(
+      cu::CommandEncoder& encoder,
+      std::unordered_map<int64_t, void*> variant_pack);
+
+ private:
+  void* prepare_workspace(cu::CommandEncoder& encoder);
+
+  void set_tensor_attrs(
+      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+      int64_t uid,
+      const array& x,
+      const std::vector<int64_t>& shape,
+      const std::vector<int64_t>& strides);
+  void set_tensor_attrs(
+      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+      int64_t uid,
+      const array& x);
+  void set_tensor_attrs_nchw(
+      std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+      int64_t uid,
+      const array& x);
+
+  cudnnHandle_t handle_;
+};

 } // namespace mlx::core
@@ -57,7 +57,7 @@ std::string build_kernel(
     const std::vector<std::string>& output_names,
     const std::vector<Dtype>& output_dtypes,
     const std::vector<std::pair<std::string, TemplateArg>>& template_args,
-    const std::vector<CustomKernelShapeInfo>& shape_infos) {
+    const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
   std::string kernel_source;
   kernel_source.reserve(header.size() + source.size() + 8192);
   kernel_source += default_header;
@@ -81,17 +81,17 @@ std::string build_kernel(
       kernel_source += ",\n";
     // Add input shape, strides and ndim if present in the source
     if (arr.ndim() > 0) {
-      if (shape_infos[i].shape) {
+      if (std::get<0>(shape_infos[i])) {
         kernel_source += "  const __grid_constant__ Shape ";
         kernel_source += name;
         kernel_source += "_shape,\n";
       }
-      if (shape_infos[i].strides) {
+      if (std::get<1>(shape_infos[i])) {
         kernel_source += "  const __grid_constant__ Strides ";
         kernel_source += name;
         kernel_source += "_strides,\n";
       }
-      if (shape_infos[i].ndim) {
+      if (std::get<2>(shape_infos[i])) {
         kernel_source += "  const __grid_constant__ int ";
         kernel_source += name;
         kernel_source += "_ndim,\n";
@@ -154,12 +154,12 @@ CustomKernelFunction cuda_kernel(
         "[custom_kernel] Must specify at least one output.");
   }

-  std::vector<CustomKernelShapeInfo> shape_infos;
+  std::vector<std::tuple<bool, bool, bool>> shape_infos;
   for (auto& n : input_names) {
-    CustomKernelShapeInfo shape_info;
-    shape_info.shape = source.find(n + "_shape") != std::string::npos;
-    shape_info.strides = source.find(n + "_strides") != std::string::npos;
-    shape_info.ndim = source.find(n + "_ndim") != std::string::npos;
+    std::tuple<bool, bool, bool> shape_info;
+    std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
+    std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
+    std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
     shape_infos.push_back(shape_info);
   }

@@ -254,8 +254,8 @@ std::vector<array> precompiled_cuda_kernel(
     std::optional<float> init_value,
     bool ensure_row_contiguous,
     StreamOrDevice s) {
-  std::vector<CustomKernelShapeInfo> shape_infos(
-      inputs.size(), CustomKernelShapeInfo{false, false, false});
+  std::vector<std::tuple<bool, bool, bool>> shape_infos(
+      inputs.size(), {false, false, false});
   return array::make_arrays(
       output_shapes,
       output_dtypes,
@@ -289,7 +289,7 @@ void CustomKernel::eval_gpu(
       copies.emplace_back(init_value_.value(), out.dtype());
       fill_gpu(copies.back(), out, s);
     } else {
-      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder));
     }
   }

@@ -327,13 +327,13 @@ void CustomKernel::eval_gpu(
     const array& in = checked_inputs[i];
     auto& shape_info = shape_infos_[i];
     args.append(in);
-    if (shape_info.shape) {
+    if (std::get<0>(shape_info)) {
       args.append_ndim(in.shape());
     }
-    if (shape_info.strides) {
+    if (std::get<1>(shape_info)) {
       args.append_ndim(in.strides());
     }
-    if (shape_info.ndim) {
+    if (std::get<2>(shape_info)) {
       args.append<int32_t>(in.ndim());
    }
  }
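Editor's note: the indexed std::get calls are correct but anonymous; a readability alternative (a sketch, not in the commit) is to unpack the three flags with structured bindings at the use site:

    // Sketch: name the (shape, strides, ndim) flags before branching on them.
    const auto& [has_shape, has_strides, has_ndim] = shape_infos_[i];
    if (has_shape) {
      args.append_ndim(in.shape());
    }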
@@ -14,20 +14,20 @@ namespace mlx::core::cu {

 namespace {

-#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
-
-void check_cudnn_error(const char* name, cudnnStatus_t err) {
-  if (err != CUDNN_STATUS_SUCCESS) {
-    throw std::runtime_error(
-        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
-  }
-}
-
 bool use_cuda_graphs() {
-  static bool use_graphs = []() {
-    return env::get_var("MLX_USE_CUDA_GRAPHS", true);
-  }();
+  static bool use_graphs = env::get_var("MLX_USE_CUDA_GRAPHS", true);
   return use_graphs;
 }

+const char* save_cuda_graphs_dot_file() {
+  static const char* filename = []() -> const char* {
+    const char* env = std::getenv("MLX_SAVE_CUDA_GRAPHS_DOT_FILE");
+    if (env && std::strlen(env) == 0) {
+      return nullptr;
+    }
+    return env;
+  }();
+  return filename;
+}
+
 } // namespace
@@ -46,6 +46,7 @@ Device::Device(int device) : device_(device) {
         "Device {} does not support synchronization in managed memory.",
         device_));
   }

   // The cublasLt handle is used by matmul.
   make_current();
   CHECK_CUBLAS_ERROR(cublasLtCreate(&lt_));
@@ -86,7 +87,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
     return;
   }
   CHECK_CUDA_ERROR(
-      cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
+      cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeThreadLocal));
 }

 CommandEncoder::CaptureContext::~CaptureContext() {
@@ -114,18 +115,17 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
   }

   // Use an empty graph node for synchronization
-  CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
-  enc.empty_node_count_++;
+  CommandEncoder::GraphNode empty{NULL, "E", std::to_string(enc.node_count_++)};
   CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));

   // Insert the concurrent -> empty node dependencies
   for (auto& from : enc.concurrent_nodes_) {
     enc.from_nodes_.push_back(from.node);
     enc.to_nodes_.push_back(empty.node);
-    enc.graph_key_ += from.id;
-    enc.graph_key_ += from.node_type;
-    enc.graph_key_ += empty.id;
-    enc.graph_key_ += empty.node_type;
+    enc.graph_deps_key_ += from.id;
+    enc.graph_deps_key_ += "-";
+    enc.graph_deps_key_ += empty.id;
+    enc.graph_deps_key_ += "-";
   }

   // Insert the input -> concurrent node dependencies without updating output
@@ -140,9 +140,6 @@ CommandEncoder::ConcurrentContext::~ConcurrentContext() {
 }

 void CommandEncoder::insert_graph_dependencies(GraphNode node) {
-  if (node.node_type == 'G') {
-    graph_node_count_++;
-  }
   node.id = std::to_string(node_count_++);
   if (in_concurrent_) {
     concurrent_nodes_.push_back(std::move(node));
@@ -154,6 +151,10 @@ void CommandEncoder::insert_graph_dependencies(GraphNode node) {
 }

 void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
+  for (auto& node : nodes) {
+    graph_nodes_key_ += node.node_type;
+    graph_nodes_key_ += "-";
+  }
   std::vector<GraphNode> deps;
   {
     // Dependencies must be added in the same order to produce a consistent
@@ -181,20 +182,49 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
     for (auto& to : nodes) {
       from_nodes_.push_back(from.node);
       to_nodes_.push_back(to.node);
-      graph_key_ += from.id;
-      graph_key_ += from.node_type;
-      graph_key_ += to.id;
-      graph_key_ += to.node_type;
+      graph_deps_key_ += from.id;
+      graph_deps_key_ += "-";
+      graph_deps_key_ += to.id;
+      graph_deps_key_ += "-";
     }
   }
 }

+// Can be tuned with MLX_MAX_OPS_PER_BUFFER, MLX_MAX_MB_PER_BUFFER
+std::pair<int, int> get_graph_limits(Device& d) {
+  auto cc =
+      d.compute_capability_major() * 100 + d.compute_capability_minor() * 10;
+  int ops = 20;
+  int mb = 100;
+  switch (cc) {
+    case 800: // A100
+      ops = 20;
+      mb = 400;
+      break;
+    case 900: // H100
+      ops = 30;
+      mb = 400;
+      break;
+    case 1000: // B200
+      ops = 50;
+      mb = 500;
+      break;
+    case 1210: // DGX Spark
+      ops = 20;
+      mb = 25;
+      break;
+  }
+  return {env::max_ops_per_buffer(ops), env::max_mb_per_buffer(mb)};
+}
+
 CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
       worker_(d),
-      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
+      graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {
+  std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(d);
+}

 void CommandEncoder::add_completed_handler(std::function<void()> task) {
   worker_.add_task(std::move(task));
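Editor's note, with the numbers read straight off the table above: on an A100 (compute capability 8.0, so cc == 800) the defaults resolve to 20 ops and 400 MB per graph, and any part not listed falls back to 20 ops and 100 MB. Either value can still be overridden at runtime through MLX_MAX_OPS_PER_BUFFER and MLX_MAX_MB_PER_BUFFER.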
@@ -204,6 +234,7 @@ void CommandEncoder::set_input_array(const array& arr) {
   if (!use_cuda_graphs()) {
     return;
   }
+  bytes_in_graph_ += arr.data_size();
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
 }
@@ -278,13 +309,76 @@ void CommandEncoder::add_kernel_node(
 void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
   cudaGraphNode_t node;
   CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params));
-  insert_graph_dependencies(GraphNode{node, 'K'});
+  insert_graph_dependencies(GraphNode{node, "K"});
 }

 void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
   CUgraphNode node;
   CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, &params));
-  insert_graph_dependencies(GraphNode{node, 'K'});
+  insert_graph_dependencies(GraphNode{node, "K"});
+}
+
+std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
+  // Constructs a key representing the nodes of a sub-graph.
+  // Also checks if the sub-graph is updatable as CUDA graphs do not get
+  // updated correctly if a kernel node getting updated has a different cluster
+  // shape than the node it's being updated with.
+  std::string key = "(";
+  size_t num_nodes = 0;
+  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nullptr, &num_nodes));
+  if (num_nodes == 0) {
+    return {key + ")", true};
+  }
+  bool is_updatable = true;
+  std::vector<cudaGraphNode_t> nodes(num_nodes);
+  CHECK_CUDA_ERROR(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));
+  for (const auto& node : nodes) {
+    if (!is_updatable) {
+      break;
+    }
+    cudaGraphNodeType type;
+    CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
+    switch (type) {
+      case cudaGraphNodeTypeGraph: {
+        // Try to be updatable for a structure like graph -> graph -> kernel
+        cudaGraph_t child;
+        CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
+        auto [subkey, sub_is_updatable] = subgraph_to_key(child);
+        is_updatable &= sub_is_updatable;
+        key += subkey;
+        break;
+      }
+      case cudaGraphNodeTypeHost:
+        key += "H";
+        break;
+      case cudaGraphNodeTypeMemset:
+        key += "M";
+        break;
+      case cudaGraphNodeTypeKernel: {
+        cudaLaunchAttributeValue cluster_dim;
+        CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
+            node, cudaLaunchAttributeClusterDimension, &cluster_dim));
+        // Only allow dim.x to be greater than 1
+        if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
+          is_updatable = false;
+        } else {
+          key += "K";
+          key += std::to_string(cluster_dim.clusterDim.x);
+        }
+        break;
+      }
+      case cudaGraphNodeTypeWaitEvent:
+        key += "W";
+        break;
+      case cudaGraphNodeTypeEventRecord:
+        key += "R";
+        break;
+      default:
+        is_updatable = false;
+    }
+  }
+  key += ")";
+  return {key, is_updatable};
 }

 void CommandEncoder::add_graph_node(cudaGraph_t child) {
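Editor's note, reading the key grammar off the switch above: a child graph whose nodes enumerate as kernel (cluster dim 1), memset, kernel serializes as "(K1MK1)", and a child graph nesting another child graph with one kernel becomes "(K1(K1))". Any node type outside the switch, or a kernel with a cluster extent above 1 in y or z, marks the whole graph non-updatable.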
@@ -297,12 +391,15 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
     return;
   }
   cudaGraphNode_t node;
+  auto [sub_graph_key, is_updatable] = subgraph_to_key(child);
+  is_graph_updatable_ &= is_updatable;
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
-  insert_graph_dependencies(GraphNode{node, 'G'});
+  insert_graph_dependencies(GraphNode{node, sub_graph_key});
 }

-int CommandEncoder::get_num_ops() {
-  return node_count_;
+bool CommandEncoder::needs_commit() {
+  return (node_count_ > max_ops_per_graph_) ||
+      ((bytes_in_graph_ >> 20) > max_mb_per_graph_);
 }

 void CommandEncoder::commit() {
@@ -322,53 +419,63 @@ void CommandEncoder::commit() {
             from_nodes_.size()));
   }

-  graph_key_ += ".";
-  graph_key_ += std::to_string(node_count_);
-  graph_key_ += ".";
-  graph_key_ += std::to_string(graph_node_count_);
-  graph_key_ += ".";
-  graph_key_ += std::to_string(empty_node_count_);
-
-  CudaGraphExec& graph_exec = graph_cache_[graph_key_];
-
-  if (graph_exec != nullptr) {
-    cudaGraphExecUpdateResult update_result;
-#if CUDART_VERSION >= 12000
-    cudaGraphExecUpdateResultInfo info;
-    cudaGraphExecUpdate(graph_exec, graph_, &info);
-    update_result = info.result;
-#else
-    cudaGraphNode_t error_node;
-    cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
-#endif // CUDART_VERSION >= 12000
-    if (update_result != cudaGraphExecUpdateSuccess) {
-      cudaGetLastError(); // reset error
-      graph_exec.reset();
-    }
-  }
-  if (graph_exec == nullptr) {
-    graph_exec.instantiate(graph_);
-  }
   device_.make_current();
-  CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+
+  if (!is_graph_updatable_) {
+    CudaGraphExec graph_exec;
+    graph_exec.instantiate(graph_);
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+  } else {
+    auto graph_key = graph_nodes_key_ + ":" + graph_deps_key_;
+    auto& graph_exec = graph_cache_[graph_key];
+
+    if (graph_exec != nullptr) {
+      cudaGraphExecUpdateResult update_result;
+#if CUDART_VERSION >= 12000
+      cudaGraphExecUpdateResultInfo info;
+      cudaGraphExecUpdate(graph_exec, graph_, &info);
+      update_result = info.result;
+#else
+      cudaGraphNode_t error_node;
+      cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
+#endif // CUDART_VERSION >= 12000
+      if (update_result != cudaGraphExecUpdateSuccess) {
+        cudaGetLastError(); // reset error
+        graph_exec.reset();
+      }
+    }
+    if (graph_exec == nullptr) {
+      graph_exec.instantiate(graph_);
+    }
+
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
+  }
+
+  // Save cuda graph to dot file
+  if (const char* filename = save_cuda_graphs_dot_file(); filename) {
+    static int count = 0;
+    auto path = fmt::format("{}_{}.dot", filename, ++count);
+    CHECK_CUDA_ERROR(cudaGraphDebugDotPrint(graph_, path.c_str(), 0));
+  }

   // Reset state
-  graph_node_count_ = 0;
-  empty_node_count_ = 0;
   from_nodes_.clear();
   to_nodes_.clear();
-  graph_key_.clear();
+  graph_deps_key_.clear();
+  graph_nodes_key_.clear();
   node_map_.clear();
   graph_ = CudaGraph(device_);
+  is_graph_updatable_ = true;
  }

   // Put completion handlers in a batch.
   worker_.commit(stream_);
   node_count_ = 0;
+  bytes_in_graph_ = 0;
 }

 void CommandEncoder::synchronize() {
-  cudaStreamSynchronize(stream_);
+  CHECK_CUDA_ERROR(cudaStreamSynchronize(stream_));
   auto p = std::make_shared<std::promise<void>>();
   std::future<void> f = p->get_future();
   add_completed_handler([p = std::move(p)]() { p->set_value(); });
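Editor's note on the new debugging hook: setting MLX_SAVE_CUDA_GRAPHS_DOT_FILE to, say, trace makes every commit write trace_1.dot, trace_2.dot, and so on via cudaGraphDebugDotPrint, which Graphviz can render directly; an empty value disables the feature, per the save_cuda_graphs_dot_file helper above.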
@@ -84,7 +84,7 @@ class CommandEncoder {
   }

   void add_completed_handler(std::function<void()> task);
-  int get_num_ops();
+  bool needs_commit();
   void commit();

   Device& device() {
@@ -106,8 +106,9 @@ class CommandEncoder {
     cudaGraphNode_t node;
     // K = kernel
     // E = empty
-    // G = subgraph
-    char node_type;
+    // () = subgraph (with metadata)
+    // Symbols ':', '-' are reserved as separators
+    std::string node_type;
     std::string id;
   };

@@ -119,18 +120,21 @@ class CommandEncoder {
   CudaGraph graph_;
   Worker worker_;
   char node_count_{0};
-  char graph_node_count_{0};
-  char empty_node_count_{0};
   bool in_concurrent_{false};
   std::vector<cudaGraphNode_t> from_nodes_;
   std::vector<cudaGraphNode_t> to_nodes_;
-  std::string graph_key_;
+  std::string graph_nodes_key_;
+  std::string graph_deps_key_;
   std::vector<GraphNode> concurrent_nodes_;
   std::vector<std::shared_ptr<array::Data>> temporaries_;
   LRUCache<std::string, CudaGraphExec> graph_cache_;
   std::vector<std::uintptr_t> active_deps_;
   std::vector<std::uintptr_t> active_outputs_;
   std::unordered_map<std::uintptr_t, GraphNode> node_map_;
+  size_t bytes_in_graph_{0};
+  bool is_graph_updatable_{true};
+  int max_ops_per_graph_;
+  int max_mb_per_graph_;
 };

 class Device {
@@ -166,6 +170,7 @@ class Device {
   int device_;
   int compute_capability_major_;
   int compute_capability_minor_;
+  std::string device_name_;
   cublasLtHandle_t lt_;
   cudnnHandle_t cudnn_;
   std::unordered_map<int, CommandEncoder> encoders_;
@@ -26,7 +26,7 @@ void AllReduce::eval_gpu(
       out.copy_shared_buffer(in);
       return {in, out};
     } else {
-      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder));
       return {in, out};
     }
   };
@@ -74,7 +74,7 @@ void AllGather::eval_gpu(
   };

   auto input = ensure_contiguous(inputs[0]);
-  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
+  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));

   encoder.set_input_array(input);
   encoder.set_output_array(outputs[0]);
@@ -103,7 +103,7 @@ void ReduceScatter::eval_gpu(
   };

   auto input = ensure_contiguous(inputs[0]);
-  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder.stream()));
+  outputs[0].set_data(cu::malloc_async(outputs[0].nbytes(), encoder));

   encoder.set_input_array(input);
   encoder.set_output_array(outputs[0]);
@@ -11,9 +11,6 @@

 namespace mlx::core::gpu {

-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-constexpr int default_max_nodes_per_graph = 20;
-
 bool is_available() {
   return true;
 }
@@ -53,8 +50,7 @@ void eval(array& arr) {
     encoder.add_temporary(s);
   }

-  if (encoder.get_num_ops() >=
-      env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+  if (encoder.needs_commit()) {
     scheduler::notify_new_task(stream);
     encoder.add_completed_handler(
         [stream]() { scheduler::notify_task_completion(stream); });
@@ -305,6 +305,7 @@ void Event::wait() {
   } else {
     event->atomic->wait(value());
   }
+  CHECK_CUDA_ERROR(cudaPeekAtLastError());
 }

 void Event::wait(Stream s) {
@@ -370,7 +370,7 @@ void CublasGemm::execute(
   // Ensure workspace is 256-byte aligned
   int nbytes = cuda::ceil_div(heuristic_.workspaceSize, 256) * 256;
   array workspace(
-      cu::malloc_async(nbytes, encoder.stream()),
+      cu::malloc_async(nbytes, encoder),
       {static_cast<int>(heuristic_.workspaceSize)},
       int8);
   encoder.add_temporary(workspace);
@@ -163,7 +163,7 @@ void CublasGemm::run_batched(

   // Launch kernel to set device offsets
   auto pointers = array(
-      cu::malloc_async(batch_count * sizeof(void*) * 3, encoder.stream()),
+      cu::malloc_async(batch_count * sizeof(void*) * 3, encoder),
       {batch_count * 3},
       uint64);

@@ -251,7 +251,7 @@ void CublasGemm::run_batched(

   // Launch kernel to set device offsets
   auto pointers = array(
-      cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder.stream()),
+      cu::malloc_async(batch_count * sizeof(uint64_t) * 4, encoder),
       {batch_count * 4},
       uint64);

@@ -61,7 +61,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {

   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   if (out.size() == 0) {
     return;
   }
@@ -241,7 +241,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {

   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   if (out.size() == 0) {
     return;
   }
@@ -279,11 +279,14 @@ void compile(
   // Compile program.
   std::vector<const char*> args;
   bool use_sass = compiler_supports_device_sass(device);
+  auto cc = device.compute_capability_major();
+  std::string arch_tag = (cc == 90 || cc == 100 || cc == 121) ? "a" : "";
   std::string compute = fmt::format(
-      "--gpu-architecture={}_{}{}",
+      "--gpu-architecture={}_{}{}{}",
       use_sass ? "sm" : "compute",
-      device.compute_capability_major(),
-      device.compute_capability_minor());
+      cc,
+      device.compute_capability_minor(),
+      arch_tag);
   args.push_back(compute.c_str());
   std::string cccl_include = cccl_dir();
   if (!cccl_include.empty()) {
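Editor's note: the "a" suffix matters on Hopper and Blackwell-class parts. NVRTC accepts sm_90a, sm_100a, and sm_121a targets that unlock architecture-specific instructions which the plain sm_90-style targets cannot use; on every other architecture the tag stays empty and the flag is unchanged.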
@@ -244,7 +244,7 @@ void LayerNorm::eval_gpu(
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
+        cu::malloc_async(x.data_size() * x.itemsize(), encoder),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -335,7 +335,7 @@ void LayerNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
+    gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
   }
   if (g_copied && !g_in_gx) {
     encoder.add_temporary(g);
@@ -355,7 +355,7 @@ void LayerNormVJP::eval_gpu(
       g_in_gw = true;
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
+      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
       encoder.add_temporary(gw_temp);
     }
   }
@@ -32,7 +32,7 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& encoder = cu::get_command_encoder(stream());
   auto size = out.size();
   auto nbytes = size * out.itemsize();
-  out.set_data(cu::malloc_async(nbytes, encoder.stream()));
+  out.set_data(cu::malloc_async(nbytes, encoder));
   auto out_ptr = malloc(nbytes);
   reader_->read(static_cast<char*>(out_ptr), nbytes, offset_);
   if (swap_endianness_) {
@@ -115,7 +115,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {

   auto in = ensure_contiguous(inputs[0]);
   if (in.flags().row_contiguous) {
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
   } else {
     auto n = in.shape(-1);
     auto flags = in.flags();
@@ -130,7 +130,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
     flags.col_contiguous = col_contig;
     out.set_data(
-        cu::malloc_async(in.nbytes() / n, encoder.stream()),
+        cu::malloc_async(in.nbytes() / n, encoder),
         in.data_size() / n,
         std::move(strides),
         flags);
@@ -135,12 +135,19 @@ class LRUCache {
 };

 // Turn a POD struct into a container key by doing bytes compare.
+//
+// Usage:
+//   BytesKey<MyKey> key;
+//   key.pod = { ... };
 template <typename T>
 struct BytesKey {
   T pod;
   static_assert(std::is_standard_layout_v<T>, "T is not POD");

-  BytesKey(T pod) : pod(std::move(pod)) {}
+  BytesKey() {
+    // Make sure the paddings between members are filled with 0.
+    memset(&pod, 0, sizeof(T));
+  }

   BytesKey(const BytesKey& other) {
     memcpy(&pod, &other.pod, sizeof(T));
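Editor's note: the new default constructor zero-fills the whole struct before the members are assigned, so padding bytes compare equal across keys built from the same values. A hypothetical sketch (ConvKey is invented):

    // Sketch: padding-safe byte-wise cache key, assuming a POD ConvKey struct.
    struct ConvKey {
      int n, h, w;
      int64_t c; // padding before this member is zeroed by BytesKey()
    };
    BytesKey<ConvKey> key;      // zero-fills, including padding
    key.pod = {8, 32, 32, 64};  // then assign the members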
@@ -121,7 +121,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }

-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));

   int M = a_pre.shape(-2);
   int N = b_pre.shape(-1);
@@ -163,7 +163,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {

   if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
       c.data_size() == out.shape(-1)) {
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
     gemm_and_bias(
         encoder,
         M,
@@ -187,10 +187,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto sty = c.strides()[c.ndim() - 1];
   if (sty == 1 && stx == c.shape(-1)) {
     ldc = stx;
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
   } else if (sty == 1 && stx == 0) {
     ldc = 0;
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
   } else {
     // Copy C into out and set C to out
     ldc = c.shape(-1);
@@ -37,6 +37,7 @@ NO_GPU(Inverse)
 NO_GPU(Cholesky)
 NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
+NO_GPU(MaskedScatter)

 namespace distributed {
 NO_GPU_MULTI(Send)
@@ -2,7 +2,11 @@

 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh"
+#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh"
 #include "mlx/backend/cuda/quantized/quantized.h"
+#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
+#include "mlx/backend/cuda/vector_types.cuh"
 #include "mlx/dtype_utils.h"

 #include <cooperative_groups.h>
@@ -13,17 +17,6 @@
|
|||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
template <int bits>
|
|
||||||
struct Quantize {
|
|
||||||
__device__ uint8_t operator()(float x) {
|
|
||||||
if constexpr (bits == 8) {
|
|
||||||
return __nv_fp8_e4m3(x).__x;
|
|
||||||
} else {
|
|
||||||
return __nv_fp4_e2m1(x).__x;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <int bits>
|
template <int bits>
|
||||||
struct Dequantize {
|
struct Dequantize {
|
||||||
__device__ float operator()(uint8_t x) {
|
__device__ float operator()(uint8_t x) {
|
||||||
@@ -37,29 +30,40 @@ struct Dequantize {

 namespace cg = cooperative_groups;

-template <typename T, int group_size, int bits, bool use_mx_scale>
-__global__ void
-fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
+template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
+__global__ void fp_quantize(T* w, uint8_t* out, uint8_t* scales, size_t size) {
+  using Tx2 = Vector2_t<T>;
+  using Tx4 = Vector4_t<T>;
+  uint32_t rbits = 0; // reserved bits for future use
   auto block_size = cg::this_thread_block().dim_threads();
   auto block_idx = cg::this_thread_block().group_index();
   auto idx_in_block = cg::this_thread_block().thread_index();

   auto tidx = block_idx.x * block_size.x + idx_in_block.x;
   auto tidy = block_idx.y * block_size.y + idx_in_block.y;
+  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;

-  auto grid_dim_x =
-      cg::this_grid().dim_blocks().x * cg::this_grid().block_index().x;
-  size_t index = tidx + grid_dim_x * size_t(tidy);
-  if (index >= size) {
+  size_t thread_idx = tidx + grid_dim_x * size_t(tidy);
+  size_t base_idx = thread_idx * group_size;
+  if (base_idx >= size) {
     return;
   }

-  float w_thread = w[index];
-
-  cg::greater<float> max_op;
-  auto warp = cg::tiled_partition<group_size>(cg::this_thread_block());
+  auto w_tile = load_vector<group_size, T>(w, thread_idx);
+  float scale = 0.0f;
+
+  Tx2 amax_2x = Tx2{0.0f, 0.0f};
+#pragma unroll
+  for (int i = 0; i < group_size; i += 2) {
+    auto pair = Tx2{w_tile[i], w_tile[i + 1]};
+    abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
+  }
+
+  scale = static_cast<float>(
+      max(fabsf(static_cast<float>(amax_2x.x)),
+          fabsf(static_cast<float>(amax_2x.y))));

-  float scale = cg::reduce(warp, abs(w_thread), max_op);
   scale /= bits == 4 ? 6.0f : 448.0f;
   // Convert to mx scale or nv scale
   using ScaleType =
@@ -68,21 +72,24 @@ fp_quantize(const T* w, uint8_t* out, uint8_t* scales, size_t size) {
   uint8_t q_scale = s.__x;
   scale = float(s);

-  // Write out the scales
-  size_t gindex = index / group_size;
-  if (index % group_size == 0) {
-    scales[gindex] = q_scale;
-  }
-
-  uint8_t output = Quantize<bits>{}(scale == 0 ? 0.0f : w_thread / scale);
-  if (bits == 4) {
-    uint8_t sval = warp.shfl_down(output, 1);
-    output |= sval << bits;
-  }
-  constexpr int pack_factor = bits == 8 ? 1 : 2;
-  if (index % pack_factor == 0) {
-    out[index / pack_factor] = output;
+  scales[thread_idx] = q_scale;
+  constexpr int elem_per_byte = bits == 8 ? 1 : 2;
+  AlignedVector<uint8_t, group_size / elem_per_byte> quantized;
+
+#pragma unroll
+  for (int i = 0; i < group_size / 4; i++) {
+    Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
+    if constexpr (bits == 8) {
+      uint32_t quantized_val =
+          scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
+      *reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
+    } else {
+      uint16_t quantized_val =
+          scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
+      *reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
+    }
   }
+  store_vector<group_size / elem_per_byte>(out, thread_idx, quantized);
 }

 template <typename T, int group_size, int bits, bool use_mx_scale>
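The divisors in `scale /= bits == 4 ? 6.0f : 448.0f` come from the target formats: 6.0 is the largest finite value of fp4 e2m1 and 448 is the largest finite value of fp8 e4m3, so dividing the group's absolute maximum by them maps the group onto the representable range. A scalar host-side reference of the per-group scale (sketch only):

    #include <cmath>

    // Reference for one group of `group_size` weights.
    float group_scale(const float* w, int group_size, int bits) {
      float amax = 0.0f;
      for (int i = 0; i < group_size; ++i) {
        amax = std::fmax(amax, std::fabs(w[i]));
      }
      // e2m1 saturates at 6.0, e4m3 at 448; each value is then encoded
      // as round_to_format(w[i] / scale).
      return amax / (bits == 4 ? 6.0f : 448.0f);
    }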
@@ -142,15 +149,16 @@ void fp_quantize(
   dispatch_float_types(w.dtype(), "fp_quantize", [&](auto type_tag) {
     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     if constexpr (!std::is_same_v<T, double>) {
-      auto kernel = cu::fp_quantize<T, 32, 4, true>;
+      auto kernel = cu::fp_quantize<T, 32, 4, true, false>;
       if (bits == 8) {
-        kernel = cu::fp_quantize<T, 32, 8, true>;
+        kernel = cu::fp_quantize<T, 32, 8, true, false>;
       } else if (group_size == 16) {
-        kernel = cu::fp_quantize<T, 16, 4, false>;
+        kernel = cu::fp_quantize<T, 16, 4, false, false>;
       }
       bool large = w.size() > UINT_MAX;
       auto [num_blocks, block_dims] =
-          get_launch_args(w.size(), w.shape(), w.strides(), large);
+          get_launch_args(w.size(), w.shape(), w.strides(), large, group_size);

       enc.add_kernel_node(
           kernel,
           num_blocks,
mlx/backend/cuda/quantized/mxfp8_quantize.cuh (new file, 32 lines)
@@ -0,0 +1,32 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuda_fp8.h>
+#include <cuda_runtime.h>
+#include "mlx/backend/cuda/vector_types.cuh"
+
+namespace mlx::core::cu {
+
+// TODO implement fast path
+template <typename T>
+__device__ __forceinline__ uint32_t
+scale_cvt_Tx4_to_fp8x4_fallback(const Vector4_t<T> input, const float scale) {
+  uint32_t out_fp8x4 = 0;
+  float4 scaled;
+  scaled.x = static_cast<float>(input.x) * scale;
+  scaled.y = static_cast<float>(input.y) * scale;
+  scaled.z = static_cast<float>(input.z) * scale;
+  scaled.w = static_cast<float>(input.w) * scale;
+  out_fp8x4 = __nv_fp8x4_e4m3(scaled).__x;
+  return out_fp8x4;
+}
+
+// Placeholder for a future fast path implementation
+template <typename T, bool USE_SR>
+__device__ __forceinline__ uint32_t scale_cvt_Tx4_to_fp8x4(
+    const Vector4_t<T> input,
+    const float scale,
+    uint32_t rbits) {
+  return scale_cvt_Tx4_to_fp8x4_fallback(input, scale);
+}
+} // namespace mlx::core::cu
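Usage sketch for the fallback above (device code); it assumes `Vector4_t<float>` has `float4`-style `.x/.y/.z/.w` members, which matches how the function body reads them:

    // Pack four scaled floats into one 32-bit word, one fp8 (e4m3) byte each.
    __device__ void pack_fp8_example(float4 v, uint32_t* out) {
      *out = mlx::core::cu::scale_cvt_Tx4_to_fp8x4<float, /*USE_SR=*/false>(
          v, /*scale=*/1.0f / 448.0f, /*rbits=*/0);
    }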
mlx/backend/cuda/quantized/nvfp4_quantize.cuh (new file, 334 lines)
@@ -0,0 +1,334 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuda_fp4.h>
+#include <cuda_runtime.h>
+#include "mlx/backend/cuda/vector_types.cuh"
+
+namespace mlx::core::cu {
+
+using bf16x4 = Vector4_t<__nv_bfloat16>;
+using fp16x4 = Vector4_t<__half>;
+using f32x4 = Vector4_t<float>;
+
+template <typename T>
+__device__ __forceinline__ uint16_t
+scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
+  // Fallback implementation for architectures that do not support cvt
+  // instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
+  uint16_t out_fp4x4 = 0;
+  f32x4 scaled;
+  scaled.x = static_cast<float>(input.x) * scale;
+  scaled.y = static_cast<float>(input.y) * scale;
+  scaled.z = static_cast<float>(input.z) * scale;
+  scaled.w = static_cast<float>(input.w) * scale;
+  uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
+  uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
+  uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
+  uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
+  out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
+      (static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
+      static_cast<uint16_t>(q0);
+  return out_fp4x4;
+}
+
+#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
+    defined(__CUDA_ARCH_SPECIFIC__)
+
+__device__ __forceinline__ uint16_t
+scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
+  uint16_t out_fp4x4 = 0;
+  asm volatile(
+      "{\n"
+      ".reg.b16 x0_bf16; \n\t" // first bf16
+      ".reg.b16 x1_bf16; \n\t" // second bf16
+      ".reg.b16 x2_bf16; \n\t" // third bf16
+      ".reg.b16 x3_bf16; \n\t" // fourth bf16
+      ".reg.b32 x0; \n\t" // to hold scaled first
+      ".reg.b32 x1; \n\t" // to hold scaled second
+      ".reg.b32 x2; \n\t" // to hold scaled third
+      ".reg.b32 x3; \n\t" // to hold scaled fourth
+      ".reg.b64 x01; \n\t" // to hold vector mul
+      ".reg.b64 x23; \n\t"
+      ".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
+      ".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
+      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t" // unpack bf16
+      "cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
+      "cvt.f32.bf16 x1, x1_bf16; \n\t"
+      "cvt.f32.bf16 x2, x2_bf16; \n\t"
+      "cvt.f32.bf16 x3, x3_bf16; \n\t"
+      "mov.b64 x01, {x0, x1}; \n\t"
+      "mul.f32x2 x01, x01, %2; \n\t" // scale first pair
+      "mov.b64 x23, {x2, x3}; \n\t"
+      "mul.f32x2 x23, x23, %2; \n\t" // scale second pair
+      "mov.b64 {x0, x1}, x01; \n\t"
+      "mov.b64 {x2, x3}, x23; \n\t"
+      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first
+                                                     // pair
+      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2 second
+                                                     // pair
+      "mov.b16 %0, {q0, q1}; \n\t" // pack to output
+      "}"
+      : "=h"(out_fp4x4)
+      : "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
+        "l"(reinterpret_cast<const uint64_t&>(
+            scale))); // here cast is needed because an asm operand must have
+                      // scalar type
+  return out_fp4x4;
+}
+
+__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
+    const bf16x4 input_bf16x4,
+    const float2 scale,
+    uint32_t rbits) {
+  uint16_t out_fp4x4 = 0;
+  asm volatile(
+      "{\n"
+      ".reg.b16 x0_bf16; \n\t"
+      ".reg.b16 x1_bf16; \n\t"
+      ".reg.b16 x2_bf16; \n\t"
+      ".reg.b16 x3_bf16; \n\t"
+      ".reg.b32 x0; \n\t"
+      ".reg.b32 x1; \n\t"
+      ".reg.b32 x2; \n\t"
+      ".reg.b32 x3; \n\t"
+      ".reg.b64 x01; \n\t"
+      ".reg.b64 x23; \n\t"
+      ".reg.b16 q0; \n\t"
+      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
+      "cvt.f32.bf16 x0, x0_bf16; \n\t"
+      "cvt.f32.bf16 x1, x1_bf16; \n\t"
+      "cvt.f32.bf16 x2, x2_bf16; \n\t"
+      "cvt.f32.bf16 x3, x3_bf16; \n\t"
+      "mov.b64 x01, {x0, x1}; \n\t"
+      "mul.f32x2 x01, x01, %2; \n\t"
+      "mov.b64 x23, {x2, x3}; \n\t"
+      "mul.f32x2 x23, x23, %2; \n\t"
+      "mov.b64 {x0, x1}, x01; \n\t"
+      "mov.b64 {x2, x3}, x23; \n\t"
+      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
+      "}"
+      : "=h"(out_fp4x4)
+      : "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
+        "l"(reinterpret_cast<const uint64_t&>(scale)),
+        "r"(rbits));
+  return out_fp4x4;
+}
+
+__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
+    const float2 input_fp32x2_0,
+    const float2 input_fp32x2_1,
+    const float2 scale) {
+  uint16_t out_fp4x4 = 0;
+  asm volatile(
+      "{\n"
+      ".reg.b32 x0; \n\t"
+      ".reg.b32 x1; \n\t"
+      ".reg.b32 x2; \n\t"
+      ".reg.b32 x3; \n\t"
+      ".reg.b64 x01; \n\t"
+      ".reg.b64 x23; \n\t"
+      ".reg.b8 q0; \n\t"
+      ".reg.b8 q1; \n\t"
+      "mov.b64 x01, {%1, %2}; \n\t"
+      "mul.f32x2 x01, x01, %5; \n\t"
+      "mov.b64 x23, {%3, %4}; \n\t"
+      "mul.f32x2 x23, x23, %5; \n\t"
+      "mov.b64 {x0, x1}, x01; \n\t"
+      "mov.b64 {x2, x3}, x23; \n\t"
+      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
+      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
+      "mov.b16 %0, {q0, q1}; \n\t"
+      "}"
+      : "=h"(out_fp4x4)
+      : "f"(input_fp32x2_0.x),
+        "f"(input_fp32x2_0.y),
+        "f"(input_fp32x2_1.x),
+        "f"(input_fp32x2_1.y),
+        "l"(reinterpret_cast<const uint64_t&>(scale)));
+  return out_fp4x4;
+}
+
+__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
+    const float2 input_fp32x2_0,
+    const float2 input_fp32x2_1,
+    const float2 scale,
+    uint32_t rbits) {
+  uint16_t out_fp4x4 = 0;
+  asm volatile(
+      "{\n"
+      ".reg.b32 x0; \n\t"
+      ".reg.b32 x1; \n\t"
+      ".reg.b32 x2; \n\t"
+      ".reg.b32 x3; \n\t"
+      ".reg.b64 x01; \n\t"
+      ".reg.b64 x23; \n\t"
+      ".reg.b16 q0; \n\t"
+      "mov.b64 x01, {%1, %2}; \n\t"
+      "mul.f32x2 x01, x01, %5; \n\t"
+      "mov.b64 x23, {%3, %4}; \n\t"
+      "mul.f32x2 x23, x23, %5; \n\t"
+      "mov.b64 {x0, x1}, x01; \n\t"
+      "mov.b64 {x2, x3}, x23; \n\t"
+      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
+      "}"
+      : "=h"(out_fp4x4)
+      : "f"(input_fp32x2_0.x),
+        "f"(input_fp32x2_0.y),
+        "f"(input_fp32x2_1.x),
+        "f"(input_fp32x2_1.y),
+        "l"(reinterpret_cast<const uint64_t&>(scale)),
+        "r"(rbits));
+  return out_fp4x4;
+}
+
+__device__ __forceinline__ uint16_t
+scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
+  uint16_t out_fp4x4 = 0;
+  asm volatile(
+      "{\n"
+      ".reg.b16 x0_fp16; \n\t"
+      ".reg.b16 x1_fp16; \n\t"
+      ".reg.b16 x2_fp16; \n\t"
+      ".reg.b16 x3_fp16; \n\t"
+      ".reg.b32 x0; \n\t"
+      ".reg.b32 x1; \n\t"
+      ".reg.b32 x2; \n\t"
+      ".reg.b32 x3; \n\t"
+      ".reg.b64 x01; \n\t"
+      ".reg.b64 x23; \n\t"
+      ".reg.b8 q0; \n\t"
+      ".reg.b8 q1; \n\t"
+      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
+      "cvt.f32.f16 x0, x0_fp16; \n\t"
+      "cvt.f32.f16 x1, x1_fp16; \n\t"
+      "cvt.f32.f16 x2, x2_fp16; \n\t"
+      "cvt.f32.f16 x3, x3_fp16; \n\t"
+      "mov.b64 x01, {x0, x1}; \n\t"
+      "mul.f32x2 x01, x01, %2; \n\t"
+      "mov.b64 x23, {x2, x3}; \n\t"
+      "mul.f32x2 x23, x23, %2; \n\t"
+      "mov.b64 {x0, x1}, x01; \n\t"
+      "mov.b64 {x2, x3}, x23; \n\t"
+      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
+      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
+      "mov.b16 %0, {q0, q1}; \n\t"
+      "}"
+      : "=h"(out_fp4x4)
+      : "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
+        "l"(reinterpret_cast<const uint64_t&>(scale)));
+  return out_fp4x4;
+}
+
+__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
+    const fp16x4 input_fp16x4,
+    const float2 scale,
+    uint32_t rbits) {
+  uint16_t out_fp4x4 = 0;
+  asm volatile(
+      "{\n"
+      ".reg.b16 x0_fp16; \n\t"
+      ".reg.b16 x1_fp16; \n\t"
+      ".reg.b16 x2_fp16; \n\t"
+      ".reg.b16 x3_fp16; \n\t"
+      ".reg.b32 x0; \n\t"
+      ".reg.b32 x1; \n\t"
+      ".reg.b32 x2; \n\t"
+      ".reg.b32 x3; \n\t"
+      ".reg.b64 x01; \n\t"
+      ".reg.b64 x23; \n\t"
+      ".reg.b16 q0; \n\t"
+      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
+      "cvt.f32.f16 x0, x0_fp16; \n\t"
+      "cvt.f32.f16 x1, x1_fp16; \n\t"
+      "cvt.f32.f16 x2, x2_fp16; \n\t"
+      "cvt.f32.f16 x3, x3_fp16; \n\t"
+      "mov.b64 x01, {x0, x1}; \n\t"
+      "mul.f32x2 x01, x01, %2; \n\t"
+      "mov.b64 x23, {x2, x3}; \n\t"
+      "mul.f32x2 x23, x23, %2; \n\t"
+      "mov.b64 {x0, x1}, x01; \n\t"
+      "mov.b64 {x2, x3}, x23; \n\t"
+      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
+      "}"
+      : "=h"(out_fp4x4)
+      : "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
+        "l"(reinterpret_cast<const uint64_t&>(scale)),
+        "r"(rbits));
+  return out_fp4x4;
+}
+
+template <bool USE_SR>
+__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
+    const bf16x4 input,
+    const float scale,
+    uint32_t rbits) {
+  float2 scale_fp32x2 = make_float2(scale, scale);
+  if constexpr (USE_SR) {
+    return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
+  } else {
+    return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
+  }
+}
+
+template <bool USE_SR>
+__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
+    const fp16x4 input,
+    const float scale,
+    uint32_t rbits) {
+  float2 scale_fp32x2 = make_float2(scale, scale);
+  if constexpr (USE_SR) {
+    return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
+  } else {
+    return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
+  }
+}
+
+template <bool USE_SR>
+__device__ __forceinline__ uint16_t
+scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
+  float2 scale_fp32x2 = make_float2(scale, scale);
+  float2 input_fp32x2_0 = make_float2(input.x, input.y);
+  float2 input_fp32x2_1 = make_float2(input.z, input.w);
+
+  if constexpr (USE_SR) {
+    return scale_cvt_fp32x4_to_fp4x4_rs(
+        input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
+  } else {
+    return scale_cvt_fp32x4_to_fp4x4_rn(
+        input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
+  }
+}
+
+template <typename T, bool USE_SR>
+__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
+    const Vector4_t<T> input,
+    const float scale,
+    uint32_t rbits) {
+  if constexpr (std::is_same<T, __nv_bfloat16>::value) {
+    return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
+  } else if constexpr (std::is_same<T, __half>::value) {
+    return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
+  } else {
+    return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
+  }
+}
+#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
+       // (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
+
+template <typename T, bool USE_SR>
+__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
+    const Vector4_t<T> input,
+    const float scale,
+    uint32_t rbits) {
+#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
+    (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
+  return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
+#else
+  static_assert(
+      !USE_SR,
+      "Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
+  return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
+#endif
+}
+} // namespace mlx::core::cu
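The `_rn` variants above round to nearest-even, while the `_rs` variants use the `cvt.rs` instruction's stochastic rounding, seeded by the caller's `rbits`; `scale_cvt_Tx4_to_fp4x4` selects between them at compile time via `USE_SR`. What stochastic rounding means, shown on an integer grid for simplicity (illustrative sketch, not the hardware algorithm):

    #include <cmath>
    #include <cstdint>

    // Rounds x up with probability equal to its fractional part, so the
    // expected value of the rounded result equals x.
    float stochastic_round(float x, uint32_t rbits) {
      float lo = std::floor(x);
      float frac = x - lo;
      float u = (rbits >> 8) * (1.0f / 16777216.0f); // uniform in [0, 1)
      return lo + (u < frac ? 1.0f : 0.0f);
    }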
@@ -59,7 +59,7 @@ void fast::Quantize::eval_gpu(
   auto scales = ensure_row_contiguous(inputs[1], enc, s);
   auto& w = outputs[0];

-  w.set_data(cu::malloc_async(w.nbytes(), enc.stream()));
+  w.set_data(cu::malloc_async(w.nbytes(), enc));

   if (mode_ == QuantizationMode::Affine) {
     auto biases = ensure_row_contiguous(inputs[2], enc, s);
@@ -72,11 +72,11 @@ void fast::Quantize::eval_gpu(
   auto& wq = outputs[0];
   auto& scales = outputs[1];

-  wq.set_data(cu::malloc_async(wq.nbytes(), enc.stream()));
-  scales.set_data(cu::malloc_async(scales.nbytes(), enc.stream()));
+  wq.set_data(cu::malloc_async(wq.nbytes(), enc));
+  scales.set_data(cu::malloc_async(scales.nbytes(), enc));
   if (mode_ == QuantizationMode::Affine) {
     auto& biases = outputs[2];
-    biases.set_data(cu::malloc_async(biases.nbytes(), enc.stream()));
+    biases.set_data(cu::malloc_async(biases.nbytes(), enc));
     affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
   } else {
     fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
@@ -15,6 +15,22 @@ inline constexpr __device__ short get_bytes_per_pack() {
   return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
 }

+template <typename T>
+__device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
+  if constexpr (
+      (std::is_same<T, __nv_bfloat162>::value) ||
+      (std::is_same<T, __half2>::value)) {
+    T a = x1;
+    T b = x2;
+    out = __hmax2(__habs2(a), __habs2(b));
+  } else if constexpr (std::is_same<T, float2>::value) {
+    float2 a = x1;
+    float2 b = x2;
+    out.x = fmaxf(fabsf(a.x), fabsf(b.x));
+    out.y = fmaxf(fabsf(a.y), fabsf(b.y));
+  }
+}
+
 } // namespace cu

 template <typename F>
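`abs_max_x2` folds an element-wise absolute maximum over two-wide vectors, using the packed `__habs2`/`__hmax2` intrinsics for half and bfloat16 pairs. Intended usage pattern (sketch):

    // Running |max| over 8 floats, two lanes at a time; the two lanes are
    // collapsed with a final scalar max.
    __device__ float absmax8(const float (&v)[8]) {
      float2 acc = {0.0f, 0.0f};
    #pragma unroll
      for (int i = 0; i < 8; i += 2) {
        mlx::core::cu::abs_max_x2(acc, acc, float2{v[i], v[i + 1]});
      }
      return fmaxf(acc.x, acc.y);
    }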
@@ -139,30 +139,36 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   // keys has shape (N1, ..., NK, 2)
   // out has shape (N1, ..., NK, M1, M2, ...)
   auto& keys = inputs[0];
-  uint32_t num_keys = keys.size() / 2;
-  uint32_t elems_per_key = out.size() / num_keys;
-  uint32_t bytes_per_key = out.itemsize() * elems_per_key;
+  size_t num_keys = keys.size() / 2;
+  size_t elems_per_key = out.size() / num_keys;
+  size_t bytes_per_key = out.itemsize() * elems_per_key;
   auto& s = stream();
   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
   if (out.size() == 0) {
     return;
   }

-  uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4;
-  uint32_t half_size = out_per_key / 2;
+  size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
+  size_t half_size = out_per_key / 2;
   bool odd = out_per_key % 2;
+  if ((half_size + odd) >= UINT32_MAX || num_keys >= UINT32_MAX) {
+    throw std::runtime_error("[RandomBits::eval_gpu] Large size unsupported");
+  }

   encoder.set_input_array(keys);
   encoder.set_output_array(out);
-  dim3 grid_dims{num_keys, half_size + odd};
-  int64_t total = grid_dims.x * grid_dims.y;
-  int32_t threads_y = 1;
-  while ((total / threads_y) >= (1U << 31)) {
+  int64_t total = num_keys * (half_size + odd);
+  uint32_t threads_y = 1;
+  while ((total / threads_y) >= UINT_MAX) {
     threads_y *= 2;
   }
-  int32_t threads_x = cuda::ceil_div(total, threads_y);
+  uint32_t threads_x = cuda::ceil_div(total, threads_y);
+
+  dim3 grid_dims{
+      static_cast<uint32_t>(num_keys), static_cast<uint32_t>(half_size + odd)};
   auto [grid, block] = get_grid_and_block(threads_x, threads_y, 1);
   auto& stream = encoder.stream();
   if (keys.flags().row_contiguous) {
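The rewritten launch logic first rejects sizes whose grid coordinates cannot fit in 32 bits, then splits `total` across a second thread dimension until the x-extent fits. The doubling loop in isolation (sketch):

    #include <cstdint>
    #include <utility>

    // Returns {threads_x, threads_y} with threads_x * threads_y >= total and
    // threads_x below the 32-bit limit, mirroring the loop above.
    std::pair<uint32_t, uint32_t> split_threads(int64_t total) {
      uint32_t threads_y = 1;
      while ((total / threads_y) >= UINT32_MAX) {
        threads_y *= 2;
      }
      uint32_t threads_x =
          static_cast<uint32_t>((total + threads_y - 1) / threads_y);
      return {threads_x, threads_y};
    }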
@@ -66,7 +66,7 @@ void all_reduce(
     Reduce::ReduceType reduce_type) {
   constexpr int N_READS = 8;

-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));

   auto get_args = [](size_t size, int N) {
     int threads = std::min(512UL, (size + N - 1) / N);
@@ -107,8 +107,7 @@ void all_reduce(
   encoder.set_input_array(in);
   if (blocks > 1) {
     array intermediate({blocks}, out.dtype(), nullptr, {});
-    intermediate.set_data(
-        cu::malloc_async(intermediate.nbytes(), encoder.stream()));
+    intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
     encoder.add_temporary(intermediate);
     encoder.set_output_array(intermediate);
     dispatch_all_types(dt, [&](auto type_tag) {
@@ -89,9 +89,13 @@ template <
     int NDIM,
     int BM,
     int BN,
-    int N_READS = 4>
-__global__ void
-col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
+    int N_READS = 4,
+    int BLOCKS = 1>
+__global__ void col_reduce_looped(
+    T* in,
+    U* out,
+    const __grid_constant__ ColReduceArgs args,
+    int64_t out_size) {
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
   auto warp = cg::tiled_partition<WARP_SIZE>(block);
@@ -102,6 +106,8 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
   size_t tile_idx = grid.block_rank();
   size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN);
   size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN);
+  size_t tile_out = tile_y / out_size;
+  tile_y = tile_y % out_size;

   // Compute the indices for the thread within the tile
   short thread_x = block.thread_rank() % threads_per_row;
@@ -118,12 +124,23 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
     totals[i] = ReduceInit<Op, T>::value();
   }

-  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
-  loop.next(thread_y, args.reduce_shape.data(), args.reduce_strides.data());
   size_t total = args.non_col_reductions * args.reduction_size;
+  size_t per_block, start, end;
+  if constexpr (BLOCKS > 1) {
+    per_block = (total + BLOCKS - 1) / BLOCKS;
+    start = tile_out * per_block + thread_y;
+    end = min((tile_out + 1) * per_block, total);
+  } else {
+    per_block = total;
+    start = thread_y;
+    end = total;
+  }
+
+  LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
+  loop.next(start, args.reduce_shape.data(), args.reduce_strides.data());
   if (tile_x * BN + BN <= args.reduction_stride) {
     if (args.reduction_stride % N_READS == 0) {
-      for (size_t r = thread_y; r < total; r += BM) {
+      for (size_t r = start; r < end; r += BM) {
         T vals[N_READS];
         cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
         for (int i = 0; i < N_READS; i++) {
@@ -132,7 +149,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
         loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
       }
     } else {
-      for (size_t r = thread_y; r < total; r += BM) {
+      for (size_t r = start; r < end; r += BM) {
         T vals[N_READS];
         cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
         for (int i = 0; i < N_READS; i++) {
@@ -142,7 +159,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
       }
     }
   } else {
-    for (size_t r = thread_y; r < total; r += BM) {
+    for (size_t r = start; r < end; r += BM) {
       T vals[N_READS];
       cub::LoadDirectBlocked(
           thread_x,
@@ -173,6 +190,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {

   // Write result.
   if (warp.thread_rank() == 0) {
+    if (BLOCKS > 1) {
+      out += tile_out * out_size * args.reduction_stride;
+    }
     cub::StoreDirectBlocked(
         warp.meta_group_rank(),
         out + tile_y * args.reduction_stride + tile_x * BN,
@@ -227,11 +247,12 @@ __global__ void col_reduce_small(
 inline auto output_grid_for_col_reduce(
     const array& out,
     const cu::ColReduceArgs& args,
-    int bn) {
+    int bn,
+    int outer = 1) {
   int gx, gy = 1;
   size_t n_inner_blocks = cuda::ceil_div(args.reduction_stride, bn);
   size_t n_outer_blocks = out.size() / args.reduction_stride;
-  size_t n_blocks = n_outer_blocks * n_inner_blocks;
+  size_t n_blocks = n_outer_blocks * n_inner_blocks * outer;
   while (n_blocks / gy > INT32_MAX) {
     gy *= 2;
   }
@@ -277,7 +298,8 @@ void col_reduce_looped(
             0,
             indata,
             gpu_ptr<U>(out),
-            static_cast<cu::ColReduceArgs>(args));
+            static_cast<cu::ColReduceArgs>(args),
+            out.size() / args.reduction_stride);
       });
     });
   });
@@ -320,6 +342,117 @@ void col_reduce_small(
   });
 }

+void col_reduce_two_pass(
+    cu::CommandEncoder& encoder,
+    const array& in,
+    array& out,
+    Reduce::ReduceType reduce_type,
+    const std::vector<int>& axes,
+    const ReductionPlan& plan,
+    const cu::ColReduceArgs& args) {
+  // Allocate data for the output using in's layout to access them as
+  // contiguously as possible.
+  allocate_same_layout(out, in, axes, encoder);
+
+  // Allocate an intermediate array to hold the 1st pass result
+  constexpr int outer = 32;
+
+  Shape intermediate_shape;
+  intermediate_shape.push_back(outer);
+  intermediate_shape.insert(
+      intermediate_shape.end(), out.shape().begin(), out.shape().end());
+
+  Strides intermediate_strides;
+  intermediate_strides.push_back(out.size());
+  intermediate_strides.insert(
+      intermediate_strides.end(), out.strides().begin(), out.strides().end());
+
+  array intermediate(intermediate_shape, out.dtype(), nullptr, {});
+  auto [data_size, rc, cc] =
+      check_contiguity(intermediate_shape, intermediate_strides);
+  auto fl = out.flags();
+  fl.row_contiguous = rc;
+  fl.col_contiguous = cc;
+  fl.contiguous = true;
+  intermediate.set_data(
+      cu::malloc_async(intermediate.nbytes(), encoder),
+      data_size,
+      intermediate_strides,
+      fl,
+      allocator::free);
+
+  encoder.add_temporary(intermediate);
+  encoder.set_input_array(in);
+  encoder.set_output_array(intermediate);
+  dispatch_all_types(in.dtype(), [&](auto type_tag) {
+    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+      dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
+        // Cub doesn't like const pointers for vectorized loads. (sigh)
+        T* indata = const_cast<T*>(gpu_ptr<T>(in));
+
+        constexpr int N_READS = 4;
+        constexpr int BM = 32;
+        constexpr int BN = 32;
+        dim3 grid = output_grid_for_col_reduce(out, args, BN, outer);
+        int blocks = BM * BN / N_READS;
+        auto kernel = cu::
+            col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS, outer>;
+        encoder.add_kernel_node(
+            kernel,
+            grid,
+            blocks,
+            0,
+            indata,
+            gpu_ptr<U>(intermediate),
+            static_cast<cu::ColReduceArgs>(args),
+            out.size() / args.reduction_stride);
+      });
+    });
+  });
+
+  // Prepare the reduction arguments for the 2nd pass
+  cu::ColReduceArgs second_args = args;
+  second_args.reduction_size = outer;
+  second_args.reduction_stride = out.size();
+  second_args.ndim = 0;
+  second_args.reduce_shape[0] = outer;
+  second_args.reduce_strides[0] = out.size();
+  second_args.reduce_ndim = 1;
+  second_args.non_col_reductions = 1;
+
+  encoder.set_input_array(intermediate);
+  encoder.set_output_array(out);
+  dispatch_all_types(intermediate.dtype(), [&](auto type_tag) {
+    dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+      dispatch_reduce_ndim(second_args.reduce_ndim, [&](auto reduce_ndim) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
+
+        constexpr int N_READS = 4;
+        constexpr int BM = 32;
+        constexpr int BN = 32;
+        dim3 grid = output_grid_for_col_reduce(out, second_args, BN);
+        int blocks = BM * BN / N_READS;
+        auto kernel =
+            cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
+        encoder.add_kernel_node(
+            kernel,
+            grid,
+            blocks,
+            0,
+            gpu_ptr<T>(intermediate),
+            gpu_ptr<U>(out),
+            second_args,
+            second_args.reduction_stride);
+      });
+    });
+  });
+}
+
 void col_reduce(
     cu::CommandEncoder& encoder,
     const array& in,
@@ -334,6 +467,18 @@ void col_reduce(
   // It is a general strided reduce. Each threadblock computes the output for
   // a subrow of the fast moving axis. For instance 32 elements.
   //
+  // - col_reduce_small
+  //
+  //   It is a column reduce for small columns. Each thread loops over the whole
+  //   column without communicating with any other thread.
+  //
+  // - col_reduce_two_pass
+  //
+  //   It is a reduce for long columns. To increase parallelism, we split the
+  //   reduction in two passes. First we do a column reduce where many
+  //   threadblocks operate on different parts of the reduced axis. Then we
+  //   perform a final column reduce.
+  //
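A CPU model of the two passes just described, for a column sum of an (R, C) row-major array with long columns (`outer` = 32 as in this diff); sketch only:

    #include <algorithm>
    #include <vector>

    std::vector<float>
    two_pass_col_sum(const std::vector<float>& in, int R, int C) {
      constexpr int outer = 32;
      std::vector<float> inter(outer * C, 0.0f), out(C, 0.0f);
      int per_block = (R + outer - 1) / outer;
      // Pass 1: 32 independent partial column reductions.
      for (int o = 0; o < outer; ++o) {
        int start = o * per_block;
        int end = std::min(R, (o + 1) * per_block);
        for (int r = start; r < end; ++r) {
          for (int c = 0; c < C; ++c) {
            inter[o * C + c] += in[r * C + c];
          }
        }
      }
      // Pass 2: fold the (outer, C) intermediate into the final row.
      for (int o = 0; o < outer; ++o) {
        for (int c = 0; c < C; ++c) {
          out[c] += inter[o * C + c];
        }
      }
      return out;
    }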
@@ -349,6 +494,14 @@ void col_reduce(
     return;
   }

+  // Long column with smallish row
+  size_t total_sums = args.non_col_reductions * args.reduction_size;
+  size_t approx_threads = out.size();
+  if (total_sums / approx_threads > 32) {
+    col_reduce_two_pass(encoder, in, out, reduce_type, axes, plan, args);
+    return;
+  }
+
   // Fallback col reduce
   col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args);
 }
@@ -28,7 +28,7 @@ void init_reduce(
     Reduce::ReduceType reduce_type) {
   // Allocate if needed
   if (out.data_shared_ptr() == nullptr) {
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
   }

   encoder.set_output_array(out);
@@ -96,7 +96,7 @@ inline void allocate_same_layout(
     const std::vector<int>& axes,
     cu::CommandEncoder& encoder) {
   if (in.flags().row_contiguous) {
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
     return;
   }

@@ -135,7 +135,7 @@ inline void allocate_same_layout(
   fl.col_contiguous = cc;
   fl.contiguous = true;
   out.set_data(
-      cu::malloc_async(out.nbytes(), encoder.stream()),
+      cu::malloc_async(out.nbytes(), encoder),
       data_size,
       final_strides,
       fl,
@@ -22,26 +22,28 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
 }

 // Similar to cub::BlockReduce, but result is broadcasted to every thread.
-template <typename T, int BLOCK_DIM>
+template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
 struct BlockBroadcastReduce {
-  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
-  static_assert(BLOCK_DIM % WARP_SIZE == 0);
-  using TempStorage = T[BLOCK_DIM / WARP_SIZE];
+  using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];

   cg::thread_block& block;
   TempStorage& temp;

   template <typename Op>
   __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
-    auto warp = cg::tiled_partition<WARP_SIZE>(block);
+    auto warp = cg::tiled_partition<GROUP_DIM>(block);
     T x = cg::reduce(warp, input, op);
-    if (warp.thread_rank() == 0) {
-      temp[warp.meta_group_rank()] = x;
+    if constexpr (BLOCK_DIM > GROUP_DIM) {
+      if (warp.thread_rank() == 0) {
+        temp[warp.meta_group_rank()] = x;
+      }
+      block.sync();
+      x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
+                                                      : init_value;
+      return cg::reduce(warp, x, op);
+    } else {
+      return x;
     }
-    block.sync();
-    x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
-                                                    : init_value;
-    return cg::reduce(warp, x, op);
   }

   __device__ T Sum(const T& input) {
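The new `GROUP_DIM` parameter lets the reduction tile be narrower than a full warp; when `BLOCK_DIM == GROUP_DIM` the whole reduction is a single `cg::reduce` and the shared-memory round trip is compiled away. For example (sketch), a kernel that assigns one 8-thread group per row can instantiate:

    // One cg::reduce over 8 threads, no TempStorage traffic at all.
    using RowReduce =
        BlockBroadcastReduce<float, /*BLOCK_DIM=*/8, /*GROUP_DIM=*/8>;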
@@ -49,6 +51,52 @@ struct BlockBroadcastReduce {
   }
 };

+template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
+__global__ void rms_norm_small(
+    const T* x,
+    const T* w,
+    T* out,
+    float eps,
+    uint32_t axis_size,
+    uint32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceT::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+  x += row * axis_size;
+  out += row * axis_size;
+
+  // Normalizer.
+  float normalizer = 0;
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float t = static_cast<float>(xn[i]);
+    normalizer += t * t;
+  }
+
+  normalizer = BlockReduceT{block, temp}.Sum(normalizer);
+  normalizer = rsqrt(normalizer / axis_size + eps);
+
+  // Outputs.
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float y = static_cast<float>(xn[i]) * normalizer;
+    xn[i] = wn[i] * static_cast<T>(y);
+  }
+  store_vector<N_READS>(out, index, xn, axis_size);
+}
+
 template <typename T, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm(
     const T* x,
@@ -94,6 +142,74 @@ __global__ void rms_norm(
   }
 }

+template <
+    typename T,
+    bool HAS_W,
+    int BLOCK_DIM,
+    int REDUCE_DIM,
+    int N_READS = 4>
+__global__ void rms_norm_vjp_small(
+    const T* x,
+    const T* w,
+    const T* g,
+    T* gx,
+    T* gw,
+    float eps,
+    int32_t axis_size,
+    int32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceF2::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+
+  x += row * axis_size;
+  g += row * axis_size;
+  gx += row * axis_size;
+  gw += row * axis_size;
+
+  // Normalizer.
+  float2 factors = {};
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+  auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+  for (int i = 0; i < N_READS; i++) {
+    float t = static_cast<float>(xn[i]);
+    float wi = wn[i];
+    float gi = gn[i];
+    float wg = wi * gi;
+    factors = plus_f2(factors, {wg * t, t * t});
+  }
+
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
+  float meangwx = factors.x / axis_size;
+  float normalizer = rsqrt(factors.y / axis_size + eps);
+  float normalizer3 = normalizer * normalizer * normalizer;
+
+  // Outputs.
+  for (int i = 0; i < N_READS; i++) {
+    float xi = xn[i];
+    float wi = wn[i];
+    float gi = gn[i];
+    xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
+    if constexpr (HAS_W) {
+      wn[i] = static_cast<T>(gi * xi * normalizer);
+    }
+  }
+  store_vector<N_READS>(gx, index, xn, axis_size);
+  if constexpr (HAS_W) {
+    store_vector<N_READS>(gw, index, wn, axis_size);
+  }
+}
+
 template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm_vjp(
     const T* x,
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();

-  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
   using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
-  __shared__ union {
-    typename BlockReduceF::TempStorage f;
-    typename BlockReduceF2::TempStorage f2;
-  } temp;
+  __shared__ typename BlockReduceF2::TempStorage temp;

   x += grid.block_rank() * axis_size;
   g += grid.block_rank() * axis_size;
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
       factors = plus_f2(factors, {wg * t, t * t});
     }
   }
-  factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
   float meangwx = factors.x / axis_size;
   float normalizer = rsqrt(factors.y / axis_size + eps);
   float normalizer3 = normalizer * normalizer * normalizer;
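The union was only needed while the kernel ran both a `float` and a `float2` block reduction over the same shared memory; with the `float`-typed reduction gone, the single `float2` storage suffices (and is the larger of the two anyway). A compile-time sanity check of that size relationship (sketch):

    static_assert(
        sizeof(BlockBroadcastReduce<float2, 1024>::TempStorage) >=
        sizeof(BlockBroadcastReduce<float, 1024>::TempStorage));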
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
   return s.device == Device::cpu;
 }

+template <int n_per_thread, typename F>
+void dispatch_group_dim(int axis_size, F&& f) {
+  if (axis_size <= n_per_thread * 8) {
+    f(std::integral_constant<int, 8>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 16>());
+  } else if (axis_size <= n_per_thread * 16) {
+    f(std::integral_constant<int, 16>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 8>());
+  } else if (axis_size <= n_per_thread * 32) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 4>());
+  } else if (axis_size <= n_per_thread * 32 * 2) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 2>(),
+      std::integral_constant<int, 2>());
+  } else if (axis_size <= n_per_thread * 32 * 4) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 4>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 8) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 8>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 16) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 16>(),
+      std::integral_constant<int, 1>());
+  } else {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 32>(),
+      std::integral_constant<int, 1>());
+  }
+}
+
 // TODO: There are duplicate code with backend/metal/normalization.cpp
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
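`dispatch_group_dim` maps a row length onto a `(group_dim, n_groups, groups_per_block)` triple: `group_dim * n_groups` threads cooperate on one row, and `groups_per_block` rows share a thread block, so short rows still fill a block. For instance, with `n_per_thread = 8` and `axis_size = 4096` the `n_per_thread * 32 * 16` bucket fires (sketch):

    dispatch_group_dim<8>(4096, [&](auto group_dim, auto n_groups, auto gpb) {
      // Here group_dim() == 32, n_groups() == 16, gpb() == 1:
      // a 512-thread block, all of it working on a single row.
      int threads_per_row = group_dim() * n_groups();
      int rows_per_block = gpb();
      (void)threads_per_row;
      (void)rows_per_block;
    });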
@@ -190,7 +339,7 @@ void RMSNorm::eval_gpu(
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
+        cu::malloc_async(x.data_size() * x.itemsize(), encoder),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -216,12 +365,33 @@ void RMSNorm::eval_gpu(
   dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     constexpr int N_READS = 16 / sizeof(DataType);
-    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
+    if (axis_size <= N_READS * 1024) {
+      dispatch_group_dim<N_READS>(
+          axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
+            constexpr int block_dim = n_groups() * group_dim();
+            auto kernel =
+                cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
+            auto n_blocks =
+                (n_rows + groups_per_block() - 1) / groups_per_block();
+            encoder.add_kernel_node(
+                kernel,
+                n_blocks,
+                {block_dim, groups_per_block()},
+                0,
+                gpu_ptr<DataType>(x),
+                gpu_ptr<DataType>(w),
+                gpu_ptr<DataType>(out),
+                eps_,
+                axis_size,
+                n_rows,
+                w_stride);
+          });
+    } else {
+      auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
       encoder.add_kernel_node(
           kernel,
           n_rows,
-          block_dim(),
+          1024,
           0,
           gpu_ptr<DataType>(x),
           gpu_ptr<DataType>(w),
@@ -229,7 +399,7 @@ void RMSNorm::eval_gpu(
           eps_,
           axis_size,
           w_stride);
-    });
+    }
   });
 }

@@ -274,7 +444,7 @@ void RMSNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(cu::malloc_async(gx.nbytes(), encoder.stream()));
+    gx.set_data(cu::malloc_async(gx.nbytes(), encoder));
   }
   if (g_copied && !g_in_gx) {
     encoder.add_temporary(g);
@@ -292,7 +462,7 @@ void RMSNormVJP::eval_gpu(
     if (!g_in_gx && donate_g) {
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder.stream()));
+      gw_temp.set_data(cu::malloc_async(gw_temp.nbytes(), encoder));
       encoder.add_temporary(gw_temp);
     }
   }
@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
   dispatch_bool(has_w, [&](auto has_w_constant) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     constexpr int N_READS = 16 / sizeof(DataType);
-    dispatch_block_dim(
-        cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-          auto kernel = cu::rms_norm_vjp<
-              DataType,
-              has_w_constant.value,
-              block_dim(),
-              N_READS>;
-          encoder.add_kernel_node(
-              kernel,
-              n_rows,
-              block_dim(),
-              0,
-              gpu_ptr<DataType>(x),
-              gpu_ptr<DataType>(w),
-              gpu_ptr<DataType>(g),
-              gpu_ptr<DataType>(gx),
-              gpu_ptr<DataType>(gw_temp),
-              eps_,
-              axis_size,
-              w_stride);
-        });
+    if (axis_size <= N_READS * 1024) {
+      dispatch_group_dim<N_READS>(
+          axis_size,
+          [&](auto group_dim, auto n_groups, auto groups_per_block) {
+            constexpr int block_dim = group_dim() * n_groups();
+            auto kernel = cu::rms_norm_vjp_small<
+                DataType,
+                has_w_constant.value,
+                block_dim,
+                group_dim(),
+                N_READS>;
+            auto n_blocks =
+                (n_rows + groups_per_block() - 1) / groups_per_block();
+            encoder.add_kernel_node(
+                kernel,
+                n_blocks,
+                {block_dim, groups_per_block()},
+                0,
+                gpu_ptr<DataType>(x),
+                gpu_ptr<DataType>(w),
+                gpu_ptr<DataType>(g),
+                gpu_ptr<DataType>(gx),
+                gpu_ptr<DataType>(gw_temp),
+                eps_,
+                axis_size,
+                n_rows,
+                w_stride);
+          });
+    } else {
+      auto kernel =
+          cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
+      encoder.add_kernel_node(
+          kernel,
+          n_rows,
+          1024,
+          0,
+          gpu_ptr<DataType>(x),
+          gpu_ptr<DataType>(w),
+          gpu_ptr<DataType>(g),
+          gpu_ptr<DataType>(gx),
+          gpu_ptr<DataType>(gw_temp),
+          eps_,
+          axis_size,
+          w_stride);
+    }
   });
 });

|
|||||||
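The `dispatch_bool(has_w, ...)` wrapper above lifts a runtime flag into a compile-time constant so `has_w_constant.value` can select a kernel template parameter. A generic sketch of that pattern (standard C++, not the MLX helper itself):

#include <cstdio>
#include <type_traits>

// Calls f with std::true_type or std::false_type, turning a runtime
// bool into a compile-time value inside the lambda.
template <typename F>
void dispatch_bool(bool value, F&& f) {
  if (value) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

int main() {
  dispatch_bool(true, [](auto has_w) {
    if constexpr (has_w.value) {
      std::printf("weight gradient path\n");
    } else {
      std::printf("no-weight path\n");
    }
  });
  return 0;
}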
@@ -292,14 +292,14 @@ void RoPE::eval_gpu(
       donated = true;
       out.copy_shared_buffer(in);
     } else {
-      out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+      out.set_data(cu::malloc_async(out.nbytes(), encoder));
     }
     strides[0] = mat_size;
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
   } else if (dispatch_ndim == 3) {
     // Handle non-contiguous 3D inputs
-    out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
     strides[0] = in.strides()[ndim - 3];
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
mlx/backend/cuda/scaled_dot_product_attention.cpp (new file, 506 lines)
@@ -0,0 +1,506 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cudnn_utils.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/lru_cache.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/fast_primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+
+namespace mlx::core {
+
+namespace {
+
+array prepare_sdpa_input(const array& x, Stream s) {
+  // SDPA kernel's requirements on inputs:
+  // 1. last dim's stride be 1;
+  // 2. pointer be aligned.
+  if (x.strides(-1) != 1 || get_alignment(x) < 16) {
+    array x_copy = contiguous_copy_gpu(x, s);
+    auto& encoder = cu::get_command_encoder(s);
+    encoder.add_temporary(x_copy);
+    return x_copy;
+  }
+  return x;
+}
+
+void malloc_with_same_layout(
+    cu::CommandEncoder& encoder,
+    array& o,
+    const array& q) {
+  if (q.flags().row_contiguous) {
+    o.set_data(cu::malloc_async(o.nbytes(), encoder));
+    return;
+  }
+  // fill_order = argsort(q.strides())
+  Shape fill_order(q.ndim());
+  std::iota(fill_order.begin(), fill_order.end(), 0);
+  std::stable_sort(
+      fill_order.begin(), fill_order.end(), [&q](int idx1, int idx2) {
+        auto s1 = q.strides(idx1) > 0 ? q.strides(idx1) : 1;
+        auto s2 = q.strides(idx2) > 0 ? q.strides(idx2) : 1;
+        return s1 < s2;
+      });
+  // Generate o_strides with fill_order
+  Strides o_strides(q.ndim());
+  int64_t stride = 1;
+  for (int i : fill_order) {
+    o_strides[i] = stride;
+    stride *= o.shape(i);
+  }
+  // o is a transposed contiguous array
+  o.set_data(
+      cu::malloc_async(o.nbytes(), encoder),
+      o.size(),
+      o_strides,
+      {true, false, false});
+}
+
+constexpr int QKV_NDIM = 4;
+
+struct SDPACacheKey {
+  int device_id;
+  fe::DataType_t cudnn_dtype;
+  std::array<int, QKV_NDIM> q_shape;
+  std::array<int, QKV_NDIM> k_shape;
+  std::array<int, QKV_NDIM> v_shape;
+  std::array<int64_t, QKV_NDIM> q_strides;
+  std::array<int64_t, QKV_NDIM> k_strides;
+  std::array<int64_t, QKV_NDIM> v_strides;
+  bool do_causal;
+  std::array<int, QKV_NDIM> mask_shape;
+  std::array<int64_t, QKV_NDIM> mask_strides;
+  bool output_logsumexp;
+};
+
+inline BytesKey<SDPACacheKey> build_sdpa_cache_key(
+    cu::CommandEncoder& encoder,
+    const array& q,
+    const array& k,
+    const array& v,
+    bool do_causal,
+    const std::optional<array>& mask_arr,
+    bool output_logsumexp = true) {
+  BytesKey<SDPACacheKey> cache_key;
+  cache_key.pod = {
+      encoder.device().cuda_device(),
+      dtype_to_cudnn_type(q.dtype()),
+      vector_key<QKV_NDIM>(q.shape()),
+      vector_key<QKV_NDIM>(k.shape()),
+      vector_key<QKV_NDIM>(v.shape()),
+      vector_key<QKV_NDIM>(q.strides()),
+      vector_key<QKV_NDIM>(k.strides()),
+      vector_key<QKV_NDIM>(v.strides()),
+      do_causal,
+      {},
+      {},
+      output_logsumexp,
+  };
+  if (mask_arr) {
+    cache_key.pod.mask_shape = vector_key<QKV_NDIM>(mask_arr->shape());
+    cache_key.pod.mask_strides = vector_key<QKV_NDIM>(mask_arr->strides());
+  }
+  return cache_key;
+}
+
+auto& sdpa_cache() {
+  static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
+      "MLX_CUDA_SDPA_CACHE_SIZE", /* default_capacity */ 64);
+  return cache;
+}
+
+auto& sdpa_backward_cache() {
+  static LRUBytesKeyCache<SDPACacheKey, DnnGraph> cache(
+      "MLX_CUDA_SDPA_BACKWARD_CACHE_SIZE", /* default_capacity */ 64);
+  return cache;
+}
+
+enum UIDS {
+  Q,
+  K,
+  V,
+  SCALE,
+  BIAS,
+  O,
+  STATS,
+  // Backward graph:
+  D_Q,
+  D_K,
+  D_V,
+  D_O,
+};
+
+DnnGraph build_sdpa_graph(
+    cudnnHandle_t handle,
+    const array& q,
+    const array& k,
+    const array& v,
+    bool do_causal,
+    const std::optional<array>& mask_arr,
+    bool output_logsumexp,
+    const array& o,
+    const array& stats) {
+  DnnGraph graph(handle, q.dtype());
+
+  auto q_ = graph.tensor("Q", Q, q);
+  auto k_ = graph.tensor("K", K, k);
+  auto v_ = graph.tensor("V", V, v);
+
+  auto options = fe::graph::SDPA_attributes()
+                     .set_name("sdpa_cudnn")
+                     .set_attn_scale(graph.scalar("Scale", SCALE, float32))
+                     .set_generate_stats(output_logsumexp);
+  if (do_causal) {
+    if (q.shape(2) > k.shape(2)) {
+      options.set_causal_mask(do_causal);
+    } else {
+      options.set_causal_mask_bottom_right(do_causal);
+    }
+  }
+  if (mask_arr) {
+    options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
+  }
+
+  auto [o_, stats_] = graph.sdpa(q_, k_, v_, options);
+  graph.tensor(o_, O, o)->set_output(true);
+  if (output_logsumexp) {
+    graph.tensor(stats_, STATS, stats)->set_output(true);
+  }
+
+  CHECK_CUDNN_FE_ERROR(graph.prepare());
+  graph.select_behavior_notes(
+      {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
+  CHECK_CUDNN_FE_ERROR(graph.build());
+  return graph;
+}
+
+DnnGraph build_sdpa_backward_graph(
+    cudnnHandle_t handle,
+    const array& q,
+    const array& k,
+    const array& v,
+    bool do_causal,
+    const std::optional<array>& mask_arr,
+    const array& o,
+    const array& d_o,
+    const array& stats,
+    array& d_q,
+    array& d_k,
+    array& d_v) {
+  DnnGraph graph(handle, q.dtype());
+
+  auto q_ = graph.tensor("Q", Q, q);
+  auto k_ = graph.tensor("K", K, k);
+  auto v_ = graph.tensor("V", V, v);
+  auto o_ = graph.tensor("O", O, o);
+  auto d_o_ = graph.tensor("D_O", D_O, d_o);
+  auto stats_ = graph.tensor("STATS", STATS, stats);
+
+  auto options = fe::graph::SDPA_backward_attributes()
+                     .set_name("sdpa_backward_cudnn")
+                     .set_attn_scale(graph.scalar("Scale", SCALE, float32));
+  if (do_causal) {
+    if (q.shape(2) > k.shape(2)) {
+      options.set_causal_mask(do_causal);
+    } else {
+      options.set_causal_mask_bottom_right(do_causal);
+    }
+  }
+  if (mask_arr) {
+    options.set_bias(graph.tensor("BIAS", BIAS, *mask_arr));
+  }
+
+  auto [d_q_, d_k_, d_v_] =
+      graph.sdpa_backward(q_, k_, v_, o_, d_o_, stats_, options);
+  graph.tensor(d_q_, D_Q, d_q)->set_output(true);
+  graph.tensor(d_k_, D_K, d_k)->set_output(true);
+  graph.tensor(d_v_, D_V, d_v)->set_output(true);
+
+  CHECK_CUDNN_FE_ERROR(graph.prepare());
+  graph.select_behavior_notes(
+      {fe::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API});
+  CHECK_CUDNN_FE_ERROR(graph.build());
+  return graph;
+}
+
+} // namespace
+
+bool supports_sdpa_cudnn(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool do_causal,
+    Stream s) {
+  static bool enabled = env::get_var("MLX_CUDA_USE_CUDNN_SPDA", 1);
+  if (!enabled) {
+    return false;
+  }
+
+  // cuDNN SDPA requires Ampere and later.
+  if (cu::device(s.device).compute_capability_major() < 8) {
+    return false;
+  }
+
+  // Only use cuDNN for prefilling (T_q > 1) and training (T_q == T_kv).
+  if ((q.shape(2) == 1) && (q.shape(2) != k.shape(2))) {
+    return false;
+  }
+
+  // D_qk and D_v must be a multiple of 8 with maximum value 128.
+  if ((q.shape(-1) % 8 != 0) || (q.shape(-1) > 128) || (v.shape(-1) % 8 != 0) ||
+      (v.shape(-1) > 128)) {
+    return false;
+  }
+
+  Dtype dtype = q.dtype();
+  return dtype == float16 || dtype == bfloat16;
+}
+
+void sdpa_cudnn(
+    const array& q,
+    const array& k,
+    const array& v,
+    float scale,
+    array& o,
+    array& stats,
+    bool do_causal,
+    const std::optional<array>& mask_arr,
+    bool output_logsumexp,
+    Stream s) {
+  auto& encoder = cu::get_command_encoder(s);
+  auto handle = encoder.device().cudnn_handle();
+
+  malloc_with_same_layout(encoder, o, q);
+
+  encoder.set_input_array(q);
+  encoder.set_input_array(k);
+  encoder.set_input_array(v);
+  encoder.set_output_array(o);
+  if (mask_arr) {
+    encoder.set_input_array(*mask_arr);
+  }
+  if (output_logsumexp) {
+    stats.set_data(cu::malloc_async(stats.nbytes(), encoder));
+    encoder.set_output_array(stats);
+  }
+
+  // Search cache.
+  auto cache_key = build_sdpa_cache_key(
+      encoder, q, k, v, do_causal, mask_arr, output_logsumexp);
+  auto it = sdpa_cache().find(cache_key);
+  if (it == sdpa_cache().end()) {
+    auto graph = build_sdpa_graph(
+        handle, q, k, v, do_causal, mask_arr, output_logsumexp, o, stats);
+    it = sdpa_cache().emplace(cache_key, std::move(graph)).first;
+  }
+  auto& graph = it->second;
+
+  std::unordered_map<int64_t, void*> variant_pack{
+      {Q, gpu_ptr<void>(q)},
+      {K, gpu_ptr<void>(k)},
+      {V, gpu_ptr<void>(v)},
+      {SCALE, &scale},
+      {O, gpu_ptr<void>(o)}};
+  if (mask_arr) {
+    variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
+  }
+  if (output_logsumexp) {
+    variant_pack[STATS] = gpu_ptr<void>(stats);
+  }
+
+  CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
+}
+
+void sdpa_backward_cudnn(
+    const array& q,
+    const array& k,
+    const array& v,
+    float scale,
+    const array& o,
+    const array& stats,
+    bool do_causal,
+    const std::optional<array>& mask_arr,
+    const array& d_o,
+    array& d_q,
+    array& d_k,
+    array& d_v,
+    Stream s) {
+  auto& encoder = cu::get_command_encoder(s);
+  auto handle = encoder.device().cudnn_handle();
+
+  malloc_with_same_layout(encoder, d_q, q);
+  malloc_with_same_layout(encoder, d_k, k);
+  malloc_with_same_layout(encoder, d_v, v);
+
+  encoder.set_input_array(q);
+  encoder.set_input_array(k);
+  encoder.set_input_array(v);
+  encoder.set_input_array(o);
+  encoder.set_input_array(stats);
+  encoder.set_input_array(d_o);
+  encoder.set_output_array(d_q);
+  encoder.set_output_array(d_k);
+  encoder.set_output_array(d_v);
+  if (mask_arr) {
+    encoder.set_input_array(*mask_arr);
+  }
+
+  // Search cache.
+  auto cache_key = build_sdpa_cache_key(encoder, q, k, v, do_causal, mask_arr);
+  auto it = sdpa_backward_cache().find(cache_key);
+  if (it == sdpa_backward_cache().end()) {
+    auto graph = build_sdpa_backward_graph(
+        handle, q, k, v, do_causal, mask_arr, o, d_o, stats, d_q, d_k, d_v);
+    it = sdpa_backward_cache().emplace(cache_key, std::move(graph)).first;
+  }
+  auto& graph = it->second;
+
+  std::unordered_map<int64_t, void*> variant_pack{
+      {Q, gpu_ptr<void>(q)},
+      {K, gpu_ptr<void>(k)},
+      {V, gpu_ptr<void>(v)},
+      {SCALE, &scale},
+      {O, gpu_ptr<void>(o)},
+      {STATS, gpu_ptr<void>(stats)},
+      {D_O, gpu_ptr<void>(d_o)},
+      {D_Q, gpu_ptr<void>(d_q)},
+      {D_K, gpu_ptr<void>(d_k)},
+      {D_V, gpu_ptr<void>(d_v)}};
+  if (mask_arr) {
+    variant_pack[BIAS] = gpu_ptr<void>(*mask_arr);
+  }
+
+  CHECK_CUDNN_FE_ERROR(graph.encode_graph(encoder, std::move(variant_pack)));
+}
+
+// Defined in scaled_dot_product_attention.cu file.
+bool supports_sdpa_vector(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    bool output_logsumexp);
+void sdpa_vector(
+    const array& q,
+    const array& k,
+    const array& v,
+    float scale,
+    array& o,
+    bool do_causal,
+    const std::optional<array>& sinks,
+    Stream s);
+
+namespace fast {
+
+bool ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    bool is_training,
+    bool output_logsumexp,
+    Stream s) {
+  if (s.device == Device::cpu) {
+    return true;
+  }
+
+  return !supports_sdpa_vector(
+             q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) &&
+      !supports_sdpa_cudnn(q, k, v, do_causal, s);
+}
+
+bool ScaledDotProductAttention::supports_bool_mask() {
+  return false;
+}
+
+void ScaledDotProductAttention::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
+
+  auto& s = stream();
+
+  array q = prepare_sdpa_input(inputs[0], s);
+  array k = prepare_sdpa_input(inputs[1], s);
+  array v = prepare_sdpa_input(inputs[2], s);
+  auto& out = outputs[0];
+  auto& stats = outputs[1];
+  bool has_mask = inputs.size() - has_sinks_ > 3;
+  bool has_arr_mask = has_mask && !do_causal_;
+
+  std::optional<array> mask_arr;
+  if (has_arr_mask) {
+    mask_arr = prepare_sdpa_input(inputs[3], s);
+  }
+
+  if (supports_sdpa_vector(
+          q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_)) {
+    if (has_sinks_) {
+      sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s);
+    } else {
+      sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s);
+    }
+  } else {
+    sdpa_cudnn(
+        q,
+        k,
+        v,
+        scale_,
+        out,
+        stats,
+        do_causal_,
+        mask_arr,
+        output_logsumexp_,
+        s);
+  }
+}
+
+bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
+  // The frontend adds a padding mask when sequence length is not a multiple of
+  // tile size.
+  if (q.shape(2) % 128 != 0) {
+    return true;
+  }
+  return s.device == Device::cpu;
+}
+
+void ScaledDotProductAttentionVJP::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  nvtx3::scoped_range r("ScaledDotProductAttentionVJP::eval_gpu");
+
+  auto& s = stream();
+
+  assert(inputs.size() >= 6);
+  int primals_size = inputs.size() - 3;
+  bool has_arr_mask = primals_size > 3 + has_sinks_;
+
+  array q = prepare_sdpa_input(inputs[0], s);
+  array k = prepare_sdpa_input(inputs[1], s);
+  array v = prepare_sdpa_input(inputs[2], s);
+  array o = prepare_sdpa_input(inputs[primals_size], s);
+  array stats = prepare_sdpa_input(inputs[primals_size + 1], s);
+  array d_o = prepare_sdpa_input(inputs[primals_size + 2], s);
+
+  std::optional<array> mask_arr;
+  if (has_arr_mask) {
+    mask_arr = prepare_sdpa_input(inputs[3], s);
+  }
+
+  assert(outputs.size() == 3);
+  auto& d_q = outputs[0];
+  auto& d_k = outputs[1];
+  auto& d_v = outputs[2];
+
+  sdpa_backward_cudnn(
+      q, k, v, scale_, o, stats, do_causal_, mask_arr, d_o, d_q, d_k, d_v, s);
+}
+
+} // namespace fast
+
+} // namespace mlx::core
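One detail worth noting in the new file: `malloc_with_same_layout` argsorts the query's strides and fills the output's strides in that order, so the output buffer matches the query's memory order even when the query is transposed. A standalone sketch of the same computation with plain vectors (the shape and strides below are made up for illustration):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // A 4-D "q" that is transposed on its middle axes.
  std::vector<int> shape = {2, 8, 4, 16};
  std::vector<int64_t> q_strides = {512, 16, 128, 1};

  // fill_order = argsort(q_strides), smallest stride first.
  std::vector<int> order(shape.size());
  std::iota(order.begin(), order.end(), 0);
  std::stable_sort(order.begin(), order.end(), [&](int a, int b) {
    return q_strides[a] < q_strides[b];
  });

  // Fill output strides along that order: fastest-moving axis first.
  std::vector<int64_t> o_strides(shape.size());
  int64_t stride = 1;
  for (int i : order) {
    o_strides[i] = stride;
    stride *= shape[i];
  }
  for (auto s : o_strides) {
    std::printf("%lld ", static_cast<long long>(s));
  }
  std::printf("\n"); // prints: 512 16 128 1 -- same memory order as q
  return 0;
}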
@@ -6,10 +6,6 @@
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
-#include "mlx/fast_primitives.h"
-#include "mlx/transforms_impl.h"
-
-#include <nvtx3/nvtx3.hpp>

 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
@@ -565,10 +561,9 @@ void sdpa_vector_2pass_fallback(
   array sums(intermediate_shape, float32, nullptr, {});
   array maxs(std::move(intermediate_shape), float32, nullptr, {});

-  intermediate.set_data(
-      cu::malloc_async(intermediate.nbytes(), encoder.stream()));
-  sums.set_data(cu::malloc_async(sums.nbytes(), encoder.stream()));
-  maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder.stream()));
+  intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
+  sums.set_data(cu::malloc_async(sums.nbytes(), encoder));
+  maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder));

   encoder.add_temporary(intermediate);
   encoder.add_temporary(sums);
@@ -663,21 +658,16 @@ void sdpa_vector_fallback(

 } // namespace

-namespace fast {
-
-bool ScaledDotProductAttention::use_fallback(
+bool supports_sdpa_vector(
     const array& q,
     const array& k,
     const array& v,
     bool has_mask,
     bool has_arr_mask,
     bool do_causal,
-    Stream s) {
-  if (detail::in_grad_tracing()) {
-    return true;
-  }
-  if (s.device == Device::cpu) {
-    return true;
-  }
+    bool output_logsumexp) {
+  if (output_logsumexp) {
+    return false;
+  }

   const int value_head_dim = v.shape(-1);
@@ -691,29 +681,24 @@ bool ScaledDotProductAttention::use_fallback(
   const bool supported_vector_config =
       sdpa_supported_head_dim && query_sequence_length < 4;

-  const bool supported_config = supported_vector_config;
-
-  return has_arr_mask || !supported_config;
+  return supported_vector_config && !has_arr_mask;
 }

-void ScaledDotProductAttention::eval_gpu(
-    const std::vector<array>& inputs,
-    array& out) {
-  nvtx3::scoped_range r("ScaledDotProductAttention::eval_gpu");
-
-  auto& s = stream();
+void sdpa_vector(
+    const array& q_pre,
+    const array& k_pre,
+    const array& v_pre,
+    float scale,
+    array& o,
+    bool do_causal,
+    const std::optional<array>& sinks_pre,
+    Stream s) {
   auto& encoder = cu::get_command_encoder(s);

-  auto& q_pre = inputs[0];
-  auto& k_pre = inputs[1];
-  auto& v_pre = inputs[2];
-  auto& o = out;
-
   std::vector<array> copies;

   // Define some copy functions to ensure the layout of the inputs is as
   // expected.
-  copies.reserve(inputs.size());
+  copies.reserve(4);
   auto copy_unless = [&copies, &s](
                          auto predicate, const array& arr) -> const array& {
     if (!predicate(arr)) {
@@ -731,8 +716,8 @@ void ScaledDotProductAttention::eval_gpu(
   };

   std::optional<array> sinks = std::nullopt;
-  if (has_sinks_) {
-    sinks = copy_unless(is_matrix_contiguous, inputs.back());
+  if (sinks_pre) {
+    sinks = copy_unless(is_matrix_contiguous, sinks_pre.value());
   }

   // We are in vector mode ie single query
@@ -788,7 +773,7 @@ void ScaledDotProductAttention::eval_gpu(
   };

   o.set_data(
-      cu::malloc_async(o.nbytes(), encoder.stream()),
+      cu::malloc_async(o.nbytes(), encoder),
       o.size(),
       {str_oB, str_oH, str_oL, str_oD},
       flags);
@@ -798,8 +783,7 @@ void ScaledDotProductAttention::eval_gpu(
     encoder.add_temporary(cp);
   }

-  return sdpa_vector_fallback(
-      s, encoder, q, k, v, scale_, o, do_causal_, sinks);
+  sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks);
 }

 // Full attention mode should never reach here
@@ -808,6 +792,4 @@ void ScaledDotProductAttention::eval_gpu(
   }
 }

-} // namespace fast
-
 } // namespace mlx::core
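The refactor above splits what used to be a single `use_fallback` into two predicates: `supports_sdpa_vector` (kept in this .cu file) and `supports_sdpa_cudnn` (in the new .cpp), with the primitive falling back only when neither path applies. A simplified sketch of that decision structure (the predicate bodies below are stand-ins, not the real checks):

#include <cstdio>

bool supports_vector_kernel(int query_len, bool has_arr_mask) {
  // Stand-in for supports_sdpa_vector: short decode-style queries only.
  return query_len < 4 && !has_arr_mask;
}

bool supports_cudnn(int head_dim) {
  // Stand-in for supports_sdpa_cudnn: head dim a multiple of 8, at most 128.
  return head_dim % 8 == 0 && head_dim <= 128;
}

bool use_fallback(int query_len, int head_dim, bool has_arr_mask) {
  // Fall back only when neither specialized path can run.
  return !supports_vector_kernel(query_len, has_arr_mask) &&
      !supports_cudnn(head_dim);
}

int main() {
  std::printf(
      "%d\n",
      use_fallback(/*query_len=*/1024, /*head_dim=*/96,
                   /*has_arr_mask=*/false)); // prints 0: cuDNN handles it
  return 0;
}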
@@ -374,7 +374,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(in);
   } else {
     out.set_data(
-        cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
+        cu::malloc_async(in.data_size() * out.itemsize(), encoder),
         in.data_size(),
         in.strides(),
         in.flags());
@@ -24,7 +24,7 @@ void concatenate_gpu(
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());

   auto& encoder = cu::get_command_encoder(s);
-  out.set_data(cu::malloc_async(out.nbytes(), encoder.stream()));
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));

   auto strides = out.strides();
   auto flags = out.flags();
@@ -89,7 +89,7 @@ array compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(cu::malloc_async(offset.itemsize(), encoder.stream()));
+    offset.set_data(cu::malloc_async(offset.itemsize(), encoder));
   }

   encoder.add_temporary(offset);
@@ -118,7 +118,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        cu::malloc_async(x.data_size() * x.itemsize(), encoder.stream()),
+        cu::malloc_async(x.data_size() * x.itemsize(), encoder),
         x.data_size(),
         x.strides(),
         x.flags());
@@ -49,14 +49,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
     array trans = swapaxes_in_eval(in, axis, last_dim);
     in = contiguous_copy_gpu(trans, s);
     encoder.add_temporary(in);
-    out = array(
-        cu::malloc_async(out.nbytes(), encoder.stream()),
-        in.shape(),
-        out.dtype());
+    out =
+        array(cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
     encoder.add_temporary(out);
   } else {
     out.set_data(
-        cu::malloc_async(in.data_size() * out.itemsize(), encoder.stream()),
+        cu::malloc_async(in.data_size() * out.itemsize(), encoder),
         in.data_size(),
         in.strides(),
         in.flags());
@@ -74,17 +72,13 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   if (argsort) {
     // Indices in the sorted dimension.
     array indices(
-        cu::malloc_async(out.nbytes(), encoder.stream()),
-        in.shape(),
-        out.dtype());
+        cu::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype());
     encoder.add_temporary(indices);

     // In argsort though we don't need the result of sorted values, the
     // API requires us to provide an array to store it.
     array discard(
-        cu::malloc_async(in.nbytes(), encoder.stream()),
-        in.shape(),
-        in.dtype());
+        cu::malloc_async(in.nbytes(), encoder), in.shape(), in.dtype());
     encoder.add_temporary(discard);

     size_t size;
@@ -104,9 +98,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         stream));

     array temp(
-        cu::malloc_async(size, encoder.stream()),
-        {static_cast<int>(size)},
-        uint8);
+        cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
     encoder.add_temporary(temp);

     // Start capturing after allocations
@@ -148,9 +140,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         stream));

     array temp(
-        cu::malloc_async(size, encoder.stream()),
-        {static_cast<int>(size)},
-        uint8);
+        cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);
     encoder.add_temporary(temp);

     // Start capturing after allocations
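The `size_t size; ... array temp(cu::malloc_async(size, encoder), {static_cast<int>(size)}, uint8);` sequence in gpu_sort follows CUB's standard two-phase protocol: a first call with a null temp pointer only reports the required scratch size, and the same call is then repeated with the allocated buffer. A minimal sketch of that protocol with `cub::DeviceRadixSort` (error handling simplified; this is not the MLX wrapper):

#include <cub/cub.cuh>

void radix_sort_example(
    const int* d_keys_in,
    int* d_keys_out,
    int n,
    cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call: d_temp == nullptr, so CUB only writes the size it needs.
  cub::DeviceRadixSort::SortKeys(
      d_temp, temp_bytes, d_keys_in, d_keys_out, n, 0, sizeof(int) * 8, stream);
  cudaMallocAsync(&d_temp, temp_bytes, stream);
  // Second call with real scratch space: actually sorts.
  cub::DeviceRadixSort::SortKeys(
      d_temp, temp_bytes, d_keys_in, d_keys_out, n, 0, sizeof(int) * 8, stream);
  cudaFreeAsync(d_temp, stream);
}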
@@ -3,31 +3,10 @@
 #pragma once

 #include "mlx/backend/cuda/steel/utils.cuh"
+#include "mlx/backend/cuda/vector_types.cuh"

 namespace mlx::core::cu {

-// Map types to their vector of 2 type float -> float2, double -> double2 etc
-template <typename T>
-struct Vector2;
-template <>
-struct Vector2<double> {
-  using type = double2;
-};
-template <>
-struct Vector2<float> {
-  using type = float2;
-};
-template <>
-struct Vector2<__half> {
-  using type = __half2;
-};
-template <>
-struct Vector2<__nv_bfloat16> {
-  using type = __nv_bfloat162;
-};
-template <typename T>
-using Vector2_t = typename Vector2<T>::type;
-
 /**
  * The basic building block for Ampere mmas. A 16x16 tile distributed across
  * the warp.
@@ -257,9 +257,8 @@ void ternary_op_gpu(
   auto& c = inputs[2];
   auto topt = get_ternary_op_type(a, b, c);
   auto& encoder = cu::get_command_encoder(s);
-  set_ternary_op_output_data(a, b, c, out, topt, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  set_ternary_op_output_data(
+      a, b, c, out, topt, [&](auto n) { return cu::malloc_async(n, encoder); });
   ternary_op_gpu_inplace<Op>(inputs, out, s);
 }

@@ -208,9 +208,8 @@ void unary_op_gpu(
     const char* op,
     const Stream& s) {
   auto& encoder = cu::get_command_encoder(s);
-  set_unary_output_data(inputs[0], out, [&](auto n) {
-    return cu::malloc_async(n, encoder.stream());
-  });
+  set_unary_output_data(
+      inputs[0], out, [&](auto n) { return cu::malloc_async(n, encoder); });
   unary_op_gpu_inplace<Op>(inputs, out, op, s);
 }

@@ -5,6 +5,7 @@
 #include "mlx/dtype_utils.h"

 #include <fmt/format.h>
+#include <vector>

 namespace mlx::core {

@@ -31,6 +32,13 @@ void check_cuda_error(const char* name, CUresult err) {
   }
 }

+void check_cudnn_error(const char* name, cudnnStatus_t err) {
+  if (err != CUDNN_STATUS_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
+  }
+}
+
 const char* dtype_to_cuda_type(const Dtype& dtype) {
   switch (dtype) {
     case bool_:
@@ -60,7 +68,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
     case float64:
       return "double";
     case complex64:
-      return "complex64_t";
+      return "mlx::core::cu::complex64_t";
     default:
       return "unknown";
   }
@@ -72,7 +80,6 @@ CudaGraph::CudaGraph(cu::Device& device) {
 }

 void CudaGraph::end_capture(cudaStream_t stream) {
-  assert(handle_ == nullptr);
   CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_));
 }

@@ -31,8 +31,10 @@ inline T* gpu_ptr(array& arr) {
       arr.offset());
 }

+// For const array, keep constness in pointer unless it is untyped.
 template <typename T>
-inline const T* gpu_ptr(const array& arr) {
+inline std::conditional_t<std::is_same_v<T, void>, void*, const T*> gpu_ptr(
+    const array& arr) {
   return gpu_ptr<T>(const_cast<array&>(arr));
 }

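The new `gpu_ptr` return type uses `std::conditional_t` so typed const access stays const while `gpu_ptr<void>` yields a mutable `void*`, which suits APIs like the cuDNN variant packs above that take non-const `void*`. The type arithmetic in isolation:

#include <type_traits>

template <typename T>
using const_ptr_t =
    std::conditional_t<std::is_same_v<T, void>, void*, const T*>;

// Typed access keeps constness; untyped access drops it.
static_assert(std::is_same_v<const_ptr_t<float>, const float*>);
static_assert(std::is_same_v<const_ptr_t<void>, void*>);

int main() {
  return 0;
}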
mlx/backend/cuda/vector_types.cuh (new file, 48 lines)
@@ -0,0 +1,48 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+namespace mlx::core::cu {
+
+template <typename T>
+struct Vector2;
+
+template <>
+struct Vector2<double> {
+  using type = double2;
+};
+
+template <>
+struct Vector2<float> {
+  using type = float2;
+};
+
+template <>
+struct Vector2<__half> {
+  using type = __half2;
+};
+
+template <>
+struct Vector2<__nv_bfloat16> {
+  using type = __nv_bfloat162;
+};
+
+template <typename T>
+using Vector2_t = typename Vector2<T>::type;
+
+template <typename T>
+struct Vector4 {
+  T x, y, z, w;
+};
+
+template <typename T>
+using Vector4_t = Vector4<T>;
+
+using bf16x4 = Vector4_t<__nv_bfloat16>;
+using fp16x4 = Vector4_t<__half>;
+using fp32x4 = Vector4_t<float>;
+
+} // namespace mlx::core::cu
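Usage sketch for the extracted `Vector2` mapping: kernels templated on an element type can name the paired CUDA vector type (float -> float2, __half -> __half2, ...) for two-element loads and stores. The kernel below is illustrative, not from the diff, and re-declares a trimmed copy of the trait so the snippet stands alone:

#include <cuda_fp16.h>

template <typename T>
struct Vector2;
template <>
struct Vector2<float> {
  using type = float2;
};
template <>
struct Vector2<__half> {
  using type = __half2;
};
template <typename T>
using Vector2_t = typename Vector2<T>::type;

// Copies two elements at a time through the paired vector type;
// n2 is the number of element pairs.
template <typename T>
__global__ void copy2(const T* in, T* out, int n2) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) {
    auto v = reinterpret_cast<const Vector2_t<T>*>(in)[i];
    reinterpret_cast<Vector2_t<T>*>(out)[i] = v;
  }
}

template __global__ void copy2<float>(const float*, float*, int);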
@@ -44,7 +44,7 @@ void Worker::commit(cudaStream_t stream) {
   }
   signal_event_.record(stream);
   signal_event_.wait(signal_stream_);
-  cudaLaunchHostFunc(signal_stream_, signal, this);
+  CHECK_CUDA_ERROR(cudaLaunchHostFunc(signal_stream_, signal, this));
 }

 void Worker::thread_fn() {
@@ -7,8 +7,6 @@

 namespace mlx::core {

-void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s);
-
 void copy_gpu(const array& in, array& out, CopyType ctype) {
   copy_gpu(in, out, ctype, out.primitive().stream());
 }
@@ -11,7 +11,7 @@ void slice_gpu(
     array& out,
     const Shape& start_indices,
     const Shape& strides,
-    const Stream& s) {
+    const Stream&) {
   slice(in, out, start_indices, strides);
 }

@@ -28,6 +28,7 @@ make_jit_source(binary_ops)
 make_jit_source(ternary_ops)
 make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
 make_jit_source(indexing/scatter kernels/indexing/indexing.h)
+make_jit_source(indexing/masked_scatter)
 make_jit_source(indexing/gather kernels/indexing/indexing.h)
 make_jit_source(indexing/gather_front kernels/indexing/indexing.h)
 make_jit_source(indexing/gather_axis)
@@ -149,7 +149,9 @@ Buffer MetalAllocator::malloc(size_t size) {
     buf = device_->newBuffer(size, resource_options);
   }
   if (!buf) {
-    return Buffer{nullptr};
+    std::ostringstream msg;
+    msg << "[malloc] Unable to allocate " << size << " bytes.";
+    throw std::runtime_error(msg.str());
   }
   lk.lock();
   num_resources_++;
@@ -201,6 +203,32 @@ size_t MetalAllocator::size(Buffer buffer) const {
   return static_cast<MTL::Buffer*>(buffer.ptr())->length();
 }

+Buffer MetalAllocator::make_buffer(void* ptr, size_t size) {
+  auto buf = device_->newBuffer(ptr, size, resource_options, nullptr);
+  if (!buf) {
+    return Buffer{nullptr};
+  }
+  std::unique_lock lk(mutex_);
+  residency_set_.insert(buf);
+  active_memory_ += buf->length();
+  peak_memory_ = std::max(peak_memory_, active_memory_);
+  num_resources_++;
+  return Buffer{static_cast<void*>(buf)};
+}
+
+void MetalAllocator::release(Buffer buffer) {
+  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
+  if (buf == nullptr) {
+    return;
+  }
+  std::unique_lock lk(mutex_);
+  active_memory_ -= buf->length();
+  num_resources_--;
+  lk.unlock();
+  auto pool = metal::new_scoped_memory_pool();
+  buf->release();
+}
+
 MetalAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of MetalAllocator
   // will not be called on exit and buffers in the cache will be leaked. This
@@ -21,6 +21,9 @@ class MetalAllocator : public allocator::Allocator {
   virtual Buffer malloc(size_t size) override;
   virtual void free(Buffer buffer) override;
   virtual size_t size(Buffer buffer) const override;
+  virtual Buffer make_buffer(void* ptr, size_t size) override;
+  virtual void release(Buffer buffer) override;

   size_t get_active_memory() {
     return active_memory_;
   };
Some files were not shown because too many files have changed in this diff.