mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-09 04:34:11 +08:00
Compare commits
198 Commits
split_logs
...
sdpav-back
Author | SHA1 | Date | |
---|---|---|---|
![]() |
a22d0bf273 | ||
![]() |
99d8de8445 | ||
![]() |
c66b76a8c8 | ||
![]() |
f81edd184f | ||
![]() |
7f8ba2a003 | ||
![]() |
c28249b81a | ||
![]() |
e74bcdc5e3 | ||
![]() |
d8ed6c1aa3 | ||
![]() |
db5c7efcf6 | ||
![]() |
7bb96e4249 | ||
![]() |
fa89f0b150 | ||
![]() |
ca973d1e83 | ||
![]() |
828c5f1137 | ||
![]() |
7d86a5c108 | ||
![]() |
0b807893a7 | ||
![]() |
6ad0889c8a | ||
![]() |
737dd6d1ac | ||
![]() |
aaf78f4c6b | ||
![]() |
8831064493 | ||
![]() |
be9bc96da4 | ||
![]() |
86258f292f | ||
![]() |
b26d88591c | ||
![]() |
86c6a15571 | ||
![]() |
8b25ce62d5 | ||
![]() |
da5912e4f2 | ||
![]() |
daafee676f | ||
![]() |
d32519c8ee | ||
![]() |
b405591249 | ||
![]() |
3bf81ed1bd | ||
![]() |
2204182bba | ||
![]() |
3628e5d497 | ||
![]() |
a0ae49d397 | ||
![]() |
254476718b | ||
![]() |
3adba92ebe | ||
![]() |
ef631d63af | ||
![]() |
970dbe8e25 | ||
![]() |
641be9463b | ||
![]() |
ab0e608862 | ||
![]() |
1588659062 | ||
![]() |
b9e88fb976 | ||
![]() |
4ad53414dd | ||
![]() |
d1165b215e | ||
![]() |
dcb8319f3d | ||
![]() |
5597fa089c | ||
![]() |
9acec364c2 | ||
![]() |
7d9d6ef456 | ||
![]() |
6f5874a2f2 | ||
![]() |
70dc336785 | ||
![]() |
4e504039f5 | ||
![]() |
d1f4d291e8 | ||
![]() |
e1840853ce | ||
![]() |
0f5ce173da | ||
![]() |
588854195f | ||
![]() |
28d068bce6 | ||
![]() |
d107d8d495 | ||
![]() |
1e496ddb82 | ||
![]() |
74eccbf3fa | ||
![]() |
08638223ca | ||
![]() |
56cc858af9 | ||
![]() |
f55c4ed1d6 | ||
![]() |
93d70419e7 | ||
![]() |
63f663d9c6 | ||
![]() |
84b4d96efa | ||
![]() |
aec67f2fa6 | ||
![]() |
deee214a95 | ||
![]() |
45adec102c | ||
![]() |
31fc530c76 | ||
![]() |
fbb3f65a1a | ||
![]() |
6b1b8ea91b | ||
![]() |
b2273733ea | ||
![]() |
f409b229a4 | ||
![]() |
30571e2326 | ||
![]() |
d7734edd9f | ||
![]() |
2ba69bc8fa | ||
![]() |
cb349a291c | ||
![]() |
f0a0b077a0 | ||
![]() |
49114f28ab | ||
![]() |
e7d2ebadd2 | ||
![]() |
e569803d7c | ||
![]() |
d34f887abc | ||
![]() |
5201df5030 | ||
![]() |
2d3c26c565 | ||
![]() |
6325f60d52 | ||
![]() |
42cc9cfbc7 | ||
![]() |
8347575ba1 | ||
![]() |
b6eec20260 | ||
![]() |
0eb035b4b1 | ||
![]() |
afb9817599 | ||
![]() |
8fb3e7a26c | ||
![]() |
8c7bc30ce4 | ||
![]() |
85873cb162 | ||
![]() |
e14ee12491 | ||
![]() |
8b9a3f3cea | ||
![]() |
fb4e8b896b | ||
![]() |
2ca533b279 | ||
![]() |
4a9b29a875 | ||
![]() |
a4fcc893cd | ||
![]() |
9d10239af7 | ||
![]() |
19facd4b20 | ||
![]() |
f5299f72cd | ||
![]() |
0e0d9ac522 | ||
![]() |
8917022deb | ||
![]() |
ec0d5db67b | ||
![]() |
e76e9b87f0 | ||
![]() |
cfb6a244ea | ||
![]() |
58f3860306 | ||
![]() |
dd4f53db63 | ||
![]() |
3d5e17e507 | ||
![]() |
33bf1a244b | ||
![]() |
772f471ff2 | ||
![]() |
2c11d10f8d | ||
![]() |
656ed7f780 | ||
![]() |
81bb9a2a9e | ||
![]() |
5adf185f86 | ||
![]() |
c9a9180584 | ||
![]() |
76831ed83d | ||
![]() |
b3d7b85376 | ||
![]() |
cad5c0241c | ||
![]() |
b8022c578a | ||
![]() |
bc53f8293f | ||
![]() |
c552ff2451 | ||
![]() |
4fda5fbdf9 | ||
![]() |
580776559b | ||
![]() |
a14aaa7c9d | ||
![]() |
a6d780154f | ||
![]() |
6871e2eeb7 | ||
![]() |
8402a2acf4 | ||
![]() |
fddb6933e1 | ||
![]() |
c8b4787e4e | ||
![]() |
2188199ff8 | ||
![]() |
aa07429bad | ||
![]() |
918761a25a | ||
![]() |
a4fc671d3e | ||
![]() |
f5f65ef48c | ||
![]() |
c2dd81a8aa | ||
![]() |
d7e680ffe4 | ||
![]() |
c371baf53a | ||
![]() |
ccf78f566c | ||
![]() |
c9fa68664a | ||
![]() |
c35f4d089a | ||
![]() |
8590c0941e | ||
![]() |
095163b8d1 | ||
![]() |
99c33d011d | ||
![]() |
62fecf3e13 | ||
![]() |
7c4eb5d03e | ||
![]() |
bae9a6b404 | ||
![]() |
004c1d8ef2 | ||
![]() |
7ebb2e0193 | ||
![]() |
9ce77798b1 | ||
![]() |
f8bad60609 | ||
![]() |
5866b3857b | ||
![]() |
1ca616844b | ||
![]() |
2e8cf0b450 | ||
![]() |
24f89173d1 | ||
![]() |
c6a20b427a | ||
![]() |
a5ac9244c4 | ||
![]() |
c763fe1be0 | ||
![]() |
52dc8c8cd5 | ||
![]() |
aede70e81d | ||
![]() |
85a8beb5e4 | ||
![]() |
0bb89e9e5f | ||
![]() |
5685ceb3c7 | ||
![]() |
0408ba0a76 | ||
![]() |
cbad6c3093 | ||
![]() |
1b021f6984 | ||
![]() |
95b7551d65 | ||
![]() |
db5a7c6192 | ||
![]() |
6ef2f67e7f | ||
![]() |
f76ee1ffd2 | ||
![]() |
54a71f270a | ||
![]() |
55b4062dd8 | ||
![]() |
79071bfba4 | ||
![]() |
7774b87cbd | ||
![]() |
35c87741cf | ||
![]() |
4cbe605214 | ||
![]() |
ab8883dd55 | ||
![]() |
eebe73001a | ||
![]() |
0359bf02c9 | ||
![]() |
237f9e58a8 | ||
![]() |
8576e6fe36 | ||
![]() |
0654543dcc | ||
![]() |
48ef3e74e2 | ||
![]() |
7d4b378952 | ||
![]() |
7ff5c41e06 | ||
![]() |
602f43e3d1 | ||
![]() |
a2cadb8218 | ||
![]() |
c1eb9d05d9 | ||
![]() |
cf6c939e86 | ||
![]() |
130df35e1b | ||
![]() |
0751263dec | ||
![]() |
eca2f3eb97 | ||
![]() |
3aa9cf3f9e | ||
![]() |
8f3d208dce | ||
![]() |
caaa3f1f8c | ||
![]() |
659a51919f | ||
![]() |
6661387066 | ||
![]() |
a7fae8a176 | ||
![]() |
0cae0bdac8 |
@@ -7,15 +7,9 @@ parameters:
|
||||
nightly_build:
|
||||
type: boolean
|
||||
default: false
|
||||
weekly_build:
|
||||
type: boolean
|
||||
default: false
|
||||
test_release:
|
||||
type: boolean
|
||||
default: false
|
||||
linux_release:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build_documentation:
|
||||
@@ -38,7 +32,7 @@ jobs:
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install -r docs/requirements.txt
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
|
||||
pip install . -v
|
||||
- when:
|
||||
condition:
|
||||
not: << parameters.upload-docs >>
|
||||
@@ -70,9 +64,9 @@ jobs:
|
||||
git push -f origin gh-pages
|
||||
|
||||
linux_build_and_test:
|
||||
docker:
|
||||
- image: cimg/python:3.9
|
||||
|
||||
machine:
|
||||
image: ubuntu-2204:current
|
||||
resource_class: large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -84,37 +78,36 @@ jobs:
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py develop
|
||||
uv venv
|
||||
uv pip install cmake
|
||||
uv pip install -e ".[dev]" -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
echo "stubs"
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
uv pip install typing_extensions
|
||||
uv run --no-project setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
python3 -m unittest discover python/tests -v
|
||||
source .venv/bin/activate
|
||||
python -m unittest discover python/tests -v
|
||||
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
mkdir -p build && cd build
|
||||
source .venv/bin/activate
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||
make -j `nproc`
|
||||
- run:
|
||||
@@ -139,51 +132,49 @@ jobs:
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@3.9
|
||||
brew install openmpi
|
||||
python3.9 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
pip install unittest-xml-reporting
|
||||
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
|
||||
brew install openmpi uv
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
pip install -e . -v
|
||||
uv venv --python 3.9
|
||||
uv pip install \
|
||||
nanobind==2.4.0 \
|
||||
cmake \
|
||||
numpy \
|
||||
torch \
|
||||
tensorflow \
|
||||
unittest-xml-reporting
|
||||
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
uv pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
uv pip install typing_extensions
|
||||
uv run --no-project setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
source .venv/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||
- run:
|
||||
name: Build example extension
|
||||
command: |
|
||||
source env/bin/activate
|
||||
source .venv/bin/activate
|
||||
cd examples/extensions
|
||||
pip install -r requirements.txt
|
||||
python setup.py build_ext -j8
|
||||
uv pip install -r requirements.txt
|
||||
uv run --no-project setup.py build_ext --inplace
|
||||
uv run --no-project python test.py
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source env/bin/activate
|
||||
source .venv/bin/activate
|
||||
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
@@ -192,7 +183,7 @@ jobs:
|
||||
- run:
|
||||
name: Build small binary
|
||||
command: |
|
||||
source env/bin/activate
|
||||
source .venv/bin/activate
|
||||
cd build/
|
||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
@@ -204,13 +195,60 @@ jobs:
|
||||
- run:
|
||||
name: Run Python tests with JIT
|
||||
command: |
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
pip install -e . -v
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
uv pip install -e .
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||
uv run --no-project python -m xmlrunner discover \
|
||||
-v python/tests \
|
||||
-o test-results/gpu_jit
|
||||
|
||||
cuda_build_and_test:
|
||||
parameters:
|
||||
image_date:
|
||||
type: string
|
||||
default: "2023.11.1"
|
||||
machine:
|
||||
image: "linux-cuda-12:<< parameters.image_date >>"
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- restore_cache:
|
||||
keys:
|
||||
- cuda-<< parameters.image_date >>-{{ arch }}-
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
|
||||
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
|
||||
rm -rf ccache-4.11.3-linux-x86_64
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
uv venv
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
uv pip install -e ".[dev]" -v
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source .venv/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||
- run:
|
||||
name: CCache report
|
||||
command: |
|
||||
ccache --show-stats
|
||||
ccache --zero-stats
|
||||
ccache --max-size 400MB
|
||||
ccache --cleanup
|
||||
- save_cache:
|
||||
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
|
||||
paths:
|
||||
- /home/circleci/.cache/ccache
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
@@ -252,21 +290,29 @@ jobs:
|
||||
command: |
|
||||
source env/bin/activate
|
||||
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
pip install . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
<< parameters.build_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
python -m build -w
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
|
||||
- when:
|
||||
condition:
|
||||
equal: ["3.9", << parameters.python_version >>]
|
||||
steps:
|
||||
- run:
|
||||
name: Build common package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
@@ -283,52 +329,100 @@ jobs:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
extra_env:
|
||||
build_env:
|
||||
type: string
|
||||
default: "DEV_RELEASE=1"
|
||||
docker:
|
||||
- image: ubuntu:20.04
|
||||
default: ""
|
||||
machine:
|
||||
image: ubuntu-2204:current
|
||||
resource_class: large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build wheel
|
||||
command: |
|
||||
PYTHON=python<< parameters.python_version >>
|
||||
apt-get update
|
||||
apt-get upgrade -y
|
||||
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
|
||||
apt-get install -y apt-utils
|
||||
apt-get install -y software-properties-common
|
||||
add-apt-repository -y ppa:deadsnakes/ppa
|
||||
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
apt-get install -y build-essential git
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
sudo apt-get update
|
||||
TZ=Etc/UTC sudo apt-get -y install tzdata
|
||||
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
$PYTHON -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
pip install . -v
|
||||
<< parameters.build_env >> pip install ".[dev]" -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python -m build --wheel
|
||||
auditwheel show dist/*
|
||||
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
||||
python setup.py generate_stubs
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
|
||||
bash python/scripts/repair_linux.sh
|
||||
- when:
|
||||
condition:
|
||||
equal: ["3.9", << parameters.python_version >>]
|
||||
steps:
|
||||
- run:
|
||||
name: Build common package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||
python -m build -w
|
||||
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
- run:
|
||||
name: Upload packages
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*.whl
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
build_cuda_release:
|
||||
parameters:
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
machine:
|
||||
image: ubuntu-2204:current
|
||||
resource_class: large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Upload package
|
||||
name: Build wheel
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
sudo apt-get update
|
||||
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install zip
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
python -m build -w
|
||||
bash python/scripts/repair_cuda.sh
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
twine upload wheelhouse/*.whl
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
@@ -340,7 +434,6 @@ workflows:
|
||||
pattern: "^(?!pull/)[-\\w]+$"
|
||||
value: << pipeline.git.branch >>
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.weekly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- mac_build_and_test:
|
||||
@@ -348,13 +441,16 @@ workflows:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test
|
||||
- cuda_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
image_date: ["2023.11.1", "2025.05.1"]
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
when:
|
||||
and:
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.weekly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- build_release:
|
||||
@@ -437,6 +533,25 @@ workflows:
|
||||
branches:
|
||||
ignore: /.*/
|
||||
upload-docs: true
|
||||
- build_linux_release:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
- build_cuda_release:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
|
||||
prb:
|
||||
when:
|
||||
@@ -455,6 +570,11 @@ workflows:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
- cuda_build_and_test:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
image_date: ["2023.11.1", "2025.05.1"]
|
||||
nightly_build:
|
||||
when:
|
||||
and:
|
||||
@@ -513,11 +633,17 @@ workflows:
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
weekly_build:
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
- build_cuda_release
|
||||
|
||||
build_dev_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.weekly_build >>
|
||||
- << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- build_release:
|
||||
matrix:
|
||||
@@ -587,14 +713,12 @@ workflows:
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
linux_test_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.linux_release >>
|
||||
jobs:
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
- build_cuda_release:
|
||||
matrix:
|
||||
parameters:
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
|
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
@@ -34,13 +34,16 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
message(
|
||||
@@ -63,10 +66,17 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
else()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||
endif()
|
||||
|
||||
if(MLX_USE_CCACHE)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# ----------------------------- Lib -----------------------------
|
||||
@@ -83,6 +93,10 @@ if(MLX_BUILD_METAL)
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
@@ -226,12 +240,19 @@ target_include_directories(
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||
$<INSTALL_INTERFACE:include>)
|
||||
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
# Do not add mlx_EXPORTS define for shared library.
|
||||
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||
|
||||
if(USE_SYSTEM_FMT)
|
||||
find_package(fmt REQUIRED)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
endif()
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||
|
||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
|
21
README.md
21
README.md
@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
|
||||
|
||||
Some key features of MLX include:
|
||||
|
||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
||||
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
||||
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
||||
more complex models.
|
||||
|
||||
@@ -68,18 +68,23 @@ in the documentation.
|
||||
|
||||
## Installation
|
||||
|
||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
|
||||
macOS, run:
|
||||
|
||||
**With `pip`**:
|
||||
|
||||
```
|
||||
```bash
|
||||
pip install mlx
|
||||
```
|
||||
|
||||
**With `conda`**:
|
||||
To install the CUDA backend on Linux, run:
|
||||
|
||||
```bash
|
||||
pip install mlx[cuda]
|
||||
```
|
||||
conda install -c conda-forge mlx
|
||||
|
||||
To install a CPU-only Linux package, run:
|
||||
|
||||
```bash
|
||||
pip install mlx[cpu]
|
||||
```
|
||||
|
||||
Checkout the
|
||||
|
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
|
@@ -192,6 +192,22 @@ void time_reductions() {
|
||||
|
||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||
TIME(argmin_along_1);
|
||||
|
||||
auto indices = mx::array({1});
|
||||
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
|
||||
std::vector<int> axes{0};
|
||||
auto b = scatter(a, {indices}, updates, axes);
|
||||
mx::eval(b);
|
||||
|
||||
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
|
||||
TIME(max_along_0);
|
||||
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||
TIME(max_along_1);
|
||||
|
||||
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||
TIME(min_along_0);
|
||||
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||
TIME(min_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
|
@@ -5,6 +5,7 @@ import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.mps
|
||||
|
||||
|
||||
@@ -44,8 +45,10 @@ def bench(f, *args):
|
||||
|
||||
|
||||
def sync_if_needed(x):
|
||||
if x.device != torch.device("cpu"):
|
||||
if x.device == torch.device("mps"):
|
||||
torch.mps.synchronize()
|
||||
elif x.device == torch.device("cuda"):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sum_and_add(axis, x, y):
|
||||
z = x.sum(axis=axis, keepdims=True)
|
||||
for i in range(50):
|
||||
z = (z + y).sum(axis=axis, keepdims=True)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def softmax(axis, x):
|
||||
ys = []
|
||||
@@ -340,7 +351,11 @@ if __name__ == "__main__":
|
||||
args.axis.pop(0)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
device = "cpu" if args.cpu else "mps"
|
||||
device = "mps"
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
if args.cpu:
|
||||
device = "cpu"
|
||||
|
||||
types = args.dtype
|
||||
if not types:
|
||||
@@ -460,5 +475,8 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
elif args.benchmark == "sum_and_add":
|
||||
print(bench(sum_and_add, axis, *xs))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||
|
107
benchmarks/python/conv_unaligned_bench.py
Normal file
107
benchmarks/python/conv_unaligned_bench.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtype = "float32"
|
||||
shapes = (
|
||||
(4, 32, 32, 21, 3, 3, 128),
|
||||
(4, 32, 32, 21, 3, 3, 37),
|
||||
(4, 32, 32, 370, 3, 3, 370),
|
||||
(4, 32, 32, 370, 7, 7, 128),
|
||||
(2, 320, 640, 21, 7, 7, 21),
|
||||
)
|
||||
for N, H, W, C, kh, kw, O in shapes:
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
@@ -1,5 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from time_utils import time_fn
|
||||
@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
|
||||
return y
|
||||
|
||||
|
||||
def time_layer_norm():
|
||||
def time_layer_norm(N, dt):
|
||||
L = 1024
|
||||
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
||||
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x, w, b):
|
||||
def layer_norm_loop(f, x, w, b):
|
||||
for _ in range(32):
|
||||
x = f(x, w, b)
|
||||
return x
|
||||
|
||||
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
|
||||
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
|
||||
|
||||
def layer_norm_grad_loop(g, x, w, b):
|
||||
gx, gw, gb = x, w, b
|
||||
for _ in range(32):
|
||||
gx, gw, gb = g(gx, gw, gb, y)
|
||||
return gx, gw, gb
|
||||
|
||||
time_fn(layer_norm_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
||||
time_fn(layer_norm_grad_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_grad_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
|
||||
|
||||
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0,))
|
||||
g2 = mx.grad(f2, argnums=(0,))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x):
|
||||
def layer_norm_grad_x_loop(g, x):
|
||||
gx = x
|
||||
for _ in range(32):
|
||||
gx = g(gx, y)
|
||||
return gx
|
||||
|
||||
time_fn(layer_norm_loop, g1, x)
|
||||
time_fn(layer_norm_loop, g2, x)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x)
|
||||
time_fn(layer_norm_grad_x_loop, g1, x)
|
||||
time_fn(layer_norm_grad_x_loop, g2, x)
|
||||
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_layer_norm()
|
||||
for dt in [mx.float32, mx.float16, mx.bfloat16]:
|
||||
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
|
||||
print(dt, n)
|
||||
time_layer_norm(n, dt)
|
||||
|
@@ -51,6 +51,20 @@ def time_maximum():
|
||||
time_fn(mx.maximum, a, b)
|
||||
|
||||
|
||||
def time_max():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
a[1, 1] = mx.nan
|
||||
mx.eval(a)
|
||||
time_fn(mx.max, a, 0)
|
||||
|
||||
|
||||
def time_min():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
a[1, 1] = mx.nan
|
||||
mx.eval(a)
|
||||
time_fn(mx.min, a, 0)
|
||||
|
||||
|
||||
def time_negative():
|
||||
a = mx.random.uniform(shape=(10000, 1000))
|
||||
mx.eval(a)
|
||||
@@ -108,6 +122,8 @@ if __name__ == "__main__":
|
||||
|
||||
time_add()
|
||||
time_matmul()
|
||||
time_min()
|
||||
time_max()
|
||||
time_maximum()
|
||||
time_exp()
|
||||
time_negative()
|
||||
|
@@ -11,13 +11,14 @@ include(CMakeParseArguments)
|
||||
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
||||
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
||||
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||
# files (like headers)
|
||||
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
|
||||
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
|
||||
#
|
||||
# clang format on
|
||||
|
||||
macro(mlx_build_metallib)
|
||||
# Parse args
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
|
||||
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
||||
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
@@ -26,6 +27,10 @@ macro(mlx_build_metallib)
|
||||
|
||||
# Collect compile options
|
||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
|
||||
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
|
||||
-frecord-sources)
|
||||
endif()
|
||||
|
||||
# Prepare metallib build command
|
||||
add_custom_command(
|
||||
|
@@ -10,7 +10,7 @@ import mlx.core as mx
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "MLX"
|
||||
copyright = "2023, MLX Contributors"
|
||||
copyright = "2023, Apple"
|
||||
author = "MLX Contributors"
|
||||
version = ".".join(mx.__version__.split(".")[:3])
|
||||
release = version
|
||||
|
@@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
||||
Simple Example
|
||||
--------------
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
Every time you make a kernel, a new Metal library is created and possibly
|
||||
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||
:func:`fast.metal_kernel` and then use it many times.
|
||||
|
||||
.. note::
|
||||
We are only required to pass the body of the Metal kernel in ``source``.
|
||||
Only pass the body of the Metal kernel in ``source``. The function
|
||||
signature is generated automatically.
|
||||
|
||||
The full function signature will be generated using:
|
||||
|
||||
@@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
||||
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
||||
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
||||
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
||||
dimension should be less than or equal to the corresponding grid dimension.
|
||||
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
||||
generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
-------------------
|
||||
|
||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||
when indexing.
|
||||
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
||||
is ``True`` by default. This will copy the array inputs if needed
|
||||
before the kernel is launched to ensure that the memory layout is row
|
||||
contiguous. Generally this makes writing the kernel easier, since we don't
|
||||
have to worry about gaps or the ordering of the dims when indexing.
|
||||
|
||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||
input array ``a`` if any are present in ``source``.
|
||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
||||
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
||||
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
||||
the right elements for each thread.
|
||||
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
||||
relying on a copy from ``ensure_row_contiguous``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
|
||||
return output
|
||||
return output
|
||||
|
||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
||||
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
|
||||
to write a fast GPU kernel for both the forward and backward passes.
|
||||
|
||||
First we'll implement the forward pass as a fused kernel:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
For a reasonably sized input such as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
|
||||
On an M1 Max, we see a big performance improvement:
|
||||
|
||||
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
|
||||
Grid Sample VJP
|
||||
---------------
|
||||
|
||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
||||
its custom vjp transform so MLX can differentiate it.
|
||||
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||
define its custom vjp transform so MLX can differentiate it.
|
||||
|
||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
||||
requires a few extra :func:`fast.metal_kernel` features:
|
||||
|
||||
* ``init_value=0``
|
||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||
@@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
|
||||
There's an even larger speed up for the vjp:
|
||||
|
||||
|
@@ -138,13 +138,13 @@ more concrete:
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
void print(std::ostream& os) override {
|
||||
os << "Axpby";
|
||||
/** The name of primitive. */
|
||||
const char* name() const override {
|
||||
return "Axpby";
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
@@ -394,14 +394,14 @@ below.
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
std::stream kname;
|
||||
kname = "axpby_general_" + type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname, lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@@ -13,7 +13,7 @@ silicon computer is
|
||||
|
||||
pip install mlx
|
||||
|
||||
To install from PyPI you must meet the following requirements:
|
||||
To install from PyPI your system must meet the following requirements:
|
||||
|
||||
- Using an M series chip (Apple silicon)
|
||||
- Using a native Python >= 3.9
|
||||
@@ -23,12 +23,39 @@ To install from PyPI you must meet the following requirements:
|
||||
MLX is only available on devices running macOS >= 13.5
|
||||
It is highly recommended to use macOS 14 (Sonoma)
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
MLX is also available on conda-forge. To install MLX with conda do:
|
||||
MLX has a CUDA backend which you can install with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
conda install conda-forge::mlx
|
||||
pip install mlx[cuda]
|
||||
|
||||
To install the CUDA package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
- Nvidia architecture >= SM 7.0 (Volta)
|
||||
- Nvidia driver >= 550.54.14
|
||||
- CUDA toolkit >= 12.0
|
||||
- Linux distribution with glibc >= 2.35
|
||||
- Python >= 3.9
|
||||
|
||||
|
||||
CPU-only (Linux)
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
For a CPU-only version of MLX that runs on Linux use:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install mlx[cpu]
|
||||
|
||||
To install the CPU-only package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
- Linux distribution with glibc >= 2.35
|
||||
- Python >= 3.9
|
||||
|
||||
|
||||
Troubleshooting
|
||||
@@ -65,6 +92,8 @@ Build Requirements
|
||||
Python API
|
||||
^^^^^^^^^^
|
||||
|
||||
.. _python install:
|
||||
|
||||
To build and install the MLX python library from source, first, clone MLX from
|
||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||
|
||||
@@ -76,20 +105,20 @@ Then simply build and install MLX using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
|
||||
pip install .
|
||||
|
||||
For developing, install the package with development dependencies, and use an
|
||||
editable install:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
|
||||
pip install -e ".[dev]"
|
||||
|
||||
Once the development dependencies are installed, you can build faster with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
Run the tests with:
|
||||
|
||||
@@ -107,6 +136,8 @@ IDE:
|
||||
C++ API
|
||||
^^^^^^^
|
||||
|
||||
.. _cpp install:
|
||||
|
||||
Currently, MLX must be built and installed from source.
|
||||
|
||||
Similarly to the python library, to build and install the MLX C++ library start
|
||||
@@ -185,6 +216,7 @@ should point to the path to the built metal library.
|
||||
|
||||
xcrun -sdk macosx --show-sdk-version
|
||||
|
||||
|
||||
Binary Size Minimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -213,6 +245,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
||||
application. Once a kernel is compiled, it will be cached by the system. The
|
||||
Metal kernel cache persists across reboots.
|
||||
|
||||
Linux
|
||||
^^^^^
|
||||
|
||||
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||
For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
apt-get update -y
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
From here follow the instructions to install either the :ref:`Python <python
|
||||
install>` or :ref:`C++ <cpp install>` APIs.
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
apt-get update -y
|
||||
apt-get -y install cuda-toolkit-12-9
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
|
||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||
|
||||
To build the C++ package run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
@@ -19,6 +19,8 @@ Array
|
||||
array.ndim
|
||||
array.shape
|
||||
array.size
|
||||
array.real
|
||||
array.imag
|
||||
array.abs
|
||||
array.all
|
||||
array.any
|
||||
|
@@ -16,6 +16,8 @@ Linear Algebra
|
||||
cross
|
||||
qr
|
||||
svd
|
||||
eigvals
|
||||
eig
|
||||
eigvalsh
|
||||
eigh
|
||||
lu
|
||||
|
@@ -19,3 +19,4 @@ Common Optimizers
|
||||
Adamax
|
||||
Lion
|
||||
MultiOptimizer
|
||||
Muon
|
||||
|
@@ -107,6 +107,16 @@ same array:
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
|
||||
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> a[[0, 0]] = mx.array([4, 5])
|
||||
|
||||
The first element of ``a`` could be ``4`` or ``5``.
|
||||
|
||||
Transformations of functions which use in-place updates are allowed and work as
|
||||
expected. For example:
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
@@ -16,6 +17,19 @@
|
||||
|
||||
namespace my_ext {
|
||||
|
||||
// A helper function to find the location of the current binary on disk.
|
||||
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
|
||||
std::string current_binary_dir() {
|
||||
static std::string binary_dir = []() {
|
||||
Dl_info info;
|
||||
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
||||
throw std::runtime_error("Unable to get current binary dir.");
|
||||
}
|
||||
return std::filesystem::path(info.dli_fname).parent_path().string();
|
||||
}();
|
||||
return binary_dir;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation Implementation
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -167,16 +181,15 @@ void Axpby::eval_gpu(
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_";
|
||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname << type_to_name(out);
|
||||
std::string kname = "axpby_";
|
||||
kname += (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname += type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname, lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
void print(std::ostream& os) override {
|
||||
os << "Axpby";
|
||||
/** The name of primitive. */
|
||||
const char* name() const override {
|
||||
return "Axpby";
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
|
@@ -1,4 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.25
|
||||
mlx>=0.21.0
|
||||
nanobind==2.2.0
|
||||
nanobind==2.4.0
|
||||
|
@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
|
||||
|
||||
a = mx.ones((3, 4))
|
||||
b = mx.ones((3, 4))
|
||||
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
print(f"c shape: {c_cpu.shape}")
|
||||
print(f"c dtype: {c_cpu.dtype}")
|
||||
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
|
||||
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
|
||||
|
@@ -21,7 +21,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
# Define MLX_VERSION only in the version.cpp file.
|
||||
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
|
||||
|
||||
@@ -47,10 +47,21 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||
endif()
|
||||
|
@@ -10,6 +10,7 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/dtype.h"
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/small_vector.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -18,8 +19,8 @@ class Primitive;
|
||||
|
||||
using Deleter = std::function<void(allocator::Buffer)>;
|
||||
using ShapeElem = int32_t;
|
||||
using Shape = std::vector<ShapeElem>;
|
||||
using Strides = std::vector<int64_t>;
|
||||
using Shape = SmallVector<ShapeElem>;
|
||||
using Strides = SmallVector<int64_t>;
|
||||
|
||||
class array {
|
||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||
@@ -224,6 +225,10 @@ class array {
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
Data(Data&& o) : buffer(o.buffer), d(o.d) {
|
||||
o.buffer = allocator::Buffer(nullptr);
|
||||
o.d = [](allocator::Buffer) {};
|
||||
}
|
||||
~Data() {
|
||||
d(buffer);
|
||||
}
|
||||
|
157
mlx/backend/common/buffer_cache.h
Normal file
157
mlx/backend/common/buffer_cache.h
Normal file
@@ -0,0 +1,157 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(
|
||||
size_t page_size,
|
||||
std::function<size_t(T*)> get_size,
|
||||
std::function<void(T*)> free)
|
||||
: page_size_(page_size),
|
||||
get_size_(std::move(get_size)),
|
||||
free_(std::move(free)) {}
|
||||
|
||||
~BufferCache() {
|
||||
clear();
|
||||
}
|
||||
|
||||
BufferCache(const BufferCache&) = delete;
|
||||
BufferCache& operator=(const BufferCache&) = delete;
|
||||
|
||||
T* reuse_from_cache(size_t size) {
|
||||
// Find the closest buffer in pool.
|
||||
auto it = buffer_pool_.lower_bound(size);
|
||||
if (it == buffer_pool_.end() ||
|
||||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Collect from the cache.
|
||||
T* buf = it->second->buf;
|
||||
pool_size_ -= it->first;
|
||||
|
||||
// Remove from record.
|
||||
remove_from_list(it->second);
|
||||
buffer_pool_.erase(it);
|
||||
return buf;
|
||||
}
|
||||
|
||||
void recycle_to_cache(T* buf) {
|
||||
assert(buf);
|
||||
// Add to cache.
|
||||
BufferHolder* bh = new BufferHolder(buf);
|
||||
add_at_head(bh);
|
||||
size_t size = get_size_(buf);
|
||||
pool_size_ += size;
|
||||
buffer_pool_.emplace(size, bh);
|
||||
}
|
||||
|
||||
int release_cached_buffers(size_t min_bytes_to_free) {
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
return clear();
|
||||
} else {
|
||||
int n_release = 0;
|
||||
size_t total_bytes_freed = 0;
|
||||
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
// Release buffer.
|
||||
size_t size = get_size_(tail_->buf);
|
||||
total_bytes_freed += size;
|
||||
free_(tail_->buf);
|
||||
n_release++;
|
||||
|
||||
// Remove from record.
|
||||
auto its = buffer_pool_.equal_range(size);
|
||||
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
|
||||
return el.second == tail_;
|
||||
});
|
||||
assert(it != buffer_pool_.end());
|
||||
buffer_pool_.erase(it);
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return n_release;
|
||||
}
|
||||
}
|
||||
|
||||
int clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
free_(holder->buf);
|
||||
n_release++;
|
||||
delete holder;
|
||||
}
|
||||
buffer_pool_.clear();
|
||||
pool_size_ = 0;
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
return n_release;
|
||||
}
|
||||
|
||||
size_t cache_size() const {
|
||||
return pool_size_;
|
||||
}
|
||||
|
||||
size_t page_size() const {
|
||||
return page_size_;
|
||||
}
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
public:
|
||||
explicit BufferHolder(T* buf_) : buf(buf_) {}
|
||||
|
||||
BufferHolder* prev{nullptr};
|
||||
BufferHolder* next{nullptr};
|
||||
T* buf;
|
||||
};
|
||||
|
||||
void add_at_head(BufferHolder* to_add) {
|
||||
if (!head_) {
|
||||
head_ = to_add;
|
||||
tail_ = to_add;
|
||||
} else {
|
||||
head_->prev = to_add;
|
||||
to_add->next = head_;
|
||||
head_ = to_add;
|
||||
}
|
||||
}
|
||||
|
||||
void remove_from_list(BufferHolder* to_remove) {
|
||||
if (to_remove->prev && to_remove->next) { // if middle
|
||||
to_remove->prev->next = to_remove->next;
|
||||
to_remove->next->prev = to_remove->prev;
|
||||
} else if (to_remove->prev && to_remove == tail_) { // if tail
|
||||
tail_ = to_remove->prev;
|
||||
tail_->next = nullptr;
|
||||
} else if (to_remove == head_ && to_remove->next) { // if head
|
||||
head_ = to_remove->next;
|
||||
head_->prev = nullptr;
|
||||
} else if (to_remove == head_ && to_remove == tail_) { // if only element
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
}
|
||||
|
||||
delete to_remove;
|
||||
}
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_{nullptr};
|
||||
BufferHolder* tail_{nullptr};
|
||||
size_t pool_size_{0};
|
||||
|
||||
const size_t page_size_;
|
||||
std::function<size_t(T*)> get_size_;
|
||||
std::function<void(T*)> free_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,8 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -15,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
return print_float_constant<float16_t>(os, x);
|
||||
case bfloat16:
|
||||
return print_float_constant<bfloat16_t>(os, x);
|
||||
case float64:
|
||||
return print_float_constant<double>(os, x);
|
||||
case complex64:
|
||||
return print_complex_constant<complex64_t>(os, x);
|
||||
case int8:
|
||||
@@ -51,6 +52,8 @@ std::string get_type_string(Dtype d) {
|
||||
return "float16_t";
|
||||
case bfloat16:
|
||||
return "bfloat16_t";
|
||||
case float64:
|
||||
return "double";
|
||||
case complex64:
|
||||
return "complex64_t";
|
||||
case bool_:
|
||||
@@ -79,55 +82,6 @@ std::string get_type_string(Dtype d) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
NodeNamer namer;
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// Fill the input names. This is not really necessary, I just like having A,
|
||||
// B, C, ... as the inputs.
|
||||
for (auto& x : inputs) {
|
||||
namer.get_name(x);
|
||||
}
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
// name and type of output
|
||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||
// computation performed
|
||||
a.primitive().print(os);
|
||||
// name of inputs to the function
|
||||
for (auto& inp : a.inputs()) {
|
||||
os << namer.get_name(inp);
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << (is_scalar(x) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const Shape& shape) {
|
||||
@@ -159,8 +113,7 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
@@ -175,8 +128,7 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
in.is_donatable() && is_constant(i)) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
@@ -204,7 +156,7 @@ void compiled_allocate_outputs(
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
is_constant(i)) {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
@@ -216,4 +168,74 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant) {
|
||||
const Shape& shape = out.shape();
|
||||
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||
if (contiguous) {
|
||||
return {true, shape, {}};
|
||||
}
|
||||
|
||||
std::vector<Strides> strides_vec{out.strides()};
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants.
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip scalar inputs.
|
||||
const auto& x = inputs[i];
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
Strides xstrides;
|
||||
size_t j = 0;
|
||||
for (; j < shape.size() - x.ndim(); ++j) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides_vec.push_back(std::move(xstrides));
|
||||
}
|
||||
|
||||
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
|
||||
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
|
||||
}
|
||||
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
for (const auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
for (const auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,9 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -14,19 +13,17 @@ inline bool is_static_cast(const Primitive& p) {
|
||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids);
|
||||
|
||||
std::string get_type_string(Dtype d);
|
||||
|
||||
template <typename T>
|
||||
void print_float_constant(std::ostream& os, const array& x) {
|
||||
auto old_precision = os.precision();
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
||||
<< x.item<T>() << std::setprecision(old_precision);
|
||||
if constexpr (std::is_same_v<T, double>) {
|
||||
os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
|
||||
} else {
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
|
||||
}
|
||||
os << x.item<T>() << std::setprecision(old_precision);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -60,8 +57,19 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous);
|
||||
|
||||
// Collapse contiguous dims ignoring scalars and constants.
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant);
|
||||
|
||||
// Return whether the kernel should use large index.
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
// have the same size, then the input buffer can hold the output.
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
return true;
|
||||
} else {
|
||||
|
67
mlx/backend/common/matmul.h
Normal file
67
mlx/backend/common/matmul.h
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
const array& a,
|
||||
const array& b) {
|
||||
if (a.ndim() == 2) {
|
||||
return {{1}, {0}, {0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] =
|
||||
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
|
||||
|
||||
auto a_batch_strides = batch_strides[0];
|
||||
auto b_batch_strides = batch_strides[1];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
a_batch_strides.push_back(0);
|
||||
b_batch_strides.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
|
||||
}
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||
collapse_batches(const array& a, const array& b, const array& c) {
|
||||
if (a.ndim() == 2) {
|
||||
return {{1}, {0}, {0}, {0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
|
||||
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
|
||||
|
||||
auto A_batch_stride = batch_strides[0];
|
||||
auto B_batch_stride = batch_strides[1];
|
||||
auto C_batch_stride = batch_strides[2];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
A_batch_stride.push_back(0);
|
||||
B_batch_stride.push_back(0);
|
||||
C_batch_stride.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(
|
||||
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -5,11 +5,9 @@
|
||||
namespace mlx::core {
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
return shapes_without_reduction_axes(
|
||||
std::move(shape), std::move(strides), axes);
|
||||
}
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
|
@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes);
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
const std::vector<int>& axes);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
26
mlx/backend/common/unary.h
Normal file
26
mlx/backend/common/unary.h
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,9 +1,22 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::filesystem::path current_binary_dir() {
|
||||
static std::filesystem::path binary_dir = []() {
|
||||
Dl_info info;
|
||||
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
||||
throw std::runtime_error("Unable to get current binary dir.");
|
||||
}
|
||||
return std::filesystem::path(info.dli_fname).parent_path();
|
||||
}();
|
||||
return binary_dir;
|
||||
}
|
||||
|
||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const std::vector<Strides>& strides,
|
||||
@@ -101,4 +114,145 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
||||
}
|
||||
|
||||
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
|
||||
int pows[3] = {0, 0, 0};
|
||||
int sum = 0;
|
||||
while (true) {
|
||||
int presum = sum;
|
||||
// Check all the pows
|
||||
if (dim0 >= (1 << (pows[0] + 1))) {
|
||||
pows[0]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == 10) {
|
||||
break;
|
||||
}
|
||||
if (dim1 >= (1 << (pows[1] + 1))) {
|
||||
pows[1]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == 10) {
|
||||
break;
|
||||
}
|
||||
if (dim2 >= (1 << (pows[2] + 1))) {
|
||||
pows[2]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == presum || sum == pow2) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
|
||||
}
|
||||
|
||||
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
|
||||
// Dims with strides of 0 are ignored as they
|
||||
// correspond to broadcasted dimensions
|
||||
size_t grid_x = 1;
|
||||
size_t grid_y = 1;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
if (strides[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
if (grid_x * shape[i] < UINT32_MAX) {
|
||||
grid_x *= shape[i];
|
||||
} else {
|
||||
grid_y *= shape[i];
|
||||
}
|
||||
}
|
||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||
throw std::runtime_error("Unable to safely factor shape.");
|
||||
}
|
||||
if (grid_y > grid_x) {
|
||||
std::swap(grid_x, grid_y);
|
||||
}
|
||||
return std::make_tuple(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
Dims get_2d_grid_dims_common(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor) {
|
||||
// Compute the 2d grid dimensions such that the total size of the grid is
|
||||
// divided by divisor.
|
||||
size_t grid_x = 1;
|
||||
size_t grid_y = 1;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
if (strides[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// No need to add this shape we can just remove it from the divisor.
|
||||
if (divisor % shape[i] == 0) {
|
||||
divisor /= shape[i];
|
||||
continue;
|
||||
}
|
||||
|
||||
if (grid_x * shape[i] < UINT32_MAX) {
|
||||
grid_x *= shape[i];
|
||||
} else {
|
||||
grid_y *= shape[i];
|
||||
}
|
||||
|
||||
if (divisor > 1) {
|
||||
if (grid_x % divisor == 0) {
|
||||
grid_x /= divisor;
|
||||
divisor = 1;
|
||||
} else if (grid_y % divisor == 0) {
|
||||
grid_y /= divisor;
|
||||
divisor = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||
throw std::runtime_error("Unable to safely factor shape.");
|
||||
}
|
||||
if (grid_y > grid_x) {
|
||||
std::swap(grid_x, grid_y);
|
||||
}
|
||||
if (divisor > 1) {
|
||||
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
|
||||
}
|
||||
return std::make_tuple(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
||||
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
|
||||
auto gx = (dim0 + bx - 1) / bx;
|
||||
auto gy = (dim1 + by - 1) / by;
|
||||
auto gz = (dim2 + bz - 1) / bz;
|
||||
|
||||
return std::make_pair(
|
||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||
}
|
||||
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
||||
int ndim = x.ndim();
|
||||
if (axis1 < 0) {
|
||||
axis1 += ndim;
|
||||
}
|
||||
if (axis2 < 0) {
|
||||
axis2 += ndim;
|
||||
}
|
||||
|
||||
auto shape = x.shape();
|
||||
std::swap(shape[axis1], shape[axis2]);
|
||||
auto strides = x.strides();
|
||||
std::swap(strides[axis1], strides[axis2]);
|
||||
|
||||
auto [data_size, row_contiguous, col_contiguous] =
|
||||
check_contiguity(shape, strides);
|
||||
bool contiguous = data_size == x.data_size();
|
||||
|
||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(
|
||||
x,
|
||||
std::move(strides),
|
||||
{contiguous, row_contiguous, col_contiguous},
|
||||
x.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,12 +2,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <filesystem>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Return the directory that contains current shared library.
|
||||
std::filesystem::path current_binary_dir();
|
||||
|
||||
inline int64_t
|
||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||
int64_t loc = 0;
|
||||
@@ -70,6 +75,31 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
const array& a,
|
||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
|
||||
// Compute the thread block dimensions which fit the given
|
||||
// input dimensions.
|
||||
// - The thread block dimensions will be powers of two
|
||||
// - The thread block size will be less than 2^pow2
|
||||
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
|
||||
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||
|
||||
// Computes a 2D grid where each element is < UINT_MAX
|
||||
// Assumes:
|
||||
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
||||
// - shape and strides correspond to a contiguous (no holes) but
|
||||
// possibly broadcasted array
|
||||
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
|
||||
|
||||
// Same as above but we do an implicit division with divisor.
|
||||
// Basically, equivalent to factorizing
|
||||
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
|
||||
Dims get_2d_grid_dims_common(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor);
|
||||
|
||||
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
|
||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
|
||||
|
||||
struct ContiguousIterator {
|
||||
inline void step() {
|
||||
int dims = shape_.size();
|
||||
@@ -165,4 +195,14 @@ void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
|
||||
// Like the swapaxes op but safe to call in eval_gpu.
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||
|
||||
template <typename T>
|
||||
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
return vec;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -46,6 +46,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
|
@@ -14,10 +14,8 @@ template <typename InT, typename OpT>
|
||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
auto axis_size = in.shape()[axis];
|
||||
auto axis_stride = in.strides()[axis];
|
||||
Strides strides = in.strides();
|
||||
Shape shape = in.shape();
|
||||
strides.erase(strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
Strides strides = remove_index(in.strides(), axis);
|
||||
Shape shape = remove_index(in.shape(), axis);
|
||||
auto in_ptr = in.data<InT>();
|
||||
auto out_ptr = out.data<uint32_t>();
|
||||
|
||||
|
@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
|
||||
|
||||
// The decomposition is computed in place, so just copy the input to the
|
||||
// output.
|
||||
copy(
|
||||
copy_cpu(
|
||||
a,
|
||||
factor,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
|
@@ -146,18 +146,9 @@ inline void build_kernel(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous,
|
||||
int ndim) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
|
||||
// Constants are scalars that are captured by value and cannot change
|
||||
auto is_constant = [&constant_ids](const array& x) {
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
NodeNamer namer;
|
||||
|
||||
#ifdef _MSC_VER
|
||||
@@ -170,14 +161,15 @@ inline void build_kernel(
|
||||
|
||||
// Add the input arguments
|
||||
int cnt = 0;
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants from the input list
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
auto tstr = get_type_string(x.dtype());
|
||||
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
@@ -211,10 +203,11 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
for (auto& x : inputs) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
print_constant(os, x);
|
||||
os << ";" << std::endl;
|
||||
@@ -238,7 +231,7 @@ inline void build_kernel(
|
||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||
} else {
|
||||
x.primitive().print(os);
|
||||
os << x.primitive().name();
|
||||
os << "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
@@ -264,8 +257,9 @@ inline void build_kernel(
|
||||
} else {
|
||||
for (int d = ndim - 1; d >= 0; --d) {
|
||||
// Update pointers
|
||||
for (auto& x : inputs) {
|
||||
if (is_constant(x) || is_scalar(x)) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
if (is_constant(i) || is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
auto& xname = namer.get_name(x);
|
||||
@@ -287,65 +281,45 @@ inline void build_kernel(
|
||||
void Compiled::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (kernel_lib_.empty()) {
|
||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
auto contiguous = compiled_check_contiguity(inputs, shape);
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Force allocating shape/strides on heap so we can take their data() first
|
||||
// and then std::move them.
|
||||
// TODO: Refactor code to avoid heap allocation.
|
||||
shape.grow();
|
||||
for (auto& s : strides) {
|
||||
s.grow();
|
||||
}
|
||||
|
||||
// Collect function input arguments.
|
||||
std::vector<void*> args;
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
const auto& x = inputs[i];
|
||||
encoder.set_input_array(x);
|
||||
args.push_back((void*)x.data<void>());
|
||||
|
||||
if (contiguous || is_scalar(x)) {
|
||||
continue;
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
args.push_back(strides[strides_index++].data());
|
||||
}
|
||||
|
||||
// Broadcast the input to the output shape.
|
||||
std::vector<size_t> xstrides;
|
||||
int j = 0;
|
||||
for (; j < shape.size() - x.ndim(); j++) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides.push_back(std::move(xstrides));
|
||||
args.push_back(strides.back().data());
|
||||
}
|
||||
|
||||
// Get the kernel name from the lib
|
||||
int ndim = shape.size();
|
||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||
if (!contiguous) {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
kernel_name += std::to_string(ndim);
|
||||
}
|
||||
|
||||
// Get the function
|
||||
auto fn_ptr = compile(kernel_name, [&]() {
|
||||
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
|
||||
std::ostringstream kernel;
|
||||
kernel << get_kernel_preamble() << std::endl;
|
||||
kernel << "extern \"C\" {" << std::endl;
|
||||
@@ -355,7 +329,7 @@ void Compiled::eval_cpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
contiguous,
|
||||
ndim);
|
||||
// Close extern "C"
|
||||
@@ -363,26 +337,22 @@ void Compiled::eval_cpu(
|
||||
return kernel.str();
|
||||
});
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
|
||||
for (auto& x : outputs) {
|
||||
args.push_back(x.data<void>());
|
||||
encoder.set_output_array(x);
|
||||
}
|
||||
Shape out_shape;
|
||||
if (!contiguous) {
|
||||
out_shape = outputs[0].shape();
|
||||
args.push_back((void*)out_shape.data());
|
||||
args.push_back((void*)shape.data());
|
||||
} else {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
encoder.dispatch(
|
||||
[fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
|
||||
encoder.dispatch([fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
shape = std::move(shape)]() mutable { fun(args.data()); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -295,7 +295,11 @@ inline void copy_inplace_dispatch(
|
||||
|
||||
} // namespace
|
||||
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
void copy_cpu_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
@@ -305,7 +309,7 @@ void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
|
||||
}
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
bool donated = set_copy_output_data(src, dst, ctype);
|
||||
if (donated && src.dtype() == dst.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
@@ -315,10 +319,10 @@ void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_inplace(src, dst, ctype, stream);
|
||||
copy_cpu_inplace(src, dst, ctype, stream);
|
||||
}
|
||||
|
||||
void copy_inplace(
|
||||
void copy_cpu_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
@@ -373,4 +377,10 @@ void copy_inplace(
|
||||
});
|
||||
}
|
||||
|
||||
array contiguous_copy_cpu(const array& arr, Stream stream) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
||||
return arr_copy;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -10,10 +10,14 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||
void copy_cpu_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream);
|
||||
|
||||
void copy_inplace(
|
||||
void copy_cpu_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
@@ -26,4 +30,7 @@ void copy_inplace(
|
||||
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
||||
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
||||
|
||||
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||
array contiguous_copy_cpu(const array& arr, Stream stream);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -13,9 +13,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return {arr, false};
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General, stream);
|
||||
return {arr_copy, true};
|
||||
return {contiguous_copy_cpu(arr, stream), true};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -34,8 +32,7 @@ void AllReduce::eval_cpu(
|
||||
}
|
||||
return in;
|
||||
} else {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy(in, arr_copy, CopyType::General, s);
|
||||
array arr_copy = contiguous_copy_cpu(in, s);
|
||||
out.copy_shared_buffer(arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
|
174
mlx/backend/cpu/eig.cpp
Normal file
174
mlx/backend/cpu/eig.cpp
Normal file
@@ -0,0 +1,174 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void eig_impl(
|
||||
array& a,
|
||||
array& vectors,
|
||||
array& values,
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
using OT = std::complex<T>;
|
||||
auto a_ptr = a.data<T>();
|
||||
auto eig_ptr = values.data<OT>();
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(values);
|
||||
OT* vec_ptr = nullptr;
|
||||
if (compute_eigenvectors) {
|
||||
encoder.set_output_array(vectors);
|
||||
vec_ptr = vectors.data<OT>();
|
||||
}
|
||||
encoder.dispatch([a_ptr,
|
||||
vec_ptr,
|
||||
eig_ptr,
|
||||
compute_eigenvectors,
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
char jobr = 'N';
|
||||
char jobl = compute_eigenvectors ? 'V' : 'N';
|
||||
int n_vecs_r = 1;
|
||||
int n_vecs_l = compute_eigenvectors ? N : 1;
|
||||
int lwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
&work,
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
}
|
||||
|
||||
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
|
||||
auto vec_tmp_data =
|
||||
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
|
||||
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
|
||||
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
a_ptr,
|
||||
&N,
|
||||
eig_tmp,
|
||||
eig_tmp + N,
|
||||
vec_tmp,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
&info);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
|
||||
}
|
||||
if (vec_ptr) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
if (eig_ptr[i].imag() != 0) {
|
||||
// This vector and the next are a pair
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vec_ptr[i * N + j] = {
|
||||
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
|
||||
vec_ptr[(i + 1) * N + j] = {
|
||||
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
|
||||
}
|
||||
i += 1;
|
||||
} else {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
|
||||
}
|
||||
}
|
||||
}
|
||||
vec_ptr += N * N;
|
||||
}
|
||||
a_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
});
|
||||
encoder.add_temporary(a);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Eig::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const auto& a = inputs[0];
|
||||
auto& values = outputs[0];
|
||||
|
||||
auto vectors = compute_eigenvectors_
|
||||
? outputs[1]
|
||||
: array(a.shape(), complex64, nullptr, {});
|
||||
|
||||
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
|
||||
copy_cpu(
|
||||
a,
|
||||
a_copy,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream());
|
||||
|
||||
values.set_data(allocator::malloc(values.nbytes()));
|
||||
|
||||
if (compute_eigenvectors_) {
|
||||
// Set the strides and flags so the eigenvectors
|
||||
// are in the columns of the output
|
||||
auto flags = vectors.flags();
|
||||
auto strides = vectors.strides();
|
||||
auto ndim = a.ndim();
|
||||
std::swap(strides[ndim - 1], strides[ndim - 2]);
|
||||
|
||||
if (a.size() > 1) {
|
||||
flags.row_contiguous = false;
|
||||
if (ndim > 2) {
|
||||
flags.col_contiguous = false;
|
||||
} else {
|
||||
flags.col_contiguous = true;
|
||||
}
|
||||
}
|
||||
vectors.set_data(
|
||||
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
|
||||
}
|
||||
switch (a.dtype()) {
|
||||
case float32:
|
||||
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -12,6 +12,133 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, class Enable = void>
|
||||
struct EighWork {};
|
||||
|
||||
template <typename T>
|
||||
struct EighWork<
|
||||
T,
|
||||
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||
using R = T;
|
||||
|
||||
char jobz;
|
||||
char uplo;
|
||||
int N;
|
||||
int lwork;
|
||||
int liwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EighWork(char jobz_, char uplo_, int N_)
|
||||
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||
}
|
||||
|
||||
void run(T* vectors, T* values) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vectors,
|
||||
&N,
|
||||
values,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(buffers[1].buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct EighWork<std::complex<float>> {
|
||||
using T = std::complex<float>;
|
||||
using R = float;
|
||||
|
||||
char jobz;
|
||||
char uplo;
|
||||
int N;
|
||||
int lwork;
|
||||
int lrwork;
|
||||
int liwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EighWork(char jobz_, char uplo_, int N_)
|
||||
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
|
||||
T work;
|
||||
R rwork;
|
||||
int iwork;
|
||||
heevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&rwork,
|
||||
&lrwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work.real());
|
||||
lrwork = static_cast<int>(rwork);
|
||||
liwork = iwork;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||
}
|
||||
|
||||
void run(T* vectors, R* values) {
|
||||
heevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vectors,
|
||||
&N,
|
||||
values,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<R*>(buffers[1].buffer.raw_ptr()),
|
||||
&lrwork,
|
||||
static_cast<int*>(buffers[2].buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
if (jobz == 'V') {
|
||||
// We have pre-transposed the vectors but we also must conjugate them
|
||||
// when they are complex.
|
||||
//
|
||||
// We could vectorize this but it is so fast in comparison to heevd that
|
||||
// it doesn't really matter.
|
||||
for (int i = 0; i < N; i++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
*vectors = std::conj(*vectors);
|
||||
vectors++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void eigh_impl(
|
||||
array& vectors,
|
||||
@@ -19,8 +146,10 @@ void eigh_impl(
|
||||
const std::string& uplo,
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
using R = typename EighWork<T>::R;
|
||||
|
||||
auto vec_ptr = vectors.data<T>();
|
||||
auto eig_ptr = values.data<T>();
|
||||
auto eig_ptr = values.data<R>();
|
||||
char jobz = compute_eigenvectors ? 'V' : 'N';
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
@@ -33,49 +162,17 @@ void eigh_impl(
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
int lwork = -1;
|
||||
int liwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
EighWork<T> work(jobz, uplo, N);
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
|
||||
// Work loop
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vec_ptr,
|
||||
&N,
|
||||
eig_ptr,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
work.run(vec_ptr, eig_ptr);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
if (work.info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
<< work.info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -99,7 +196,7 @@ void Eigh::eval_cpu(
|
||||
|
||||
values.set_data(allocator::malloc(values.nbytes()));
|
||||
|
||||
copy(
|
||||
copy_cpu(
|
||||
a,
|
||||
vectors,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
@@ -131,6 +228,10 @@ void Eigh::eval_cpu(
|
||||
eigh_impl<double>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
eigh_impl<std::complex<float>>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Eigh::eval_cpu] only supports float32 or float64.");
|
||||
|
@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (in.flags().row_contiguous && in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy(
|
||||
copy_cpu(
|
||||
in,
|
||||
out,
|
||||
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
|
@@ -257,15 +257,11 @@ void gather_axis(
|
||||
const array& ind,
|
||||
array& out,
|
||||
const int axis) {
|
||||
auto strides = ind.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = ind.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
|
||||
|
||||
strides = src.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
|
||||
auto shape = remove_index(ind.shape(), axis);
|
||||
ContiguousIterator ind_it(
|
||||
shape, remove_index(ind.strides(), axis), src.ndim() - 1);
|
||||
ContiguousIterator src_it(
|
||||
shape, remove_index(src.strides(), axis), src.ndim() - 1);
|
||||
|
||||
auto ind_ptr = ind.data<IdxT>();
|
||||
auto src_ptr = src.data<T>();
|
||||
@@ -521,7 +517,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
auto ctype =
|
||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(src, out, ctype, stream());
|
||||
copy_cpu(src, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
std::vector<array> inds;
|
||||
@@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
template <typename T, typename IdxT, typename OpT>
|
||||
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
|
||||
auto strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
|
||||
|
||||
strides = upd.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
|
||||
auto shape = remove_index(idx.shape(), axis);
|
||||
ContiguousIterator idx_it(
|
||||
shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
|
||||
ContiguousIterator upd_it(
|
||||
shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
|
||||
|
||||
auto idx_ptr = idx.data<IdxT>();
|
||||
auto upd_ptr = upd.data<T>();
|
||||
@@ -694,7 +686,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
auto ctype =
|
||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(src, out, ctype, stream());
|
||||
copy_cpu(src, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(idx);
|
||||
|
@@ -115,7 +115,7 @@ void inverse_impl(
|
||||
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
|
||||
|
||||
// The inverse is computed in place, so just copy the input to the output.
|
||||
copy(
|
||||
copy_cpu(
|
||||
a,
|
||||
inv,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/backend/cpu/jit_compiler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
|
@@ -2,14 +2,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// Required for Visual Studio.
|
||||
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
|
||||
#ifdef _MSC_VER
|
||||
#include <complex>
|
||||
#define LAPACK_COMPLEX_CUSTOM
|
||||
#define lapack_complex_float std::complex<float>
|
||||
#define lapack_complex_double std::complex<double>
|
||||
#endif
|
||||
#define lapack_complex_float_real(z) ((z).real())
|
||||
#define lapack_complex_float_imag(z) ((z).imag())
|
||||
#define lapack_complex_double_real(z) ((z).real())
|
||||
#define lapack_complex_double_imag(z) ((z).imag())
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include <Accelerate/Accelerate.h>
|
||||
@@ -32,7 +32,7 @@
|
||||
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
|
||||
#define INSTANTIATE_LAPACK_REAL(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, float>) { \
|
||||
@@ -42,11 +42,24 @@
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_TYPES(geqrf)
|
||||
INSTANTIATE_LAPACK_TYPES(orgqr)
|
||||
INSTANTIATE_LAPACK_TYPES(syevd)
|
||||
INSTANTIATE_LAPACK_TYPES(potrf)
|
||||
INSTANTIATE_LAPACK_TYPES(gesvdx)
|
||||
INSTANTIATE_LAPACK_TYPES(getrf)
|
||||
INSTANTIATE_LAPACK_TYPES(getri)
|
||||
INSTANTIATE_LAPACK_TYPES(trtri)
|
||||
INSTANTIATE_LAPACK_REAL(geqrf)
|
||||
INSTANTIATE_LAPACK_REAL(orgqr)
|
||||
INSTANTIATE_LAPACK_REAL(syevd)
|
||||
INSTANTIATE_LAPACK_REAL(geev)
|
||||
INSTANTIATE_LAPACK_REAL(potrf)
|
||||
INSTANTIATE_LAPACK_REAL(gesvdx)
|
||||
INSTANTIATE_LAPACK_REAL(getrf)
|
||||
INSTANTIATE_LAPACK_REAL(getri)
|
||||
INSTANTIATE_LAPACK_REAL(trtri)
|
||||
|
||||
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, std::complex<float>>) { \
|
||||
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
|
||||
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
|
||||
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
||||
|
@@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_cpu(x, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
|
@@ -31,7 +31,7 @@ void luf_impl(
|
||||
strides[ndim - 1] = M;
|
||||
strides[ndim - 2] = 1;
|
||||
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
|
||||
copy_inplace(
|
||||
copy_cpu_inplace(
|
||||
a,
|
||||
lu,
|
||||
a.shape(),
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -52,6 +53,58 @@ inline void mask_matrix(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void segmented_mm(
|
||||
const T* a,
|
||||
const T* b,
|
||||
const uint32_t* segments,
|
||||
T* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides,
|
||||
size_t num_segments,
|
||||
const Shape& segments_shape,
|
||||
const Strides& segments_strides) {
|
||||
int ndim = a_shape.size();
|
||||
Shape a_copy = a_shape;
|
||||
Shape b_copy = b_shape;
|
||||
int32_t M = a_copy[ndim - 2];
|
||||
int32_t N = b_copy[ndim - 1];
|
||||
for (int i = 0; i < num_segments; i++) {
|
||||
uint32_t k_start =
|
||||
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
|
||||
uint32_t k_end =
|
||||
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
||||
if (k_end <= k_start) {
|
||||
std::fill_n(out + i * M * N, M * N, T(0));
|
||||
continue;
|
||||
}
|
||||
a_copy[ndim - 1] = k_end - k_start;
|
||||
b_copy[ndim - 2] = k_end - k_start;
|
||||
matmul<T>(
|
||||
a + k_start * a_strides[ndim - 1],
|
||||
b + k_start * b_strides[ndim - 2],
|
||||
out + i * M * N,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
N,
|
||||
1.0,
|
||||
0.0,
|
||||
1,
|
||||
a_copy,
|
||||
a_strides,
|
||||
b_copy,
|
||||
b_strides);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -71,21 +124,20 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector, s);
|
||||
copy_cpu(arr, arr_copy, CopyType::Vector, s);
|
||||
return std::make_tuple(false, stx, arr_copy, true);
|
||||
}
|
||||
return std::make_tuple(false, stx, arr, false);
|
||||
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector, s);
|
||||
copy_cpu(arr, arr_copy, CopyType::Vector, s);
|
||||
return std::make_tuple(true, sty, arr_copy, true);
|
||||
}
|
||||
return std::make_tuple(true, sty, arr, false);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General, s);
|
||||
int64_t stx = arr.shape(-1);
|
||||
array arr_copy = contiguous_copy_cpu(arr, s);
|
||||
return std::make_tuple(false, stx, arr_copy, true);
|
||||
}
|
||||
};
|
||||
@@ -333,7 +385,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy(arr, temps.back(), CopyType::General, s);
|
||||
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||
int64_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, temps.back());
|
||||
}
|
||||
@@ -437,4 +489,121 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
auto check_transpose = [&s, &encoder](const array& x) {
|
||||
auto stx = x.strides()[x.ndim() - 2];
|
||||
auto sty = x.strides()[x.ndim() - 1];
|
||||
if (stx == x.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, x);
|
||||
} else if (stx == 1 && sty == x.shape(-2)) {
|
||||
return std::make_tuple(true, sty, x);
|
||||
} else {
|
||||
array xc(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_cpu(x, xc, CopyType::General, s);
|
||||
encoder.add_temporary(xc);
|
||||
int64_t stx = x.shape(-1);
|
||||
return std::make_tuple(false, stx, xc);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(inputs[0]);
|
||||
auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
|
||||
auto& segments = inputs[2];
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(segments);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
segments = array::unsafe_weak_copy(segments),
|
||||
out_ptr = out.data<void>(),
|
||||
a_transposed = a_transposed,
|
||||
b_transposed = b_transposed,
|
||||
lda = lda,
|
||||
ldb = ldb]() {
|
||||
switch (a.dtype()) {
|
||||
case float64:
|
||||
segmented_mm<double>(
|
||||
a.data<double>(),
|
||||
b.data<double>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<double*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case float32:
|
||||
segmented_mm<float>(
|
||||
a.data<float>(),
|
||||
b.data<float>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<float*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case float16:
|
||||
segmented_mm<float16_t>(
|
||||
a.data<float16_t>(),
|
||||
b.data<float16_t>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<float16_t*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case bfloat16:
|
||||
segmented_mm<bfloat16_t>(
|
||||
a.data<bfloat16_t>(),
|
||||
b.data<bfloat16_t>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<bfloat16_t*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"Segmented mm supports only real float types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -81,7 +81,7 @@ void matmul_general(
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy(arr, temps.back(), CopyType::General, stream);
|
||||
copy_cpu(arr, temps.back(), CopyType::General, stream);
|
||||
stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, temps.back());
|
||||
}
|
||||
@@ -132,14 +132,20 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[AddMM::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
if (out.size() == 0) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return;
|
||||
}
|
||||
|
||||
// Fill output with C
|
||||
auto& c = inputs[2];
|
||||
CopyType ctype = c.data_size() == 1
|
||||
? CopyType::Scalar
|
||||
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
copy(c, out, ctype, stream());
|
||||
|
||||
copy_cpu(c, out, ctype, stream());
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
return;
|
||||
}
|
||||
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
|
||||
}
|
||||
|
||||
|
@@ -22,7 +22,7 @@ void reshape(const array& in, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General, out.primitive().stream());
|
||||
copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream());
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
@@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream());
|
||||
copy_cpu(in, out, ctype, stream());
|
||||
}
|
||||
|
||||
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -198,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
size_t data_offset = strides[axis_] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,7 +211,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy(in, out, CopyType::General, stream());
|
||||
copy_cpu(in, out, CopyType::General, stream());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,7 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy(in, out, ctype, stream());
|
||||
copy_cpu(in, out, ctype, stream());
|
||||
}
|
||||
|
||||
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -251,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
// Fill output with val
|
||||
copy(val, out, CopyType::Scalar, stream());
|
||||
copy_cpu(val, out, CopyType::Scalar, stream());
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
@@ -266,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
|
||||
copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
|
||||
}
|
||||
|
||||
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -340,7 +340,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto [in_offset, donated] =
|
||||
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
|
||||
copy_inplace(
|
||||
copy_cpu_inplace(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ out.shape(),
|
||||
@@ -372,11 +372,11 @@ void DynamicSliceUpdate::eval_cpu(
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
auto [out_offset, donated] =
|
||||
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
|
||||
copy_inplace(
|
||||
copy_cpu_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||
@@ -412,14 +412,14 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] =
|
||||
prepare_slice(out, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
copy_inplace(
|
||||
copy_cpu_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||
@@ -456,9 +456,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (in.dtype() == bool_) {
|
||||
auto in_tmp = array(in.shape(), uint8, nullptr, {});
|
||||
in_tmp.copy_shared_buffer(in);
|
||||
copy_inplace(in_tmp, tmp, CopyType::General, stream());
|
||||
copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream());
|
||||
} else {
|
||||
copy_inplace(in, tmp, CopyType::General, stream());
|
||||
copy_cpu_inplace(in, tmp, CopyType::General, stream());
|
||||
}
|
||||
|
||||
auto flags = out.flags();
|
||||
|
@@ -26,7 +26,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
strides[in.ndim() - 2] = 1;
|
||||
strides[in.ndim() - 1] = M;
|
||||
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
|
||||
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
|
||||
copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
q.set_data(allocator::malloc(q.nbytes()));
|
||||
r.set_data(allocator::malloc(r.nbytes()));
|
||||
|
@@ -13,9 +13,18 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
|
||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||
}
|
||||
|
||||
inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
|
||||
auto power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||
}
|
||||
|
||||
template <typename T, int bits>
|
||||
void extract_bits(const uint8_t* w_in, T* w_out) {
|
||||
assert(bits == 3 || bits == 6);
|
||||
static_assert(bits == 3 || bits == 5 || bits == 6);
|
||||
if (bits == 3) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x7);
|
||||
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
|
||||
@@ -25,6 +34,16 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
|
||||
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
|
||||
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
|
||||
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
|
||||
} else if (bits == 5) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x1f);
|
||||
w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
|
||||
w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
|
||||
w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
|
||||
w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
|
||||
w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
|
||||
w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
|
||||
w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
|
||||
|
||||
} else if (bits == 6) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
|
||||
w_out[1] =
|
||||
@@ -46,8 +65,8 @@ void _qmm(
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
||||
constexpr int pack_factor = get_pack_factor(bits, 8);
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
@@ -65,7 +84,7 @@ void _qmm(
|
||||
T scale = *scales_local++;
|
||||
T bias = *biases_local++;
|
||||
for (int ng = 0; ng < packs_in_group; ng++) {
|
||||
if (bits == 3 || bits == 6) {
|
||||
if constexpr (bits == 3 || bits == 5 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
@@ -104,8 +123,9 @@ void _qmm_t(
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
||||
|
||||
constexpr int pack_factor = get_pack_factor(bits, 8);
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
@@ -121,7 +141,7 @@ void _qmm_t(
|
||||
T bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw++) {
|
||||
if (bits == 3 || bits == 6) {
|
||||
if constexpr (bits == 3 || bits == 5 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
@@ -304,6 +324,10 @@ void _qmm_dispatch_typed(
|
||||
_qmm_dispatch_group<T, 4>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 5:
|
||||
_qmm_dispatch_group<T, 5>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 6:
|
||||
_qmm_dispatch_group<T, 6>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
@@ -505,7 +529,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return arr;
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy(arr, temps.back(), CopyType::General, s);
|
||||
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
}
|
||||
};
|
||||
@@ -555,7 +579,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return arr;
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy(arr, temps.back(), CopyType::General, s);
|
||||
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
}
|
||||
};
|
||||
@@ -613,9 +637,8 @@ void quantize(
|
||||
float eps = 1e-7;
|
||||
|
||||
bool power_of_2_bits = is_power_of_2(bits);
|
||||
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
|
||||
int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
int el_per_int = get_pack_factor(bits, 32);
|
||||
int bytes_per_pack = get_bytes_per_pack(bits);
|
||||
int int_per_group = group_size * bytes_per_pack / el_per_int;
|
||||
size_t n_groups = w_size / group_size;
|
||||
|
||||
@@ -640,15 +663,21 @@ void quantize(
|
||||
}
|
||||
size_t out_idx = i * int_per_group;
|
||||
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
|
||||
uint32_t out_el = 0;
|
||||
uint64_t out_el = 0;
|
||||
for (int k = 0; k < el_per_int; ++k) {
|
||||
float w_el = w[w_idx + j * el_per_int + k];
|
||||
w_el = std::rint((w_el - bias) / scale);
|
||||
w_el = std::min(std::max(w_el, 0.0f), n_bins);
|
||||
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
|
||||
out_el |= static_cast<uint64_t>(w_el) << (k * bits);
|
||||
}
|
||||
if (power_of_2_bits) {
|
||||
out[out_idx + j] = out_el;
|
||||
} else if (bits == 5) {
|
||||
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
|
||||
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
|
||||
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
|
||||
out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
|
||||
out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
|
||||
} else {
|
||||
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
|
||||
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
|
||||
@@ -683,9 +712,7 @@ void fast::AffineQuantize::eval_cpu(
|
||||
if (arr.flags().row_contiguous) {
|
||||
return std::make_pair(arr, false);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General, s);
|
||||
return std::make_pair(arr_copy, true);
|
||||
return std::make_pair(contiguous_copy_cpu(arr, s), true);
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -325,7 +325,15 @@ struct MaxReduce {
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
return simd::max(x);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
if (simd::any(x != x)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
return simd::max(x);
|
||||
};
|
||||
};
|
||||
@@ -342,7 +350,15 @@ struct MinReduce {
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
return simd::min(x);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
if (simd::any(x != x)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
return simd::min(x);
|
||||
};
|
||||
};
|
||||
@@ -527,10 +543,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
|
@@ -250,10 +250,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Ensure contiguity
|
||||
auto in = inputs[0];
|
||||
if (!in.flags().row_contiguous) {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy(in, arr_copy, CopyType::General, stream());
|
||||
in = arr_copy;
|
||||
encoder.add_temporary(arr_copy);
|
||||
in = contiguous_copy_cpu(in, stream());
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
|
@@ -131,8 +131,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_cpu(x, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
|
@@ -8,7 +8,7 @@
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -333,45 +333,24 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
int axis = axis_;
|
||||
if (axis < 0) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream());
|
||||
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_cpu(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch(
|
||||
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return sort<bool>(out, axis_);
|
||||
case uint8:
|
||||
return sort<uint8_t>(out, axis_);
|
||||
case uint16:
|
||||
return sort<uint16_t>(out, axis_);
|
||||
case uint32:
|
||||
return sort<uint32_t>(out, axis_);
|
||||
case uint64:
|
||||
return sort<uint64_t>(out, axis_);
|
||||
case int8:
|
||||
return sort<int8_t>(out, axis_);
|
||||
case int16:
|
||||
return sort<int16_t>(out, axis_);
|
||||
case int32:
|
||||
return sort<int32_t>(out, axis_);
|
||||
case int64:
|
||||
return sort<int64_t>(out, axis_);
|
||||
case float32:
|
||||
return sort<float>(out, axis_);
|
||||
case float64:
|
||||
return sort<double>(out, axis_);
|
||||
case float16:
|
||||
return sort<float16_t>(out, axis_);
|
||||
case bfloat16:
|
||||
return sort<bfloat16_t>(out, axis_);
|
||||
case complex64:
|
||||
return sort<complex64_t>(out, axis_);
|
||||
}
|
||||
});
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
|
||||
dispatch_all_types(out.dtype(), [&](auto type_tag) {
|
||||
sort<MLX_GET_TYPE(type_tag)>(out, axis);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -426,8 +405,10 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream());
|
||||
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_cpu(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
|
@@ -31,7 +31,7 @@ void svd_impl(
|
||||
|
||||
// lapack clobbers the input, so we have to make a copy.
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
copy(
|
||||
copy_cpu(
|
||||
a,
|
||||
in,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
|
@@ -2,32 +2,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
|
||||
for (size_t i = 0; i < shape; i += 1) {
|
||||
|
170
mlx/backend/cuda/CMakeLists.txt
Normal file
170
mlx/backend/cuda/CMakeLists.txt
Normal file
@@ -0,0 +1,170 @@
|
||||
# Filename rules in cuda backend:
|
||||
#
|
||||
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
|
||||
# * Device-only code should be put in device/ subdir.
|
||||
# * Files in device/ subdir should not include files outside.
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
|
||||
target_sources(
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
|
||||
else()
|
||||
target_sources(
|
||||
mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
|
||||
endif()
|
||||
|
||||
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||
|
||||
# Embed kernel sources in binary for JIT compilation.
|
||||
file(
|
||||
GLOB MLX_JIT_SOURCES
|
||||
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
|
||||
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
|
||||
add_custom_command(
|
||||
OUTPUT gen/cuda_jit_sources.h
|
||||
COMMAND
|
||||
${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
|
||||
-DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
|
||||
DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
|
||||
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
|
||||
add_dependencies(mlx cuda_jit_sources)
|
||||
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
|
||||
|
||||
# Enable defining device lambda functions.
|
||||
target_compile_options(mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||
|
||||
# Enable calling host constexpr functions from device. This is needed because
|
||||
# the constexpr version of isnan is host only.
|
||||
target_compile_options(
|
||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
|
||||
|
||||
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
|
||||
# Explicitly pass this flag to suppress the warning, it is safe to set it to
|
||||
# true but the warning wouldn't be suppressed.
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
|
||||
target_compile_options(
|
||||
mlx
|
||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
|
||||
endif()
|
||||
|
||||
# Suppress warning when building for compute capability 7 used by V100.
|
||||
target_compile_options(
|
||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")
|
||||
|
||||
# Use stronger binaries compression. This feature was introduced in CUDA 12.8
|
||||
# and requires drivers released after CUDA 12.4.
|
||||
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
|
||||
target_compile_options(
|
||||
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
|
||||
endif()
|
||||
|
||||
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
|
||||
# managed memory.
|
||||
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
|
||||
set(MLX_CUDA_ARCHITECTURES "native")
|
||||
endif()
|
||||
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
|
||||
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
|
||||
"${MLX_CUDA_ARCHITECTURES}")
|
||||
|
||||
# Use fixed version of CCCL.
|
||||
FetchContent_Declare(
|
||||
cccl
|
||||
URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
|
||||
FetchContent_MakeAvailable(cccl)
|
||||
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")
|
||||
|
||||
# Use fixed version of NVTX.
|
||||
FetchContent_Declare(
|
||||
nvtx3
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
|
||||
GIT_TAG v3.1.1
|
||||
GIT_SHALLOW TRUE
|
||||
SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(nvtx3)
|
||||
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)
|
||||
|
||||
# Make cuda runtime APIs available in non-cuda files.
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
|
||||
|
||||
# Use cublasLt.
|
||||
target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
||||
|
||||
# Use NVRTC and driver APIs.
|
||||
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||
|
||||
# Use the frontend APIs of cuDNN.
|
||||
FetchContent_Declare(
|
||||
cudnn
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
||||
GIT_TAG v1.12.1
|
||||
GIT_SHALLOW TRUE
|
||||
EXCLUDE_FROM_ALL)
|
||||
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
||||
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
|
||||
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
|
||||
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
|
||||
FetchContent_MakeAvailable(cudnn)
|
||||
target_link_libraries(mlx PRIVATE cudnn_frontend)
|
||||
# Link with the actual cuDNN libraries.
|
||||
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
|
||||
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
||||
|
||||
# Suppress nvcc warnings on MLX headers.
|
||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||
--diag_suppress=997>)
|
||||
|
||||
# Install CCCL headers for JIT.
|
||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
|
258
mlx/backend/cuda/allocator.cpp
Normal file
258
mlx/backend/cuda/allocator.cpp
Normal file
@@ -0,0 +1,258 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <fmt/format.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
constexpr int page_size = 16384;
|
||||
|
||||
// Any allocations smaller than this will try to use the small pool
|
||||
constexpr int small_block_size = 8;
|
||||
|
||||
// The small pool size in bytes. This should be a multiple of the host page
|
||||
// size and small_block_size.
|
||||
constexpr int small_pool_size = 4 * page_size;
|
||||
|
||||
SmallSizePool::SmallSizePool() {
|
||||
auto num_blocks = small_pool_size / small_block_size;
|
||||
buffer_ = new Block[num_blocks];
|
||||
|
||||
next_free_ = buffer_;
|
||||
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
|
||||
|
||||
auto curr = next_free_;
|
||||
for (size_t i = 1; i < num_blocks; ++i) {
|
||||
curr->next = buffer_ + i;
|
||||
curr = curr->next;
|
||||
}
|
||||
curr->next = nullptr;
|
||||
}
|
||||
|
||||
SmallSizePool::~SmallSizePool() {
|
||||
CHECK_CUDA_ERROR(cudaFree(data_));
|
||||
delete[] buffer_;
|
||||
}
|
||||
|
||||
CudaBuffer* SmallSizePool::malloc() {
|
||||
if (next_free_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
Block* b = next_free_;
|
||||
uint64_t i = next_free_ - buffer_;
|
||||
next_free_ = next_free_->next;
|
||||
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
|
||||
b->buf.size = small_block_size;
|
||||
return &b->buf;
|
||||
}
|
||||
|
||||
void SmallSizePool::free(CudaBuffer* buf) {
|
||||
auto b = reinterpret_cast<Block*>(buf);
|
||||
b->next = next_free_;
|
||||
next_free_ = b;
|
||||
}
|
||||
|
||||
bool SmallSizePool::in_pool(CudaBuffer* buf) {
|
||||
constexpr int num_blocks = (small_pool_size / small_block_size);
|
||||
auto b = reinterpret_cast<Block*>(buf);
|
||||
int64_t block_num = b - buffer_;
|
||||
return block_num >= 0 && block_num < num_blocks;
|
||||
}
|
||||
|
||||
CudaAllocator::CudaAllocator()
|
||||
: buffer_cache_(
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.8;
|
||||
max_pool_size_ = memory_limit_;
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
// Find available buffer from cache.
|
||||
auto orig_size = size;
|
||||
std::unique_lock lock(mutex_);
|
||||
if (size <= small_block_size) {
|
||||
size = 8;
|
||||
} else if (size < page_size) {
|
||||
size = next_power_of_2(size);
|
||||
} else {
|
||||
size = page_size * ((size + page_size - 1) / page_size);
|
||||
}
|
||||
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure try to reclaim memory from the cache.
|
||||
int64_t mem_to_free =
|
||||
get_active_memory() + get_cache_memory() + size - memory_limit_;
|
||||
if (mem_to_free > 0) {
|
||||
buffer_cache_.release_cached_buffers(mem_to_free);
|
||||
}
|
||||
|
||||
// Try the scalar pool first
|
||||
if (size <= small_block_size) {
|
||||
buf = scalar_pool_.malloc();
|
||||
}
|
||||
lock.unlock();
|
||||
if (!buf) {
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
lock.lock();
|
||||
}
|
||||
active_memory_ += size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
|
||||
// Maintain the cache below the requested limit.
|
||||
if (get_cache_memory() > max_pool_size_) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
void CudaAllocator::free(Buffer buffer) {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock lock(mutex_);
|
||||
active_memory_ -= buf->size;
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
cuda_free(buf);
|
||||
}
|
||||
}
|
||||
|
||||
size_t CudaAllocator::size(Buffer buffer) const {
|
||||
auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
|
||||
if (!buf) {
|
||||
return 0;
|
||||
}
|
||||
return buf->size;
|
||||
}
|
||||
|
||||
// This must be called with mutex_ aquired
|
||||
void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||
if (scalar_pool_.in_pool(buf)) {
|
||||
scalar_pool_.free(buf);
|
||||
} else {
|
||||
cudaFree(buf->data);
|
||||
delete buf;
|
||||
}
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_active_memory() const {
|
||||
return active_memory_;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_peak_memory() const {
|
||||
return peak_memory_;
|
||||
}
|
||||
|
||||
void CudaAllocator::reset_peak_memory() {
|
||||
std::lock_guard lock(mutex_);
|
||||
peak_memory_ = 0;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_memory_limit() {
|
||||
return memory_limit_;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::set_memory_limit(size_t limit) {
|
||||
std::lock_guard lock(mutex_);
|
||||
std::swap(limit, memory_limit_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_cache_memory() const {
|
||||
return buffer_cache_.cache_size();
|
||||
}
|
||||
|
||||
size_t CudaAllocator::set_cache_limit(size_t limit) {
|
||||
std::lock_guard lk(mutex_);
|
||||
std::swap(limit, max_pool_size_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
void CudaAllocator::clear_cache() {
|
||||
std::lock_guard lk(mutex_);
|
||||
buffer_cache_.clear();
|
||||
}
|
||||
|
||||
CudaAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
// can save some time at program exit.
|
||||
static CudaAllocator* allocator_ = new CudaAllocator;
|
||||
return *allocator_;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
namespace allocator {
|
||||
|
||||
Allocator& allocator() {
|
||||
return cu::allocator();
|
||||
}
|
||||
|
||||
void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<cu::CudaBuffer*>(ptr_)->data;
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
size_t get_active_memory() {
|
||||
return cu::allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return cu::allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {
|
||||
return cu::allocator().reset_peak_memory();
|
||||
}
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return cu::allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return cu::allocator().get_memory_limit();
|
||||
}
|
||||
size_t get_cache_memory() {
|
||||
return cu::allocator().get_cache_memory();
|
||||
}
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
return cu::allocator().set_cache_limit(limit);
|
||||
}
|
||||
void clear_cache() {
|
||||
cu::allocator().clear_cache();
|
||||
}
|
||||
|
||||
// Not supported in CUDA.
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
77
mlx/backend/cuda/allocator.h
Normal file
77
mlx/backend/cuda/allocator.h
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/buffer_cache.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
// Stores cuda-managed unified memory.
|
||||
struct CudaBuffer {
|
||||
void* data;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
class SmallSizePool {
|
||||
private:
|
||||
union Block {
|
||||
Block* next;
|
||||
CudaBuffer buf;
|
||||
};
|
||||
|
||||
Block* buffer_{nullptr};
|
||||
void* data_{nullptr};
|
||||
Block* next_free_{nullptr};
|
||||
|
||||
public:
|
||||
SmallSizePool();
|
||||
~SmallSizePool();
|
||||
|
||||
SmallSizePool(const SmallSizePool&) = delete;
|
||||
SmallSizePool& operator=(const SmallSizePool&) = delete;
|
||||
|
||||
CudaBuffer* malloc();
|
||||
void free(CudaBuffer* buf);
|
||||
bool in_pool(CudaBuffer* buf);
|
||||
};
|
||||
|
||||
class CudaAllocator : public allocator::Allocator {
|
||||
public:
|
||||
Buffer malloc(size_t size) override;
|
||||
void free(Buffer buffer) override;
|
||||
size_t size(Buffer buffer) const override;
|
||||
|
||||
size_t get_active_memory() const;
|
||||
size_t get_peak_memory() const;
|
||||
void reset_peak_memory();
|
||||
size_t get_memory_limit();
|
||||
size_t set_memory_limit(size_t limit);
|
||||
size_t get_cache_memory() const;
|
||||
size_t set_cache_limit(size_t limit);
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
void cuda_free(CudaBuffer* buf);
|
||||
|
||||
CudaAllocator();
|
||||
friend CudaAllocator& allocator();
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t max_pool_size_;
|
||||
BufferCache<CudaBuffer> buffer_cache_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
SmallSizePool scalar_pool_;
|
||||
};
|
||||
|
||||
CudaAllocator& allocator();
|
||||
|
||||
} // namespace mlx::core::cu
|
55
mlx/backend/cuda/arange.cu
Normal file
55
mlx/backend/cuda/arange.cu
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/transform.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
template <typename T>
|
||||
struct Arange {
|
||||
const T start;
|
||||
const T step;
|
||||
|
||||
__device__ T operator()(uint32_t i) const {
|
||||
return start + i * step;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Arange::eval_gpu");
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto capture = encoder.capture_context();
|
||||
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
|
||||
using CTYPE = MLX_GET_TYPE(type_tag);
|
||||
using OutType = cuda_type_t<CTYPE>;
|
||||
CTYPE step =
|
||||
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
|
||||
thrust::transform(
|
||||
cu::thrust_policy(encoder.stream()),
|
||||
thrust::counting_iterator<uint32_t>(0),
|
||||
thrust::counting_iterator<uint32_t>(out.data_size()),
|
||||
thrust::device_pointer_cast(out.data<OutType>()),
|
||||
cu::Arange<OutType>{
|
||||
static_cast<OutType>(start_), static_cast<OutType>(step)});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
188
mlx/backend/cuda/arg_reduce.cu
Normal file
188
mlx/backend/cuda/arg_reduce.cu
Normal file
@@ -0,0 +1,188 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T>
|
||||
struct IndexValPair {
|
||||
uint32_t index;
|
||||
T val;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ArgMin {
|
||||
constexpr __device__ T init() {
|
||||
return Limits<T>::max();
|
||||
}
|
||||
|
||||
__device__ IndexValPair<T> operator()(
|
||||
const IndexValPair<T>& best,
|
||||
const IndexValPair<T>& current) {
|
||||
if (best.val > current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ IndexValPair<T> reduce_many(
|
||||
IndexValPair<T> best,
|
||||
const AlignedVector<T, N>& vals,
|
||||
uint32_t offset) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] < best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ArgMax {
|
||||
constexpr __device__ T init() {
|
||||
return Limits<T>::min();
|
||||
}
|
||||
|
||||
__device__ IndexValPair<T> operator()(
|
||||
const IndexValPair<T>& best,
|
||||
const IndexValPair<T>& current) {
|
||||
if (best.val < current.val ||
|
||||
(best.val == current.val && best.index > current.index)) {
|
||||
return current;
|
||||
} else {
|
||||
return best;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ IndexValPair<T> reduce_many(
|
||||
IndexValPair<T> best,
|
||||
const AlignedVector<T, N>& vals,
|
||||
uint32_t offset) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (vals[i] > best.val) {
|
||||
best.val = vals[i];
|
||||
best.index = offset + i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
|
||||
__global__ void arg_reduce_general(
|
||||
const T* in,
|
||||
uint32_t* out,
|
||||
size_t size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides in_strides,
|
||||
const __grid_constant__ Strides out_strides,
|
||||
int32_t ndim,
|
||||
int64_t axis_stride,
|
||||
int32_t axis_size) {
|
||||
auto block = cg::this_thread_block();
|
||||
|
||||
int64_t index = cg::this_grid().block_rank();
|
||||
if (index >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
|
||||
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
|
||||
in += in_idx;
|
||||
|
||||
Op op;
|
||||
T init = op.init();
|
||||
IndexValPair<T> best{0, init};
|
||||
|
||||
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
|
||||
auto tid = r * BLOCK_DIM + block.thread_index().x;
|
||||
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
|
||||
best = op.reduce_many(best, vals, tid * N_READS);
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
|
||||
__shared__ typename BlockReduceT::TempStorage temp;
|
||||
|
||||
best = BlockReduceT(temp).Reduce(best, op);
|
||||
|
||||
if (block.thread_rank() == 0) {
|
||||
out[out_idx] = best.index;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("ArgReduce::eval_gpu");
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& s = stream();
|
||||
|
||||
// Prepare the shapes, strides and axis arguments.
|
||||
Shape shape = remove_index(in.shape(), axis_);
|
||||
Strides in_strides = remove_index(in.strides(), axis_);
|
||||
Strides out_strides = out.ndim() == in.ndim()
|
||||
? remove_index(out.strides(), axis_)
|
||||
: out.strides();
|
||||
int64_t axis_stride = in.strides()[axis_];
|
||||
int32_t axis_size = in.shape()[axis_];
|
||||
int32_t ndim = shape.size();
|
||||
|
||||
// ArgReduce.
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
|
||||
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
||||
constexpr uint32_t N_READS = 4;
|
||||
dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
|
||||
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
|
||||
auto kernel =
|
||||
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
|
||||
if (reduce_type_ == ArgReduce::ArgMin) {
|
||||
kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
|
||||
}
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dim(),
|
||||
0,
|
||||
in.data<T>(),
|
||||
out.data<uint32_t>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(in_strides),
|
||||
const_param(out_strides),
|
||||
ndim,
|
||||
axis_stride,
|
||||
axis_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
150
mlx/backend/cuda/bin2h.cmake
Normal file
150
mlx/backend/cuda/bin2h.cmake
Normal file
@@ -0,0 +1,150 @@
|
||||
# Based on: https://github.com/sivachandran/cmake-bin2h
|
||||
#
|
||||
# Copyright 2020 Sivachandran Paramasivam
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
include(CMakeParseArguments)
|
||||
|
||||
# Function to wrap a given string into multiple lines at the given column
|
||||
# position.
|
||||
#
|
||||
# Parameters:
|
||||
#
|
||||
# * VARIABLE - The name of the CMake variable holding the string.
|
||||
# * AT_COLUMN - The column position at which string will be wrapped.
|
||||
function(WRAP_STRING)
|
||||
set(oneValueArgs VARIABLE AT_COLUMN)
|
||||
cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})
|
||||
|
||||
string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
|
||||
math(EXPR offset "0")
|
||||
|
||||
while(stringLength GREATER 0)
|
||||
if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
|
||||
math(EXPR length "${WRAP_STRING_AT_COLUMN}")
|
||||
else()
|
||||
math(EXPR length "${stringLength}")
|
||||
endif()
|
||||
|
||||
string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
|
||||
set(lines "${lines}\n ${line}")
|
||||
|
||||
math(EXPR stringLength "${stringLength} - ${length}")
|
||||
math(EXPR offset "${offset} + ${length}")
|
||||
endwhile()
|
||||
|
||||
set(${WRAP_STRING_VARIABLE}
|
||||
"${lines}"
|
||||
PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Function to embed contents of a file as byte array in C/C++ header file(.h).
|
||||
# The header file will contain a byte array and integer variable holding the
|
||||
# size of the array.
|
||||
#
|
||||
# Parameters:
|
||||
#
|
||||
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
|
||||
# the header file.
|
||||
# * VARIABLE_NAME - The name of the variable for the byte array. The string
|
||||
# "_SIZE" will be append to this name and will be used a variable name for
|
||||
# size variable.
|
||||
# * HEADER_FILE - The path of header file.
|
||||
# * APPEND - If specified appends to the header file instead of overwriting it
|
||||
# * HEADER_NAMESPACE - The namespace, where the array should be located in.
|
||||
# * NULL_TERMINATE - If specified a null byte(zero) will be append to the byte
|
||||
# array.
|
||||
#
|
||||
# Usage:
|
||||
#
|
||||
# bin2h(SOURCE_FILE "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
|
||||
function(BIN2H)
|
||||
set(options APPEND NULL_TERMINATE)
|
||||
set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
|
||||
set(multiValueArgs SOURCE_FILES)
|
||||
cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
|
||||
"${multiValueArgs}" ${ARGN})
|
||||
|
||||
set(arrayDefinition "")
|
||||
foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
|
||||
# get filename without extension
|
||||
get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
|
||||
# convert the filename to a valid C identifier
|
||||
string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)
|
||||
|
||||
# reads source file contents as hex string
|
||||
file(READ ${SOURCE_FILE} hexString HEX)
|
||||
|
||||
# append null
|
||||
if(BIN2H_NULL_TERMINATE)
|
||||
string(APPEND hexString "00")
|
||||
endif()
|
||||
|
||||
# wraps the hex string into multiple lines
|
||||
wrap_string(VARIABLE hexString AT_COLUMN 24)
|
||||
|
||||
# strip the © in source code
|
||||
string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})
|
||||
|
||||
string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
|
||||
${arrayValues})
|
||||
|
||||
# make a full variable name for the array
|
||||
set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")
|
||||
|
||||
# declares byte array and the length variables
|
||||
string(APPEND arrayDefinition
|
||||
"constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
|
||||
endforeach()
|
||||
|
||||
# add namespace wrapper if defined
|
||||
if(DEFINED BIN2H_HEADER_NAMESPACE)
|
||||
set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
|
||||
set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
|
||||
set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
|
||||
endif()
|
||||
|
||||
set(arrayIncludes "#pragma once")
|
||||
string(PREPEND declarations "${arrayIncludes}\n\n")
|
||||
|
||||
if(BIN2H_APPEND)
|
||||
file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
|
||||
else()
|
||||
file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# ----------------------------- CLI args -----------------------------
|
||||
|
||||
string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
|
||||
foreach(source ${MLX_JIT_SOURCES_LIST})
|
||||
list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
|
||||
endforeach()
|
||||
|
||||
bin2h(
|
||||
SOURCE_FILES
|
||||
${MLX_JIT_SOURCES_ABS}
|
||||
NULL_TERMINATE
|
||||
VARIABLE_NAME
|
||||
"jit_source"
|
||||
HEADER_NAMESPACE
|
||||
"mlx::core"
|
||||
HEADER_FILE
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")
|
357
mlx/backend/cuda/binary.cu
Normal file
357
mlx/backend/cuda/binary.cu
Normal file
@@ -0,0 +1,357 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (int i = index * N_READS; i < size; ++i) {
|
||||
out[i] = Op{}(a[0], b[0]);
|
||||
}
|
||||
} else {
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(a[0], b[0]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||
out[i] = Op{}(a[0], b[i]);
|
||||
}
|
||||
} else {
|
||||
auto b_vec = load_vector<N_READS>(b, index);
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(a[0], b_vec[i]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||
out[i] = Op{}(a[i], b[0]);
|
||||
}
|
||||
} else {
|
||||
auto a_vec = load_vector<N_READS>(a, index);
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(a_vec[i], b[0]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||
out[i] = Op{}(a[i], b[i]);
|
||||
}
|
||||
} else {
|
||||
auto a_vec = load_vector<N_READS>(a, index);
|
||||
auto b_vec = load_vector<N_READS>(b, index);
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_binary_op() {
|
||||
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
|
||||
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
|
||||
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
|
||||
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
|
||||
return std::is_same_v<In, Out>;
|
||||
}
|
||||
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
|
||||
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
|
||||
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
|
||||
return std::is_same_v<Out, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
|
||||
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
||||
}
|
||||
if (std::is_same_v<Op, NaNEqual>) {
|
||||
return std::is_same_v<Out, bool> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, LogAddExp>) {
|
||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, ArcTan2>) {
|
||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
||||
std::is_same_v<Op, BitwiseXor>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In>;
|
||||
}
|
||||
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
|
||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||
!std::is_same_v<In, bool>;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
cu::binary_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
cu::binary_g<Op, InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
constexpr int N_READS = 16 / sizeof(InType);
|
||||
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
out.data_size(), out.shape(), out.strides(), large(), N_READS);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out.data<OutType>(),
|
||||
out.data_size());
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
binary_op_gpu_inplace<Op>(inputs, out, op, s);
|
||||
}
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
||||
auto& s = out.primitive().stream(); \
|
||||
binary_op_gpu<cu::func>(inputs, out, name(), s); \
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
BINARY_GPU(LessEqual)
|
||||
BINARY_GPU(LogicalAnd)
|
||||
BINARY_GPU(LogicalOr)
|
||||
BINARY_GPU(LogAddExp)
|
||||
BINARY_GPU(Maximum)
|
||||
BINARY_GPU(Minimum)
|
||||
BINARY_GPU(Multiply)
|
||||
BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
if (equal_nan_) {
|
||||
binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
|
||||
} else {
|
||||
binary_op_gpu<cu::Equal>(inputs, out, name(), s);
|
||||
}
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||
auto& s = out.primitive().stream();
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
333
mlx/backend/cuda/binary_two.cu
Normal file
333
mlx/backend/cuda/binary_two.cu
Normal file
@@ -0,0 +1,333 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void
|
||||
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||
auto out = Op{}(a[0], b[0]);
|
||||
out_a[i] = out[0];
|
||||
out_b[i] = out[1];
|
||||
}
|
||||
} else {
|
||||
AlignedVector<Out, N_READS> out_a_vec;
|
||||
AlignedVector<Out, N_READS> out_b_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
auto out = Op{}(a[0], b[0]);
|
||||
out_a_vec[i] = out[0];
|
||||
out_b_vec[i] = out[1];
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void
|
||||
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||
auto out = Op{}(a[0], b[i]);
|
||||
out_a[i] = out[0];
|
||||
out_b[i] = out[1];
|
||||
}
|
||||
} else {
|
||||
auto b_vec = load_vector<N_READS>(b, index);
|
||||
|
||||
AlignedVector<Out, N_READS> out_a_vec;
|
||||
AlignedVector<Out, N_READS> out_b_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
auto out = Op{}(a[0], b_vec[i]);
|
||||
out_a_vec[i] = out[0];
|
||||
out_b_vec[i] = out[1];
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void
|
||||
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||
auto out = Op{}(a[i], b[0]);
|
||||
out_a[i] = out[0];
|
||||
out_b[i] = out[1];
|
||||
}
|
||||
} else {
|
||||
auto a_vec = load_vector<N_READS>(a, index);
|
||||
|
||||
AlignedVector<Out, N_READS> out_a_vec;
|
||||
AlignedVector<Out, N_READS> out_b_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
auto out = Op{}(a_vec[i], b[0]);
|
||||
out_a_vec[i] = out[0];
|
||||
out_b_vec[i] = out[1];
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void
|
||||
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||
auto out = Op{}(a[i], b[i]);
|
||||
out_a[i] = out[0];
|
||||
out_b[i] = out[1];
|
||||
}
|
||||
} else {
|
||||
auto a_vec = load_vector<N_READS>(a, index);
|
||||
auto b_vec = load_vector<N_READS>(b, index);
|
||||
|
||||
AlignedVector<Out, N_READS> out_a_vec;
|
||||
AlignedVector<Out, N_READS> out_b_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
auto out = Op{}(a_vec[i], b_vec[i]);
|
||||
out_a_vec[i] = out[0];
|
||||
out_b_vec[i] = out[1];
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out_a, index, out_a_vec);
|
||||
store_vector<N_READS>(out_b, index, out_b_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void binary_two_g_nd(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), a_strides.data(), b_strides.data());
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out, typename IdxT>
|
||||
__global__ void binary_two_g(
|
||||
const In* a,
|
||||
const In* b,
|
||||
Out* out_a,
|
||||
Out* out_b,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides a_strides,
|
||||
const __grid_constant__ Strides b_strides,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [a_idx, b_idx] = elem_to_loc(
|
||||
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||
auto out = Op{}(a[a_idx], b[b_idx]);
|
||||
out_a[index] = out[0];
|
||||
out_b[index] = out[1];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename In, typename Out>
|
||||
constexpr bool supports_binary_two_op() {
|
||||
if (std::is_same_v<Op, DivMod>) {
|
||||
return std::is_same_v<In, Out> &&
|
||||
(std::is_integral_v<Out> || is_floating_v<Out>);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
template <typename Op>
|
||||
void binary_two_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() > 1);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
if (out_a.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
|
||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||
if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||
using InType = cuda_type_t<CTYPE_IN>;
|
||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (bopt == BinaryOpType::General) {
|
||||
dispatch_bool(
|
||||
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||
out_a.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
Shape shape;
|
||||
std::vector<Strides> strides;
|
||||
std::tie(shape, strides) =
|
||||
collapse_contiguous_dims(a, b, out_a);
|
||||
auto& a_strides = strides[0];
|
||||
auto& b_strides = strides[1];
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(out_a, large());
|
||||
encoder.add_kernel_node(
|
||||
cu::binary_two_g_nd<
|
||||
Op,
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(a_strides),
|
||||
const_param<dims_constant()>(b_strides));
|
||||
});
|
||||
} else {
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(out_a, large());
|
||||
encoder.add_kernel_node(
|
||||
cu::binary_two_g<Op, InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.size(),
|
||||
const_param(shape),
|
||||
const_param(a_strides),
|
||||
const_param(b_strides),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
constexpr int N_READS = 16 / sizeof(InType);
|
||||
auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
|
||||
} else if (bopt == BinaryOpType::VectorVector) {
|
||||
kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
out_a.data_size(),
|
||||
out_a.shape(),
|
||||
out_a.strides(),
|
||||
large(),
|
||||
N_READS);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
a.data<InType>(),
|
||||
b.data<InType>(),
|
||||
out_a.data<OutType>(),
|
||||
out_b.data<OutType>(),
|
||||
out_a.data_size());
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Can not do binary op {} on inputs of {} with result of {}.",
|
||||
op,
|
||||
dtype_to_string(a.dtype()),
|
||||
dtype_to_string(out_a.dtype())));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_two_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const char* op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("DivMod::eval_gpu");
|
||||
auto& s = outputs[0].primitive().stream();
|
||||
binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
340
mlx/backend/cuda/compiled.cpp
Normal file
340
mlx/backend/cuda/compiled.cpp
Normal file
@@ -0,0 +1,340 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
struct FusedKernelBuilder {
|
||||
std::string os;
|
||||
const std::string& kernel_name;
|
||||
const std::vector<array>& inputs;
|
||||
const std::vector<array>& outputs;
|
||||
const std::vector<array>& tape;
|
||||
const std::function<bool(size_t)>& is_constant;
|
||||
|
||||
void build(const char* name, bool contiguous) {
|
||||
NodeNamer namer;
|
||||
|
||||
// Function parameters.
|
||||
std::vector<std::string> params;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
params.push_back(
|
||||
fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname));
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
params.push_back(fmt::format(
|
||||
"const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides",
|
||||
xname));
|
||||
}
|
||||
}
|
||||
for (const auto& x : outputs) {
|
||||
params.push_back(fmt::format(
|
||||
"{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));
|
||||
}
|
||||
if (!contiguous) {
|
||||
params.push_back(
|
||||
"const __grid_constant__ cuda::std::array<int32_t, NDIM> shape");
|
||||
}
|
||||
params.push_back("IdxT size");
|
||||
|
||||
// Build function signature.
|
||||
if (contiguous) {
|
||||
os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||
} else {
|
||||
os +=
|
||||
"template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
|
||||
}
|
||||
os += fmt::format("__global__ void {}(\n", kernel_name + name);
|
||||
for (size_t i = 0; i < params.size(); ++i) {
|
||||
os += " ";
|
||||
os += params[i];
|
||||
if (i != params.size() - 1) {
|
||||
os += ",\n";
|
||||
}
|
||||
}
|
||||
os += ") {\n";
|
||||
|
||||
// Index. For non contiguous kernels we create a separate index
|
||||
// variable per variable otherwise everyone uses `index`.
|
||||
os +=
|
||||
" IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
|
||||
" if (index >= size) {\n"
|
||||
" return;\n"
|
||||
" }\n";
|
||||
if (!contiguous) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += " IdxT " + xname + "_idx = 0;\n";
|
||||
}
|
||||
os += " {\n";
|
||||
os += " IdxT loc = index;\n";
|
||||
os +=
|
||||
" #pragma unroll\n"
|
||||
" for (int i = NDIM - 1; i >= 0; i--) {\n";
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
|
||||
"_strides[i]);\n";
|
||||
}
|
||||
os +=
|
||||
" loc /= shape[i];\n"
|
||||
" }\n"
|
||||
" }\n";
|
||||
}
|
||||
|
||||
// Vectorized read loop
|
||||
if (contiguous) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
os += fmt::format(
|
||||
" auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
|
||||
xname,
|
||||
type);
|
||||
}
|
||||
}
|
||||
|
||||
// Create some space for the outputs
|
||||
for (const auto& x : outputs) {
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
os += fmt::format(
|
||||
" AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
|
||||
}
|
||||
|
||||
// Work loop
|
||||
if (!contiguous) {
|
||||
os +=
|
||||
"\n"
|
||||
" for (int i = 0; i < work_per_thread && index < size; i++) {\n";
|
||||
} else {
|
||||
os +=
|
||||
"\n"
|
||||
" #pragma unroll\n"
|
||||
" for (int i = 0; i < work_per_thread; i++) {\n";
|
||||
}
|
||||
|
||||
// Read inputs.
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
std::string value;
|
||||
if (is_constant(i)) {
|
||||
std::ostringstream ss;
|
||||
print_constant(ss, x);
|
||||
value = fmt::format("static_cast<{}>({})", type, ss.str());
|
||||
} else if (is_scalar(x)) {
|
||||
value = fmt::format("{}[0]", xname);
|
||||
} else if (contiguous) {
|
||||
value = fmt::format("vec_{}[i]", xname);
|
||||
} else {
|
||||
value = fmt::format("{}[{}_idx]", xname, xname);
|
||||
}
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
}
|
||||
|
||||
// Write tape.
|
||||
for (const auto& x : tape) {
|
||||
const std::string& xname = namer.get_name(x);
|
||||
std::string type = dtype_to_cuda_type(x.dtype());
|
||||
std::string value;
|
||||
if (is_static_cast(x.primitive())) {
|
||||
value = fmt::format(
|
||||
"static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
|
||||
} else {
|
||||
value = x.primitive().name();
|
||||
value += "{}(";
|
||||
for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
|
||||
value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
|
||||
}
|
||||
value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
|
||||
}
|
||||
os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
|
||||
}
|
||||
|
||||
// Write output.
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
|
||||
}
|
||||
|
||||
// End of work loop
|
||||
if (!contiguous) {
|
||||
os += "\n";
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
const std::string& xname = namer.get_name(x);
|
||||
if (is_scalar(x) || is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
|
||||
}
|
||||
}
|
||||
os += " }\n";
|
||||
|
||||
// Store the output to global memory
|
||||
for (const auto& x : outputs) {
|
||||
os += fmt::format(
|
||||
" store_vector({0} + index, 0, vec_{0}, size - index);\n",
|
||||
namer.get_name(x));
|
||||
}
|
||||
|
||||
os += "}\n";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cu
|
||||
|
||||
constexpr const char* g_jit_includes = R"(
|
||||
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/ternary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#define inf cuda::std::numeric_limits<float>::infinity()
|
||||
)";
|
||||
|
||||
void Compiled::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
nvtx3::scoped_range r("Compiled::eval_gpu");
|
||||
auto& s = stream();
|
||||
|
||||
// Determine the work per thread for the vectorized reads/writes. We take it
|
||||
// as 16 over the max itemsize for the outputs. Another heuristic could be
|
||||
// over the max itemsize of all arrays.
|
||||
int max_size = 1;
|
||||
for (const auto& x : outputs) {
|
||||
max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
|
||||
}
|
||||
int work_per_thread = 16 / max_size;
|
||||
|
||||
cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
|
||||
// Build source code.
|
||||
cu::FusedKernelBuilder builder{
|
||||
g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
|
||||
builder.os +=
|
||||
"namespace mlx::core::cu {\n\n"
|
||||
"namespace cg = cooperative_groups;\n\n";
|
||||
builder.build("_contiguous", true);
|
||||
builder.os += "\n";
|
||||
builder.build("_strided", false);
|
||||
builder.os += "\n} // namespace mlx::core::cu\n";
|
||||
// Build kernel names.
|
||||
std::vector<std::string> kernel_names;
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<uint32_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_contiguous<int64_t, {}>",
|
||||
lib_name(),
|
||||
work_per_thread));
|
||||
for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
|
||||
for (int i = 1; i <= MAX_NDIM; ++i) {
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
|
||||
kernel_names.push_back(fmt::format(
|
||||
"mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(std::move(builder.os), std::move(kernel_names));
|
||||
});
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
auto [contiguous, shape, strides_vec] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Whether to use large index.
|
||||
bool large = compiled_use_large_index(inputs, outputs, contiguous);
|
||||
|
||||
cu::KernelArgs args;
|
||||
// Put inputs.
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
const auto& x = inputs[i];
|
||||
args.append(x);
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
args.append_ptr(strides_vec[strides_index++].data());
|
||||
}
|
||||
}
|
||||
|
||||
// Put outputs.
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
for (auto& x : outputs) {
|
||||
args.append(x);
|
||||
}
|
||||
|
||||
// Put shape and size.
|
||||
if (!contiguous) {
|
||||
args.append_ptr(shape.data());
|
||||
}
|
||||
if (large) {
|
||||
args.append<int64_t>(outputs[0].data_size());
|
||||
} else {
|
||||
args.append<uint32_t>(outputs[0].data_size());
|
||||
}
|
||||
|
||||
// Choose work per thread
|
||||
if (!contiguous && shape.back() % work_per_thread != 0) {
|
||||
work_per_thread = 1;
|
||||
}
|
||||
|
||||
// Launch kernel.
|
||||
const char* index_type = large ? "int64_t" : "uint32_t";
|
||||
std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
|
||||
if (contiguous) {
|
||||
kernel_name +=
|
||||
fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
|
||||
} else {
|
||||
kernel_name += fmt::format(
|
||||
"_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
|
||||
}
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
for (const auto& in : inputs) {
|
||||
encoder.set_input_array(in);
|
||||
}
|
||||
for (const auto& out : outputs) {
|
||||
encoder.set_output_array(out);
|
||||
}
|
||||
|
||||
auto kernel = mod.get_kernel(kernel_name);
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(outputs[0], large, work_per_thread);
|
||||
encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
546
mlx/backend/cuda/conv.cpp
Normal file
546
mlx/backend/cuda/conv.cpp
Normal file
@@ -0,0 +1,546 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
// cudnn_frontend.h redefines this macro.
|
||||
#undef CHECK_CUDA_ERROR
|
||||
|
||||
#include <cudnn_frontend.h>
|
||||
#include <cudnn_frontend_find_plan.h>
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Not all engines support it so can not use this API now.
|
||||
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0
|
||||
|
||||
// Alias for better readability.
|
||||
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
|
||||
#define CONV_BACKWARD_INPUT \
|
||||
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
|
||||
#define CONV_BACKWARD_WEIGHT \
|
||||
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
|
||||
|
||||
struct ConvCacheKey {
|
||||
int device_id;
|
||||
cudnnDataType_t cudnn_dtype;
|
||||
std::array<int, MAX_NDIM> input_shape;
|
||||
std::array<int, MAX_NDIM> weight_shape;
|
||||
std::array<int, MAX_NDIM> stride;
|
||||
std::array<int, MAX_NDIM> padding_lo;
|
||||
std::array<int, MAX_NDIM> padding_hi;
|
||||
std::array<int, MAX_NDIM> dilation;
|
||||
int groups;
|
||||
bool flip;
|
||||
uint8_t input_alignment;
|
||||
uint8_t weight_alignment;
|
||||
uint8_t output_alignment;
|
||||
};
|
||||
|
||||
auto& conv_cache() {
|
||||
static LRUBytesKeyCache<
|
||||
ConvCacheKey,
|
||||
std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
|
||||
cache(/* capacity */ 128);
|
||||
return cache;
|
||||
}
|
||||
|
||||
template <typename T, typename Vec>
|
||||
inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||
return SmallVector<T>(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
template <typename T, template <typename U> class Vec>
|
||||
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
}
|
||||
std::array<T, MAX_NDIM> result = {};
|
||||
std::copy_n(vec.begin(), vec.size(), result.begin());
|
||||
return result;
|
||||
}
|
||||
|
||||
auto nhwc_to_nchw(const array& x) {
|
||||
auto shape = convert_vector<int64_t>(x.shape());
|
||||
shape.insert(shape.begin() + 1, shape.back());
|
||||
shape.erase(shape.end() - 1);
|
||||
auto strides = convert_vector<int64_t>(x.strides());
|
||||
strides.insert(strides.begin() + 1, strides.back());
|
||||
strides.erase(strides.end() - 1);
|
||||
return std::make_tuple(std::move(shape), std::move(strides));
|
||||
}
|
||||
|
||||
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
return CUDNN_DATA_INT8;
|
||||
case int32:
|
||||
return CUDNN_DATA_INT32;
|
||||
case uint8:
|
||||
return CUDNN_DATA_UINT8;
|
||||
case float16:
|
||||
return CUDNN_DATA_HALF;
|
||||
case bfloat16:
|
||||
return CUDNN_DATA_BFLOAT16;
|
||||
case float32:
|
||||
return CUDNN_DATA_FLOAT;
|
||||
case float64:
|
||||
return CUDNN_DATA_DOUBLE;
|
||||
default:
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
|
||||
}
|
||||
}
|
||||
|
||||
inline uint8_t get_alignment(const array& x) {
|
||||
uint8_t alignment = 1;
|
||||
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
|
||||
for (; alignment < 32; alignment *= 2) {
|
||||
if (address % (alignment * 2)) {
|
||||
return alignment;
|
||||
}
|
||||
}
|
||||
return alignment;
|
||||
}
|
||||
|
||||
inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
|
||||
auto [shape, strides] = nhwc_to_nchw(x);
|
||||
return cudnn_frontend::TensorBuilder()
|
||||
.setDim(shape.size(), shape.data())
|
||||
.setStrides(strides.size(), strides.data())
|
||||
.setId(id)
|
||||
.setAlignment(get_alignment(x))
|
||||
.setDataType(dtype_to_cudnn_type(x.dtype()))
|
||||
.build();
|
||||
}
|
||||
|
||||
cudnn_frontend::EngineConfigList get_engine_configs(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
cudnn_frontend::OperationGraph& op_graph,
|
||||
bool use_fallback = false) {
|
||||
cudnn_frontend::GeneratorSource source;
|
||||
if (use_fallback) {
|
||||
source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setOperation(backend_type)
|
||||
.build();
|
||||
return fallback.getFallbackList();
|
||||
};
|
||||
} else {
|
||||
source = [](cudnn_frontend::OperationGraph& op_graph) {
|
||||
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
|
||||
.setOperationGraph(op_graph)
|
||||
.setHeurMode(CUDNN_HEUR_MODE_A)
|
||||
.build();
|
||||
return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
|
||||
};
|
||||
}
|
||||
|
||||
cudnn_frontend::EngineConfigGenerator generator(1, &source);
|
||||
auto configs = generator.generate_engine_config(op_graph);
|
||||
|
||||
cudnn_frontend::EngineConfigList filtered_configs;
|
||||
cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
|
||||
if (cudnn_frontend::hasNumericalNote<
|
||||
CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
|
||||
return true;
|
||||
}
|
||||
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
|
||||
dtype == float32 && !env::enable_tf32()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return filtered_configs;
|
||||
}
|
||||
|
||||
bool execute_plan(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnn_frontend::ExecutionPlan& plan,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y) {
|
||||
int workspace_size = plan.getWorkspaceSize();
|
||||
array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);
|
||||
|
||||
int64_t uids[3] = {'x', 'w', 'y'};
|
||||
void* data_ptrs[3] = {
|
||||
x.data<void>(),
|
||||
w.data<void>(),
|
||||
y.data<void>(),
|
||||
};
|
||||
|
||||
auto variantPack = cudnn_frontend::VariantPackBuilder()
|
||||
.setWorkspacePointer(workspace.data<void>())
|
||||
.setDataPointers(3, data_ptrs)
|
||||
.setUids(3, uids)
|
||||
.build();
|
||||
|
||||
auto handle = encoder.device().cudnn_handle();
|
||||
cudnnSetStream(handle, encoder.stream());
|
||||
|
||||
#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
|
||||
cudaGraph_t graph;
|
||||
cudaGraphCreate(&graph, 0);
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
|
||||
if (cudnnBackendPopulateCudaGraph(
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
encoder.add_graph_node(graph);
|
||||
#else
|
||||
auto capture = encoder.capture_context();
|
||||
if (cudnnBackendExecute(
|
||||
handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
|
||||
CUDNN_STATUS_SUCCESS) {
|
||||
// Discard the captured graph when failed.
|
||||
capture.discard = true;
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
encoder.add_temporary(workspace);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool try_engines(
|
||||
cu::CommandEncoder& encoder,
|
||||
const ConvCacheKey& cache_key,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
cudnn_frontend::EngineConfigList& configs,
|
||||
const std::string& op_graph_tag,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y) {
|
||||
for (auto& config : configs) {
|
||||
try {
|
||||
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setEngineConfig(config, op_graph_tag)
|
||||
.build();
|
||||
if (execute_plan(encoder, plan, x, w, y)) {
|
||||
conv_cache().emplace(
|
||||
cache_key, std::make_pair(backend_type, std::move(plan)));
|
||||
return true;
|
||||
}
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
auto get_conv_op_settings(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y,
|
||||
const std::vector<int>& kernel_strides,
|
||||
const std::vector<int>& padding_lo_,
|
||||
const std::vector<int>& padding_hi_,
|
||||
const std::vector<int>& kernel_dilation,
|
||||
const std::vector<int>& input_dilation) {
|
||||
auto padding_lo = convert_vector<int64_t>(padding_lo_);
|
||||
auto padding_hi = convert_vector<int64_t>(padding_hi_);
|
||||
|
||||
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||
for (int i = 0; i < padding_lo.size(); ++i) {
|
||||
int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
|
||||
padding_lo[i] = wt_size - padding_lo[i] - 1;
|
||||
int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
|
||||
int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
|
||||
padding_hi[i] = out_size - in_size + padding_hi[i];
|
||||
}
|
||||
return std::make_tuple(
|
||||
convert_vector<int64_t>(input_dilation),
|
||||
std::move(padding_lo),
|
||||
std::move(padding_hi),
|
||||
convert_vector<int64_t>(kernel_dilation));
|
||||
|
||||
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||
padding_hi = padding_lo;
|
||||
return std::make_tuple(
|
||||
convert_vector<int64_t>(kernel_dilation),
|
||||
std::move(padding_lo),
|
||||
std::move(padding_hi),
|
||||
convert_vector<int64_t>(kernel_strides));
|
||||
|
||||
} else {
|
||||
return std::make_tuple(
|
||||
convert_vector<int64_t>(kernel_strides),
|
||||
std::move(padding_lo),
|
||||
std::move(padding_hi),
|
||||
convert_vector<int64_t>(kernel_dilation));
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
Dtype dtype,
|
||||
array& x,
|
||||
array& w,
|
||||
array& y,
|
||||
const SmallVector<int64_t>& stride,
|
||||
const SmallVector<int64_t>& padding_lo,
|
||||
const SmallVector<int64_t>& padding_hi,
|
||||
const SmallVector<int64_t>& dilation) {
|
||||
try {
|
||||
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
|
||||
? CUDNN_DATA_FLOAT
|
||||
: dtype_to_cudnn_type(dtype);
|
||||
auto conv_desc = cudnn_frontend::ConvDescBuilder()
|
||||
.setDataType(compute_dtype)
|
||||
.setMathMode(CUDNN_CROSS_CORRELATION)
|
||||
.setNDims(stride.size())
|
||||
.setStrides(stride.size(), stride.data())
|
||||
.setPrePadding(padding_lo.size(), padding_lo.data())
|
||||
.setPostPadding(padding_hi.size(), padding_hi.data())
|
||||
.setDilation(dilation.size(), dilation.data())
|
||||
.build();
|
||||
|
||||
auto op = cudnn_frontend::OperationBuilder(backend_type)
|
||||
.setxDesc(build_tensor('x', x))
|
||||
.setwDesc(build_tensor('w', w))
|
||||
.setyDesc(build_tensor('y', y))
|
||||
.setcDesc(conv_desc)
|
||||
.build();
|
||||
|
||||
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
|
||||
return cudnn_frontend::OperationGraphBuilder()
|
||||
.setHandle(encoder.device().cudnn_handle())
|
||||
.setOperationGraph(ops.size(), ops.data())
|
||||
.build();
|
||||
} catch (cudnn_frontend::cudnnException& error) {
|
||||
if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
|
||||
throw;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// Do necessary transposes and copies to prepare the inputs and outputs for
|
||||
// building the cuDNN conv op. It is safe to be called multiple times in one
|
||||
// eval_gpu, with cost of possible redundant copies.
|
||||
std::tuple<array, array, array> prepare_args(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array in,
|
||||
array wt,
|
||||
array out,
|
||||
Stream s) {
|
||||
// Transpose the args depending on the backend type.
|
||||
// TODO: Handle groups.
|
||||
if (backend_type == CONV_BACKWARD_INPUT) {
|
||||
wt = swapaxes_in_eval(wt, 0, -1);
|
||||
} else if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||
in = swapaxes_in_eval(in, 0, -1);
|
||||
wt = swapaxes_in_eval(wt, 0, -1);
|
||||
// Create a contiguous array that shares the data with |out|, but with dim
|
||||
// C_in and C_out swapped.
|
||||
Shape shape(out.shape());
|
||||
std::swap(shape.front(), shape.back());
|
||||
Strides strides(shape.size(), 1);
|
||||
for (int i = shape.size() - 2; i >= 0; --i) {
|
||||
strides[i] = shape[i + 1] * strides[i + 1];
|
||||
}
|
||||
array intermediate(std::move(shape), out.dtype(), nullptr, {});
|
||||
intermediate.copy_shared_buffer(
|
||||
out, std::move(strides), {true, true, false}, out.data_size());
|
||||
out = intermediate;
|
||||
}
|
||||
|
||||
// cuDNN requires contiguous input.
|
||||
if (!in.flags().row_contiguous) {
|
||||
in = contiguous_copy_gpu(in, s);
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
if (!wt.flags().row_contiguous) {
|
||||
wt = contiguous_copy_gpu(wt, s);
|
||||
encoder.add_temporary(wt);
|
||||
}
|
||||
|
||||
return {std::move(in), std::move(wt), std::move(out)};
|
||||
}
|
||||
|
||||
// Get the x/w/y args from the in/wt/out args depending on backend type.
|
||||
inline std::tuple<array&, array&, array&> dispatch_args(
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array& in,
|
||||
array& wt,
|
||||
array& out) {
|
||||
switch (backend_type) {
|
||||
case CONV_BACKWARD_INPUT:
|
||||
return {out, wt, in};
|
||||
case CONV_BACKWARD_WEIGHT:
|
||||
return {in, out, wt};
|
||||
default:
|
||||
return {in, wt, out};
|
||||
}
|
||||
}
|
||||
|
||||
// Register inputs and outputs before actually running conv op. Can only be
|
||||
// called once per eval_gpu.
|
||||
void register_args(
|
||||
cu::CommandEncoder& encoder,
|
||||
cudnnBackendDescriptorType_t backend_type,
|
||||
array& in,
|
||||
array& wt,
|
||||
array& intermediate_out,
|
||||
array& final_out) {
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(wt);
|
||||
encoder.set_output_array(final_out);
|
||||
|
||||
if (backend_type == CONV_BACKWARD_WEIGHT) {
|
||||
// Turn |out| into a strided array, which will have C_in and C_out swapped
|
||||
// in vjp and the final |grad_weight| will then be contiguous.
|
||||
Strides strides = intermediate_out.strides();
|
||||
std::swap(strides.front(), strides.back());
|
||||
final_out.copy_shared_buffer(
|
||||
intermediate_out,
|
||||
std::move(strides),
|
||||
{false, false, false},
|
||||
intermediate_out.data_size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||
if (out_.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert(inputs.size() == 2);
|
||||
array in = inputs[0];
|
||||
array wt = inputs[1];
|
||||
array out = out_;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
Dtype dtype = out.dtype();
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
|
||||
// Search cache.
|
||||
ConvCacheKey cache_key{
|
||||
encoder.device().cuda_device(),
|
||||
dtype_to_cudnn_type(dtype),
|
||||
fixed_vector(in.shape()),
|
||||
fixed_vector(wt.shape()),
|
||||
fixed_vector(kernel_strides_),
|
||||
fixed_vector(padding_lo_),
|
||||
fixed_vector(padding_hi_),
|
||||
fixed_vector(kernel_dilation_),
|
||||
groups_,
|
||||
flip_,
|
||||
get_alignment(in),
|
||||
get_alignment(wt),
|
||||
get_alignment(out)};
|
||||
if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
|
||||
auto& [backend_type, plan] = it->second;
|
||||
std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (!execute_plan(encoder, plan, x, w, y)) {
|
||||
throw std::runtime_error("[conv] Cached plan failed to execute.");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// There is no reliable way to deduce the proper cuDNN backend for the
|
||||
// convolution, so we make a best guess and then try.
|
||||
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
|
||||
if (flip_) {
|
||||
// When weight is flipped, we assume it is backward input convolution.
|
||||
try_backends.push_back(CONV_BACKWARD_INPUT);
|
||||
} else {
|
||||
// Otherwise it could be backward weight convolution or forward convolution,
|
||||
// mathematically there is no difference so we have to use heuristics.
|
||||
// Empirically backward convolutions have large kernel dimensions, and
|
||||
// usually have |in| and |wt| transposed.
|
||||
if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
|
||||
wt.shape(2) > out.shape(2)) {
|
||||
try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
|
||||
} else {
|
||||
try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
|
||||
}
|
||||
}
|
||||
|
||||
// Try to build op graph.
|
||||
cudnnBackendDescriptorType_t backend_type;
|
||||
std::optional<cudnn_frontend::OperationGraph> op_graph;
|
||||
for (auto try_backend : try_backends) {
|
||||
auto [in_copy, wt_copy, out_copy] =
|
||||
prepare_args(encoder, try_backend, in, wt, out, s);
|
||||
auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
|
||||
auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
|
||||
try_backend,
|
||||
x,
|
||||
w,
|
||||
y,
|
||||
kernel_strides_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_dilation_,
|
||||
input_dilation_);
|
||||
op_graph = build_op_graph(
|
||||
encoder,
|
||||
try_backend,
|
||||
dtype,
|
||||
x,
|
||||
w,
|
||||
y,
|
||||
stride,
|
||||
padding_lo,
|
||||
padding_hi,
|
||||
dilation);
|
||||
if (op_graph) {
|
||||
backend_type = try_backend;
|
||||
in = std::move(in_copy);
|
||||
wt = std::move(wt_copy);
|
||||
out = std::move(out_copy);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!op_graph) {
|
||||
throw std::runtime_error("[conv] Can not build op graph.");
|
||||
}
|
||||
|
||||
// Get ready to execute the graph.
|
||||
register_args(encoder, backend_type, in, wt, out, out_);
|
||||
|
||||
// Try to run plans based on heuristics.
|
||||
auto configs = get_engine_configs(backend_type, dtype, *op_graph);
|
||||
auto tag = op_graph->getTag();
|
||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
|
||||
return;
|
||||
}
|
||||
// Then try fallback plans.
|
||||
configs = get_engine_configs(backend_type, dtype, *op_graph);
|
||||
if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("[conv] Unable to find a working engine.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
87
mlx/backend/cuda/copy.cu
Normal file
87
mlx/backend/cuda/copy.cu
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_gpu_inplace(
|
||||
const array& in,
|
||||
array& out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
CopyType ctype,
|
||||
const Stream& s,
|
||||
const std::optional<array>& dynamic_offset_in,
|
||||
const std::optional<array>& dynamic_offset_out) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
||||
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
|
||||
shape, std::vector{strides_in, strides_out}, INT32_MAX);
|
||||
if (ctype == CopyType::General) {
|
||||
copy_general_input(
|
||||
encoder,
|
||||
ctype,
|
||||
in,
|
||||
out,
|
||||
offset_in,
|
||||
offset_out,
|
||||
shape_collapsed,
|
||||
strides_vec[0]);
|
||||
} else {
|
||||
if (dynamic_offset_in || dynamic_offset_out) {
|
||||
copy_general_dynamic(
|
||||
encoder,
|
||||
ctype,
|
||||
in,
|
||||
out,
|
||||
offset_in,
|
||||
offset_out,
|
||||
shape_collapsed,
|
||||
strides_vec[0],
|
||||
strides_vec[1],
|
||||
dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
|
||||
dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
|
||||
} else {
|
||||
copy_general(
|
||||
encoder,
|
||||
ctype,
|
||||
in,
|
||||
out,
|
||||
offset_in,
|
||||
offset_out,
|
||||
shape_collapsed,
|
||||
strides_vec[0],
|
||||
strides_vec[1]);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void fill_gpu(const array& in, array& out, const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cu::get_command_encoder(s);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
55
mlx/backend/cuda/copy/copy.cuh
Normal file
55
mlx/backend/cuda/copy/copy.cuh
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/device/cast_op.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_contiguous(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out);
|
||||
|
||||
void copy_general(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out);
|
||||
|
||||
void copy_general_dynamic(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out,
|
||||
const array& dynamic_offset_in,
|
||||
const array& dynamic_offset_out);
|
||||
|
||||
void copy_general_input(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in);
|
||||
|
||||
} // namespace mlx::core
|
88
mlx/backend/cuda/copy/copy_contiguous.cu
Normal file
88
mlx/backend/cuda/copy/copy_contiguous.cu
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void copy_s(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||
out[i] = cast_to<Out>(in[0]);
|
||||
}
|
||||
} else {
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = cast_to<Out>(in[0]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int N_READS>
|
||||
__global__ void copy_v(const In* in, Out* out, IdxT size) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
|
||||
if ((index + 1) * N_READS > size) {
|
||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||
out[i] = cast_to<Out>(in[i]);
|
||||
}
|
||||
} else {
|
||||
auto in_vec = load_vector<N_READS>(in, index);
|
||||
|
||||
AlignedVector<Out, N_READS> out_vec;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N_READS; ++i) {
|
||||
out_vec[i] = cast_to<Out>(in_vec[i]);
|
||||
}
|
||||
|
||||
store_vector<N_READS>(out, index, out_vec);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_contiguous(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t in_offset,
|
||||
int64_t out_offset) {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||
constexpr int N_READS = 16 / sizeof(InType);
|
||||
auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
|
||||
if (ctype == CopyType::Vector) {
|
||||
kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
|
||||
}
|
||||
auto [num_blocks, block_dims] = get_launch_args(
|
||||
out.data_size(), out.shape(), out.strides(), large(), N_READS);
|
||||
encoder.add_kernel_node(
|
||||
kernel,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in.data<InType>() + in_offset,
|
||||
out.data<OutType>() + out_offset,
|
||||
out.data_size());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
109
mlx/backend/cuda/copy/copy_general.cu
Normal file
109
mlx/backend/cuda/copy/copy_general.cu
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_gg_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), strides_in.data(), strides_out.data());
|
||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_gg(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
const __grid_constant__ Strides strides_out,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc(
|
||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||
out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_general(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out) {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
size_t data_size = 1;
|
||||
for (auto& s : shape)
|
||||
data_size *= s;
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(data_size, shape, out.strides(), large());
|
||||
encoder.add_kernel_node(
|
||||
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
const_param<ndim_constant()>(shape),
|
||||
const_param<ndim_constant()>(strides_in),
|
||||
const_param<ndim_constant()>(strides_out));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto [num_blocks, block_dims] =
|
||||
get_launch_args(data_size, shape, out.strides(), large());
|
||||
encoder.add_kernel_node(
|
||||
cu::copy_gg<InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
data_size,
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
118
mlx/backend/cuda/copy/copy_general_dynamic.cu
Normal file
118
mlx/backend/cuda/copy/copy_general_dynamic.cu
Normal file
@@ -0,0 +1,118 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_gg_dynamic_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out,
|
||||
const int64_t* offset_in,
|
||||
const int64_t* offset_out) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
|
||||
index, shape.data(), strides_in.data(), strides_out.data());
|
||||
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_gg_dynamic(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
const __grid_constant__ Strides strides_out,
|
||||
int ndim,
|
||||
const int64_t* offset_in,
|
||||
const int64_t* offset_out) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
auto [idx_in, idx_out] = elem_to_loc(
|
||||
index, shape.data(), strides_in.data(), strides_out.data(), ndim);
|
||||
out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_general_dynamic(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in,
|
||||
const Strides& strides_out,
|
||||
const array& dynamic_offset_in,
|
||||
const array& dynamic_offset_out) {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
cu::copy_gg_dynamic_nd<
|
||||
InType,
|
||||
OutType,
|
||||
IdxT,
|
||||
dims_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in),
|
||||
const_param<dims_constant()>(strides_out),
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
cu::copy_gg_dynamic<InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
const_param(strides_out),
|
||||
ndim,
|
||||
dynamic_offset_in.data<int64_t>(),
|
||||
dynamic_offset_out.data<int64_t>());
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
97
mlx/backend/cuda/copy/copy_general_input.cu
Normal file
97
mlx/backend/cuda/copy/copy_general_input.cu
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/copy/copy.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename In, typename Out, typename IdxT, int NDIM>
|
||||
__global__ void copy_g_nd(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename In, typename Out, typename IdxT>
|
||||
__global__ void copy_g(
|
||||
const In* in,
|
||||
Out* out,
|
||||
IdxT size,
|
||||
const __grid_constant__ Shape shape,
|
||||
const __grid_constant__ Strides strides_in,
|
||||
int ndim) {
|
||||
IdxT index = cg::this_grid().thread_rank();
|
||||
if (index < size) {
|
||||
IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
|
||||
out[index] = CastOp<In, Out>{}(in[idx_in]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
void copy_general_input(
|
||||
cu::CommandEncoder& encoder,
|
||||
CopyType ctype,
|
||||
const array& in,
|
||||
array& out,
|
||||
int64_t offset_in,
|
||||
int64_t offset_out,
|
||||
const Shape& shape,
|
||||
const Strides& strides_in) {
|
||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||
dispatch_bool(
|
||||
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
|
||||
[&](auto large) {
|
||||
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
|
||||
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
|
||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||
int ndim = shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param<dims_constant()>(shape),
|
||||
const_param<dims_constant()>(strides_in));
|
||||
});
|
||||
} else { // ndim >= 4
|
||||
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||
encoder.add_kernel_node(
|
||||
cu::copy_g<InType, OutType, IdxT>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
out.size(),
|
||||
const_param(shape),
|
||||
const_param(strides_in),
|
||||
ndim);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/cuda/cuda.cpp
Normal file
11
mlx/backend/cuda/cuda.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/cuda.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
10
mlx/backend/cuda/cuda.h
Normal file
10
mlx/backend/cuda/cuda.h
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
/* Check if the CUDA backend is available. */
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cu
|
344
mlx/backend/cuda/device.cpp
Normal file
344
mlx/backend/cuda/device.cpp
Normal file
@@ -0,0 +1,344 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/jit_module.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
#include <future>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace {
|
||||
|
||||
// Can be tuned with MLX_MAX_OPS_PER_BUFFER
|
||||
// This should be less than 255
|
||||
constexpr int default_max_nodes_per_graph = 20;
|
||||
|
||||
#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
|
||||
|
||||
void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||
if (err != CUDNN_STATUS_SUCCESS) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
int cuda_graph_cache_size() {
|
||||
static int cache_size = []() {
|
||||
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
|
||||
}();
|
||||
return cache_size;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Device::Device(int device) : device_(device) {
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
&compute_capability_major_, cudaDevAttrComputeCapabilityMajor, device_));
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
&compute_capability_minor_, cudaDevAttrComputeCapabilityMinor, device_));
|
||||
// Validate the requirements of device.
|
||||
int attr = 0;
|
||||
CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
|
||||
&attr, cudaDevAttrConcurrentManagedAccess, device_));
|
||||
if (attr != 1) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Device {} does not support synchronization in managed memory.",
|
||||
device_));
|
||||
}
|
||||
// The cublasLt handle is used by matmul.
|
||||
make_current();
|
||||
CHECK_CUBLAS_ERROR(cublasLtCreate(<_));
|
||||
// The cudnn handle is used by Convolution.
|
||||
CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
|
||||
|
||||
// Initialize the jit module cache here ensures it is not
|
||||
// unloaded before any evaluation is done
|
||||
get_jit_module_cache();
|
||||
}
|
||||
|
||||
Device::~Device() {
|
||||
CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
|
||||
CHECK_CUBLAS_ERROR(cublasLtDestroy(lt_));
|
||||
}
|
||||
|
||||
void Device::make_current() {
|
||||
// We need to set/get current CUDA device very frequently, cache it to reduce
|
||||
// actual calls of CUDA APIs. This function assumes single-thread in host.
|
||||
static int current = 0;
|
||||
if (current != device_) {
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(device_));
|
||||
current = device_;
|
||||
}
|
||||
}
|
||||
|
||||
CommandEncoder& Device::get_command_encoder(Stream s) {
|
||||
auto it = encoders_.find(s.index);
|
||||
if (it == encoders_.end()) {
|
||||
it = encoders_.try_emplace(s.index, *this).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
|
||||
enc.device().make_current();
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
|
||||
}
|
||||
|
||||
CommandEncoder::CaptureContext::~CaptureContext() {
|
||||
CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph));
|
||||
std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
|
||||
&graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); });
|
||||
if (discard) {
|
||||
return;
|
||||
}
|
||||
enc.add_graph_node(graph);
|
||||
}
|
||||
|
||||
CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
|
||||
: enc(enc) {
|
||||
enc.in_concurrent_ = true;
|
||||
}
|
||||
|
||||
CommandEncoder::ConcurrentContext::~ConcurrentContext() {
|
||||
enc.in_concurrent_ = false;
|
||||
|
||||
// Use an empty graph node for synchronization
|
||||
CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
|
||||
enc.empty_node_count_++;
|
||||
CHECK_CUDA_ERROR(cudaGraphAddEmptyNode(&empty.node, enc.graph_, NULL, 0));
|
||||
|
||||
// Insert the concurrent -> empty node dependencies
|
||||
for (auto& from : enc.concurrent_nodes_) {
|
||||
enc.from_nodes_.push_back(from.node);
|
||||
enc.to_nodes_.push_back(empty.node);
|
||||
enc.graph_key_ += from.id;
|
||||
enc.graph_key_ += from.node_type;
|
||||
enc.graph_key_ += empty.id;
|
||||
enc.graph_key_ += empty.node_type;
|
||||
}
|
||||
|
||||
// Insert the input -> concurrent node dependencies without updating output
|
||||
// nodes
|
||||
auto outputs = std::move(enc.active_outputs_);
|
||||
enc.insert_graph_dependencies(std::move(enc.concurrent_nodes_));
|
||||
|
||||
// Update output node to be the empty node
|
||||
for (auto o : outputs) {
|
||||
enc.node_map_.emplace(o, empty).first->second = empty;
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::insert_graph_dependencies(GraphNode node) {
|
||||
if (node.node_type == 'G') {
|
||||
graph_node_count_++;
|
||||
}
|
||||
node.id = std::to_string(node_count_++);
|
||||
if (in_concurrent_) {
|
||||
concurrent_nodes_.push_back(std::move(node));
|
||||
} else {
|
||||
std::vector<GraphNode> nodes;
|
||||
nodes.push_back(std::move(node));
|
||||
insert_graph_dependencies(std::move(nodes));
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
|
||||
std::vector<GraphNode> deps;
|
||||
{
|
||||
// Dependencies must be added in the same order to produce a consistent
|
||||
// topology
|
||||
std::unordered_set<cudaGraphNode_t> set_deps;
|
||||
for (auto d : active_deps_) {
|
||||
if (auto it = node_map_.find(d); it != node_map_.end()) {
|
||||
auto [_, inserted] = set_deps.insert(it->second.node);
|
||||
if (inserted) {
|
||||
deps.push_back(it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
active_deps_.clear();
|
||||
|
||||
for (auto o : active_outputs_) {
|
||||
for (auto& node : nodes) {
|
||||
node_map_.emplace(o, node).first->second = node;
|
||||
}
|
||||
}
|
||||
active_outputs_.clear();
|
||||
|
||||
for (auto& from : deps) {
|
||||
for (auto& to : nodes) {
|
||||
from_nodes_.push_back(from.node);
|
||||
to_nodes_.push_back(to.node);
|
||||
graph_key_ += from.id;
|
||||
graph_key_ += from.node_type;
|
||||
graph_key_ += to.id;
|
||||
graph_key_ += to.node_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CommandEncoder::CommandEncoder(Device& d)
|
||||
: device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) {
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
worker_.add_task(std::move(task));
|
||||
}
|
||||
|
||||
void CommandEncoder::set_input_array(const array& arr) {
|
||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||
active_deps_.push_back(id);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_output_array(const array& arr) {
|
||||
auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
|
||||
active_deps_.push_back(id);
|
||||
active_outputs_.push_back(id);
|
||||
}
|
||||
|
||||
void CommandEncoder::maybe_commit() {
|
||||
if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
|
||||
commit();
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(
|
||||
void* func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
void** params) {
|
||||
cudaKernelNodeParams kernel_params = {0};
|
||||
kernel_params.func = func;
|
||||
kernel_params.gridDim = grid_dim;
|
||||
kernel_params.blockDim = block_dim;
|
||||
kernel_params.kernelParams = params;
|
||||
kernel_params.sharedMemBytes = smem_bytes;
|
||||
add_kernel_node(kernel_params);
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(
|
||||
CUfunction func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
void** params) {
|
||||
CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
|
||||
kernel_params.func = func;
|
||||
kernel_params.gridDimX = grid_dim.x;
|
||||
kernel_params.gridDimY = grid_dim.y;
|
||||
kernel_params.gridDimZ = grid_dim.z;
|
||||
kernel_params.blockDimX = block_dim.x;
|
||||
kernel_params.blockDimY = block_dim.y;
|
||||
kernel_params.blockDimZ = block_dim.z;
|
||||
kernel_params.kernelParams = params;
|
||||
kernel_params.sharedMemBytes = smem_bytes;
|
||||
add_kernel_node(kernel_params);
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(const cudaKernelNodeParams& params) {
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(cudaGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
}
|
||||
|
||||
void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
|
||||
CUgraphNode node;
|
||||
CHECK_CUDA_ERROR(cuGraphAddKernelNode(&node, graph_, NULL, 0, ¶ms));
|
||||
insert_graph_dependencies(GraphNode{node, 'K'});
|
||||
}
|
||||
|
||||
void CommandEncoder::add_graph_node(cudaGraph_t child) {
|
||||
cudaGraphNode_t node;
|
||||
CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
|
||||
insert_graph_dependencies(GraphNode{node, 'G'});
|
||||
}
|
||||
|
||||
void CommandEncoder::commit() {
|
||||
nvtx3::scoped_range r("CommandEncoder::commit");
|
||||
if (!temporaries_.empty()) {
|
||||
add_completed_handler([temporaries = std::move(temporaries_)]() {});
|
||||
}
|
||||
if (node_count_ > 0) {
|
||||
if (!from_nodes_.empty()) {
|
||||
CHECK_CUDA_ERROR(cudaGraphAddDependencies(
|
||||
graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size()));
|
||||
}
|
||||
|
||||
graph_key_ += ".";
|
||||
graph_key_ += std::to_string(node_count_);
|
||||
graph_key_ += ".";
|
||||
graph_key_ += std::to_string(graph_node_count_);
|
||||
graph_key_ += ".";
|
||||
graph_key_ += std::to_string(empty_node_count_);
|
||||
|
||||
CudaGraphExec& graph_exec = graph_cache_[graph_key_];
|
||||
|
||||
if (graph_exec != nullptr) {
|
||||
cudaGraphExecUpdateResult update_result;
|
||||
#if CUDART_VERSION >= 12000
|
||||
cudaGraphExecUpdateResultInfo info;
|
||||
cudaGraphExecUpdate(graph_exec, graph_, &info);
|
||||
update_result = info.result;
|
||||
#else
|
||||
cudaGraphNode_t error_node;
|
||||
cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
|
||||
#endif // CUDART_VERSION >= 12000
|
||||
if (update_result != cudaGraphExecUpdateSuccess) {
|
||||
cudaGetLastError(); // reset error
|
||||
graph_exec.reset();
|
||||
}
|
||||
}
|
||||
if (graph_exec == nullptr) {
|
||||
graph_exec.instantiate(graph_);
|
||||
}
|
||||
device_.make_current();
|
||||
CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
|
||||
|
||||
// Reset state
|
||||
node_count_ = 0;
|
||||
graph_node_count_ = 0;
|
||||
empty_node_count_ = 0;
|
||||
from_nodes_.clear();
|
||||
to_nodes_.clear();
|
||||
graph_key_.clear();
|
||||
node_map_.clear();
|
||||
CHECK_CUDA_ERROR(cudaGraphDestroy(graph_));
|
||||
CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
|
||||
}
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
worker_.commit(stream_);
|
||||
}
|
||||
|
||||
void CommandEncoder::synchronize() {
|
||||
cudaStreamSynchronize(stream_);
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
std::future<void> f = p->get_future();
|
||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||
commit();
|
||||
f.wait();
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device device) {
|
||||
static std::unordered_map<int, Device> devices;
|
||||
auto it = devices.find(device.index);
|
||||
if (it == devices.end()) {
|
||||
it = devices.try_emplace(device.index, device.index).first;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s) {
|
||||
return device(s.device).get_command_encoder(s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
183
mlx/backend/cuda/device.h
Normal file
183
mlx/backend/cuda/device.h
Normal file
@@ -0,0 +1,183 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <cuda.h>
|
||||
#include <cudnn.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class CommandEncoder {
|
||||
public:
|
||||
struct CaptureContext {
|
||||
CaptureContext(CommandEncoder& enc);
|
||||
~CaptureContext();
|
||||
cudaGraph_t graph;
|
||||
CommandEncoder& enc;
|
||||
bool discard{false};
|
||||
};
|
||||
struct ConcurrentContext {
|
||||
ConcurrentContext(CommandEncoder& enc);
|
||||
~ConcurrentContext();
|
||||
CommandEncoder& enc;
|
||||
};
|
||||
|
||||
explicit CommandEncoder(Device& d);
|
||||
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
CaptureContext capture_context() {
|
||||
return CaptureContext{*this};
|
||||
}
|
||||
ConcurrentContext concurrent_context() {
|
||||
return ConcurrentContext{*this};
|
||||
}
|
||||
|
||||
void set_input_array(const array& arr);
|
||||
void set_output_array(const array& arr);
|
||||
|
||||
template <typename F, typename... Params>
|
||||
void add_kernel_node(
|
||||
F* func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
Params&&... params) {
|
||||
constexpr size_t num = sizeof...(Params);
|
||||
void* ptrs[num];
|
||||
size_t i = 0;
|
||||
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
|
||||
std::forward<Params>(params)),
|
||||
...);
|
||||
add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
|
||||
}
|
||||
|
||||
void add_kernel_node(
|
||||
CUfunction func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
void** params);
|
||||
|
||||
void add_kernel_node(
|
||||
void* func,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
uint32_t smem_bytes,
|
||||
void** params);
|
||||
|
||||
// Low-level graph helpers.
|
||||
void add_kernel_node(const cudaKernelNodeParams& params);
|
||||
void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
|
||||
void add_graph_node(cudaGraph_t child);
|
||||
|
||||
void add_temporary(const array& arr) {
|
||||
temporaries_.push_back(arr.data_shared_ptr());
|
||||
}
|
||||
|
||||
void add_completed_handler(std::function<void()> task);
|
||||
void maybe_commit();
|
||||
void commit();
|
||||
|
||||
Device& device() {
|
||||
return device_;
|
||||
}
|
||||
|
||||
CudaStream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
// Wait until kernels and completion handlers are finished
|
||||
void synchronize();
|
||||
|
||||
private:
|
||||
struct GraphNode {
|
||||
cudaGraphNode_t node;
|
||||
// K = kernel
|
||||
// E = empty
|
||||
// G = subgraph
|
||||
char node_type;
|
||||
std::string id;
|
||||
};
|
||||
|
||||
void insert_graph_dependencies(GraphNode node);
|
||||
void insert_graph_dependencies(std::vector<GraphNode> nodes);
|
||||
|
||||
Device& device_;
|
||||
CudaStream stream_;
|
||||
cudaGraph_t graph_;
|
||||
Worker worker_;
|
||||
char node_count_{0};
|
||||
char graph_node_count_{0};
|
||||
char empty_node_count_{0};
|
||||
bool in_concurrent_{false};
|
||||
std::vector<cudaGraphNode_t> from_nodes_;
|
||||
std::vector<cudaGraphNode_t> to_nodes_;
|
||||
std::string graph_key_;
|
||||
std::vector<GraphNode> concurrent_nodes_;
|
||||
std::vector<std::shared_ptr<array::Data>> temporaries_;
|
||||
LRUCache<std::string, CudaGraphExec> graph_cache_;
|
||||
std::vector<std::uintptr_t> active_deps_;
|
||||
std::vector<std::uintptr_t> active_outputs_;
|
||||
std::unordered_map<std::uintptr_t, GraphNode> node_map_;
|
||||
};
|
||||
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(int device);
|
||||
~Device();
|
||||
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
|
||||
// Make this device the current cuda device, required by some cuda calls.
|
||||
void make_current();
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
int cuda_device() const {
|
||||
return device_;
|
||||
}
|
||||
int compute_capability_major() const {
|
||||
return compute_capability_major_;
|
||||
}
|
||||
int compute_capability_minor() const {
|
||||
return compute_capability_minor_;
|
||||
}
|
||||
cublasLtHandle_t lt_handle() const {
|
||||
return lt_;
|
||||
}
|
||||
cudnnHandle_t cudnn_handle() const {
|
||||
return cudnn_;
|
||||
}
|
||||
|
||||
private:
|
||||
int device_;
|
||||
int compute_capability_major_;
|
||||
int compute_capability_minor_;
|
||||
cublasLtHandle_t lt_;
|
||||
cudnnHandle_t cudnn_;
|
||||
std::unordered_map<int, CommandEncoder> encoders_;
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device device);
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
||||
// Return an execution policy that does not sync for result.
|
||||
// Note that not all thrust APIs support async policy, confirm before using.
|
||||
inline auto thrust_policy(cudaStream_t stream) {
|
||||
// TODO: Connect thrust's custom allocator with mlx's allocator.
|
||||
return thrust::cuda::par_nosync.on(stream);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
63
mlx/backend/cuda/device/atomic_ops.cuh
Normal file
63
mlx/backend/cuda/device/atomic_ops.cuh
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/complex.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
|
||||
#include <cuda/atomic>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_add(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
ref += val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_prod(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
T old = ref.load();
|
||||
while (!ref.compare_exchange_strong(old, old * val)) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_max(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
ref.fetch_max(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __device__ void atomic_min(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
ref.fetch_min(val);
|
||||
}
|
||||
|
||||
// Somehow cuda::atomic_ref does not provide atomic add for following types.
|
||||
template <typename T>
|
||||
inline __device__ void atomic_add_general(T* out, T val) {
|
||||
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
|
||||
T old = ref.load();
|
||||
while (!ref.compare_exchange_strong(old, old + val)) {
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(__half* out, __half val) {
|
||||
atomicAdd(out, val);
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(complex64_t* out, complex64_t val) {
|
||||
atomic_add_general(out, val);
|
||||
}
|
||||
|
||||
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
|
||||
#if __CUDA_ARCH__ < 800
|
||||
atomic_add_general(out, val);
|
||||
#else
|
||||
atomicAdd(out, val);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
293
mlx/backend/cuda/device/binary_ops.cuh
Normal file
293
mlx/backend/cuda/device/binary_ops.cuh
Normal file
@@ -0,0 +1,293 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||
|
||||
#include <cuda/std/array>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Add {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
struct FloorDivide {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x / y;
|
||||
} else {
|
||||
return truncf(x / y);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
if constexpr (cuda::std::is_signed_v<T>) {
|
||||
auto r = x % y;
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
} else {
|
||||
return x % y;
|
||||
}
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
return x % y;
|
||||
} else {
|
||||
T r = fmod(x, y);
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r = r + y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x == y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
return x == y ||
|
||||
(isnan(x.real()) && isnan(y.real()) && isnan(x.imag()) &&
|
||||
isnan(y.imag())) ||
|
||||
(x.real() == y.real() && isnan(x.imag()) && isnan(y.imag())) ||
|
||||
(isnan(x.real()) && isnan(y.real()) && x.imag() == y.imag());
|
||||
} else {
|
||||
return x == y || (isnan(x) && isnan(y));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x > y;
|
||||
}
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x >= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
return x <= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
if (isnan(x.real()) || isnan(x.imag()) || isnan(y.real()) ||
|
||||
isnan(y.imag())) {
|
||||
return {
|
||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||
}
|
||||
auto max = x.real() > y.real() ? x : y;
|
||||
auto min = x.real() < y.real() ? x : y;
|
||||
auto min_real = min.real();
|
||||
auto max_real = max.real();
|
||||
if (!isfinite(min_real) && (min_real == max_real)) {
|
||||
if (min_real < 0) {
|
||||
return min;
|
||||
} else {
|
||||
return Log{}(Exp{}(min) + Exp{}(max));
|
||||
}
|
||||
} else {
|
||||
return Log1p{}(Exp{}(min - max)) + max;
|
||||
}
|
||||
} else {
|
||||
if (isnan(x) || isnan(y)) {
|
||||
return cuda::std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
T maxval = max(x, y);
|
||||
T minval = min(x, y);
|
||||
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
|
||||
maxval == cuda::std::numeric_limits<T>::infinity())
|
||||
? maxval
|
||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return max(x, y);
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
if (isnan(x.real()) || isnan(x.imag())) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
} else {
|
||||
if (isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return min(x, y);
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
if (isnan(x.real()) || isnan(x.imag())) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
} else {
|
||||
if (isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T>
|
||||
__device__ bool operator()(T x, T y) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
return x.real() != y.real() || x.imag() != y.imag();
|
||||
} else {
|
||||
return x != y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
template <typename T>
|
||||
__device__ T operator()(T base, T exp) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
return pow(base, exp);
|
||||
} else {
|
||||
return powf(base, exp);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Subtract {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x && y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x || y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseAnd {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x & y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x | y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x << y;
|
||||
};
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T y, T x) {
|
||||
return atan2f(y, x);
|
||||
}
|
||||
};
|
||||
|
||||
struct DivMod {
|
||||
template <typename T>
|
||||
__device__ cuda::std::array<T, 2> operator()(T x, T y) {
|
||||
return {FloorDivide{}(x, y), Remainder{}(x, y)};
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
130
mlx/backend/cuda/device/cast_op.cuh
Normal file
130
mlx/backend/cuda/device/cast_op.cuh
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/complex.cuh"
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// An op that does static_cast, with custom conversions for some types.
|
||||
template <typename SrcT, typename DstT, typename = void>
|
||||
struct CastOp {
|
||||
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, DstT>;
|
||||
|
||||
__device__ DstT operator()(SrcT x) {
|
||||
return static_cast<DstT>(x);
|
||||
}
|
||||
};
|
||||
|
||||
// Castings between complex and boolean.
|
||||
template <typename T>
|
||||
struct CastOp<complex_t<T>, bool> {
|
||||
static constexpr bool is_castable = true;
|
||||
|
||||
__device__ bool operator()(complex_t<T> x) {
|
||||
return x.real() != 0 && x.imag() != 0;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CastOp<bool, complex_t<T>> {
|
||||
static constexpr bool is_castable = true;
|
||||
|
||||
__device__ complex_t<T> operator()(bool x) {
|
||||
return x ? complex_t<T>{1, 1} : complex_t<T>{0, 0};
|
||||
}
|
||||
};
|
||||
|
||||
// Converting a complex number to real number discards the imaginary part.
|
||||
template <typename T, typename DstT>
|
||||
struct CastOp<complex_t<T>, DstT, cuda::std::enable_if_t<!is_complex_v<DstT>>> {
|
||||
static constexpr bool is_castable = cuda::std::is_convertible_v<T, DstT>;
|
||||
|
||||
__device__ DstT operator()(complex_t<T> x) {
|
||||
static_assert(!is_complex_v<DstT>);
|
||||
return static_cast<DstT>(x.real());
|
||||
}
|
||||
};
|
||||
|
||||
// Allow converting a real number to complex number.
|
||||
template <typename SrcT, typename T>
|
||||
struct CastOp<SrcT, complex_t<T>, cuda::std::enable_if_t<!is_complex_v<SrcT>>> {
|
||||
static constexpr bool is_castable = cuda::std::is_convertible_v<SrcT, T>;
|
||||
|
||||
__device__ complex_t<T> operator()(SrcT x) {
|
||||
static_assert(!is_complex_v<SrcT>);
|
||||
return complex_t<T>{static_cast<T>(x), 0};
|
||||
}
|
||||
};
|
||||
|
||||
// Do nothing when no casting is needed.
|
||||
template <typename SrcT, typename DstT>
|
||||
struct CastOp<
|
||||
SrcT,
|
||||
DstT,
|
||||
cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {
|
||||
static constexpr bool is_castable = true;
|
||||
|
||||
__device__ SrcT operator()(SrcT x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
// In CUDA 11 the half types do not define conversions between some types,
|
||||
// provide fallbacks here.
|
||||
#if CUDART_VERSION < 12000
|
||||
template <typename SrcT, typename DstT>
|
||||
struct CastOp<
|
||||
SrcT,
|
||||
DstT,
|
||||
cuda::std::enable_if_t<
|
||||
!cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
|
||||
(cuda::std::is_same_v<DstT, __half> ||
|
||||
cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
|
||||
static constexpr bool is_castable = true;
|
||||
|
||||
__device__ DstT operator()(SrcT x) {
|
||||
return DstT(static_cast<float>(x));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
struct CastOp<
|
||||
SrcT,
|
||||
DstT,
|
||||
cuda::std::enable_if_t<
|
||||
!cuda::std::is_convertible_v<SrcT, DstT> && !is_complex_v<SrcT> &&
|
||||
!cuda::std::is_same_v<DstT, __half> &&
|
||||
!cuda::std::is_same_v<DstT, __nv_bfloat16> &&
|
||||
(cuda::std::is_same_v<SrcT, __half> ||
|
||||
cuda::std::is_same_v<SrcT, __nv_bfloat16>)>> {
|
||||
static constexpr bool is_castable = true;
|
||||
|
||||
__device__ DstT operator()(SrcT x) {
|
||||
return DstT(static_cast<float>(x));
|
||||
}
|
||||
};
|
||||
#endif // CUDART_VERSION < 12000
|
||||
|
||||
// Helper to deduce the SrcT.
|
||||
template <typename DstT, typename SrcT>
|
||||
inline __host__ __device__ auto cast_to(SrcT x) {
|
||||
return CastOp<SrcT, DstT>{}(x);
|
||||
}
|
||||
|
||||
// Return an iterator that cast the value to DstT using CastOp.
|
||||
template <typename DstT, typename Iterator>
|
||||
inline __host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||
using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
|
||||
if constexpr (std::is_same_v<SrcT, DstT>) {
|
||||
return it;
|
||||
} else {
|
||||
return thrust::make_transform_iterator(it, CastOp<SrcT, DstT>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
60
mlx/backend/cuda/device/complex.cuh
Normal file
60
mlx/backend/cuda/device/complex.cuh
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Make multiplication and division faster.
|
||||
#define LIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS
|
||||
|
||||
#include <cuda/std/complex>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// TODO: Consider using a faster implementation as cuda::std::complex has to
|
||||
// conform to C++ standard.
|
||||
template <typename T>
|
||||
using complex_t = cuda::std::complex<T>;
|
||||
|
||||
using complex64_t = complex_t<float>;
|
||||
using complex128_t = complex_t<double>;
|
||||
|
||||
template <typename T>
|
||||
struct is_complex : cuda::std::false_type {};
|
||||
|
||||
template <typename T>
|
||||
struct is_complex<cuda::std::complex<T>> : cuda::std::true_type {};
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_complex_v = is_complex<T>::value;
|
||||
|
||||
// cuda::std::complex is missing some operators.
|
||||
template <typename T>
|
||||
inline __host__ __device__ complex_t<T> operator%(
|
||||
complex_t<T> a,
|
||||
complex_t<T> b) {
|
||||
T r = a.real() - floor(a.real() / b.real()) * b.real();
|
||||
T i = a.imag() - floor(a.imag() / b.imag()) * b.imag();
|
||||
return complex_t<T>{r, i};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __host__ __device__ bool operator>(complex_t<T> a, complex_t<T> b) {
|
||||
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __host__ __device__ bool operator<(complex_t<T> a, complex_t<T> b) {
|
||||
return operator>(b, a);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __host__ __device__ bool operator<=(complex_t<T> a, complex_t<T> b) {
|
||||
return !(a > b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline __host__ __device__ bool operator>=(complex_t<T> a, complex_t<T> b) {
|
||||
return !(a < b);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
12
mlx/backend/cuda/device/config.h
Normal file
12
mlx/backend/cuda/device/config.h
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file is used by both CUDA kernel code and host-only C++ code.
|
||||
|
||||
#pragma once
|
||||
|
||||
// The maximum dimensions of shape/strides passed as kernel parameters.
|
||||
#define MAX_NDIM 10
|
||||
|
||||
// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in
|
||||
// warpSize variable exists, using it would prevent compile-time optimizations.
|
||||
#define WARP_SIZE 32
|
194
mlx/backend/cuda/device/fp16_math.cuh
Normal file
194
mlx/backend/cuda/device/fp16_math.cuh
Normal file
@@ -0,0 +1,194 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Unary ops for half types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define MLX_DEFINE_UNARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return HALF_OP(x); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
#define MLX_DEFINE_UNARY_OP_FALLBCK(NAME) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return ::NAME(__half2float(x)); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return ::NAME(__bfloat162float(x)); \
|
||||
} else { \
|
||||
return ::NAME(x); \
|
||||
} \
|
||||
}
|
||||
|
||||
MLX_DEFINE_UNARY_OP(abs, __habs)
|
||||
MLX_DEFINE_UNARY_OP(ceil, hceil)
|
||||
MLX_DEFINE_UNARY_OP(cos, hcos)
|
||||
MLX_DEFINE_UNARY_OP(exp, hexp)
|
||||
MLX_DEFINE_UNARY_OP(floor, hfloor)
|
||||
MLX_DEFINE_UNARY_OP(isnan, __hisnan)
|
||||
MLX_DEFINE_UNARY_OP(log, hlog)
|
||||
MLX_DEFINE_UNARY_OP(log2, hlog2)
|
||||
MLX_DEFINE_UNARY_OP(log10, hlog10)
|
||||
MLX_DEFINE_UNARY_OP(rint, hrint)
|
||||
MLX_DEFINE_UNARY_OP(rsqrt, hrsqrt)
|
||||
MLX_DEFINE_UNARY_OP(sin, hsin)
|
||||
MLX_DEFINE_UNARY_OP(sqrt, hsqrt)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(acos)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(acosh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(asin)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(asinh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(atan)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(atanh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(cosh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(log1p)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(sinh)
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(tan)
|
||||
#if __CUDA_ARCH__ >= 1280
|
||||
MLX_DEFINE_UNARY_OP(tanh, htanh)
|
||||
#else
|
||||
MLX_DEFINE_UNARY_OP_FALLBCK(tanh)
|
||||
#endif
|
||||
|
||||
#undef MLX_DEFINE_UNARY_OP
|
||||
#undef MLX_DEFINE_UNARY_OP_FALLBCK
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Binary ops for half types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x, T y) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x, y); \
|
||||
} else { \
|
||||
return ::NAME(x, y); \
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define MLX_DEFINE_BINARY_OP(NAME, HALF_OP) \
|
||||
template <typename T> \
|
||||
__forceinline__ __device__ auto NAME(T x, T y) { \
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) { \
|
||||
return HALF_OP(x, y); \
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) { \
|
||||
return HALF_OP(x, y); \
|
||||
} else { \
|
||||
return ::NAME(x, y); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
MLX_DEFINE_BINARY_OP(max, __hmax)
|
||||
MLX_DEFINE_BINARY_OP(min, __hmin)
|
||||
|
||||
#undef MLX_DEFINE_BINARY_OP
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T fmod(T x, T y) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return __float2half(::fmod(__half2float(x), __half2float(y)));
|
||||
#if CUDART_VERSION >= 12000 || __CUDA_ARCH__ >= 800
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return __float2bfloat16(::fmod(__bfloat162float(x), __bfloat162float(y)));
|
||||
#endif
|
||||
} else {
|
||||
return ::fmod(x, y);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Additional C++ operator overrides between half types and native types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_integral_except =
|
||||
cuda::std::is_integral_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
template <typename T, typename U>
|
||||
constexpr bool is_arithmetic_except =
|
||||
cuda::std::is_arithmetic_v<T> && !cuda::std::is_same_v<T, U>;
|
||||
|
||||
#define MLX_DEFINE_HALF_OP(HALF, HALF2FLOAT, FLOAT2HALF, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(HALF x, T y) { \
|
||||
return FLOAT2HALF(HALF2FLOAT(x) OP static_cast<float>(y)); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_integral_except<T, HALF>>> \
|
||||
__forceinline__ __device__ HALF operator OP(T x, HALF y) { \
|
||||
return FLOAT2HALF(static_cast<float>(x) OP HALF2FLOAT(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_HALF_CMP(HALF, HALF2FLOAT, OP) \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(HALF x, T y) { \
|
||||
return HALF2FLOAT(x) OP static_cast<float>(y); \
|
||||
} \
|
||||
template < \
|
||||
typename T, \
|
||||
typename = cuda::std::enable_if_t<is_arithmetic_except<T, HALF>>> \
|
||||
__forceinline__ __device__ bool operator OP(T x, HALF y) { \
|
||||
return static_cast<float>(y) OP HALF2FLOAT(x); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, +)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, -)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, *)
|
||||
MLX_DEFINE_HALF_OP(__half, __half2float, __float2half, /)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, +)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, -)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, *)
|
||||
MLX_DEFINE_HALF_OP(__nv_bfloat16, __bfloat162float, __float2bfloat16, /)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__half, __half2float, !=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, <=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, >=)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, ==)
|
||||
MLX_DEFINE_HALF_CMP(__nv_bfloat16, __bfloat162float, !=)
|
||||
|
||||
#undef MLX_DEFINE_HALF_OP
|
||||
#undef MLX_DEFINE_HALF_CMP
|
||||
|
||||
} // namespace mlx::core::cu
|
53
mlx/backend/cuda/device/gather.cuh
Normal file
53
mlx/backend/cuda/device/gather.cuh
Normal file
@@ -0,0 +1,53 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
|
||||
__global__ void gather(
|
||||
const T* src,
|
||||
T* out,
|
||||
LocT size,
|
||||
const __grid_constant__ Shape src_shape,
|
||||
const __grid_constant__ Strides src_strides,
|
||||
int32_t src_ndim,
|
||||
const __grid_constant__ Shape slice_sizes,
|
||||
uint32_t slice_size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
|
||||
const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
|
||||
indices_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
|
||||
indices_strides) {
|
||||
LocT out_idx = cg::this_grid().thread_rank();
|
||||
if (out_idx >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
LocT src_elem = out_idx % slice_size;
|
||||
LocT idx_elem = out_idx / slice_size;
|
||||
|
||||
LocT src_loc =
|
||||
elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
|
||||
idx_elem,
|
||||
indices_shape.data() + i * IDX_NDIM,
|
||||
indices_strides.data() + i * IDX_NDIM);
|
||||
int32_t axis = axes[i];
|
||||
LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]);
|
||||
src_loc += idx_val * src_strides[axis];
|
||||
}
|
||||
|
||||
out[out_idx] = src[src_loc];
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
65
mlx/backend/cuda/device/gather_axis.cuh
Normal file
65
mlx/backend/cuda/device/gather_axis.cuh
Normal file
@@ -0,0 +1,65 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
int NDIM,
|
||||
bool SrcC,
|
||||
bool IdxC,
|
||||
typename LocT>
|
||||
__global__ void gather_axis(
|
||||
const T* src,
|
||||
const IdxT* indices,
|
||||
T* out,
|
||||
LocT idx_size_pre,
|
||||
LocT idx_size_axis,
|
||||
LocT idx_size_post,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> src_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
|
||||
int32_t axis,
|
||||
int32_t axis_size,
|
||||
int64_t src_stride_axis,
|
||||
int64_t idx_stride_axis) {
|
||||
LocT index = cg::this_grid().thread_rank();
|
||||
if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);
|
||||
|
||||
LocT elem_idx = z * idx_size_post;
|
||||
|
||||
LocT idx_loc = y * idx_stride_axis;
|
||||
if constexpr (IdxC) {
|
||||
idx_loc += elem_idx * idx_size_axis + x;
|
||||
} else {
|
||||
idx_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
|
||||
}
|
||||
|
||||
auto idx_val = absolute_index(indices[idx_loc], axis_size);
|
||||
|
||||
LocT src_loc = idx_val * src_stride_axis;
|
||||
if constexpr (SrcC) {
|
||||
src_loc += elem_idx * axis_size + x;
|
||||
} else {
|
||||
src_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), src_strides.data());
|
||||
}
|
||||
|
||||
LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x;
|
||||
|
||||
out[out_idx] = src[src_loc];
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
30
mlx/backend/cuda/device/indexing.cuh
Normal file
30
mlx/backend/cuda/device/indexing.cuh
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <cuda/std/tuple>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
// Convert an absolute index to positions in a 3d grid, assuming the index is
|
||||
// calculated with:
|
||||
// index = x * dim1 * dim2 + y * dim2 + z
|
||||
template <typename T>
|
||||
inline __host__ __device__ cuda::std::tuple<T, T, T>
|
||||
index_to_dims(T index, T dim1, T dim2) {
|
||||
T x = index / (dim1 * dim2);
|
||||
T y = (index % (dim1 * dim2)) / dim2;
|
||||
T z = index % dim2;
|
||||
return cuda::std::make_tuple(x, y, z);
|
||||
}
|
||||
|
||||
// Get absolute index from possible negative index.
|
||||
template <typename IdxT>
|
||||
inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) {
|
||||
if constexpr (cuda::std::is_unsigned_v<IdxT>) {
|
||||
return idx;
|
||||
} else {
|
||||
return static_cast<int32_t>(idx < 0 ? idx + size : idx);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
68
mlx/backend/cuda/device/scatter.cuh
Normal file
68
mlx/backend/cuda/device/scatter.cuh
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/scatter_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename Op,
|
||||
int NIDX,
|
||||
int IDX_NDIM,
|
||||
typename LocT>
|
||||
__global__ void scatter(
|
||||
const T* upd,
|
||||
T* out,
|
||||
LocT size,
|
||||
const __grid_constant__ Shape upd_shape,
|
||||
const __grid_constant__ Strides upd_strides,
|
||||
int32_t upd_ndim,
|
||||
LocT upd_post_idx_size,
|
||||
const __grid_constant__ Shape out_shape,
|
||||
const __grid_constant__ Strides out_strides,
|
||||
int32_t out_ndim,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
|
||||
const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
|
||||
indices_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
|
||||
indices_strides) {
|
||||
LocT upd_idx = cg::this_grid().thread_rank();
|
||||
if (upd_idx >= size) {
|
||||
return;
|
||||
}
|
||||
|
||||
LocT out_elem = upd_idx % upd_post_idx_size;
|
||||
LocT idx_elem = upd_idx / upd_post_idx_size;
|
||||
|
||||
LocT out_idx = elem_to_loc(
|
||||
out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
|
||||
idx_elem,
|
||||
indices_shape.data() + i * IDX_NDIM,
|
||||
indices_strides.data() + i * IDX_NDIM);
|
||||
int32_t axis = axes[i];
|
||||
LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]);
|
||||
out_idx += idx_val * out_strides[axis];
|
||||
}
|
||||
|
||||
LocT upd_loc = elem_to_loc(
|
||||
out_elem + idx_elem * upd_post_idx_size,
|
||||
upd_shape.data(),
|
||||
upd_strides.data(),
|
||||
upd_ndim);
|
||||
|
||||
Op{}(out + out_idx, upd[upd_loc]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
67
mlx/backend/cuda/device/scatter_axis.cuh
Normal file
67
mlx/backend/cuda/device/scatter_axis.cuh
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/device/indexing.cuh"
|
||||
#include "mlx/backend/cuda/device/scatter_ops.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename IdxT,
|
||||
typename Op,
|
||||
int NDIM,
|
||||
bool UpdC,
|
||||
bool IdxC,
|
||||
typename LocT>
|
||||
__global__ void scatter_axis(
|
||||
const T* upd,
|
||||
const IdxT* indices,
|
||||
T* out,
|
||||
LocT idx_size_pre,
|
||||
LocT idx_size_axis,
|
||||
LocT idx_size_post,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> upd_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
|
||||
int32_t axis,
|
||||
int32_t axis_size,
|
||||
int64_t upd_stride_axis,
|
||||
int64_t idx_stride_axis) {
|
||||
LocT index = cg::this_grid().thread_rank();
|
||||
if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);
|
||||
|
||||
LocT elem_idx = z * idx_size_post;
|
||||
|
||||
LocT idx_loc = y * idx_stride_axis;
|
||||
if constexpr (IdxC) {
|
||||
idx_loc += elem_idx * idx_size_axis + x;
|
||||
} else {
|
||||
idx_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
|
||||
}
|
||||
|
||||
auto idx_val = absolute_index(indices[idx_loc], axis_size);
|
||||
|
||||
LocT upd_loc = y * upd_stride_axis;
|
||||
if constexpr (UpdC) {
|
||||
upd_loc += elem_idx * idx_size_axis + x;
|
||||
} else {
|
||||
upd_loc +=
|
||||
elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), upd_strides.data());
|
||||
}
|
||||
|
||||
LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x;
|
||||
|
||||
Op{}(out + out_idx, upd[upd_loc]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cu
|
44
mlx/backend/cuda/device/scatter_ops.cuh
Normal file
44
mlx/backend/cuda/device/scatter_ops.cuh
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/atomic_ops.cuh"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct ScatterAssign {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
*out = val;
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterSum {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_add(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterProd {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_prod(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterMax {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_max(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
struct ScatterMin {
|
||||
template <typename T>
|
||||
__device__ void operator()(T* out, T val) const {
|
||||
atomic_min(out, val);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
13
mlx/backend/cuda/device/ternary_ops.cuh
Normal file
13
mlx/backend/cuda/device/ternary_ops.cuh
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Select {
|
||||
template <typename T>
|
||||
__device__ T operator()(bool condition, T x, T y) {
|
||||
return condition ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
337
mlx/backend/cuda/device/unary_ops.cuh
Normal file
337
mlx/backend/cuda/device/unary_ops.cuh
Normal file
@@ -0,0 +1,337 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
#include <math_constants.h>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
struct Abs {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||
return x;
|
||||
} else {
|
||||
return abs(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return acosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return asinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return atanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct BitwiseInvert {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return ~x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
return T{ceil(x.real()), ceil(x.imag())};
|
||||
} else {
|
||||
return ceil(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
template <typename T>
|
||||
__device__ complex_t<T> operator()(complex_t<T> x) {
|
||||
return conj(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return cos(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return cosh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return erf(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return erf(__bfloat162float(x));
|
||||
} else {
|
||||
return erf(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return erfinv(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return erfinv(__bfloat162float(x));
|
||||
} else {
|
||||
return erfinv(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return exp(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Expm1 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, __half>) {
|
||||
return expm1(__half2float(x));
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return expm1(__bfloat162float(x));
|
||||
} else {
|
||||
return expm1(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_integral_v<T>) {
|
||||
return x;
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
return T{floor(x.real()), floor(x.imag())};
|
||||
} else {
|
||||
return floor(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Imag {
|
||||
template <typename T>
|
||||
__device__ auto operator()(complex_t<T> x) {
|
||||
return x.imag();
|
||||
}
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
auto y = Log{}(x);
|
||||
return {y.real() / CUDART_LN2_F, y.imag() / CUDART_LN2_F};
|
||||
} else {
|
||||
return log2(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log10(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
__device__ T operator()(T z) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
float x = z.real();
|
||||
float y = z.imag();
|
||||
float zabs = Abs{}(z).real();
|
||||
float theta = atan2f(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1pf(r), theta};
|
||||
} else {
|
||||
float z0 = hypotf(x + 1, y);
|
||||
return {logf(z0), theta};
|
||||
}
|
||||
} else {
|
||||
return log1p(z);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
__device__ bool operator()(bool x) {
|
||||
return !x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
return T{0, 0} - x;
|
||||
} else {
|
||||
return -x;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Real {
|
||||
template <typename T>
|
||||
__device__ auto operator()(complex_t<T> x) {
|
||||
return x.real();
|
||||
}
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
return {rint(x.real()), rint(x.imag())};
|
||||
} else {
|
||||
return rint(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
T y = 1 / (1 + exp(-abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_unsigned_v<T>) {
|
||||
return x != 0;
|
||||
} else if constexpr (is_complex_v<T>) {
|
||||
if (x.real() == 0 && x.imag() == 0) {
|
||||
return x;
|
||||
} else {
|
||||
return x / Abs()(x);
|
||||
}
|
||||
} else if constexpr (cuda::std::is_same_v<T, __nv_bfloat16>) {
|
||||
return static_cast<float>((x > T(0.f)) - (x < T(0.f)));
|
||||
} else {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return sin(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return sinh(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return x * x;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (is_complex_v<T>) {
|
||||
return 1.0f / Sqrt{}(x);
|
||||
} else {
|
||||
return rsqrt(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return tan(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return tanh(x);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
447
mlx/backend/cuda/device/utils.cuh
Normal file
447
mlx/backend/cuda/device/utils.cuh
Normal file
@@ -0,0 +1,447 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
// This file must not include any host-only code, utilies that work under both
|
||||
// host and device can be put here.
|
||||
//
|
||||
// See more about the requirements at:
|
||||
// https://docs.nvidia.com/cuda/nvrtc/#language
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/complex.cuh"
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/array>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/tuple>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CUDA kernel utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// To pass shape/strides to kernels via constant memory, their size must be
|
||||
// known at compile time.
|
||||
using Shape = cuda::std::array<int32_t, MAX_NDIM>;
|
||||
using Strides = cuda::std::array<int64_t, MAX_NDIM>;
|
||||
|
||||
// Vectorized load/store.
|
||||
template <typename T, int N>
|
||||
struct alignas(sizeof(T) * N) AlignedVector {
|
||||
T val[N];
|
||||
|
||||
__device__ T& operator[](int i) {
|
||||
return val[i];
|
||||
}
|
||||
|
||||
__device__ T operator[](int i) const {
|
||||
return val[i];
|
||||
}
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
inline __host__ __device__ bool is_aligned(T* x) {
|
||||
return (reinterpret_cast<uintptr_t>(x) % (N * sizeof(T))) == 0;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
inline __device__ AlignedVector<T, N> unsafe_load_vector(
|
||||
const T* ptr,
|
||||
uint32_t offset) {
|
||||
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||
return from[offset];
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
inline __device__ AlignedVector<T, N> load_vector(
|
||||
const T* ptr,
|
||||
uint32_t offset) {
|
||||
if (is_aligned<N>(ptr)) {
|
||||
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||
return from[offset];
|
||||
} else {
|
||||
AlignedVector<T, N> v;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) {
|
||||
v[i] = ptr[offset * N + i];
|
||||
}
|
||||
return v;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, typename T, typename SizeT>
|
||||
inline __device__ AlignedVector<T, N>
|
||||
load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) {
|
||||
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
|
||||
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||
return from[offset];
|
||||
} else {
|
||||
AlignedVector<T, N> v;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) {
|
||||
v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback;
|
||||
}
|
||||
return v;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, typename T, typename SizeT>
|
||||
inline __device__ AlignedVector<T, N> load_vector(
|
||||
const T* ptr,
|
||||
uint32_t offset,
|
||||
SizeT size,
|
||||
int64_t stride,
|
||||
T fallback) {
|
||||
if (is_aligned<N>(ptr) && stride == 1 && (offset + 1) * N <= size) {
|
||||
auto* from = reinterpret_cast<const AlignedVector<T, N>*>(ptr);
|
||||
return from[offset];
|
||||
} else {
|
||||
AlignedVector<T, N> v;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) {
|
||||
v[i] =
|
||||
(N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback;
|
||||
}
|
||||
return v;
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
inline __device__ void
|
||||
unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
||||
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||
to[offset] = vec;
|
||||
}
|
||||
|
||||
template <int N, typename T>
|
||||
inline __device__ void
|
||||
store_vector(T* ptr, uint32_t offset, const AlignedVector<T, N>& vec) {
|
||||
if (is_aligned<N>(ptr)) {
|
||||
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||
to[offset] = vec;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) {
|
||||
ptr[offset * N + i] = vec[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, typename T, typename SizeT>
|
||||
inline __device__ void store_vector(
|
||||
T* ptr,
|
||||
uint32_t offset,
|
||||
const AlignedVector<T, N>& vec,
|
||||
SizeT size) {
|
||||
if (is_aligned<N>(ptr) && (offset + 1) * N <= size) {
|
||||
auto* to = reinterpret_cast<AlignedVector<T, N>*>(ptr);
|
||||
to[offset] = vec;
|
||||
} else {
|
||||
for (int i = 0; (offset * N + i) < size && i < N; ++i) {
|
||||
ptr[offset * N + i] = vec[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Type limits utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct Limits {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
return cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
return cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
cuda::std::enable_if_t<
|
||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double>>> {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
}
|
||||
};
|
||||
|
||||
// CUDA 11 does not have host side arithmatic operators for half types.
|
||||
template <typename T>
|
||||
struct Limits<
|
||||
T,
|
||||
cuda::std::enable_if_t<
|
||||
cuda::std::is_same_v<T, __half> ||
|
||||
cuda::std::is_same_v<T, __nv_bfloat16>>> {
|
||||
static constexpr __host__ __device__ T max() {
|
||||
return cuda::std::numeric_limits<T>::infinity();
|
||||
}
|
||||
static constexpr __host__ __device__ T min() {
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
return -cuda::std::numeric_limits<float>::infinity();
|
||||
#else
|
||||
return -cuda::std::numeric_limits<T>::infinity();
|
||||
#endif
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_max() {
|
||||
return cuda::std::numeric_limits<T>::max();
|
||||
}
|
||||
static constexpr __host__ __device__ T finite_min() {
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
return cuda::std::numeric_limits<float>::lowest();
|
||||
#else
|
||||
return cuda::std::numeric_limits<T>::lowest();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Limits<bool> {
|
||||
static constexpr __host__ __device__ bool max() {
|
||||
return true;
|
||||
}
|
||||
static constexpr __host__ __device__ bool min() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Limits<complex_t<T>> {
|
||||
static constexpr __host__ __device__ complex_t<T> max() {
|
||||
return {Limits<T>::max(), Limits<T>::max()};
|
||||
}
|
||||
static constexpr __host__ __device__ complex_t<T> min() {
|
||||
return {Limits<T>::min(), Limits<T>::min()};
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) {
|
||||
IdxT loc = 0;
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
// Optimize when the ndim is known at compile time.
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ IdxT
|
||||
elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) {
|
||||
IdxT loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
loc += (elem % shape[i]) * IdxT(strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
template <int NDIM, typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
const int64_t* c_strides) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
IdxT c_loc = 0;
|
||||
#pragma unroll
|
||||
for (int i = NDIM - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
int ndim) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc);
|
||||
}
|
||||
|
||||
template <typename IdxT = int64_t>
|
||||
inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc(
|
||||
IdxT elem,
|
||||
const int* shape,
|
||||
const int64_t* a_strides,
|
||||
const int64_t* b_strides,
|
||||
const int64_t* c_strides,
|
||||
int ndim) {
|
||||
IdxT a_loc = 0;
|
||||
IdxT b_loc = 0;
|
||||
IdxT c_loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
int dim_idx = elem % shape[i];
|
||||
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||
elem /= shape[i];
|
||||
}
|
||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Elem to loc in a loop utils
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int DIM, bool General = true, typename OffsetT = size_t>
|
||||
struct LoopedElemToLoc {
|
||||
int dim;
|
||||
LoopedElemToLoc<DIM - 1, General, OffsetT> inner_looper;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
|
||||
|
||||
__device__ void next(const int* shape, const int64_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index++;
|
||||
offset += OffsetT(strides[dim - 1]);
|
||||
if (index >= shape[dim - 1]) {
|
||||
index = 0;
|
||||
inner_looper.next(shape, strides);
|
||||
offset = inner_looper.offset;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||
if (dim == 0) {
|
||||
return;
|
||||
}
|
||||
index += n;
|
||||
offset += n * OffsetT(strides[dim - 1]);
|
||||
|
||||
if (index >= shape[dim - 1]) {
|
||||
int extra = index - shape[dim - 1];
|
||||
if (extra >= shape[dim - 1]) {
|
||||
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
|
||||
extra = extra % shape[dim - 1];
|
||||
} else {
|
||||
inner_looper.next(shape, strides);
|
||||
}
|
||||
index = 0;
|
||||
offset = inner_looper.offset;
|
||||
if (extra > 0) {
|
||||
next(extra, shape, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, true, OffsetT> {
|
||||
int dim;
|
||||
OffsetT offset{0};
|
||||
int index{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int dim) : dim(dim) {}
|
||||
|
||||
__device__ void next(const int* shape, const int64_t* strides) {
|
||||
index++;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int* shape, const int64_t* strides) {
|
||||
index += n;
|
||||
if (dim > 1) {
|
||||
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
||||
} else {
|
||||
offset = index * OffsetT(strides[0]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OffsetT>
|
||||
struct LoopedElemToLoc<1, false, OffsetT> {
|
||||
OffsetT offset{0};
|
||||
|
||||
__device__ LoopedElemToLoc(int) {}
|
||||
|
||||
__device__ void next(const int*, const int64_t* strides) {
|
||||
offset += OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
__device__ void next(int n, const int*, const int64_t* strides) {
|
||||
offset += n * OffsetT(strides[0]);
|
||||
}
|
||||
|
||||
__device__ OffsetT location() {
|
||||
return offset;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
61
mlx/backend/cuda/eval.cpp
Normal file
61
mlx/backend/cuda/eval.cpp
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/eval.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/available.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core::gpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
void new_stream(Stream s) {
|
||||
// Force initalization of cuda, so cuda runtime get destroyed at last.
|
||||
cudaFree(nullptr);
|
||||
// Ensure the static stream objects get created.
|
||||
cu::get_command_encoder(s);
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
nvtx3::scoped_range r("gpu::eval");
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
|
||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
for (auto& in : arr.inputs()) {
|
||||
// Except for the donated one.
|
||||
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
encoder.add_temporary(s);
|
||||
}
|
||||
encoder.maybe_commit();
|
||||
}
|
||||
|
||||
void finalize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::finalize");
|
||||
cu::get_command_encoder(s).commit();
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
nvtx3::scoped_range r("gpu::synchronize");
|
||||
cu::get_command_encoder(s).synchronize();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::gpu
|
267
mlx/backend/cuda/event.cu
Normal file
267
mlx/backend/cuda/event.cu
Normal file
@@ -0,0 +1,267 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CudaEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Cuda event managed with RAII.
|
||||
class CudaEventHandle {
|
||||
public:
|
||||
CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
|
||||
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
|
||||
}
|
||||
|
||||
~CudaEventHandle() {
|
||||
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
|
||||
}
|
||||
|
||||
CudaEventHandle(const CudaEventHandle&) = delete;
|
||||
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
|
||||
|
||||
operator cudaEvent_t() const {
|
||||
return event_;
|
||||
}
|
||||
|
||||
private:
|
||||
cudaEvent_t event_;
|
||||
};
|
||||
|
||||
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
|
||||
|
||||
void CudaEvent::wait() {
|
||||
nvtx3::scoped_range r("cu::CudaEvent::wait");
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaEventSynchronize(*event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(cudaStream_t stream) {
|
||||
if (!recorded_) {
|
||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
||||
}
|
||||
cudaStreamWaitEvent(stream, *event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this]() mutable { wait(); });
|
||||
} else {
|
||||
auto& enc = cu::get_command_encoder(s);
|
||||
enc.commit();
|
||||
wait(enc.stream());
|
||||
}
|
||||
}
|
||||
|
||||
void CudaEvent::record(cudaStream_t stream) {
|
||||
cudaEventRecord(*event_, stream);
|
||||
recorded_ = true;
|
||||
}
|
||||
|
||||
void CudaEvent::record(Stream s) {
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
|
||||
} else {
|
||||
auto& enc = cu::get_command_encoder(s);
|
||||
enc.commit();
|
||||
record(enc.stream());
|
||||
}
|
||||
}
|
||||
|
||||
bool CudaEvent::completed() const {
|
||||
return cudaEventQuery(*event_) == cudaSuccess;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SharedEvent implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
uint64_t current;
|
||||
while ((current = ac->load()) < value) {
|
||||
ac->wait(current);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
ac->store(value);
|
||||
ac->notify_all();
|
||||
}
|
||||
|
||||
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_wait(ac, value);
|
||||
}
|
||||
|
||||
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_signal(ac, value);
|
||||
}
|
||||
|
||||
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
|
||||
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
|
||||
}
|
||||
|
||||
SharedEvent::SharedEvent() {
|
||||
buf_ = std::shared_ptr<Buffer>(
|
||||
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
|
||||
allocator().free(*ptr);
|
||||
delete ptr;
|
||||
});
|
||||
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
|
||||
}
|
||||
|
||||
void SharedEvent::wait(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
||||
event_wait(to_atomic(buf_), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.commit();
|
||||
wait(encoder.stream(), value);
|
||||
encoder.add_completed_handler([buf = buf_]() {});
|
||||
}
|
||||
}
|
||||
|
||||
void SharedEvent::signal(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
||||
event_signal(to_atomic(buf_), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
||||
if (s.device == mlx::core::Device::cpu) {
|
||||
// Signal through a GPU stream so the atomic is updated in GPU - updating
|
||||
// the atomic in CPU sometimes does not get GPU notified.
|
||||
static CudaStream stream(device(mlx::core::Device::gpu));
|
||||
scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
|
||||
} else {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.commit();
|
||||
signal(encoder.stream(), value);
|
||||
encoder.add_completed_handler([buf = buf_]() {});
|
||||
}
|
||||
}
|
||||
|
||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
||||
return to_atomic(buf_)->load() >= value;
|
||||
}
|
||||
|
||||
uint64_t SharedEvent::value() const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
||||
return to_atomic(buf_)->load();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Event implementations
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace {
|
||||
|
||||
struct EventImpl {
|
||||
// CudaEvent is preferred when possible because it is fast, however we have
|
||||
// to fallback to SharedEvent in following cases:
|
||||
// 1. the event is used to wait/signal a cpu stream;
|
||||
// 2. signal value other than 1 has been specified.
|
||||
std::unique_ptr<cu::CudaEvent> cuda;
|
||||
std::unique_ptr<cu::SharedEvent> shared;
|
||||
|
||||
bool is_created() const {
|
||||
return cuda || shared;
|
||||
}
|
||||
|
||||
void ensure_created(Stream s, uint64_t signal_value) {
|
||||
if (is_created()) {
|
||||
return;
|
||||
}
|
||||
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
|
||||
nvtx3::mark("Using slow SharedEvent");
|
||||
shared = std::make_unique<cu::SharedEvent>();
|
||||
} else {
|
||||
cuda = std::make_unique<cu::CudaEvent>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Event::Event(Stream s) : stream_(s) {
|
||||
event_ = std::shared_ptr<void>(
|
||||
new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });
|
||||
}
|
||||
|
||||
void Event::wait() {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
assert(event->is_created());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait();
|
||||
} else {
|
||||
event->shared->wait(value());
|
||||
}
|
||||
}
|
||||
|
||||
void Event::wait(Stream s) {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
assert(event->is_created());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->wait(s);
|
||||
} else {
|
||||
event->shared->wait(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal(Stream s) {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
event->ensure_created(s, value());
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
event->cuda->record(s);
|
||||
} else {
|
||||
event->shared->signal(s, value());
|
||||
}
|
||||
}
|
||||
|
||||
bool Event::is_signaled() const {
|
||||
auto* event = static_cast<EventImpl*>(event_.get());
|
||||
if (!event->is_created()) {
|
||||
return false;
|
||||
}
|
||||
if (event->cuda) {
|
||||
assert(value() == 1);
|
||||
return event->cuda->recorded() && event->cuda->completed();
|
||||
} else {
|
||||
return event->shared->is_signaled(value());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user