Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00.

Compare commits: gguf_q4_k...sdpav-back (260 commits)
Commit SHA1s in this range:

a22d0bf273 99d8de8445 c66b76a8c8 f81edd184f 7f8ba2a003 c28249b81a e74bcdc5e3 d8ed6c1aa3 db5c7efcf6 7bb96e4249
fa89f0b150 ca973d1e83 828c5f1137 7d86a5c108 0b807893a7 6ad0889c8a 737dd6d1ac aaf78f4c6b 8831064493 be9bc96da4
86258f292f b26d88591c 86c6a15571 8b25ce62d5 da5912e4f2 daafee676f d32519c8ee b405591249 3bf81ed1bd 2204182bba
3628e5d497 a0ae49d397 254476718b 3adba92ebe ef631d63af 970dbe8e25 641be9463b ab0e608862 1588659062 b9e88fb976
4ad53414dd d1165b215e dcb8319f3d 5597fa089c 9acec364c2 7d9d6ef456 6f5874a2f2 70dc336785 4e504039f5 d1f4d291e8
e1840853ce 0f5ce173da 588854195f 28d068bce6 d107d8d495 1e496ddb82 74eccbf3fa 08638223ca 56cc858af9 f55c4ed1d6
93d70419e7 63f663d9c6 84b4d96efa aec67f2fa6 deee214a95 45adec102c 31fc530c76 fbb3f65a1a 6b1b8ea91b b2273733ea
f409b229a4 30571e2326 d7734edd9f 2ba69bc8fa cb349a291c f0a0b077a0 49114f28ab e7d2ebadd2 e569803d7c d34f887abc
5201df5030 2d3c26c565 6325f60d52 42cc9cfbc7 8347575ba1 b6eec20260 0eb035b4b1 afb9817599 8fb3e7a26c 8c7bc30ce4
85873cb162 e14ee12491 8b9a3f3cea fb4e8b896b 2ca533b279 4a9b29a875 a4fcc893cd 9d10239af7 19facd4b20 f5299f72cd
0e0d9ac522 8917022deb ec0d5db67b e76e9b87f0 cfb6a244ea 58f3860306 dd4f53db63 3d5e17e507 33bf1a244b 772f471ff2
2c11d10f8d 656ed7f780 81bb9a2a9e 5adf185f86 c9a9180584 76831ed83d b3d7b85376 cad5c0241c b8022c578a bc53f8293f
c552ff2451 4fda5fbdf9 580776559b a14aaa7c9d a6d780154f 6871e2eeb7 8402a2acf4 fddb6933e1 c8b4787e4e 2188199ff8
aa07429bad 918761a25a a4fc671d3e f5f65ef48c c2dd81a8aa d7e680ffe4 c371baf53a ccf78f566c c9fa68664a c35f4d089a
8590c0941e 095163b8d1 99c33d011d 62fecf3e13 7c4eb5d03e bae9a6b404 004c1d8ef2 7ebb2e0193 9ce77798b1 f8bad60609
5866b3857b 1ca616844b 2e8cf0b450 24f89173d1 c6a20b427a a5ac9244c4 c763fe1be0 52dc8c8cd5 aede70e81d 85a8beb5e4
0bb89e9e5f 5685ceb3c7 0408ba0a76 cbad6c3093 1b021f6984 95b7551d65 db5a7c6192 6ef2f67e7f f76ee1ffd2 54a71f270a
55b4062dd8 79071bfba4 7774b87cbd 35c87741cf 4cbe605214 ab8883dd55 eebe73001a 0359bf02c9 237f9e58a8 8576e6fe36
0654543dcc 48ef3e74e2 7d4b378952 7ff5c41e06 602f43e3d1 a2cadb8218 c1eb9d05d9 cf6c939e86 130df35e1b 0751263dec
eca2f3eb97 3aa9cf3f9e 8f3d208dce caaa3f1f8c 659a51919f 6661387066 a7fae8a176 0cae0bdac8 5a1a5d5ed1 1683975acf
af705590ac 825124af8f 9c5e7da507 481349495b 9daa6b003f a3a632d567 e496c5a4b4 ea890d8710 aa5d84f102 f1606486d2
87720a8908 bb6565ef14 7bb063bcb3 b36dd472bb 167b759a38 99b9868859 6b2d5448f2 eaf709b83e f0e70afff0 86984cad68
fbc89e3ced 38c1e720c2 600e87e03c 3836445241 1d2c9d6a07 e8ac6bd2f5 fdadc4f22c 79b527f45f dc4eada7f0 70ebc3b598
b13f2aed16 5f04c0f818 55935ccae7 b529515eb1 3cde719eb7 5de6d94a90 99eefd2ec0 e9e268336b 7275ac7523 c4189a38e4
68d1b3256b 9c6953bda7 ef7ece9851 ddaa4b7dcb dfae2c6989 515f104926 9ecefd56db e5d35aa187 00794c42bc 08a1bf3f10
60c4154346 f2c85308c1 1a28b69ee2 ba09f01ce8 6cf48872b7 7b3b8fa000 ec5e2aae61 86389bf970 3290bfa690 8777fd104f
CircleCI configuration (`.circleci/config.yml`):

```diff
@@ -7,15 +7,9 @@ parameters:
   nightly_build:
     type: boolean
     default: false
-  weekly_build:
-    type: boolean
-    default: false
   test_release:
     type: boolean
     default: false
-  linux_release:
-    type: boolean
-    default: false
 
 jobs:
   build_documentation:
@@ -38,7 +32,7 @@ jobs:
             pip install --upgrade pip
             pip install --upgrade cmake
             pip install -r docs/requirements.txt
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
+            pip install . -v
       - when:
           condition:
             not: << parameters.upload-docs >>
@@ -70,9 +64,9 @@ jobs:
             git push -f origin gh-pages
 
   linux_build_and_test:
-    docker:
-      - image: cimg/python:3.9
+    machine:
+      image: ubuntu-2204:current
+    resource_class: large
     steps:
       - checkout
       - run:
@@ -84,37 +78,36 @@ jobs:
       - run:
          name: Install dependencies
          command: |
-            pip install --upgrade cmake
-            pip install nanobind==2.4.0
-            pip install numpy
+            export DEBIAN_FRONTEND=noninteractive
+            export NEEDRESTART_MODE=a
            sudo apt-get update
-            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
+            curl -LsSf https://astral.sh/uv/install.sh | sh
       - run:
          name: Install Python package
          command: |
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            python3 setup.py build_ext --inplace
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            python3 setup.py develop
+            uv venv
+            uv pip install cmake
+            uv pip install -e ".[dev]" -v
       - run:
          name: Generate package stubs
          command: |
-            echo "stubs"
-            pip install typing_extensions
-            python setup.py generate_stubs
+            uv pip install typing_extensions
+            uv run --no-project setup.py generate_stubs
       - run:
          name: Run Python tests
          command: |
-            python3 -m unittest discover python/tests -v
+            source .venv/bin/activate
+            python -m unittest discover python/tests -v
            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
-            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
+            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
+            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
       - run:
          name: Build CPP only
          command: |
+            source .venv/bin/activate
            mkdir -p build && cd build
            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
            make -j `nproc`
       - run:
@@ -139,51 +132,49 @@ jobs:
       - run:
          name: Install dependencies
          command: |
-            brew install python@3.9
-            brew install openmpi
-            python3.9 -m venv env
-            source env/bin/activate
-            pip install --upgrade pip
-            pip install --upgrade cmake
-            pip install nanobind==2.4.0
-            pip install numpy
-            pip install torch
-            pip install tensorflow
-            pip install unittest-xml-reporting
+            HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
+              brew install openmpi uv
       - run:
          name: Install Python package
          command: |
-            source env/bin/activate
-            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-            CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
-            pip install -e . -v
+            uv venv --python 3.9
+            uv pip install \
+              nanobind==2.4.0 \
+              cmake \
+              numpy \
+              torch \
+              tensorflow \
+              unittest-xml-reporting
+            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+              uv pip install -e . -v
       - run:
          name: Generate package stubs
          command: |
-            source env/bin/activate
-            pip install typing_extensions
-            python setup.py generate_stubs
+            uv pip install typing_extensions
+            uv run --no-project setup.py generate_stubs
       - run:
          name: Run Python tests
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
-            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
+            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
+            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
       - run:
          name: Build example extension
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            cd examples/extensions
-            pip install -r requirements.txt
-            python setup.py build_ext -j8
+            uv pip install -r requirements.txt
+            uv run --no-project setup.py build_ext --inplace
+            uv run --no-project python test.py
       - store_test_results:
          path: test-results
       - run:
          name: Build CPP only
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
       - run:
          name: Run CPP tests
@@ -192,7 +183,7 @@ jobs:
       - run:
          name: Build small binary
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            cd build/
            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
                  -DBUILD_SHARED_LIBS=ON \
@@ -204,13 +195,60 @@ jobs:
       - run:
          name: Run Python tests with JIT
          command: |
-            source env/bin/activate
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-            pip install -e . -v
+            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
+              uv pip install -e .
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
              METAL_DEBUG_ERROR_MODE=0 \
-              python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
+              uv run --no-project python -m xmlrunner discover \
+                -v python/tests \
+                -o test-results/gpu_jit
+
+  cuda_build_and_test:
+    parameters:
+      image_date:
+        type: string
+        default: "2023.11.1"
+    machine:
+      image: "linux-cuda-12:<< parameters.image_date >>"
+    resource_class: gpu.nvidia.small.gen2
+    steps:
+      - checkout
+      - restore_cache:
+          keys:
+            - cuda-<< parameters.image_date >>-{{ arch }}-
+      - run:
+          name: Install dependencies
+          command: |
+            sudo apt-get update
+            sudo apt-get install libcudnn9-dev-cuda-12
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
+            sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
+            rm -rf ccache-4.11.3-linux-x86_64
+            curl -LsSf https://astral.sh/uv/install.sh | sh
+      - run:
+          name: Install Python package
+          command: |
+            uv venv
+            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+              uv pip install -e ".[dev]" -v
+      - run:
+          name: Run Python tests
+          command: |
+            source .venv/bin/activate
+            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
+            LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
+      - run:
+          name: CCache report
+          command: |
+            ccache --show-stats
+            ccache --zero-stats
+            ccache --max-size 400MB
+            ccache --cleanup
+      - save_cache:
+          key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
+          paths:
+            - /home/circleci/.cache/ccache
+
   build_release:
     parameters:
@@ -251,22 +289,30 @@ jobs:
          name: Install Python package
          command: |
            source env/bin/activate
-            DEV_RELEASE=1 \
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
            pip install . -v
       - run:
          name: Generate package stubs
          command: |
            source env/bin/activate
            pip install typing_extensions
            python setup.py generate_stubs
       - run:
          name: Build Python package
          command: |
            source env/bin/activate
-            << parameters.build_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-            python -m build -w
+            python setup.py clean --all
+            << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
+      - when:
+          condition:
+            equal: ["3.9", << parameters.python_version >>]
+          steps:
+            - run:
+                name: Build common package
+                command: |
+                  source env/bin/activate
+                  python setup.py clean --all
+                  << parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
       - when:
          condition: << parameters.build_env >>
          steps:
@@ -283,52 +329,100 @@ jobs:
     python_version:
       type: string
       default: "3.9"
-    extra_env:
+    build_env:
       type: string
-      default: "DEV_RELEASE=1"
-    docker:
-      - image: ubuntu:20.04
+      default: ""
+    machine:
+      image: ubuntu-2204:current
+    resource_class: large
     steps:
       - checkout
       - run:
          name: Build wheel
          command: |
            PYTHON=python<< parameters.python_version >>
-            apt-get update
-            apt-get upgrade -y
-            DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
-            apt-get install -y apt-utils
-            apt-get install -y software-properties-common
-            add-apt-repository -y ppa:deadsnakes/ppa
-            apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
-            apt-get install -y libblas-dev liblapack-dev liblapacke-dev
-            apt-get install -y build-essential git
+            export DEBIAN_FRONTEND=noninteractive
+            export NEEDRESTART_MODE=a
+            sudo apt-get update
+            TZ=Etc/UTC sudo apt-get -y install tzdata
+            sudo add-apt-repository -y ppa:deadsnakes/ppa
+            sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
+            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            $PYTHON -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
-            pip install nanobind==2.4.0
-            pip install --upgrade setuptools
-            pip install numpy
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
-            << parameters.extra_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            pip install . -v
+            << parameters.build_env >> pip install ".[dev]" -v
            pip install typing_extensions
            python setup.py generate_stubs
-            << parameters.extra_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            python -m build --wheel
-            auditwheel show dist/*
-            auditwheel repair dist/* --plat manylinux_2_31_x86_64
+            python setup.py clean --all
+            MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
+            bash python/scripts/repair_linux.sh
+      - when:
+          condition:
+            equal: ["3.9", << parameters.python_version >>]
+          steps:
+            - run:
+                name: Build common package
+                command: |
+                  source env/bin/activate
+                  python setup.py clean --all
+                  << parameters.build_env >> MLX_BUILD_STAGE=2 \
+                    python -m build -w
+                  auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
+      - when:
+          condition: << parameters.build_env >>
+          steps:
+            - run:
+                name: Upload packages
+                command: |
+                  source env/bin/activate
+                  twine upload wheelhouse/*.whl
+      - store_artifacts:
+          path: wheelhouse/
+
+  build_cuda_release:
+    parameters:
+      build_env:
+        type: string
+        default: ""
+    machine:
+      image: ubuntu-2204:current
+    resource_class: large
+    steps:
+      - checkout
       - run:
-          name: Upload package
+          name: Build wheel
          command: |
-            source env/bin/activate
-            twine upload wheelhouse/*
+            export DEBIAN_FRONTEND=noninteractive
+            export NEEDRESTART_MODE=a
+            wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
+            sudo dpkg -i cuda-keyring_1.1-1_all.deb
+            sudo apt-get update
+            sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install zip
+            pip install auditwheel
+            pip install patchelf
+            pip install build
+            pip install twine
+            export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
+            export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
+            << parameters.build_env >> MLX_BUILD_STAGE=2 \
+              CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+              python -m build -w
+            bash python/scripts/repair_cuda.sh
+      - when:
+          condition: << parameters.build_env >>
+          steps:
+            - run:
+                name: Upload package
+                command: |
+                  twine upload wheelhouse/*.whl
       - store_artifacts:
          path: wheelhouse/
@@ -340,7 +434,6 @@ workflows:
            pattern: "^(?!pull/)[-\\w]+$"
            value: << pipeline.git.branch >>
        - not: << pipeline.parameters.nightly_build >>
-        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
     jobs:
       - mac_build_and_test:
@@ -348,13 +441,16 @@ workflows:
          parameters:
            macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test
+      - cuda_build_and_test:
+          matrix:
+            parameters:
+              image_date: ["2023.11.1", "2025.05.1"]
       - build_documentation
 
   build_pypi_release:
     when:
       and:
        - not: << pipeline.parameters.nightly_build >>
-        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
     jobs:
       - build_release:
@@ -368,6 +464,68 @@ workflows:
            python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
            macosx_deployment_target: ["13.5", "14.0", "15.0"]
            build_env: ["PYPI_RELEASE=1"]
+            xcode_version: ["16.2.0", "15.0.0"]
+          exclude:
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.9"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.10"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.11"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.12"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.13"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.9"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.10"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.11"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.12"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.13"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.9"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.10"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.11"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.12"
+              build_env: "PYPI_RELEASE=1"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.13"
+              build_env: "PYPI_RELEASE=1"
       - build_documentation:
          filters:
            tags:
@@ -375,6 +533,25 @@ workflows:
          branches:
            ignore: /.*/
          upload-docs: true
+      - build_linux_release:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          matrix:
+            parameters:
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              build_env: ["PYPI_RELEASE=1"]
+      - build_cuda_release:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          matrix:
+            parameters:
+              build_env: ["PYPI_RELEASE=1"]
 
   prb:
     when:
@@ -393,6 +570,11 @@ workflows:
            macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test:
          requires: [ hold ]
+      - cuda_build_and_test:
+          requires: [ hold ]
+          matrix:
+            parameters:
+              image_date: ["2023.11.1", "2025.05.1"]
   nightly_build:
     when:
       and:
@@ -404,11 +586,64 @@ workflows:
          parameters:
            python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
            macosx_deployment_target: ["13.5", "14.0", "15.0"]
-  weekly_build:
+            xcode_version: ["16.2.0", "15.0.0"]
+          exclude:
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.9"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.10"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.11"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.12"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.13"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.9"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.10"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.11"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.12"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.13"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.9"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.10"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.11"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.12"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.13"
+      - build_linux_release:
+          matrix:
+            parameters:
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+      - build_cuda_release
+
+  build_dev_release:
     when:
       and:
        - equal: [ main, << pipeline.git.branch >> ]
-        - << pipeline.parameters.weekly_build >>
+        - << pipeline.parameters.test_release >>
     jobs:
       - build_release:
          matrix:
@@ -416,14 +651,74 @@ workflows:
            python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
            macosx_deployment_target: ["13.5", "14.0", "15.0"]
            build_env: ["DEV_RELEASE=1"]
-  linux_test_release:
-    when:
-      and:
-        - equal: [ main, << pipeline.git.branch >> ]
-        - << pipeline.parameters.linux_release >>
-    jobs:
+            xcode_version: ["16.2.0", "15.0.0"]
+          exclude:
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.9"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.10"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.11"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.12"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "13.5"
+              xcode_version: "16.2.0"
+              python_version: "3.13"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.9"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.10"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.11"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.12"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "14.0"
+              xcode_version: "15.0.0"
+              python_version: "3.13"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.9"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.10"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.11"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.12"
+              build_env: "DEV_RELEASE=1"
+            - macosx_deployment_target: "15.0"
+              xcode_version: "15.0.0"
+              python_version: "3.13"
+              build_env: "DEV_RELEASE=1"
       - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
-              extra_env: ["PYPI_RELEASE=1"]
+              build_env: ["DEV_RELEASE=1"]
+      - build_cuda_release:
+          matrix:
+            parameters:
+              build_env: ["DEV_RELEASE=1"]
```
`.gitignore` (vendored):

```diff
@@ -36,6 +36,7 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+uv.lock
 
 # vim
 *.swp
```
Acknowledgments (ACKNOWLEDGMENTS.md):

```diff
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
```
CMake build configuration (CMakeLists.txt):

```diff
@@ -34,13 +34,16 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
 option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
 option(MLX_BUILD_METAL "Build metal backend" ON)
 option(MLX_BUILD_CPU "Build cpu backend" ON)
+option(MLX_BUILD_CUDA "Build cuda backend" OFF)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
 option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
+option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
+option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
 
 # --------------------- Processor tests -------------------------
 message(
@@ -63,10 +66,17 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
     message(WARNING "Building for x86_64 arch is not officially supported.")
   endif()
 endif()
 
 else()
   set(MLX_BUILD_METAL OFF)
-  message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
+endif()
+
+if(MLX_USE_CCACHE)
+  find_program(CCACHE_PROGRAM ccache)
+  if(CCACHE_PROGRAM)
+    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+  endif()
 endif()
 
 # ----------------------------- Lib -----------------------------
@@ -83,6 +93,10 @@ if(MLX_BUILD_METAL)
   set(QUARTZ_LIB "-framework QuartzCore")
 endif()
 
+if(MLX_BUILD_CUDA)
+  enable_language(CUDA)
+endif()
+
 if(MLX_BUILD_METAL AND NOT METAL_LIB)
   message(STATUS "Metal not found. Unable to build GPU")
   set(MLX_BUILD_METAL OFF)
@@ -226,12 +240,19 @@ target_include_directories(
 target_include_directories(
   mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
              $<INSTALL_INTERFACE:include>)
 
-FetchContent_Declare(
-  fmt
-  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
-  GIT_TAG 10.2.1
-  EXCLUDE_FROM_ALL)
-FetchContent_MakeAvailable(fmt)
+# Do not add mlx_EXPORTS define for shared library.
+set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
+
+if(USE_SYSTEM_FMT)
+  find_package(fmt REQUIRED)
+else()
+  FetchContent_Declare(
+    fmt
+    GIT_REPOSITORY https://github.com/fmtlib/fmt.git
+    GIT_TAG 10.2.1
+    EXCLUDE_FROM_ALL)
+  FetchContent_MakeAvailable(fmt)
+endif()
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
 
 if(MLX_BUILD_PYTHON_BINDINGS)
```
Package manifest (MANIFEST.in):

```diff
@@ -1,4 +1,6 @@
 include CMakeLists.txt
+include mlx.pc.in
 recursive-include mlx/ *
+include cmake/*
 include python/src/*
 include python/mlx/py.typed # support type hinting as in PEP-561
```
README.md (21 lines changed):

````diff
@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
 
 Some key features of MLX include:
 
 - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
   also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
   [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
   the Python API. MLX has higher-level packages like `mlx.nn` and
   `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
   more complex models.
 
@@ -68,18 +68,23 @@ in the documentation.
 
 ## Installation
 
-MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
+MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
+macOS, run:
 
-**With `pip`**:
-
-```
+```bash
 pip install mlx
 ```
 
-**With `conda`**:
+To install the CUDA backend on Linux, run:
 
+```bash
+pip install mlx[cuda]
 ```
-conda install -c conda-forge mlx
+
+To install a CPU-only Linux package, run:
+
+```bash
+pip install mlx[cpu]
 ```
 
 Checkout the
````
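Whichever of the install commands above you use, a quick smoke test is to import the package and run a tiny computation; this sketch assumes only the public `mlx.core` API:

```python
import mlx.core as mx

# Trivial post-install check: build an array, reduce it, and print the
# default device the computation ran on.
a = mx.array([1.0, 2.0, 3.0])
print(mx.default_device(), (a * 2).sum().item())
```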
C++ reductions benchmark (two hunks):

```diff
@@ -1,5 +1,6 @@
 // Copyright © 2023 Apple Inc.
 
+#include <cstring>
 #include <iostream>
 #include <sstream>
 
@@ -192,6 +192,22 @@ void time_reductions() {
   auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
   TIME(argmin_along_1);
 
+  auto indices = mx::array({1});
+  auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
+  std::vector<int> axes{0};
+  auto b = scatter(a, {indices}, updates, axes);
+  mx::eval(b);
+
+  auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
+  TIME(max_along_0);
+  auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
+  TIME(max_along_1);
+
+  auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
+  TIME(min_along_0);
+  auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
+  TIME(min_along_1);
 }
 
 void time_gather_scatter() {
```
PyTorch comparison benchmark:

```diff
@@ -5,6 +5,7 @@ import os
 import time
 
 import torch
+import torch.cuda
 import torch.mps
 
@@ -44,8 +45,10 @@ def bench(f, *args):
 
 
 def sync_if_needed(x):
-    if x.device != torch.device("cpu"):
+    if x.device == torch.device("mps"):
         torch.mps.synchronize()
+    elif x.device == torch.device("cuda"):
+        torch.cuda.synchronize()
 
 
 @torch.no_grad()
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def softmax(axis, x):
     ys = []
@@ -340,7 +351,11 @@ if __name__ == "__main__":
         args.axis.pop(0)
 
     torch.set_num_threads(1)
-    device = "cpu" if args.cpu else "mps"
+    device = "mps"
+    if torch.cuda.is_available():
+        device = "cuda"
+    if args.cpu:
+        device = "cpu"
 
     types = args.dtype
     if not types:
@@ -460,5 +475,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))
 
+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
```
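Only the PyTorch file changes in this hunk; for reference, a hypothetical MLX-side counterpart to the new `sum_and_add` benchmark would look like this (a sketch, not part of the diff):

```python
import mlx.core as mx

def sum_and_add(axis, x, y):
    # Mirror of the torch benchmark: one reduction, then 50 add+reduce rounds.
    z = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        z = (z + y).sum(axis=axis, keepdims=True)
    mx.eval(z)  # MLX is lazy; forcing evaluation plays the role of device sync
```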
`benchmarks/python/conv_unaligned_bench.py` (new file, 107 lines):

```python
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    dtype = "float32"
    shapes = (
        (4, 32, 32, 21, 3, 3, 128),
        (4, 32, 32, 21, 3, 3, 37),
        (4, 32, 32, 370, 3, 3, 370),
        (4, 32, 32, 370, 7, 7, 128),
        (2, 320, 640, 21, 7, 7, 21),
    )
    for N, H, W, C, kh, kw, O in shapes:
        time_mlx, time_torch = bench_shape(
            N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
        )
        diff = time_torch / time_mlx - 1.0

        print(
            f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
        )
        if time_mlx >= 2.0 * time_torch:
            print("ATTENTION ^^^^^^^")
```
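The shape list above uses channel counts (21, 37, 370) that are not multiples of typical GPU tile sizes, which is presumably the "unaligned" case the file name refers to (an inference, not stated in the diff). A single such case can be reproduced on its own:

```python
import mlx.core as mx

a = mx.random.uniform(shape=(4, 32, 32, 21))   # NHWC input, C = 21 (unaligned)
w = mx.random.uniform(shape=(128, 3, 3, 21))   # (O, kH, kW, C) weights
y = mx.conv2d(a, w, stride=(1, 1), padding=(0, 0))
print(y.shape)  # (4, 30, 30, 128)
```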
`benchmarks/python/gather_mm_bench.py` (new file, 74 lines):

```python
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_mm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = x @ w1.T
        x = x @ w2.T
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_mm()
```
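The `gather_sort`/`scatter_unsort` pair relies on the fact that `argsort(order)` is the inverse permutation of `order`, so sorting and then indexing with the inverse restores the original arrangement. A minimal sketch of that identity:

```python
import mlx.core as mx

indices = mx.array([3, 1, 2, 1, 0])
order = mx.argsort(indices)      # permutation that sorts `indices`
inv_order = mx.argsort(order)    # its inverse permutation
restored = indices[order][inv_order]
print(mx.array_equal(restored, indices))  # True: sort then unsort round-trips
```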
`benchmarks/python/gather_qmm_bench.py` (new file, 84 lines):

```python
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate(
            [
                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
                for i, j in enumerate(idx.tolist())
            ],
            axis=0,
        )
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_qmm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = mx.quantized_matmul(x, *w1, transpose=True)
        x = mx.quantized_matmul(x, *w2, transpose=True)
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_qmm()
```
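Note how the benchmark reassigns `w1 = mx.quantize(w1)` and then unpacks the result with `*w1`: `mx.quantize` returns a tuple of the packed weight plus its scales and biases. A minimal standalone sketch of that pattern:

```python
import mlx.core as mx

w = mx.random.normal((1024, 1024))
w_q, scales, biases = mx.quantize(w)  # packed weight plus quantization metadata
x = mx.random.normal((4, 1024))
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
print(y.shape)  # (4, 1024)
```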
Layer norm benchmark:

```diff
@@ -1,5 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
+from functools import partial
+
 import mlx.core as mx
 import mlx.nn as nn
 from time_utils import time_fn
@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
     return y
 
 
-def time_layer_norm():
+def time_layer_norm(N, dt):
+    L = 1024
     f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
     f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0, 1, 2))
     g2 = mx.grad(f2, argnums=(0, 1, 2))
 
-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    w = mx.random.uniform(shape=(N,)).astype(dt)
+    b = mx.random.uniform(shape=(N,)).astype(dt)
+    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
     mx.eval(x, w, b, y)
 
-    def layer_norm_loop(g, x, w, b):
+    def layer_norm_loop(f, x, w, b):
+        for _ in range(32):
+            x = f(x, w, b)
+        return x
+
+    time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
+    time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
+
+    def layer_norm_grad_loop(g, x, w, b):
         gx, gw, gb = x, w, b
         for _ in range(32):
             gx, gw, gb = g(gx, gw, gb, y)
         return gx, gw, gb
 
-    time_fn(layer_norm_loop, g1, x, w, b)
-    time_fn(layer_norm_loop, g2, x, w, b)
-    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
-    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
+    time_fn(layer_norm_grad_loop, g1, x, w, b)
+    time_fn(layer_norm_grad_loop, g2, x, w, b)
+    time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
+    time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
 
     f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
     f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0,))
     g2 = mx.grad(f2, argnums=(0,))
 
-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    w = mx.random.uniform(shape=(N,)).astype(dt)
+    b = mx.random.uniform(shape=(N,)).astype(dt)
+    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
     mx.eval(x, w, b, y)
 
-    def layer_norm_loop(g, x):
+    def layer_norm_grad_x_loop(g, x):
         gx = x
         for _ in range(32):
             gx = g(gx, y)
         return gx
 
-    time_fn(layer_norm_loop, g1, x)
-    time_fn(layer_norm_loop, g2, x)
-    time_fn(layer_norm_loop, mx.compile(g1), x)
-    time_fn(layer_norm_loop, mx.compile(g2), x)
+    time_fn(layer_norm_grad_x_loop, g1, x)
+    time_fn(layer_norm_grad_x_loop, g2, x)
+    time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
+    time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
 
 
 if __name__ == "__main__":
-    time_layer_norm()
+    for dt in [mx.float32, mx.float16, mx.bfloat16]:
+        for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
+            print(dt, n)
+            time_layer_norm(n, dt)
```
@@ -51,6 +51,20 @@ def time_maximum():
     time_fn(mx.maximum, a, b)
 
 
+def time_max():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    a[1, 1] = mx.nan
+    mx.eval(a)
+    time_fn(mx.max, a, 0)
+
+
+def time_min():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    a[1, 1] = mx.nan
+    mx.eval(a)
+    time_fn(mx.min, a, 0)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
@@ -108,6 +122,8 @@ if __name__ == "__main__":
 
     time_add()
     time_matmul()
+    time_min()
+    time_max()
     time_maximum()
     time_exp()
     time_negative()
@@ -11,13 +11,14 @@ include(CMakeParseArguments)
 # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
 # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
 # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
-# files (like headers)
+# files (like headers) DEBUG: Boolean, if true, enables debug compile options
+# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
 #
 # clang format on
 
 macro(mlx_build_metallib)
   # Parse args
-  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
+  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
   set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
   cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
 
@@ -26,6 +27,10 @@ macro(mlx_build_metallib)
 
   # Collect compile options
   set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
+  if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
+    set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
+                               -frecord-sources)
+  endif()
 
   # Prepare metallib build command
   add_custom_command(
@@ -10,7 +10,7 @@ import mlx.core as mx
 # -- Project information -----------------------------------------------------
 
 project = "MLX"
-copyright = "2023, MLX Contributors"
+copyright = "2023, Apple"
 author = "MLX Contributors"
 version = ".".join(mx.__version__.split(".")[:3])
 release = version
@@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
 Simple Example
 --------------
 
+.. currentmodule:: mlx.core
+
 Let's write a custom kernel that computes ``exp`` elementwise:
 
 .. code-block:: python
 
-    def exp_elementwise(a: mx.array):
-        source = """
-            uint elem = thread_position_in_grid.x;
-            T tmp = inp[elem];
-            out[elem] = metal::exp(tmp);
-        """
+    source = """
+        uint elem = thread_position_in_grid.x;
+        T tmp = inp[elem];
+        out[elem] = metal::exp(tmp);
+    """
 
-        kernel = mx.fast.metal_kernel(
-            name="myexp",
-            input_names=["inp"],
-            output_names=["out"],
-            source=source,
-        )
+    kernel = mx.fast.metal_kernel(
+        name="myexp",
+        input_names=["inp"],
+        output_names=["out"],
+        source=source,
+    )
+
+
+    def exp_elementwise(a: mx.array):
         outputs = kernel(
             inputs=[a],
             template=[("T", mx.float32)],
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
     b = exp_elementwise(a)
     assert mx.allclose(b, mx.exp(a))
 
+Every time you make a kernel, a new Metal library is created and possibly
+JIT compiled. To reduce the overhead from that, build the kernel once with
+:func:`fast.metal_kernel` and then use it many times.
+
 .. note::
-   We are only required to pass the body of the Metal kernel in ``source``.
+   Only pass the body of the Metal kernel in ``source``. The function
+   signature is generated automatically.
 
 The full function signature will be generated using:
 
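An editorial aside to make the reuse advice above concrete: a minimal sketch that builds the ``myexp`` kernel once and calls it from a loop, so the Metal library is compiled a single time. The loop count and array size here are illustrative, not from the diff:

.. code-block:: python

    import mlx.core as mx

    # Built once, at module scope (body from the example above).
    source = """
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    """
    kernel = mx.fast.metal_kernel(
        name="myexp", input_names=["inp"], output_names=["out"], source=source
    )

    def exp_elementwise(a: mx.array):
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )
        return outputs[0]

    a = mx.random.normal(shape=(4096,))
    for _ in range(10):  # the Metal library is compiled once, not ten times
        a = exp_elementwise(a)
    mx.eval(a)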
@@ -78,44 +86,51 @@ Putting this all together, the generated function signature for ``myexp`` is as
 
     template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
 
-Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
-This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
-For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
+Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
+<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
+function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
+``threadgroup`` size threadgroups. For optimal performance, each thread group
+dimension should be less than or equal to the corresponding grid dimension.
 
-Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
+Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
+generated code for debugging purposes.
 
 Using Shape/Strides
 -------------------
 
-``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
-This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
-Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
-when indexing.
+:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
+is ``True`` by default. This will copy the array inputs if needed
+before the kernel is launched to ensure that the memory layout is row
+contiguous. Generally this makes writing the kernel easier, since we don't
+have to worry about gaps or the ordering of the dims when indexing.
 
-If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
-input array ``a`` if any are present in ``source``.
-We can then use MLX's built in indexing utils to fetch the right elements for each thread.
+If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
+``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
+present in ``source``. We can then use MLX's built in indexing utils to fetch
+the right elements for each thread.
 
-Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
+Let's convert ``myexp`` above to support arbitrarily strided arrays without
+relying on a copy from ``ensure_row_contiguous``:
 
 .. code-block:: python
 
+    source = """
+        uint elem = thread_position_in_grid.x;
+        // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
+        uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
+        T tmp = inp[loc];
+        // Output arrays are always row contiguous
+        out[elem] = metal::exp(tmp);
+    """
+
+    kernel = mx.fast.metal_kernel(
+        name="myexp_strided",
+        input_names=["inp"],
+        output_names=["out"],
+        source=source
+    )
+
     def exp_elementwise(a: mx.array):
-        source = """
-            uint elem = thread_position_in_grid.x;
-            // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
-            uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
-            T tmp = inp[loc];
-            // Output arrays are always row contiguous
-            out[elem] = metal::exp(tmp);
-        """
-
-        kernel = mx.fast.metal_kernel(
-            name="myexp_strided",
-            input_names=["inp"],
-            output_names=["out"],
-            source=source
-        )
         outputs = kernel(
             inputs=[a],
             template=[("T", mx.float32)],
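As a brief usage sketch (assuming, as in the full documentation, that the call above also passes ``ensure_row_contiguous=False``), the strided kernel can then be applied to a non-contiguous view without triggering a copy:

.. code-block:: python

    # Illustrative only: relies on the strided `exp_elementwise` defined above.
    a = mx.random.normal(shape=(4, 16))
    at = mx.transpose(a)  # a transposed view, not row contiguous
    b = exp_elementwise(at)
    assert mx.allclose(b, mx.exp(at))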
@@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
 
 .. code-block:: python
 
     def grid_sample_ref(x, grid):
         N, H_in, W_in, _ = x.shape
         ix = ((grid[..., 0] + 1) * W_in - 1) / 2
         iy = ((grid[..., 1] + 1) * H_in - 1) / 2
 
         ix_nw = mx.floor(ix).astype(mx.int32)
         iy_nw = mx.floor(iy).astype(mx.int32)
 
         ix_ne = ix_nw + 1
         iy_ne = iy_nw
 
         ix_sw = ix_nw
         iy_sw = iy_nw + 1
 
         ix_se = ix_nw + 1
         iy_se = iy_nw + 1
 
         nw = (ix_se - ix) * (iy_se - iy)
         ne = (ix - ix_sw) * (iy_sw - iy)
         sw = (ix_ne - ix) * (iy - iy_ne)
         se = (ix - ix_nw) * (iy - iy_nw)
 
         I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
         I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
         I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
         I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
 
         mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
         mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
         mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
         mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
 
         I_nw *= mask_nw[..., None]
         I_ne *= mask_ne[..., None]
         I_sw *= mask_sw[..., None]
         I_se *= mask_se[..., None]
 
         output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
 
         return output
 
-Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
+Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
 to write a fast GPU kernel for both the forward and backward passes.
 
 First we'll implement the forward pass as a fused kernel:
 
 .. code-block:: python
 
-    @mx.custom_function
-    def grid_sample(x, grid):
-
-        assert x.ndim == 4, "`x` must be 4D."
-        assert grid.ndim == 4, "`grid` must be 4D."
-
-        B, _, _, C = x.shape
-        _, gN, gM, D = grid.shape
-        out_shape = (B, gN, gM, C)
-
-        assert D == 2, "Last dim of `grid` must be size 2."
-
-        source = """
-            uint elem = thread_position_in_grid.x;
-            int H = x_shape[1];
-            int W = x_shape[2];
-            int C = x_shape[3];
-            int gH = grid_shape[1];
-            int gW = grid_shape[2];
-
-            int w_stride = C;
-            int h_stride = W * w_stride;
-            int b_stride = H * h_stride;
-
-            uint grid_idx = elem / C * 2;
-            float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
-            float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
-
-            int ix_nw = floor(ix);
-            int iy_nw = floor(iy);
-
-            int ix_ne = ix_nw + 1;
-            int iy_ne = iy_nw;
-
-            int ix_sw = ix_nw;
-            int iy_sw = iy_nw + 1;
-
-            int ix_se = ix_nw + 1;
-            int iy_se = iy_nw + 1;
-
-            T nw = (ix_se - ix) * (iy_se - iy);
-            T ne = (ix - ix_sw) * (iy_sw - iy);
-            T sw = (ix_ne - ix) * (iy - iy_ne);
-            T se = (ix - ix_nw) * (iy - iy_nw);
-
-            int batch_idx = elem / C / gH / gW * b_stride;
-            int channel_idx = elem % C;
-            int base_idx = batch_idx + channel_idx;
-
-            T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
-            T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
-            T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
-            T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
-
-            I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
-            I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
-            I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
-            I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
-
-            out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
-        """
-        kernel = mx.fast.metal_kernel(
-            name="grid_sample",
-            input_names=["x", "grid"],
-            output_names=["out"],
-            source=source,
-        )
+    source = """
+        uint elem = thread_position_in_grid.x;
+        int H = x_shape[1];
+        int W = x_shape[2];
+        int C = x_shape[3];
+        int gH = grid_shape[1];
+        int gW = grid_shape[2];
+
+        int w_stride = C;
+        int h_stride = W * w_stride;
+        int b_stride = H * h_stride;
+
+        uint grid_idx = elem / C * 2;
+        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
+        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
+
+        int ix_nw = floor(ix);
+        int iy_nw = floor(iy);
+
+        int ix_ne = ix_nw + 1;
+        int iy_ne = iy_nw;
+
+        int ix_sw = ix_nw;
+        int iy_sw = iy_nw + 1;
+
+        int ix_se = ix_nw + 1;
+        int iy_se = iy_nw + 1;
+
+        T nw = (ix_se - ix) * (iy_se - iy);
+        T ne = (ix - ix_sw) * (iy_sw - iy);
+        T sw = (ix_ne - ix) * (iy - iy_ne);
+        T se = (ix - ix_nw) * (iy - iy_nw);
+
+        int batch_idx = elem / C / gH / gW * b_stride;
+        int channel_idx = elem % C;
+        int base_idx = batch_idx + channel_idx;
+
+        T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
+        T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
+        T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
+        T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
+
+        I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
+        I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
+        I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
+        I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
+
+        out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
+    """
+
+    kernel = mx.fast.metal_kernel(
+        name="grid_sample",
+        input_names=["x", "grid"],
+        output_names=["out"],
+        source=source,
+    )
+
+    @mx.custom_function
+    def grid_sample(x, grid):
+
+        assert x.ndim == 4, "`x` must be 4D."
+        assert grid.ndim == 4, "`grid` must be 4D."
+
+        B, _, _, C = x.shape
+        _, gN, gM, D = grid.shape
+        out_shape = (B, gN, gM, C)
+
+        assert D == 2, "Last dim of `grid` must be size 2."
+
         outputs = kernel(
             inputs=[x, grid],
             template=[("T", x.dtype)],
             output_shapes=[out_shape],
             output_dtypes=[x.dtype],
             grid=(np.prod(out_shape), 1, 1),
             threadgroup=(256, 1, 1),
         )
         return outputs[0]
 
 For a reasonably sized input such as:
 
 .. code-block:: python
 
     x.shape = (8, 1024, 1024, 64)
     grid.shape = (8, 256, 256, 2)
 
 On an M1 Max, we see a big performance improvement:
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
 Grid Sample VJP
 ---------------
 
-Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
-its custom vjp transform so MLX can differentiate it.
+Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
+define its custom vjp transform so MLX can differentiate it.
 
 The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
-requires a few extra ``mx.fast.metal_kernel`` features:
+requires a few extra :func:`fast.metal_kernel` features:
 
 * ``init_value=0``
   Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
 
 .. code-block:: python
 
-    @grid_sample.vjp
-    def grid_sample_vjp(primals, cotangent, _):
-        x, grid = primals
-        B, _, _, C = x.shape
-        _, gN, gM, D = grid.shape
-
-        assert D == 2, "Last dim of `grid` must be size 2."
-
-        source = """
-            uint elem = thread_position_in_grid.x;
-            int H = x_shape[1];
-            int W = x_shape[2];
-            int C = x_shape[3];
-            // Pad C to the nearest larger simdgroup size multiple
-            int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
-
-            int gH = grid_shape[1];
-            int gW = grid_shape[2];
-
-            int w_stride = C;
-            int h_stride = W * w_stride;
-            int b_stride = H * h_stride;
-
-            uint grid_idx = elem / C_padded * 2;
-            float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
-            float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
-
-            int ix_nw = floor(ix);
-            int iy_nw = floor(iy);
-
-            int ix_ne = ix_nw + 1;
-            int iy_ne = iy_nw;
-
-            int ix_sw = ix_nw;
-            int iy_sw = iy_nw + 1;
-
-            int ix_se = ix_nw + 1;
-            int iy_se = iy_nw + 1;
-
-            T nw = (ix_se - ix) * (iy_se - iy);
-            T ne = (ix - ix_sw) * (iy_sw - iy);
-            T sw = (ix_ne - ix) * (iy - iy_ne);
-            T se = (ix - ix_nw) * (iy - iy_nw);
-
-            int batch_idx = elem / C_padded / gH / gW * b_stride;
-            int channel_idx = elem % C_padded;
-            int base_idx = batch_idx + channel_idx;
-
-            T gix = T(0);
-            T giy = T(0);
-            if (channel_idx < C) {
-                int cot_index = elem / C_padded * C + channel_idx;
-                T cot = cotangent[cot_index];
-                if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
-                    int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
-                    atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
-
-                    T I_nw = x[offset];
-                    gix -= I_nw * (iy_se - iy) * cot;
-                    giy -= I_nw * (ix_se - ix) * cot;
-                }
-                if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
-                    int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
-                    atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
-
-                    T I_ne = x[offset];
-                    gix += I_ne * (iy_sw - iy) * cot;
-                    giy -= I_ne * (ix - ix_sw) * cot;
-                }
-                if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
-                    int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
-                    atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
-
-                    T I_sw = x[offset];
-                    gix -= I_sw * (iy - iy_ne) * cot;
-                    giy += I_sw * (ix_ne - ix) * cot;
-                }
-                if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
-                    int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
-                    atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
-
-                    T I_se = x[offset];
-                    gix += I_se * (iy - iy_nw) * cot;
-                    giy += I_se * (ix - ix_nw) * cot;
-                }
-            }
-
-            T gix_mult = W / 2;
-            T giy_mult = H / 2;
-
-            // Reduce across each simdgroup first.
-            // This is much faster than relying purely on atomics.
-            gix = simd_sum(gix);
-            giy = simd_sum(giy);
-
-            if (thread_index_in_simdgroup == 0) {
-                atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
-                atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
-            }
-        """
-        kernel = mx.fast.metal_kernel(
-            name="grid_sample_grad",
-            input_names=["x", "grid", "cotangent"],
-            output_names=["x_grad", "grid_grad"],
-            source=source,
-            atomic_outputs=True,
-        )
-        # pad the output channels to simd group size
-        # so that our `simd_sum`s don't overlap.
-        simdgroup_size = 32
-        C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
-        grid_size = B * gN * gM * C_padded
-        outputs = kernel(
-            inputs=[x, grid, cotangent],
-            template=[("T", x.dtype)],
-            output_shapes=[x.shape, grid.shape],
-            output_dtypes=[x.dtype, x.dtype],
-            grid=(grid_size, 1, 1),
-            threadgroup=(256, 1, 1),
-            init_value=0,
-        )
-        return outputs[0], outputs[1]
+    source = """
+        uint elem = thread_position_in_grid.x;
+        int H = x_shape[1];
+        int W = x_shape[2];
+        int C = x_shape[3];
+        // Pad C to the nearest larger simdgroup size multiple
+        int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
+
+        int gH = grid_shape[1];
+        int gW = grid_shape[2];
+
+        int w_stride = C;
+        int h_stride = W * w_stride;
+        int b_stride = H * h_stride;
+
+        uint grid_idx = elem / C_padded * 2;
+        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
+        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
+
+        int ix_nw = floor(ix);
+        int iy_nw = floor(iy);
+
+        int ix_ne = ix_nw + 1;
+        int iy_ne = iy_nw;
+
+        int ix_sw = ix_nw;
+        int iy_sw = iy_nw + 1;
+
+        int ix_se = ix_nw + 1;
+        int iy_se = iy_nw + 1;
+
+        T nw = (ix_se - ix) * (iy_se - iy);
+        T ne = (ix - ix_sw) * (iy_sw - iy);
+        T sw = (ix_ne - ix) * (iy - iy_ne);
+        T se = (ix - ix_nw) * (iy - iy_nw);
+
+        int batch_idx = elem / C_padded / gH / gW * b_stride;
+        int channel_idx = elem % C_padded;
+        int base_idx = batch_idx + channel_idx;
+
+        T gix = T(0);
+        T giy = T(0);
+        if (channel_idx < C) {
+            int cot_index = elem / C_padded * C + channel_idx;
+            T cot = cotangent[cot_index];
+            if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
+                int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
+                atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
+
+                T I_nw = x[offset];
+                gix -= I_nw * (iy_se - iy) * cot;
+                giy -= I_nw * (ix_se - ix) * cot;
+            }
+            if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
+                int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
+                atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
+
+                T I_ne = x[offset];
+                gix += I_ne * (iy_sw - iy) * cot;
+                giy -= I_ne * (ix - ix_sw) * cot;
+            }
+            if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
+                int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
+                atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
+
+                T I_sw = x[offset];
+                gix -= I_sw * (iy - iy_ne) * cot;
+                giy += I_sw * (ix_ne - ix) * cot;
+            }
+            if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
+                int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
+                atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
+
+                T I_se = x[offset];
+                gix += I_se * (iy - iy_nw) * cot;
+                giy += I_se * (ix - ix_nw) * cot;
+            }
+        }
+
+        T gix_mult = W / 2;
+        T giy_mult = H / 2;
+
+        // Reduce across each simdgroup first.
+        // This is much faster than relying purely on atomics.
+        gix = simd_sum(gix);
+        giy = simd_sum(giy);
+
+        if (thread_index_in_simdgroup == 0) {
+            atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
+            atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
+        }
+    """
+    kernel = mx.fast.metal_kernel(
+        name="grid_sample_grad",
+        input_names=["x", "grid", "cotangent"],
+        output_names=["x_grad", "grid_grad"],
+        source=source,
+        atomic_outputs=True,
+    )
+
+    @grid_sample.vjp
+    def grid_sample_vjp(primals, cotangent, _):
+        x, grid = primals
+        B, _, _, C = x.shape
+        _, gN, gM, D = grid.shape
+
+        assert D == 2, "Last dim of `grid` must be size 2."
+
+        # pad the output channels to simd group size
+        # so that our `simd_sum`s don't overlap.
+        simdgroup_size = 32
+        C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
+        grid_size = B * gN * gM * C_padded
+        outputs = kernel(
+            inputs=[x, grid, cotangent],
+            template=[("T", x.dtype)],
+            output_shapes=[x.shape, grid.shape],
+            output_dtypes=[x.dtype, x.dtype],
+            grid=(grid_size, 1, 1),
+            threadgroup=(256, 1, 1),
+            init_value=0,
+        )
+        return outputs[0], outputs[1]
 
 There's an even larger speed up for the vjp:
 
@@ -138,13 +138,13 @@ more concrete:
    * representing the vectorized computation and the axis which
    * corresponds to the output vectorized dimension.
    */
-  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
+  std::pair<std::vector<array>, std::vector<int>> vmap(
       const std::vector<array>& inputs,
       const std::vector<int>& axes) override;
 
-  /** Print the primitive. */
-  void print(std::ostream& os) override {
-    os << "Axpby";
+  /** The name of primitive. */
+  const char* name() const override {
+    return "Axpby";
   }
 
   /** Equivalence check **/
@@ -394,14 +394,14 @@ below.
   out.set_data(allocator::malloc(out.nbytes()));
 
   // Resolve name of kernel
-  std::ostringstream kname;
-  kname << "axpby_" << "general_" << type_to_name(out);
+  std::string kname;
+  kname = "axpby_general_" + type_to_name(out);
 
-  // Make sure the metal library is available
-  d.register_library("mlx_ext");
+  // Load the metal library
+  auto lib = d.get_library("mlx_ext", current_binary_dir());
 
   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+  auto kernel = d.get_kernel(kname, lib);
 
   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -13,7 +13,7 @@ silicon computer is
 
     pip install mlx
 
-To install from PyPI you must meet the following requirements:
+To install from PyPI your system must meet the following requirements:
 
 - Using an M series chip (Apple silicon)
 - Using a native Python >= 3.9
@@ -23,12 +23,39 @@ To install from PyPI you must meet the following requirements:
   MLX is only available on devices running macOS >= 13.5
   It is highly recommended to use macOS 14 (Sonoma)
 
-MLX is also available on conda-forge. To install MLX with conda do:
+CUDA
+^^^^
+
+MLX has a CUDA backend which you can install with:
 
 .. code-block:: shell
 
-    conda install conda-forge::mlx
+    pip install mlx[cuda]
+
+To install the CUDA package from PyPI your system must meet the following
+requirements:
+
+- Nvidia architecture >= SM 7.0 (Volta)
+- Nvidia driver >= 550.54.14
+- CUDA toolkit >= 12.0
+- Linux distribution with glibc >= 2.35
+- Python >= 3.9
+
+CPU-only (Linux)
+^^^^^^^^^^^^^^^^
+
+For a CPU-only version of MLX that runs on Linux use:
+
+.. code-block:: shell
+
+    pip install mlx[cpu]
+
+To install the CPU-only package from PyPI your system must meet the following
+requirements:
+
+- Linux distribution with glibc >= 2.35
+- Python >= 3.9
+
 
 Troubleshooting
@@ -65,6 +92,8 @@ Build Requirements
 Python API
 ^^^^^^^^^^
 
+.. _python install:
+
 To build and install the MLX python library from source, first, clone MLX from
 `its GitHub repo <https://github.com/ml-explore/mlx>`_:
 
@@ -76,20 +105,20 @@ Then simply build and install MLX using pip:
 
 .. code-block:: shell
 
-    CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
+    pip install .
 
 For developing, install the package with development dependencies, and use an
 editable install:
 
 .. code-block:: shell
 
-    CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
+    pip install -e ".[dev]"
 
 Once the development dependencies are installed, you can build faster with:
 
 .. code-block:: shell
 
-    CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
+    python setup.py build_ext --inplace
 
 Run the tests with:
 
@@ -107,6 +136,8 @@ IDE:
 C++ API
 ^^^^^^^
 
+.. _cpp install:
+
 Currently, MLX must be built and installed from source.
 
 Similarly to the python library, to build and install the MLX C++ library start
@@ -185,6 +216,7 @@ should point to the path to the built metal library.
 
     xcrun -sdk macosx --show-sdk-version
 
+
 Binary Size Minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -213,6 +245,50 @@ be anywhere from a few hundred millisecond to a few seconds depending on the
 application. Once a kernel is compiled, it will be cached by the system. The
 Metal kernel cache persists across reboots.
 
+Linux
+^^^^^
+
+To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
+For example on Ubuntu, run the following:
+
+.. code-block:: shell
+
+    apt-get update -y
+    apt-get install libblas-dev liblapack-dev liblapacke-dev -y
+
+From here follow the instructions to install either the :ref:`Python <python
+install>` or :ref:`C++ <cpp install>` APIs.
+
+CUDA
+^^^^
+
+To build from source on Linux with CUDA, install the BLAS and LAPACK headers
+and the CUDA toolkit. For example on Ubuntu, run the following:
+
+.. code-block:: shell
+
+    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
+    dpkg -i cuda-keyring_1.1-1_all.deb
+    apt-get update -y
+    apt-get -y install cuda-toolkit-12-9
+    apt-get install libblas-dev liblapack-dev liblapacke-dev -y
+
+When building either the Python or C++ APIs make sure to pass the cmake flag
+``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
+
+.. code-block:: shell
+
+    CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
+
+To build the C++ package run:
+
+.. code-block:: shell
+
+    mkdir -p build && cd build
+    cmake .. -DMLX_BUILD_CUDA=ON && make -j
+
+
 Troubleshooting
 ^^^^^^^^^^^^^^^
 
@@ -19,6 +19,8 @@ Array
    array.ndim
    array.shape
    array.size
+   array.real
+   array.imag
    array.abs
    array.all
    array.any
@@ -38,6 +40,7 @@ Array
    array.log10
    array.log1p
    array.log2
+   array.logcumsumexp
    array.logsumexp
    array.max
    array.mean
@@ -20,3 +20,5 @@ FFT
    irfft2
    rfftn
    irfftn
+   fftshift
+   ifftshift
@@ -16,6 +16,8 @@ Linear Algebra
    cross
    qr
    svd
+   eigvals
+   eig
    eigvalsh
    eigh
    lu
@@ -103,6 +103,7 @@ Operations
    log10
    log1p
    logaddexp
+   logcumsumexp
    logical_not
    logical_and
    logical_or
@@ -18,3 +18,5 @@ Common Optimizers
    AdamW
    Adamax
    Lion
+   MultiOptimizer
+   Muon
@@ -107,6 +107,16 @@ same array:
   >>> a
   array([1, 2, 0], dtype=int32)
 
+
+Note, unlike NumPy, updates to the same location are nondeterministic:
+
+.. code-block:: shell
+
+  >>> a = mx.array([1, 2, 3])
+  >>> a[[0, 0]] = mx.array([4, 5])
+
+The first element of ``a`` could be ``4`` or ``5``.
+
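An editorial aside, not part of the diff: when duplicate indices need a well-defined result, MLX's ``array.at`` syntax applies every update instead of racing, so the sketch below accumulates both values deterministically:

.. code-block:: python

  >>> a = mx.array([1, 2, 3])
  >>> a = a.at[mx.array([0, 0])].add(mx.array([4, 5]))
  >>> a
  array([10, 2, 3], dtype=int32)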
 Transformations of functions which use in-place updates are allowed and work as
 expected. For example:
 
@@ -1,5 +1,6 @@
 // Copyright © 2023-2025 Apple Inc.
 
+#include <dlfcn.h>
 #include <iostream>
 #include <sstream>
 
@@ -16,6 +17,19 @@
 
 namespace my_ext {
 
+// A helper function to find the location of the current binary on disk.
+// The Metal library ("mlx_ext.mtllib") should be in the same directory.
+std::string current_binary_dir() {
+  static std::string binary_dir = []() {
+    Dl_info info;
+    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+      throw std::runtime_error("Unable to get current binary dir.");
+    }
+    return std::filesystem::path(info.dli_fname).parent_path().string();
+  }();
+  return binary_dir;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Operation Implementation
 ///////////////////////////////////////////////////////////////////////////////
@@ -167,16 +181,15 @@ void Axpby::eval_gpu(
   }
 
   // Resolve name of kernel (corresponds to axpby.metal)
-  std::ostringstream kname;
-  kname << "axpby_";
-  kname << (contiguous_kernel ? "contiguous_" : "general_");
-  kname << type_to_name(out);
+  std::string kname = "axpby_";
+  kname += (contiguous_kernel ? "contiguous_" : "general_");
+  kname += type_to_name(out);
 
-  // Make sure the metal library is available
-  d.register_library("mlx_ext");
+  // Load the metal library
+  auto lib = d.get_library("mlx_ext", current_binary_dir());
 
   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+  auto kernel = d.get_kernel(kname, lib);
 
   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
       const std::vector<mx::array>& inputs,
       const std::vector<int>& axes) override;
 
-  /** Print the primitive. */
-  void print(std::ostream& os) override {
-    os << "Axpby";
+  /** The name of primitive. */
+  const char* name() const override {
+    return "Axpby";
   }
 
   /** Equivalence check **/
@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.25
 mlx>=0.21.0
-nanobind==2.2.0
+nanobind==2.4.0
@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
 
 a = mx.ones((3, 4))
 b = mx.ones((3, 4))
-c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
 
-print(f"c shape: {c.shape}")
-print(f"c dtype: {c.dtype}")
-print(f"c correct: {mx.all(c == 6.0).item()}")
+print(f"c shape: {c_cpu.shape}")
+print(f"c dtype: {c_cpu.dtype}")
+print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
+print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
@@ -5,6 +5,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
@@ -20,7 +21,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
 
 # Define MLX_VERSION only in the version.cpp file.
-add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
+add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
 target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
 
@@ -48,5 +49,19 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 if(MLX_BUILD_METAL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
 else()
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
+endif()
+
+if(MLX_BUILD_CUDA)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
+else()
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
+endif()
+
+if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
+else()
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
 endif()
mlx/array.h
@@ -10,6 +10,7 @@
 #include "mlx/allocator.h"
 #include "mlx/dtype.h"
 #include "mlx/event.h"
+#include "mlx/small_vector.h"
 
 namespace mlx::core {
 
@@ -18,8 +19,8 @@ class Primitive;
 
 using Deleter = std::function<void(allocator::Buffer)>;
 using ShapeElem = int32_t;
-using Shape = std::vector<ShapeElem>;
-using Strides = std::vector<int64_t>;
+using Shape = SmallVector<ShapeElem>;
+using Strides = SmallVector<int64_t>;
 
 class array {
   /* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -224,6 +225,10 @@ class array {
     // Not copyable
     Data(const Data& d) = delete;
    Data& operator=(const Data& d) = delete;
+    Data(Data&& o) : buffer(o.buffer), d(o.d) {
+      o.buffer = allocator::Buffer(nullptr);
+      o.d = [](allocator::Buffer) {};
+    }
     ~Data() {
       d(buffer);
     }
@@ -339,11 +344,11 @@ class array {
     return allocator::allocator().size(buffer());
   }
 
-  // Return a copy of the shared pointer
-  // to the array::Data struct
-  std::shared_ptr<Data> data_shared_ptr() const {
+  // Return the shared pointer to the array::Data struct
+  const std::shared_ptr<Data>& data_shared_ptr() const {
    return array_desc_->data;
  }
 
  // Return a raw pointer to the arrays data
  template <typename T>
  T* data() {
@@ -356,7 +361,7 @@ class array {
   }
 
   enum Status {
-    // The ouptut of a computation which has not been scheduled.
+    // The output of a computation which has not been scheduled.
     // For example, the status of `x` in `auto x = a + b`.
     unscheduled,
 
@@ -1,6 +1,7 @@
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
mlx/backend/common/broadcasting.cpp (new file)
@@ -0,0 +1,24 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+void broadcast(const array& in, array& out) {
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+  Strides strides(out.ndim(), 0);
+  int diff = out.ndim() - in.ndim();
+  for (int i = in.ndim() - 1; i >= 0; --i) {
+    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
+  }
+  auto flags = in.flags();
+  if (out.size() > in.size()) {
+    flags.row_contiguous = flags.col_contiguous = false;
+  }
+  out.copy_shared_buffer(in, strides, flags, in.data_size());
+}
+
+} // namespace mlx::core
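An illustrative aside, not part of the diff: at the Python level, this zero-stride trick is what makes broadcasting allocation-free. The broadcast axes simply index the same underlying buffer with stride 0:

.. code-block:: python

    import mlx.core as mx

    a = mx.array([1, 2, 3])          # shape (3,)
    b = mx.broadcast_to(a, (4, 3))   # shape (4, 3); no data is copied,
                                     # the new leading axis uses stride 0
    mx.eval(b)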
mlx/backend/common/broadcasting.h (new file)
@@ -0,0 +1,11 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+void broadcast(const array& in, array& out);
+
+} // namespace mlx::core
157 mlx/backend/common/buffer_cache.h Normal file
@@ -0,0 +1,157 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cassert>
+#include <functional>
+#include <map>
+
+namespace mlx::core {
+
+template <typename T>
+class BufferCache {
+ public:
+  BufferCache(
+      size_t page_size,
+      std::function<size_t(T*)> get_size,
+      std::function<void(T*)> free)
+      : page_size_(page_size),
+        get_size_(std::move(get_size)),
+        free_(std::move(free)) {}
+
+  ~BufferCache() {
+    clear();
+  }
+
+  BufferCache(const BufferCache&) = delete;
+  BufferCache& operator=(const BufferCache&) = delete;
+
+  T* reuse_from_cache(size_t size) {
+    // Find the closest buffer in pool.
+    auto it = buffer_pool_.lower_bound(size);
+    if (it == buffer_pool_.end() ||
+        it->first >= std::min(2 * size, size + 2 * page_size_)) {
+      return nullptr;
+    }
+
+    // Collect from the cache.
+    T* buf = it->second->buf;
+    pool_size_ -= it->first;
+
+    // Remove from record.
+    remove_from_list(it->second);
+    buffer_pool_.erase(it);
+    return buf;
+  }
+
+  void recycle_to_cache(T* buf) {
+    assert(buf);
+    // Add to cache.
+    BufferHolder* bh = new BufferHolder(buf);
+    add_at_head(bh);
+    size_t size = get_size_(buf);
+    pool_size_ += size;
+    buffer_pool_.emplace(size, bh);
+  }
+
+  int release_cached_buffers(size_t min_bytes_to_free) {
+    if (min_bytes_to_free >= 0.9 * pool_size_) {
+      return clear();
+    } else {
+      int n_release = 0;
+      size_t total_bytes_freed = 0;
+
+      while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
+        // Release buffer.
+        size_t size = get_size_(tail_->buf);
+        total_bytes_freed += size;
+        free_(tail_->buf);
+        n_release++;
+
+        // Remove from record.
+        auto its = buffer_pool_.equal_range(size);
+        auto it = std::find_if(its.first, its.second, [this](const auto& el) {
+          return el.second == tail_;
+        });
+        assert(it != buffer_pool_.end());
+        buffer_pool_.erase(it);
+        remove_from_list(tail_);
+      }
+
+      pool_size_ -= total_bytes_freed;
+      return n_release;
+    }
+  }
+
+  int clear() {
+    int n_release = 0;
+    for (auto& [size, holder] : buffer_pool_) {
+      free_(holder->buf);
+      n_release++;
+      delete holder;
+    }
+    buffer_pool_.clear();
+    pool_size_ = 0;
+    head_ = nullptr;
+    tail_ = nullptr;
+    return n_release;
+  }
+
+  size_t cache_size() const {
+    return pool_size_;
+  }
+
+  size_t page_size() const {
+    return page_size_;
+  }
+
+ private:
+  struct BufferHolder {
+   public:
+    explicit BufferHolder(T* buf_) : buf(buf_) {}
+
+    BufferHolder* prev{nullptr};
+    BufferHolder* next{nullptr};
+    T* buf;
+  };
+
+  void add_at_head(BufferHolder* to_add) {
+    if (!head_) {
+      head_ = to_add;
+      tail_ = to_add;
+    } else {
+      head_->prev = to_add;
+      to_add->next = head_;
+      head_ = to_add;
+    }
+  }
+
+  void remove_from_list(BufferHolder* to_remove) {
+    if (to_remove->prev && to_remove->next) { // if middle
+      to_remove->prev->next = to_remove->next;
+      to_remove->next->prev = to_remove->prev;
+    } else if (to_remove->prev && to_remove == tail_) { // if tail
+      tail_ = to_remove->prev;
+      tail_->next = nullptr;
+    } else if (to_remove == head_ && to_remove->next) { // if head
+      head_ = to_remove->next;
+      head_->prev = nullptr;
+    } else if (to_remove == head_ && to_remove == tail_) { // if only element
+      head_ = nullptr;
+      tail_ = nullptr;
+    }
+
+    delete to_remove;
+  }
+
+  std::multimap<size_t, BufferHolder*> buffer_pool_;
+  BufferHolder* head_{nullptr};
+  BufferHolder* tail_{nullptr};
+  size_t pool_size_{0};
+
+  const size_t page_size_;
+  std::function<size_t(T*)> get_size_;
+  std::function<void(T*)> free_;
+};
+
+} // namespace mlx::core
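For orientation, a minimal usage sketch of the new `BufferCache` template follows. The `Block` type, the sizes, and the malloc/free callbacks are illustrative assumptions, not part of this diff; real backends wrap their native buffer handles:

```cpp
#include <cstdlib>

#include "mlx/backend/common/buffer_cache.h"

// Hypothetical buffer type standing in for a backend handle.
struct Block {
  size_t size;
};

int example() {
  mlx::core::BufferCache<Block> cache(
      /* page_size */ 4096,
      /* get_size */ [](Block* b) { return b->size; },
      /* free */ [](Block* b) { std::free(b); });

  // Miss: nothing cached yet, so the caller allocates fresh...
  if (cache.reuse_from_cache(1024) == nullptr) {
    auto* b = static_cast<Block*>(std::malloc(sizeof(Block)));
    b->size = 1024;
    // ...and later recycles it into the cache instead of freeing it.
    cache.recycle_to_cache(b);
  }

  // Hit: 1024 falls within [size, min(2 * size, size + 2 * page_size)), so
  // the recycled block is returned and unlinked from the LRU list.
  Block* reused = cache.reuse_from_cache(1024);
  // Ownership moves back to the caller; free it (or recycle it again).
  std::free(reused);

  // Under memory pressure, evict from the least-recently-recycled tail.
  return cache.release_cached_buffers(/* min_bytes_to_free */ 1 << 20);
}
```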
@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.
 #include <cassert>
 
+#include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
 

@@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
 }
 
-void broadcast(const array& in, array& out) {
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-  Strides strides(out.ndim(), 0);
-  int diff = out.ndim() - in.ndim();
-  for (int i = in.ndim() - 1; i >= 0; --i) {
-    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
-  }
-  auto flags = in.flags();
-  if (out.size() > in.size()) {
-    flags.row_contiguous = flags.col_contiguous = false;
-  }
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
-}
-
 void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   broadcast(inputs[0], out);
 }
@@ -1,8 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include "mlx/backend/common/compiled.h"
-#include "mlx/graph_utils.h"
-#include "mlx/primitives.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {

@@ -15,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) {
     return print_float_constant<float16_t>(os, x);
   case bfloat16:
     return print_float_constant<bfloat16_t>(os, x);
+  case float64:
+    return print_float_constant<double>(os, x);
   case complex64:
     return print_complex_constant<complex64_t>(os, x);
   case int8:

@@ -51,6 +52,8 @@ std::string get_type_string(Dtype d) {
     return "float16_t";
   case bfloat16:
     return "bfloat16_t";
+  case float64:
+    return "double";
   case complex64:
     return "complex64_t";
   case bool_:

@@ -79,55 +82,6 @@ std::string get_type_string(Dtype d) {
   }
 }
 
-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids) {
-  NodeNamer namer;
-  std::ostringstream os;
-  std::ostringstream constant_hasher;
-
-  // Fill the input names. This is not really necessary, I just like having A,
-  // B, C, ... as the inputs.
-  for (auto& x : inputs) {
-    namer.get_name(x);
-  }
-
-  // The primitives describing the tape. For unary and binary primitives this
-  // must be enough to describe the full computation.
-  for (auto& a : tape) {
-    // name and type of output
-    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
-    // computation performed
-    a.primitive().print(os);
-    // name of inputs to the function
-    for (auto& inp : a.inputs()) {
-      os << namer.get_name(inp);
-    }
-  }
-  os << "_";
-
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      os << "C";
-      print_constant(constant_hasher, x);
-    } else {
-      os << (is_scalar(x) ? "S" : "V");
-    }
-  }
-  os << "_";
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      continue;
-    }
-    os << kindof(x.dtype()) << x.itemsize();
-  }
-  os << "_" << std::hash<std::string>{}(constant_hasher.str());
-
-  return os.str();
-}
-
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
     const Shape& shape) {

@@ -159,8 +113,7 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous) {
   if (contiguous) {
     int o = 0;

@@ -175,8 +128,7 @@ void compiled_allocate_outputs(
       // - Donatable
       // - Not a constant
       if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-          in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          in.is_donatable() && is_constant(i)) {
         outputs[o++].copy_shared_buffer(in);
       }
       // Get representative input flags to properly set non-donated outputs

@@ -204,7 +156,7 @@ void compiled_allocate_outputs(
       // - Not a constant
       if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
           in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          is_constant(i)) {
         outputs[o].copy_shared_buffer(
             in, outputs[o].strides(), in.flags(), in.data_size());
         o++;

@@ -216,4 +168,74 @@ void compiled_allocate_outputs(
     }
   }
 }
+
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant) {
+  const Shape& shape = out.shape();
+  bool contiguous = compiled_check_contiguity(inputs, shape);
+  if (contiguous) {
+    return {true, shape, {}};
+  }
+
+  std::vector<Strides> strides_vec{out.strides()};
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    // Skip constants.
+    if (is_constant(i)) {
+      continue;
+    }
+
+    // Skip scalar inputs.
+    const auto& x = inputs[i];
+    if (is_scalar(x)) {
+      continue;
+    }
+
+    // Broadcast the inputs to the output shape.
+    Strides xstrides;
+    size_t j = 0;
+    for (; j < shape.size() - x.ndim(); ++j) {
+      if (shape[j] == 1) {
+        xstrides.push_back(out.strides()[j]);
+      } else {
+        xstrides.push_back(0);
+      }
+    }
+    for (size_t i = 0; i < x.ndim(); ++i, ++j) {
+      if (x.shape(i) == 1) {
+        if (shape[j] == 1) {
+          xstrides.push_back(out.strides()[j]);
+        } else {
+          xstrides.push_back(0);
+        }
+      } else {
+        xstrides.push_back(x.strides()[i]);
+      }
+    }
+    strides_vec.push_back(std::move(xstrides));
+  }
+
+  auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
+  return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
+}
+
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    bool contiguous) {
+  if (contiguous) {
+    size_t max_size = 0;
+    for (const auto& in : inputs) {
+      max_size = std::max(max_size, in.data_size());
+    }
+    return max_size > UINT32_MAX;
+  } else {
+    size_t max_size = 0;
+    for (const auto& o : outputs) {
+      max_size = std::max(max_size, o.size());
+    }
+    return max_size > UINT32_MAX;
+  }
+}
 
 } // namespace mlx::core
@@ -1,9 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.
 #pragma once
 
+#include <functional>
 #include <iomanip>
-#include <sstream>
-#include <unordered_set>
 
 #include "mlx/array.h"
 #include "mlx/primitives.h"

@@ -14,19 +13,17 @@ inline bool is_static_cast(const Primitive& p) {
   return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
 }
 
-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids);
-
 std::string get_type_string(Dtype d);
 
 template <typename T>
 void print_float_constant(std::ostream& os, const array& x) {
   auto old_precision = os.precision();
-  os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
-     << x.item<T>() << std::setprecision(old_precision);
+  if constexpr (std::is_same_v<T, double>) {
+    os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
+  } else {
+    os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
+  }
+  os << x.item<T>() << std::setprecision(old_precision);
 }
 
 template <typename T>

@@ -60,8 +57,19 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous);
+
+// Collapse contiguous dims ignoring scalars and constants.
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant);
+
+// Return whether the kernel should use large index.
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    bool contiguous);
 
 } // namespace mlx::core
@@ -2,7 +2,7 @@
 
 #pragma once
 
-#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"
 
 namespace mlx::core {
 

@@ -26,7 +26,7 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
   if (ctype == CopyType::Vector) {
     // If the input is donateable, we are doing a vector copy and the types
     // have the same size, then the input buffer can hold the output.
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+    if (is_donatable(in, out)) {
       out.copy_shared_buffer(in);
       return true;
     } else {

@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
         "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
     }
   }
+  if (n > (1 << 26)) {
+    throw std::invalid_argument(
+        "[hadamard] Only supports n = m*2^k where k <= 26");
+  }
   return {n, m};
 }
 
 } // namespace mlx::core
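The added guard bounds the power-of-two factor of the Hadamard size. As a rough sketch of the `n = m * 2^k` rule it enforces — an illustration under stated assumptions, not the library's exact implementation:

```cpp
#include <stdexcept>
#include <utility>

// Illustrative only: peel off an allowed factor m, then require the
// remainder to be a power of two no larger than 2^26.
inline std::pair<int, int> decompose_sketch(int n) {
  int m = 1;
  for (int cand : {28, 20, 12}) {
    if (n % cand == 0) {
      m = cand;
      n /= cand;
      break;
    }
  }
  if (n <= 0 || (n & (n - 1))) {
    throw std::invalid_argument(
        "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
  }
  if (n > (1 << 26)) {
    throw std::invalid_argument(
        "[hadamard] Only supports n = m*2^k where k <= 26");
  }
  return {n, m}; // e.g. 896 = 28 * 2^5 -> {32, 28}
}
```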
67 mlx/backend/common/matmul.h Normal file
@@ -0,0 +1,67 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/utils.h"
+
+#include <sstream>
+
+namespace mlx::core {
+
+inline std::tuple<Shape, Strides, Strides> collapse_batches(
+    const array& a,
+    const array& b) {
+  if (a.ndim() == 2) {
+    return {{1}, {0}, {0}};
+  }
+
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
+
+  auto [batch_shape, batch_strides] =
+      collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
+
+  auto a_batch_strides = batch_strides[0];
+  auto b_batch_strides = batch_strides[1];
+
+  if (batch_shape.empty()) {
+    batch_shape.push_back(1);
+    a_batch_strides.push_back(0);
+    b_batch_strides.push_back(0);
+  }
+
+  return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
+}
+
+inline std::tuple<Shape, Strides, Strides, Strides>
+collapse_batches(const array& a, const array& b, const array& c) {
+  if (a.ndim() == 2) {
+    return {{1}, {0}, {0}, {0}};
+  }
+
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
+  Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
+
+  auto [batch_shape, batch_strides] = collapse_contiguous_dims(
+      A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
+
+  auto A_batch_stride = batch_strides[0];
+  auto B_batch_stride = batch_strides[1];
+  auto C_batch_stride = batch_strides[2];
+
+  if (batch_shape.empty()) {
+    batch_shape.push_back(1);
+    A_batch_stride.push_back(0);
+    B_batch_stride.push_back(0);
+    C_batch_stride.push_back(0);
+  }
+
+  return std::make_tuple(
+      batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
+}
+
+} // namespace mlx::core
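A short usage note on the two `collapse_batches` overloads above: they fold the leading batch dimensions of the operands down to the fewest dimensions whose strides remain consistent across all inputs, so a batched GEMM can run one flat batch loop. A hypothetical call, with illustrative shapes:

```cpp
// For a of shape (2, 3, M, K) and b of shape (2, 3, K, N) with contiguous
// batch dims, the two batch axes merge into one:
//
//   auto [batch_shape, a_bstride, b_bstride] = collapse_batches(a, b);
//   // batch_shape == {6}; the strides give the per-batch step in a and b.
//
// A 2-D input short-circuits to batch_shape == {1} with zero strides, so
// callers can treat the un-batched case uniformly.
```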
@@ -5,11 +5,9 @@
 namespace mlx::core {
 
 std::pair<Shape, Strides> shapes_without_reduction_axes(
-    const array& x,
+    Shape shape,
+    Strides strides,
     const std::vector<int>& axes) {
-  auto shape = x.shape();
-  auto strides = x.strides();
-
   for (int i = axes.size() - 1; i >= 0; i--) {
     int a = axes[i];
     shape.erase(shape.begin() + a);

@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
   return std::make_pair(shape, strides);
 }
 
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    const array& x,
+    const std::vector<int>& axes) {
+  auto shape = x.shape();
+  auto strides = x.strides();
+  return shapes_without_reduction_axes(
+      std::move(shape), std::move(strides), axes);
+}
+
 ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
   // The data is all there and we are reducing over everything
   if (x.size() == x.data_size() && axes.size() == x.ndim() &&

@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
 std::pair<Shape, Strides> shapes_without_reduction_axes(
     const array& x,
     const std::vector<int>& axes);
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    Shape shape,
+    Strides strides,
+    const std::vector<int>& axes);
 
 } // namespace mlx::core
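A worked example of the new split (numbers illustrative): the value-taking overload edits a shape/strides pair directly, and the array overload now just forwards its geometry:

```cpp
// With x.shape() == (4, 5, 6), x.strides() == (30, 6, 1), axes == {1}:
// erasing axis 1 from both vectors leaves the geometry without that axis.
auto [sh, st] = shapes_without_reduction_axes(
    Shape{4, 5, 6}, Strides{30, 6, 1}, {1});
// sh == {4, 6}, st == {30, 1}
```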
26 mlx/backend/common/unary.h Normal file
@@ -0,0 +1,26 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+inline void set_unary_output_data(const array& in, array& out) {
+  if (in.flags().contiguous) {
+    if (is_donatable(in, out)) {
+      out.copy_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
+  } else {
+    out.set_data(allocator::malloc(out.nbytes()));
+  }
+}
+
+} // namespace mlx::core
@@ -1,9 +1,22 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#include <dlfcn.h>
+
 #include "mlx/backend/common/utils.h"
 
 namespace mlx::core {
 
+std::filesystem::path current_binary_dir() {
+  static std::filesystem::path binary_dir = []() {
+    Dl_info info;
+    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+      throw std::runtime_error("Unable to get current binary dir.");
+    }
+    return std::filesystem::path(info.dli_fname).parent_path();
+  }();
+  return binary_dir;
+}
+
 std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
     const Shape& shape,
     const std::vector<Strides>& strides,

@@ -101,4 +114,145 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
   return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
 }
+
+Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
+  int pows[3] = {0, 0, 0};
+  int sum = 0;
+  while (true) {
+    int presum = sum;
+    // Check all the pows
+    if (dim0 >= (1 << (pows[0] + 1))) {
+      pows[0]++;
+      sum++;
+    }
+    if (sum == 10) {
+      break;
+    }
+    if (dim1 >= (1 << (pows[1] + 1))) {
+      pows[1]++;
+      sum++;
+    }
+    if (sum == 10) {
+      break;
+    }
+    if (dim2 >= (1 << (pows[2] + 1))) {
+      pows[2]++;
+      sum++;
+    }
+    if (sum == presum || sum == pow2) {
+      break;
+    }
+  }
+  return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
+}
+
+Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
+  // Dims with strides of 0 are ignored as they
+  // correspond to broadcasted dimensions
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  if (grid_y > grid_x) {
+    std::swap(grid_x, grid_y);
+  }
+  return std::make_tuple(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
+Dims get_2d_grid_dims_common(
+    const Shape& shape,
+    const Strides& strides,
+    size_t divisor) {
+  // Compute the 2d grid dimensions such that the total size of the grid is
+  // divided by divisor.
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+
+    // No need to add this shape we can just remove it from the divisor.
+    if (divisor % shape[i] == 0) {
+      divisor /= shape[i];
+      continue;
+    }
+
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+
+    if (divisor > 1) {
+      if (grid_x % divisor == 0) {
+        grid_x /= divisor;
+        divisor = 1;
+      } else if (grid_y % divisor == 0) {
+        grid_y /= divisor;
+        divisor = 1;
+      }
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  if (grid_y > grid_x) {
+    std::swap(grid_x, grid_y);
+  }
+  if (divisor > 1) {
+    grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
+  }
+  return std::make_tuple(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
+std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
+  auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
+  auto gx = (dim0 + bx - 1) / bx;
+  auto gy = (dim1 + by - 1) / by;
+  auto gz = (dim2 + bz - 1) / bz;
+
+  return std::make_pair(
+      std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
+}
+
+array swapaxes_in_eval(const array& x, int axis1, int axis2) {
+  int ndim = x.ndim();
+  if (axis1 < 0) {
+    axis1 += ndim;
+  }
+  if (axis2 < 0) {
+    axis2 += ndim;
+  }
+
+  auto shape = x.shape();
+  std::swap(shape[axis1], shape[axis2]);
+  auto strides = x.strides();
+  std::swap(strides[axis1], strides[axis2]);
+
+  auto [data_size, row_contiguous, col_contiguous] =
+      check_contiguity(shape, strides);
+  bool contiguous = data_size == x.data_size();
+
+  array out(std::move(shape), x.dtype(), nullptr, {});
+  out.copy_shared_buffer(
+      x,
+      std::move(strides),
+      {contiguous, row_contiguous, col_contiguous},
+      x.data_size());
+  return out;
+}
 
 } // namespace mlx::core
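A worked example of the round-robin block factorization added above (numbers illustrative): for dims `(100, 8, 2)` and the default `pow2 = 10`, each axis doubles in turn while its dim still allows it, up to 2^6, 2^3, and 2^1 respectively, i.e. exactly 1024 threads:

```cpp
// Sketch of the expected results under the semantics above (assumed inputs):
auto [bx, by, bz] = get_block_dims_common(100, 8, 2); // (64, 8, 2)

// get_grid_and_block_common then covers the remainder with ceil-divided
// blocks: ((100 + 63) / 64, (8 + 7) / 8, (2 + 1) / 2) = (2, 1, 1).
auto [grid, block] = get_grid_and_block_common(100, 8, 2);
```

One quirk worth noting: the two mid-loop early exits compare `sum` against the literal `10` rather than `pow2`, so passing `pow2 > 10` can still stop near 2^10 at those checkpoints.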
@@ -2,12 +2,17 @@
 
 #pragma once
 
+#include <filesystem>
+#include <tuple>
 #include <vector>
 
 #include "mlx/array.h"
 
 namespace mlx::core {
 
+// Return the directory that contains current shared library.
+std::filesystem::path current_binary_dir();
+
 inline int64_t
 elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
   int64_t loc = 0;

@@ -70,6 +75,31 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
     const array& a,
     int64_t size_cap = std::numeric_limits<int32_t>::max());
+
+// Compute the thread block dimensions which fit the given
+// input dimensions.
+// - The thread block dimensions will be powers of two
+// - The thread block size will be less than 2^pow2
+using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
+Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
+
+// Computes a 2D grid where each element is < UINT_MAX
+// Assumes:
+// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
+// - shape and strides correspond to a contiguous (no holes) but
+//   possibly broadcasted array
+Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
+
+// Same as above but we do an implicit division with divisor.
+// Basically, equivalent to factorizing
+// Prod(s \forall s in shape if strides[s] > 0) / divisor.
+Dims get_2d_grid_dims_common(
+    const Shape& shape,
+    const Strides& strides,
+    size_t divisor);
+
+// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
+std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
+
 struct ContiguousIterator {
   inline void step() {
     int dims = shape_.size();

@@ -165,4 +195,14 @@ void shared_buffer_reshape(
     const array& in,
     const Strides& out_strides,
     array& out);
+
+// Like the swapaxes op but safe to call in eval_gpu.
+array swapaxes_in_eval(const array& x, int axis1, int axis2);
+
+template <typename T>
+inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
+  vec.erase(std::next(vec.begin(), index));
+  return vec;
+}
 
 } // namespace mlx::core
@@ -40,11 +40,13 @@ add_dependencies(mlx cpu_compiled_preamble)
 
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp

@@ -74,8 +76,8 @@ target_sources(
 if(MLX_BUILD_ACCELERATE)
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
 else()
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
-                             ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
+                             ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
 endif()
 
 if(IOS)
@@ -14,10 +14,8 @@ template <typename InT, typename OpT>
 void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
   auto axis_size = in.shape()[axis];
   auto axis_stride = in.strides()[axis];
-  Strides strides = in.strides();
-  Shape shape = in.shape();
-  strides.erase(strides.begin() + axis);
-  shape.erase(shape.begin() + axis);
+  Strides strides = remove_index(in.strides(), axis);
+  Shape shape = remove_index(in.shape(), axis);
   auto in_ptr = in.data<InT>();
   auto out_ptr = out.data<uint32_t>();
 
11 mlx/backend/cpu/available.cpp Normal file
@@ -0,0 +1,11 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cpu/available.h"
+
+namespace mlx::core::cpu {
+
+bool is_available() {
+  return true;
+}
+
+} // namespace mlx::core::cpu

9 mlx/backend/cpu/available.h Normal file
@@ -0,0 +1,9 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+namespace mlx::core::cpu {
+
+bool is_available();
+
+} // namespace mlx::core::cpu
@@ -172,9 +172,12 @@ void binary_float(
       case bfloat16:
         binary_op<bfloat16_t, Op>(a, b, out, bopt);
         break;
+      case complex64:
+        binary_op<complex64_t, Op>(a, b, out, bopt);
+        break;
       default:
         throw std::runtime_error(
-            "[binary_float] Only supports non-complex floating point types.");
+            "[binary_float] Only supports floating point types.");
     }
   });
 }
@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
 
   // The decomposition is computed in place, so just copy the input to the
   // output.
-  copy(
+  copy_cpu(
       a,
       factor,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -40,7 +40,10 @@ struct CompilerCache {
   std::shared_mutex mtx;
 };
 
-static CompilerCache cache{};
+static CompilerCache& cache() {
+  static CompilerCache cache_;
+  return cache_;
+};
 
 // GPU compile is always available if the GPU is available and since we are in
 // this file CPU compile is also available.
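The `cache()` change above is the standard construct-on-first-use (Meyers singleton) idiom. Reduced to its shape, with `Cache` standing in for the real `CompilerCache`:

```cpp
struct Cache { /* kernels, libs, mutex, ... */ };

// A function-local static is initialized on first call (thread-safe since
// C++11) instead of at load time, which sidesteps static-initialization-order
// problems that a namespace-scope `static Cache cache{};` can run into when
// other static objects touch the cache during startup or teardown.
static Cache& cache() {
  static Cache cache_;
  return cache_;
}
```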
@@ -56,14 +59,16 @@ void* compile(
     const std::string& kernel_name,
     const std::function<std::string(void)>& source_builder) {
   {
-    std::shared_lock lock(cache.mtx);
-    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
+    std::shared_lock lock(cache().mtx);
+    if (auto it = cache().kernels.find(kernel_name);
+        it != cache().kernels.end()) {
       return it->second;
     }
   }
 
-  std::unique_lock lock(cache.mtx);
-  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
+  std::unique_lock lock(cache().mtx);
+  if (auto it = cache().kernels.find(kernel_name);
+      it != cache().kernels.end()) {
     return it->second;
   }
   std::string source_code = source_builder();

@@ -120,10 +125,10 @@ void* compile(
   }
 
   // load library
-  cache.libs.emplace_back(shared_lib_path);
+  cache().libs.emplace_back(shared_lib_path);
 
   // Load function
-  void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
+  void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
   if (!fun) {
     std::ostringstream msg;
     msg << "[Compile::eval_cpu] Failed to load compiled function "

@@ -131,7 +136,7 @@ void* compile(
        << dlerror();
     throw std::runtime_error(msg.str());
   }
-  cache.kernels.insert({kernel_name, fun});
+  cache().kernels.insert({kernel_name, fun});
   return fun;
 }
 
@@ -141,18 +146,9 @@ inline void build_kernel(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
     const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous,
     int ndim) {
-  // All outputs should have the exact same shape and will be row contiguous
-  auto output_shape = outputs[0].shape();
-  auto output_strides = outputs[0].strides();
-
-  // Constants are scalars that are captured by value and cannot change
-  auto is_constant = [&constant_ids](const array& x) {
-    return constant_ids.find(x.id()) != constant_ids.end();
-  };
-
   NodeNamer namer;
 
 #ifdef _MSC_VER

@@ -165,14 +161,15 @@ inline void build_kernel(
 
   // Add the input arguments
   int cnt = 0;
-  for (auto& x : inputs) {
-    auto& xname = namer.get_name(x);
-
+  for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       continue;
     }
 
+    const auto& x = inputs[i];
+    auto& xname = namer.get_name(x);
+
     auto tstr = get_type_string(x.dtype());
     os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
        << "];" << std::endl;

@@ -206,10 +203,11 @@ inline void build_kernel(
   }
 
   // Read the inputs in tmps
-  for (auto& x : inputs) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    const auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
       print_constant(os, x);
       os << ";" << std::endl;

@@ -233,7 +231,7 @@ inline void build_kernel(
       os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
          << namer.get_name(x.inputs()[0]) << ");" << std::endl;
     } else {
-      x.primitive().print(os);
+      os << x.primitive().name();
       os << "()(";
       for (int i = 0; i < x.inputs().size() - 1; i++) {
         os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";

@@ -259,8 +257,9 @@ inline void build_kernel(
   } else {
     for (int d = ndim - 1; d >= 0; --d) {
       // Update pointers
-      for (auto& x : inputs) {
-        if (is_constant(x) || is_scalar(x)) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        if (is_constant(i) || is_scalar(x)) {
           continue;
         }
         auto& xname = namer.get_name(x);

@@ -282,65 +281,45 @@ void Compiled::eval_cpu(
 void Compiled::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (kernel_lib_.empty()) {
-    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
-  }
-
-  // Figure out which kernel we are using
-  auto& shape = outputs[0].shape();
-  auto contiguous = compiled_check_contiguity(inputs, shape);
   auto& encoder = cpu::get_command_encoder(stream());
 
-  // Handle all broadcasting and collect function input arguments
+  // Collapse contiguous dims to route to a faster kernel if possible. Also
+  // handle all broadcasting.
+  auto [contiguous, shape, strides] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
+
+  // Force allocating shape/strides on heap so we can take their data() first
+  // and then std::move them.
+  // TODO: Refactor code to avoid heap allocation.
+  shape.grow();
+  for (auto& s : strides) {
+    s.grow();
+  }
+
+  // Collect function input arguments.
   std::vector<void*> args;
-  std::vector<std::vector<size_t>> strides;
-  for (int i = 0; i < inputs.size(); i++) {
-    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
+  int strides_index = 1;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    // Skip constants.
+    if (is_constant_(i)) {
       continue;
     }
-    auto& x = inputs[i];
+    const auto& x = inputs[i];
     encoder.set_input_array(x);
     args.push_back((void*)x.data<void>());
-    if (contiguous || is_scalar(x)) {
-      continue;
+    if (!contiguous && !is_scalar(x)) {
+      args.push_back(strides[strides_index++].data());
     }
-
-    // Broadcast the input to the output shape.
-    std::vector<size_t> xstrides;
-    int j = 0;
-    for (; j < shape.size() - x.ndim(); j++) {
-      if (shape[j] == 1) {
-        xstrides.push_back(outputs[0].strides()[j]);
-      } else {
-        xstrides.push_back(0);
-      }
-    }
-    for (int i = 0; i < x.ndim(); i++, j++) {
-      if (x.shape(i) == 1) {
-        if (shape[j] == 1) {
-          xstrides.push_back(outputs[0].strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      } else {
-        xstrides.push_back(x.strides()[i]);
-      }
-    }
-    strides.push_back(std::move(xstrides));
-    args.push_back(strides.back().data());
   }
 
   // Get the kernel name from the lib
   int ndim = shape.size();
   auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
   if (!contiguous) {
-    kernel_name += std::to_string(shape.size());
+    kernel_name += std::to_string(ndim);
   }
 
   // Get the function
-  auto fn_ptr = compile(kernel_name, [&]() {
+  auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
     std::ostringstream kernel;
     kernel << get_kernel_preamble() << std::endl;
     kernel << "extern \"C\" {" << std::endl;

@@ -350,7 +329,7 @@ void Compiled::eval_cpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         contiguous,
         ndim);
     // Close extern "C"

@@ -358,26 +337,22 @@ void Compiled::eval_cpu(
     return kernel.str();
   });
 
-  compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous);
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
 
   for (auto& x : outputs) {
     args.push_back(x.data<void>());
     encoder.set_output_array(x);
   }
-  Shape out_shape;
   if (!contiguous) {
-    out_shape = outputs[0].shape();
-    args.push_back((void*)out_shape.data());
+    args.push_back((void*)shape.data());
   } else {
     args.push_back((void*)outputs[0].data_size());
   }
   auto fun = (void (*)(void**))fn_ptr;
-  encoder.dispatch(
-      [fun,
-       args = std::move(args),
-       strides = std::move(strides),
-       out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
+  encoder.dispatch([fun,
+                    args = std::move(args),
+                    strides = std::move(strides),
+                    shape = std::move(shape)]() mutable { fun(args.data()); });
 }
 
 } // namespace mlx::core
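Across these hunks, the `constant_ids_` id-set plus the parallel `inputs_` vector is replaced by a single index predicate `is_constant_` threaded through `build_kernel`, `compiled_allocate_outputs`, and `Compiled::eval_cpu`. The diff does not show where that predicate is built; a hypothetical adapter with equivalent behavior would be:

```cpp
#include <functional>
#include <unordered_set>
#include <vector>

#include "mlx/array.h"

using mlx::core::array;

// Assumed helper (not in this diff): wrap the old id-set lookup as an
// index-based predicate over the primitive's captured inputs.
std::function<bool(size_t)> make_is_constant(
    const std::vector<array>& inputs,
    const std::unordered_set<uintptr_t>& constant_ids) {
  return [&inputs, &constant_ids](size_t i) {
    return constant_ids.find(inputs[i].id()) != constant_ids.end();
  };
}
```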
(File diff suppressed because it is too large)
@@ -295,7 +295,11 @@ inline void copy_inplace_dispatch(
 
 } // namespace
 
-void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
+void copy_cpu_inplace(
+    const array& src,
+    array& dst,
+    CopyType ctype,
+    Stream stream) {
   auto& encoder = cpu::get_command_encoder(stream);
   encoder.set_input_array(src);
   encoder.set_output_array(dst);

@@ -305,7 +309,7 @@ void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
       ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
 }
 
-void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
+void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
   bool donated = set_copy_output_data(src, dst, ctype);
   if (donated && src.dtype() == dst.dtype()) {
     // If the output has the same type as the input then there is nothing to

@@ -315,10 +319,10 @@ void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
   if (ctype == CopyType::GeneralGeneral) {
     ctype = CopyType::General;
   }
-  copy_inplace(src, dst, ctype, stream);
+  copy_cpu_inplace(src, dst, ctype, stream);
 }
 
-void copy_inplace(
+void copy_cpu_inplace(
     const array& src,
     array& dst,
     const Shape& data_shape,

@@ -373,4 +377,10 @@ void copy_inplace(
   });
 }
 
+array contiguous_copy_cpu(const array& arr, Stream stream) {
+  array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+  copy_cpu(arr, arr_copy, CopyType::General, stream);
+  return arr_copy;
+}
+
 } // namespace mlx::core

@@ -10,10 +10,14 @@
 
 namespace mlx::core {
 
-void copy(const array& src, array& dst, CopyType ctype, Stream stream);
-void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
+void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
+void copy_cpu_inplace(
+    const array& src,
+    array& dst,
+    CopyType ctype,
+    Stream stream);
 
-void copy_inplace(
+void copy_cpu_inplace(
     const array& src,
     array& dst,
     const Shape& data_shape,

@@ -26,4 +30,7 @@ void copy_inplace(
     const std::optional<array>& dynamic_i_offset = std::nullopt,
     const std::optional<array>& dynamic_o_offset = std::nullopt);
 
+// Return a contiguous array with same shape that copies the data of |arr|.
+array contiguous_copy_cpu(const array& arr, Stream stream);
+
 } // namespace mlx::core
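The new `contiguous_copy_cpu` helper replaces the repeated allocate-then-copy two-liner. Its typical call-site shape, mirroring the `ensure_row_contiguous` change in the next hunk (a sketch, assuming an `arr` and `stream` in scope):

```cpp
// Take a dense copy only when the input is not already row-contiguous; the
// copy keeps arr's shape and dtype and suits kernels that assume packed rows.
array dense =
    arr.flags().row_contiguous ? arr : contiguous_copy_cpu(arr, stream);
```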
@@ -13,9 +13,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
   if (arr.flags().row_contiguous) {
     return {arr, false};
   } else {
-    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-    copy(arr, arr_copy, CopyType::General, stream);
-    return {arr_copy, true};
+    return {contiguous_copy_cpu(arr, stream), true};
   }
 };
 
@@ -34,8 +32,7 @@ void AllReduce::eval_cpu(
       }
       return in;
     } else {
-      array arr_copy(in.shape(), in.dtype(), nullptr, {});
-      copy(in, arr_copy, CopyType::General, s);
+      array arr_copy = contiguous_copy_cpu(in, s);
       out.copy_shared_buffer(arr_copy);
       return arr_copy;
     }

@@ -46,8 +43,15 @@ void AllReduce::eval_cpu(
     case Sum:
       distributed::detail::all_sum(group(), in, outputs[0], stream());
       break;
+    case Max:
+      distributed::detail::all_max(group(), in, outputs[0], stream());
+      break;
+    case Min:
+      distributed::detail::all_min(group(), in, outputs[0], stream());
+      break;
     default:
-      throw std::runtime_error("Only all reduce sum is supported for now");
+      throw std::runtime_error(
+          "Only all reduce sum, min and max are supported for now");
   }
 }
174
mlx/backend/cpu/eig.cpp
Normal file
174
mlx/backend/cpu/eig.cpp
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/cpu/copy.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
|
#include "mlx/backend/cpu/lapack.h"
|
||||||
|
#include "mlx/linalg.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void eig_impl(
|
||||||
|
array& a,
|
||||||
|
array& vectors,
|
||||||
|
array& values,
|
||||||
|
bool compute_eigenvectors,
|
||||||
|
Stream stream) {
|
||||||
|
using OT = std::complex<T>;
|
||||||
|
auto a_ptr = a.data<T>();
|
||||||
|
auto eig_ptr = values.data<OT>();
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_output_array(values);
|
||||||
|
OT* vec_ptr = nullptr;
|
||||||
|
if (compute_eigenvectors) {
|
||||||
|
encoder.set_output_array(vectors);
|
||||||
|
+    vec_ptr = vectors.data<OT>();
+  }
+  encoder.dispatch([a_ptr,
+                    vec_ptr,
+                    eig_ptr,
+                    compute_eigenvectors,
+                    N = vectors.shape(-1),
+                    size = vectors.size()]() mutable {
+    // Work query
+    char jobr = 'N';
+    char jobl = compute_eigenvectors ? 'V' : 'N';
+    int n_vecs_r = 1;
+    int n_vecs_l = compute_eigenvectors ? N : 1;
+    int lwork = -1;
+    int info;
+    {
+      T work;
+      int iwork;
+      geev<T>(
+          &jobl,
+          &jobr,
+          &N,
+          nullptr,
+          &N,
+          nullptr,
+          nullptr,
+          nullptr,
+          &n_vecs_l,
+          nullptr,
+          &n_vecs_r,
+          &work,
+          &lwork,
+          &info);
+      lwork = static_cast<int>(work);
+    }
+
+    auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
+    auto vec_tmp_data =
+        array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
+    auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
+    auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
+    auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
+    for (size_t i = 0; i < size / (N * N); ++i) {
+      geev<T>(
+          &jobl,
+          &jobr,
+          &N,
+          a_ptr,
+          &N,
+          eig_tmp,
+          eig_tmp + N,
+          vec_tmp,
+          &n_vecs_l,
+          nullptr,
+          &n_vecs_r,
+          static_cast<T*>(work_buf.buffer.raw_ptr()),
+          &lwork,
+          &info);
+      for (int i = 0; i < N; ++i) {
+        eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
+      }
+      if (vec_ptr) {
+        for (int i = 0; i < N; ++i) {
+          if (eig_ptr[i].imag() != 0) {
+            // This vector and the next are a pair
+            for (int j = 0; j < N; ++j) {
+              vec_ptr[i * N + j] = {
+                  vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
+              vec_ptr[(i + 1) * N + j] = {
+                  vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
+            }
+            i += 1;
+          } else {
+            for (int j = 0; j < N; ++j) {
+              vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
+            }
+          }
+        }
+        vec_ptr += N * N;
+      }
+      a_ptr += N * N;
+      eig_ptr += N;
+      if (info != 0) {
+        std::stringstream msg;
+        msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
+            << info;
+        throw std::runtime_error(msg.str());
+      }
+    }
+  });
+  encoder.add_temporary(a);
+}
+
+} // namespace
+
+void Eig::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  const auto& a = inputs[0];
+  auto& values = outputs[0];
+
+  auto vectors = compute_eigenvectors_
+      ? outputs[1]
+      : array(a.shape(), complex64, nullptr, {});
+
+  auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
+  copy_cpu(
+      a,
+      a_copy,
+      a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+      stream());
+
+  values.set_data(allocator::malloc(values.nbytes()));
+
+  if (compute_eigenvectors_) {
+    // Set the strides and flags so the eigenvectors
+    // are in the columns of the output
+    auto flags = vectors.flags();
+    auto strides = vectors.strides();
+    auto ndim = a.ndim();
+    std::swap(strides[ndim - 1], strides[ndim - 2]);
+
+    if (a.size() > 1) {
+      flags.row_contiguous = false;
+      if (ndim > 2) {
+        flags.col_contiguous = false;
+      } else {
+        flags.col_contiguous = true;
+      }
+    }
+    vectors.set_data(
+        allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
+  }
+  switch (a.dtype()) {
+    case float32:
+      eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
+      break;
+    default:
+      throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
+  }
+}
+
+} // namespace mlx::core
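For reference: LAPACK's real geev reports a complex-conjugate eigenpair as two consecutive real vectors, the shared real part followed by the imaginary part, which is what the pairing loop above unpacks into complex64 output. A minimal standalone sketch of that unpacking, with illustrative (not computed) values standing in for geev output on a 2x2 rotation matrix:

#include <complex>
#include <cstdio>

int main() {
  const int N = 2;
  // Illustrative stand-ins for geev output on [[0, -1], [1, 0]]:
  // eigenvalues +/- i come back as separate real and imaginary arrays.
  float wr[N] = {0.0f, 0.0f}; // real parts
  float wi[N] = {1.0f, -1.0f}; // imaginary parts
  // For a conjugate pair, one row holds the shared real part and the next
  // holds the imaginary part (rows, because the vectors are pre-transposed).
  float v[N * N] = {0.7071f, 0.0f, 0.0f, -0.7071f};
  std::complex<float> vecs[N * N];
  for (int i = 0; i < N; ++i) {
    if (wi[i] != 0) { // unpack the pair the same way the loop above does
      for (int j = 0; j < N; ++j) {
        vecs[i * N + j] = {v[i * N + j], -v[(i + 1) * N + j]};
        vecs[(i + 1) * N + j] = {v[i * N + j], v[(i + 1) * N + j]};
      }
      ++i;
    } else {
      for (int j = 0; j < N; ++j) {
        vecs[i * N + j] = {v[i * N + j], 0.0f};
      }
    }
  }
  for (int i = 0; i < N; ++i) {
    std::printf("lambda_%d = (%g, %g)\n", i, wr[i], wi[i]);
  }
  for (int i = 0; i < N * N; ++i) {
    std::printf("(%g, %g)\n", vecs[i].real(), vecs[i].imag());
  }
}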
@@ -12,6 +12,133 @@ namespace mlx::core {

 namespace {

+template <typename T, class Enable = void>
+struct EighWork {};
+
+template <typename T>
+struct EighWork<
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  using R = T;
+
+  char jobz;
+  char uplo;
+  int N;
+  int lwork;
+  int liwork;
+  int info;
+  std::vector<array::Data> buffers;
+
+  EighWork(char jobz_, char uplo_, int N_)
+      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
+    T work;
+    int iwork;
+    syevd<T>(
+        &jobz,
+        &uplo,
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        &work,
+        &lwork,
+        &iwork,
+        &liwork,
+        &info);
+    lwork = static_cast<int>(work);
+    liwork = iwork;
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
+  }
+
+  void run(T* vectors, T* values) {
+    syevd<T>(
+        &jobz,
+        &uplo,
+        &N,
+        vectors,
+        &N,
+        values,
+        static_cast<T*>(buffers[0].buffer.raw_ptr()),
+        &lwork,
+        static_cast<int*>(buffers[1].buffer.raw_ptr()),
+        &liwork,
+        &info);
+  }
+};
+
+template <>
+struct EighWork<std::complex<float>> {
+  using T = std::complex<float>;
+  using R = float;
+
+  char jobz;
+  char uplo;
+  int N;
+  int lwork;
+  int lrwork;
+  int liwork;
+  int info;
+  std::vector<array::Data> buffers;
+
+  EighWork(char jobz_, char uplo_, int N_)
+      : jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
+    T work;
+    R rwork;
+    int iwork;
+    heevd<T>(
+        &jobz,
+        &uplo,
+        &N,
+        nullptr,
+        &N,
+        nullptr,
+        &work,
+        &lwork,
+        &rwork,
+        &lrwork,
+        &iwork,
+        &liwork,
+        &info);
+    lwork = static_cast<int>(work.real());
+    lrwork = static_cast<int>(rwork);
+    liwork = iwork;
+    buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
+    buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
+    buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
+  }
+
+  void run(T* vectors, R* values) {
+    heevd<T>(
+        &jobz,
+        &uplo,
+        &N,
+        vectors,
+        &N,
+        values,
+        static_cast<T*>(buffers[0].buffer.raw_ptr()),
+        &lwork,
+        static_cast<R*>(buffers[1].buffer.raw_ptr()),
+        &lrwork,
+        static_cast<int*>(buffers[2].buffer.raw_ptr()),
+        &liwork,
+        &info);
+    if (jobz == 'V') {
+      // We have pre-transposed the vectors but we also must conjugate them
+      // when they are complex.
+      //
+      // We could vectorize this but it is so fast in comparison to heevd that
+      // it doesn't really matter.
+      for (int i = 0; i < N; i++) {
+        for (int j = 0; j < N; j++) {
+          *vectors = std::conj(*vectors);
+          vectors++;
+        }
+      }
+    }
+  }
+};
+
 template <typename T>
 void eigh_impl(
     array& vectors,
@@ -19,8 +146,10 @@ void eigh_impl(
     const std::string& uplo,
     bool compute_eigenvectors,
     Stream stream) {
+  using R = typename EighWork<T>::R;
+
   auto vec_ptr = vectors.data<T>();
-  auto eig_ptr = values.data<T>();
+  auto eig_ptr = values.data<R>();
   char jobz = compute_eigenvectors ? 'V' : 'N';

   auto& encoder = cpu::get_command_encoder(stream);
@@ -33,49 +162,17 @@ void eigh_impl(
                     N = vectors.shape(-1),
                     size = vectors.size()]() mutable {
     // Work query
-    int lwork = -1;
-    int liwork = -1;
-    int info;
-    {
-      T work;
-      int iwork;
-      syevd<T>(
-          &jobz,
-          &uplo,
-          &N,
-          nullptr,
-          &N,
-          nullptr,
-          &work,
-          &lwork,
-          &iwork,
-          &liwork,
-          &info);
-      lwork = static_cast<int>(work);
-      liwork = iwork;
-    }
+    EighWork<T> work(jobz, uplo, N);

-    auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
-    auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
+    // Work loop
     for (size_t i = 0; i < size / (N * N); ++i) {
-      syevd<T>(
-          &jobz,
-          &uplo,
-          &N,
-          vec_ptr,
-          &N,
-          eig_ptr,
-          static_cast<T*>(work_buf.buffer.raw_ptr()),
-          &lwork,
-          static_cast<int*>(iwork_buf.buffer.raw_ptr()),
-          &liwork,
-          &info);
+      work.run(vec_ptr, eig_ptr);
       vec_ptr += N * N;
       eig_ptr += N;
-      if (info != 0) {
+      if (work.info != 0) {
        std::stringstream msg;
        msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
-            << info;
+            << work.info;
        throw std::runtime_error(msg.str());
      }
    }
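The EighWork structs above wrap the standard two-phase LAPACK workspace protocol: a first call with lwork = -1 only reports the required buffer sizes, and a second call does the actual decomposition. A minimal sketch of the same protocol against the Fortran symbol ssyevd_ directly (assumes linking a LAPACK; the stream/encoder machinery is left out):

#include <cstdio>
#include <vector>

extern "C" void ssyevd_(
    const char* jobz, const char* uplo, const int* n, float* a, const int* lda,
    float* w, float* work, const int* lwork, int* iwork, const int* liwork,
    int* info);

int main() {
  int n = 2, lda = 2, info = 0;
  float a[4] = {2.0f, 1.0f, 1.0f, 2.0f}; // symmetric, eigenvalues 1 and 3
  float w[2];
  // Phase 1: workspace query (lwork = liwork = -1).
  float work_query;
  int iwork_query;
  int lwork = -1, liwork = -1;
  ssyevd_("V", "L", &n, a, &lda, w, &work_query, &lwork, &iwork_query,
          &liwork, &info);
  lwork = static_cast<int>(work_query);
  liwork = iwork_query;
  // Phase 2: the real call, with the buffers the query asked for.
  std::vector<float> work(lwork);
  std::vector<int> iwork(liwork);
  ssyevd_("V", "L", &n, a, &lda, w, work.data(), &lwork, iwork.data(),
          &liwork, &info);
  std::printf("info=%d eigenvalues: %f %f\n", info, w[0], w[1]);
}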
@@ -99,7 +196,7 @@ void Eigh::eval_cpu(

   values.set_data(allocator::malloc(values.nbytes()));

-  copy(
+  copy_cpu(
       a,
       vectors,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -131,6 +228,10 @@ void Eigh::eval_cpu(
       eigh_impl<double>(
           vectors, values, uplo_, compute_eigenvectors_, stream());
       break;
+    case complex64:
+      eigh_impl<std::complex<float>>(
+          vectors, values, uplo_, compute_eigenvectors_, stream());
+      break;
     default:
       throw std::runtime_error(
           "[Eigh::eval_cpu] only supports float32 or float64.");
@@ -1,27 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cpu/gemm.h"
-
-namespace mlx::core {
-
-template <>
-void matmul<bfloat16_t>(
-    const bfloat16_t*,
-    const bfloat16_t*,
-    bfloat16_t*,
-    bool,
-    bool,
-    size_t,
-    size_t,
-    size_t,
-    float,
-    float,
-    size_t,
-    const Shape&,
-    const Strides&,
-    const Shape&,
-    const Strides&) {
-  throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
-}
-
-} // namespace mlx::core
@@ -1,27 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cpu/gemm.h"
-
-namespace mlx::core {
-
-template <>
-void matmul<float16_t>(
-    const float16_t*,
-    const float16_t*,
-    float16_t*,
-    bool,
-    bool,
-    size_t,
-    size_t,
-    size_t,
-    float,
-    float,
-    size_t,
-    const Shape&,
-    const Strides&,
-    const Shape&,
-    const Strides&) {
-  throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
-}
-
-} // namespace mlx::core
mlx/backend/cpu/gemms/simd_bf16.cpp (new file, 45 lines)
@@ -0,0 +1,45 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/gemm.h"
+#include "mlx/backend/cpu/gemms/simd_gemm.h"
+
+namespace mlx::core {
+
+template <>
+void matmul<bfloat16_t>(
+    const bfloat16_t* a,
+    const bfloat16_t* b,
+    bfloat16_t* out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    size_t ldc,
+    float alpha,
+    float beta,
+    size_t batch_size,
+    const Shape& a_shape,
+    const Strides& a_strides,
+    const Shape& b_shape,
+    const Strides& b_strides) {
+  auto ndim = a_shape.size();
+  size_t M = a_shape[ndim - 2];
+  size_t N = b_shape[ndim - 1];
+  size_t K = a_shape[ndim - 1];
+  for (int i = 0; i < batch_size; ++i) {
+    simd_gemm<bfloat16_t, float>(
+        a + elem_to_loc(M * K * i, a_shape, a_strides),
+        b + elem_to_loc(K * N * i, b_shape, b_strides),
+        out + M * N * i,
+        a_transposed,
+        b_transposed,
+        M,
+        N,
+        K,
+        alpha,
+        beta);
+  }
+}
+
+} // namespace mlx::core
mlx/backend/cpu/gemms/simd_fp16.cpp (new file, 45 lines)
@@ -0,0 +1,45 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/gemm.h"
+#include "mlx/backend/cpu/gemms/simd_gemm.h"
+
+namespace mlx::core {
+
+template <>
+void matmul<float16_t>(
+    const float16_t* a,
+    const float16_t* b,
+    float16_t* out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    size_t ldc,
+    float alpha,
+    float beta,
+    size_t batch_size,
+    const Shape& a_shape,
+    const Strides& a_strides,
+    const Shape& b_shape,
+    const Strides& b_strides) {
+  auto ndim = a_shape.size();
+  size_t M = a_shape[ndim - 2];
+  size_t N = b_shape[ndim - 1];
+  size_t K = a_shape[ndim - 1];
+  for (int i = 0; i < batch_size; ++i) {
+    simd_gemm<float16_t, float>(
+        a + elem_to_loc(M * K * i, a_shape, a_strides),
+        b + elem_to_loc(K * N * i, b_shape, b_strides),
+        out + M * N * i,
+        a_transposed,
+        b_transposed,
+        M,
+        N,
+        K,
+        alpha,
+        beta);
+  }
+}
+
+} // namespace mlx::core
mlx/backend/cpu/gemms/simd_gemm.h (new file, 139 lines)
@@ -0,0 +1,139 @@
+// Copyright © 2025 Apple Inc.
+#pragma once
+
+#include "mlx/backend/cpu/simd/simd.h"
+
+namespace mlx::core {
+
+inline int ceildiv(int a, int b) {
+  return (a + b - 1) / b;
+}
+
+template <int block_size, typename T, typename AccT>
+void load_block(
+    const T* in,
+    AccT* out,
+    int M,
+    int N,
+    int i,
+    int j,
+    bool transpose) {
+  if (transpose) {
+    for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+      for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+        out[jj * block_size + ii] =
+            in[(i * block_size + ii) * N + j * block_size + jj];
+      }
+    }
+  } else {
+    for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+      for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+        out[ii * block_size + jj] =
+            in[(i * block_size + ii) * N + j * block_size + jj];
+      }
+    }
+  }
+}
+
+template <typename T, typename AccT>
+void simd_gemm(
+    const T* a,
+    const T* b,
+    T* c,
+    bool a_trans,
+    bool b_trans,
+    int M,
+    int N,
+    int K,
+    float alpha,
+    float beta) {
+  constexpr int block_size = 16;
+  constexpr int simd_size = simd::max_size<AccT>;
+  static_assert(
+      (block_size % simd_size) == 0,
+      "Block size must be divisible by SIMD size");
+
+  int last_k_block_size = K - block_size * (K / block_size);
+  int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
+  for (int i = 0; i < ceildiv(M, block_size); i++) {
+    for (int j = 0; j < ceildiv(N, block_size); j++) {
+      AccT c_block[block_size * block_size] = {0.0};
+      AccT a_block[block_size * block_size];
+      AccT b_block[block_size * block_size];
+
+      int k = 0;
+      for (; k < K / block_size; k++) {
+        // Load a and b blocks
+        if (a_trans) {
+          load_block<block_size>(a, a_block, K, M, k, i, true);
+        } else {
+          load_block<block_size>(a, a_block, M, K, i, k, false);
+        }
+        if (b_trans) {
+          load_block<block_size>(b, b_block, N, K, j, k, false);
+        } else {
+          load_block<block_size>(b, b_block, K, N, k, j, true);
+        }
+
+        // Multiply and accumulate
+        for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+          for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+            for (int kk = 0; kk < block_size; kk += simd_size) {
+              auto av =
+                  simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
+              auto bv =
+                  simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
+              c_block[ii * block_size + jj] += simd::sum(av * bv);
+            }
+          }
+        }
+      }
+      if (last_k_block_size) {
+        // Load a and b blocks
+        if (a_trans) {
+          load_block<block_size>(a, a_block, K, M, k, i, true);
+        } else {
+          load_block<block_size>(a, a_block, M, K, i, k, false);
+        }
+        if (b_trans) {
+          load_block<block_size>(b, b_block, N, K, j, k, false);
+        } else {
+          load_block<block_size>(b, b_block, K, N, k, j, true);
+        }
+
+        // Multiply and accumulate
+        for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+          for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+            int kk = 0;
+            for (; kk < last_k_simd_block; kk += simd_size) {
+              auto av =
+                  simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
+              auto bv =
+                  simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
+              c_block[ii * block_size + jj] += simd::sum(av * bv);
+            }
+            for (; kk < last_k_block_size; ++kk) {
+              c_block[ii * block_size + jj] +=
+                  a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
+            }
+          }
+        }
+      }
+
+      // Store
+      for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+        for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+          auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
+          if (beta != 0) {
+            c[c_idx] = static_cast<T>(
+                alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
+          } else {
+            c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
+          }
+        }
+      }
+    }
+  }
+}
+
+} // namespace mlx::core
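The header above implements a blocked GEMM: 16x16 tiles are staged into contiguous scratch buffers, the K dimension is consumed in SIMD-width strides with a scalar tail, and alpha/beta are applied on store. A plain scalar reference for the contract C = alpha * A @ B + beta * C, useful for checking small cases (not part of the diff):

#include <cstdio>

void ref_gemm(const float* a, const float* b, float* c,
              int M, int N, int K, float alpha, float beta) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += a[i * K + k] * b[k * N + j];
      }
      // Mirror the store step above: skip reading c when beta == 0.
      c[i * N + j] = alpha * acc + (beta != 0 ? beta * c[i * N + j] : 0.0f);
    }
  }
}

int main() {
  float a[4] = {1, 2, 3, 4}; // 2x2, row-major
  float b[4] = {5, 6, 7, 8}; // 2x2, row-major
  float c[4] = {0, 0, 0, 0};
  ref_gemm(a, b, c, 2, 2, 2, 1.0f, 0.0f);
  std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); // 19 22 43 50
}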
@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (in.flags().row_contiguous && in.is_donatable()) {
     out.copy_shared_buffer(in);
   } else {
-    copy(
+    copy_cpu(
         in,
         out,
         in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -257,15 +257,11 @@ void gather_axis(
     const array& ind,
     array& out,
     const int axis) {
-  auto strides = ind.strides();
-  strides.erase(strides.begin() + axis);
-  auto shape = ind.shape();
-  shape.erase(shape.begin() + axis);
-  ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
-
-  strides = src.strides();
-  strides.erase(strides.begin() + axis);
-  ContiguousIterator src_it(shape, strides, src.ndim() - 1);
+  auto shape = remove_index(ind.shape(), axis);
+  ContiguousIterator ind_it(
+      shape, remove_index(ind.strides(), axis), src.ndim() - 1);
+  ContiguousIterator src_it(
+      shape, remove_index(src.strides(), axis), src.ndim() - 1);

   auto ind_ptr = ind.data<IdxT>();
   auto src_ptr = src.data<T>();
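The gather_axis rewrite above (and the matching scatter_axis rewrite below) assumes a remove_index helper that drops one axis from a shape or strides vector. It is not shown in these hunks, but a plausible sketch, equivalent to the erase calls it replaces, would be:

// Sketch only -- the real remove_index lives elsewhere in the tree.
template <typename Vec>
Vec remove_index(Vec vec, int axis) {
  vec.erase(vec.begin() + axis); // drop the indexed axis, return by value
  return vec;
}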
@@ -521,7 +517,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Copy src into out (copy allocates memory for out)
   auto ctype =
       src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
-  copy(src, out, ctype, stream());
+  copy_cpu(src, out, ctype, stream());

   auto& encoder = cpu::get_command_encoder(stream());
   std::vector<array> inds;
@@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {

 template <typename T, typename IdxT, typename OpT>
 void scatter_axis(array& out, const array idx, const array& upd, int axis) {
-  auto strides = idx.strides();
-  strides.erase(strides.begin() + axis);
-  auto shape = idx.shape();
-  shape.erase(shape.begin() + axis);
-  ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
-
-  strides = upd.strides();
-  strides.erase(strides.begin() + axis);
-  ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
+  auto shape = remove_index(idx.shape(), axis);
+  ContiguousIterator idx_it(
+      shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
+  ContiguousIterator upd_it(
+      shape, remove_index(upd.strides(), axis), upd.ndim() - 1);

   auto idx_ptr = idx.data<IdxT>();
   auto upd_ptr = upd.data<T>();
@@ -694,7 +686,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Copy src into out (copy allocates memory for out)
   auto ctype =
       src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
-  copy(src, out, ctype, stream());
+  copy_cpu(src, out, ctype, stream());

   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(idx);
@@ -115,7 +115,7 @@ void inverse_impl(
   // (A⁻¹)ᵀ = (Aᵀ)⁻¹

   // The inverse is computed in place, so just copy the input to the output.
-  copy(
+  copy_cpu(
       a,
       inv,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
@@ -2,6 +2,7 @@

 #include "mlx/backend/cpu/jit_compiler.h"

+#include <algorithm>
 #include <sstream>
 #include <vector>

@@ -2,14 +2,14 @@

 #pragma once

-// Required for Visual Studio.
-// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
-#ifdef _MSC_VER
 #include <complex>
 #define LAPACK_COMPLEX_CUSTOM
 #define lapack_complex_float std::complex<float>
 #define lapack_complex_double std::complex<double>
-#endif
+#define lapack_complex_float_real(z) ((z).real())
+#define lapack_complex_float_imag(z) ((z).imag())
+#define lapack_complex_double_real(z) ((z).real())
+#define lapack_complex_double_imag(z) ((z).imag())

 #ifdef MLX_USE_ACCELERATE
 #include <Accelerate/Accelerate.h>
@@ -32,7 +32,7 @@

 #endif

-#define INSTANTIATE_LAPACK_TYPES(FUNC)        \
+#define INSTANTIATE_LAPACK_REAL(FUNC)         \
   template <typename T, typename... Args>     \
   void FUNC(Args... args) {                   \
     if constexpr (std::is_same_v<T, float>) { \
@@ -42,11 +42,24 @@
   }                                           \
   }

-INSTANTIATE_LAPACK_TYPES(geqrf)
-INSTANTIATE_LAPACK_TYPES(orgqr)
-INSTANTIATE_LAPACK_TYPES(syevd)
-INSTANTIATE_LAPACK_TYPES(potrf)
-INSTANTIATE_LAPACK_TYPES(gesvdx)
-INSTANTIATE_LAPACK_TYPES(getrf)
-INSTANTIATE_LAPACK_TYPES(getri)
-INSTANTIATE_LAPACK_TYPES(trtri)
+INSTANTIATE_LAPACK_REAL(geqrf)
+INSTANTIATE_LAPACK_REAL(orgqr)
+INSTANTIATE_LAPACK_REAL(syevd)
+INSTANTIATE_LAPACK_REAL(geev)
+INSTANTIATE_LAPACK_REAL(potrf)
+INSTANTIATE_LAPACK_REAL(gesvdx)
+INSTANTIATE_LAPACK_REAL(getrf)
+INSTANTIATE_LAPACK_REAL(getri)
+INSTANTIATE_LAPACK_REAL(trtri)
+
+#define INSTANTIATE_LAPACK_COMPLEX(FUNC)                            \
+  template <typename T, typename... Args>                           \
+  void FUNC(Args... args) {                                         \
+    if constexpr (std::is_same_v<T, std::complex<float>>) {         \
+      MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...);        \
+    } else if constexpr (std::is_same_v<T, std::complex<double>>) { \
+      MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...);        \
+    }                                                               \
+  }
+
+INSTANTIATE_LAPACK_COMPLEX(heevd)
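The new complex macro picks the c- or z-prefixed LAPACK routine from the template argument, so heevd<std::complex<float>>(...) resolves to cheevd. The same if-constexpr dispatch written out once, with stubs standing in for the real routines:

#include <complex>
#include <cstdio>
#include <type_traits>

void cheevd_stub() { std::printf("single-precision complex path\n"); }
void zheevd_stub() { std::printf("double-precision complex path\n"); }

template <typename T>
void heevd_dispatch() {
  if constexpr (std::is_same_v<T, std::complex<float>>) {
    cheevd_stub(); // the macro would call MLX_LAPACK_FUNC(cheevd)
  } else if constexpr (std::is_same_v<T, std::complex<double>>) {
    zheevd_stub(); // the macro would call MLX_LAPACK_FUNC(zheevd)
  }
}

int main() {
  heevd_dispatch<std::complex<float>>();
  heevd_dispatch<std::complex<double>>();
}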
@@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
     if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
       return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_cpu(x, s);
       encoder.add_temporary(x_copy);
       return x_copy;
     }
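Several hunks in this diff replace the repeated allocate-then-copy pattern with a contiguous_copy_cpu helper. The helper itself is not shown here; a plausible sketch of what it wraps, written in the repo's own types, would be:

// Sketch only -- the real helper presumably lives in mlx/backend/cpu/copy.h.
array contiguous_copy_cpu(const array& x, Stream s) {
  array x_copy(x.shape(), x.dtype(), nullptr, {}); // allocate a fresh array
  copy_cpu(x, x_copy, CopyType::General, s); // general strided copy
  return x_copy;
}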
@@ -31,7 +31,7 @@ void luf_impl(
   strides[ndim - 1] = M;
   strides[ndim - 2] = 1;
   lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
-  copy_inplace(
+  copy_cpu_inplace(
       a,
       lu,
       a.shape(),
@@ -6,6 +6,7 @@
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
+#include "mlx/backend/cpu/gemm.h"
 #include "mlx/backend/cpu/lapack.h"
 #include "mlx/primitives.h"

@@ -52,6 +53,58 @@ inline void mask_matrix(
   }
 }

+template <typename T>
+inline void segmented_mm(
+    const T* a,
+    const T* b,
+    const uint32_t* segments,
+    T* out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    const Shape& a_shape,
+    const Strides& a_strides,
+    const Shape& b_shape,
+    const Strides& b_strides,
+    size_t num_segments,
+    const Shape& segments_shape,
+    const Strides& segments_strides) {
+  int ndim = a_shape.size();
+  Shape a_copy = a_shape;
+  Shape b_copy = b_shape;
+  int32_t M = a_copy[ndim - 2];
+  int32_t N = b_copy[ndim - 1];
+  for (int i = 0; i < num_segments; i++) {
+    uint32_t k_start =
+        segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
+    uint32_t k_end =
+        segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
+    if (k_end <= k_start) {
+      std::fill_n(out + i * M * N, M * N, T(0));
+      continue;
+    }
+    a_copy[ndim - 1] = k_end - k_start;
+    b_copy[ndim - 2] = k_end - k_start;
+    matmul<T>(
+        a + k_start * a_strides[ndim - 1],
+        b + k_start * b_strides[ndim - 2],
+        out + i * M * N,
+        a_transposed,
+        b_transposed,
+        lda,
+        ldb,
+        N,
+        1.0,
+        0.0,
+        1,
+        a_copy,
+        a_strides,
+        b_copy,
+        b_strides);
+  }
+}
+
 } // namespace

 void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -71,21 +124,20 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     if (!expand_all && stx == arr.shape(-1) && sty == 1) {
       if (do_copy) {
         array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-        copy(arr, arr_copy, CopyType::Vector, s);
+        copy_cpu(arr, arr_copy, CopyType::Vector, s);
         return std::make_tuple(false, stx, arr_copy, true);
       }
       return std::make_tuple(false, stx, arr, false);
     } else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
       if (do_copy) {
         array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-        copy(arr, arr_copy, CopyType::Vector, s);
+        copy_cpu(arr, arr_copy, CopyType::Vector, s);
         return std::make_tuple(true, sty, arr_copy, true);
       }
       return std::make_tuple(true, sty, arr, false);
     } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy(arr, arr_copy, CopyType::General, s);
       int64_t stx = arr.shape(-1);
+      array arr_copy = contiguous_copy_cpu(arr, s);
       return std::make_tuple(false, stx, arr_copy, true);
     }
   };
@@ -333,7 +385,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       return std::make_tuple(true, sty, arr);
     } else {
       temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
-      copy(arr, temps.back(), CopyType::General, s);
+      copy_cpu(arr, temps.back(), CopyType::General, s);
       int64_t stx = arr.shape(-1);
       return std::make_tuple(false, stx, temps.back());
     }
@@ -437,4 +489,121 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   encoder.add_temporaries(std::move(temps));
 }

+void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  auto& s = stream();
+  auto& encoder = cpu::get_command_encoder(stream());
+  auto check_transpose = [&s, &encoder](const array& x) {
+    auto stx = x.strides()[x.ndim() - 2];
+    auto sty = x.strides()[x.ndim() - 1];
+    if (stx == x.shape(-1) && sty == 1) {
+      return std::make_tuple(false, stx, x);
+    } else if (stx == 1 && sty == x.shape(-2)) {
+      return std::make_tuple(true, sty, x);
+    } else {
+      array xc(x.shape(), x.dtype(), nullptr, {});
+      copy_cpu(x, xc, CopyType::General, s);
+      encoder.add_temporary(xc);
+      int64_t stx = x.shape(-1);
+      return std::make_tuple(false, stx, xc);
+    }
+  };
+
+  auto [a_transposed, lda, a] = check_transpose(inputs[0]);
+  auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
+  auto& segments = inputs[2];
+
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(segments);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    segments = array::unsafe_weak_copy(segments),
+                    out_ptr = out.data<void>(),
+                    a_transposed = a_transposed,
+                    b_transposed = b_transposed,
+                    lda = lda,
+                    ldb = ldb]() {
+    switch (a.dtype()) {
+      case float64:
+        segmented_mm<double>(
+            a.data<double>(),
+            b.data<double>(),
+            segments.data<uint32_t>(),
+            static_cast<double*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size() / 2,
+            segments.shape(),
+            segments.strides());
+        break;
+      case float32:
+        segmented_mm<float>(
+            a.data<float>(),
+            b.data<float>(),
+            segments.data<uint32_t>(),
+            static_cast<float*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size() / 2,
+            segments.shape(),
+            segments.strides());
+        break;
+      case float16:
+        segmented_mm<float16_t>(
+            a.data<float16_t>(),
+            b.data<float16_t>(),
+            segments.data<uint32_t>(),
+            static_cast<float16_t*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size() / 2,
+            segments.shape(),
+            segments.strides());
+        break;
+      case bfloat16:
+        segmented_mm<bfloat16_t>(
+            a.data<bfloat16_t>(),
+            b.data<bfloat16_t>(),
+            segments.data<uint32_t>(),
+            static_cast<bfloat16_t*>(out_ptr),
+            a_transposed,
+            b_transposed,
+            lda,
+            ldb,
+            a.shape(),
+            a.strides(),
+            b.shape(),
+            b.strides(),
+            segments.size() / 2,
+            segments.shape(),
+            segments.strides());
+        break;
+      default:
+        throw std::invalid_argument(
+            "Segmented mm supports only real float types.");
+    }
+  });
+}
+
 } // namespace mlx::core
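The new SegmentedMM primitive runs one matmul per (k_start, k_end) pair in segments, writing each product to its own M x N slice of the output and zero-filling empty segments. A scalar illustration of that contract (standalone, not the dispatch code above):

#include <cstdint>
#include <cstdio>

int main() {
  const int M = 2, N = 2, K = 4;
  float a[M * K] = {1, 2, 3, 4, 5, 6, 7, 8}; // 2x4, row-major
  float b[K * N] = {1, 0, 0, 1, 1, 0, 0, 1}; // 4x2, two stacked identities
  uint32_t segments[4] = {0, 2, 2, 4}; // two (k_start, k_end) pairs over K
  const int num_segments = 2;
  float out[num_segments * M * N];
  for (int s = 0; s < num_segments; ++s) {
    uint32_t k0 = segments[2 * s], k1 = segments[2 * s + 1];
    for (int i = 0; i < M; ++i) {
      for (int j = 0; j < N; ++j) {
        float acc = 0; // empty segments (k1 <= k0) naturally yield zeros
        for (uint32_t k = k0; k < k1; ++k) {
          acc += a[i * K + k] * b[k * N + j];
        }
        out[s * M * N + i * N + j] = acc;
      }
    }
  }
  for (int s = 0; s < num_segments; ++s) {
    std::printf("segment %d: %g %g %g %g\n", s,
                out[s * M * N + 0], out[s * M * N + 1],
                out[s * M * N + 2], out[s * M * N + 3]);
  }
}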
@@ -81,7 +81,7 @@ void matmul_general(
       return std::make_tuple(true, sty, arr);
     } else {
       temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
-      copy(arr, temps.back(), CopyType::General, stream);
+      copy_cpu(arr, temps.back(), CopyType::General, stream);
       stx = arr.shape(-1);
       return std::make_tuple(false, stx, temps.back());
     }
@@ -132,14 +132,20 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(
         "[AddMM::eval_cpu] Currently only supports float32.");
   }
+  if (out.size() == 0) {
+    out.set_data(allocator::malloc(out.nbytes()));
+    return;
+  }

   // Fill output with C
   auto& c = inputs[2];
   CopyType ctype = c.data_size() == 1
       ? CopyType::Scalar
       : (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
-  copy(c, out, ctype, stream());
+  copy_cpu(c, out, ctype, stream());
+  if (inputs[0].shape(-1) == 0) {
+    return;
+  }
   matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
 }

@@ -22,7 +22,7 @@ void reshape(const array& in, array& out) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
     out.set_data(allocator::malloc(out.nbytes()));
-    copy_inplace(in, out, CopyType::General, out.primitive().stream());
+    copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream());
   } else {
     shared_buffer_reshape(in, out_strides, out);
   }
@@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
   CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy(in, out, ctype, stream());
+  copy_cpu(in, out, ctype, stream());
 }

 void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -198,7 +198,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
     size_t data_offset = strides[axis_] * sizes[i];
     out_slice.copy_shared_buffer(
         out, strides, flags, out_slice.size(), data_offset);
-    copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
+    copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
   }
 }

@@ -211,7 +211,7 @@ void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
        (allow_col_major_ && in.flags().col_contiguous))) {
     out.copy_shared_buffer(in);
   } else {
-    copy(in, out, CopyType::General, stream());
+    copy_cpu(in, out, CopyType::General, stream());
   }
 }

@@ -235,7 +235,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
   } else {
     ctype = CopyType::General;
   }
-  copy(in, out, ctype, stream());
+  copy_cpu(in, out, ctype, stream());
 }

 void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -251,7 +251,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());

   // Fill output with val
-  copy(val, out, CopyType::Scalar, stream());
+  copy_cpu(val, out, CopyType::Scalar, stream());

   // Find offset for start of input values
   size_t data_offset = 0;
@@ -266,7 +266,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
       out, out.strides(), out.flags(), out_slice.size(), data_offset);

   // Copy input values into the slice
-  copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
+  copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
 }

 void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -340,7 +340,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc(out.nbytes()));
   auto [in_offset, donated] =
       compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
-  copy_inplace(
+  copy_cpu_inplace(
       /* const array& src = */ in,
       /* array& dst = */ out,
       /* const Shape& data_shape = */ out.shape(),
@@ -372,11 +372,11 @@ void DynamicSliceUpdate::eval_cpu(
   auto ctype = in.flags().contiguous && in.size() == in.data_size()
       ? CopyType::Vector
       : CopyType::General;
-  copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
+  copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());

   auto [out_offset, donated] =
       compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
-  copy_inplace(
+  copy_cpu_inplace(
       /* const array& src = */ upd,
       /* array& dst = */ out,
       /* const std::vector<int>& data_shape = */ upd.shape(),
@@ -412,14 +412,14 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto ctype = in.flags().contiguous && in.size() == in.data_size()
       ? CopyType::Vector
       : CopyType::General;
-  copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
+  copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());

   // Calculate out strides, initial offset and if copy needs to be made
   auto [data_offset, out_strides] =
       prepare_slice(out, start_indices_, strides_);

   // Do copy
-  copy_inplace(
+  copy_cpu_inplace(
       /* const array& src = */ upd,
       /* array& dst = */ out,
       /* const std::vector<int>& data_shape = */ upd.shape(),
@@ -456,9 +456,9 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
     if (in.dtype() == bool_) {
       auto in_tmp = array(in.shape(), uint8, nullptr, {});
       in_tmp.copy_shared_buffer(in);
-      copy_inplace(in_tmp, tmp, CopyType::General, stream());
+      copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream());
     } else {
-      copy_inplace(in, tmp, CopyType::General, stream());
+      copy_cpu_inplace(in, tmp, CopyType::General, stream());
     }

     auto flags = out.flags();
@@ -26,7 +26,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   strides[in.ndim() - 2] = 1;
   strides[in.ndim() - 1] = M;
   in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
-  copy_inplace(a, in, CopyType::GeneralGeneral, stream);
+  copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream);
   auto& encoder = cpu::get_command_encoder(stream);
   q.set_data(allocator::malloc(q.nbytes()));
   r.set_data(allocator::malloc(r.nbytes()));
@@ -13,9 +13,18 @@ namespace mlx::core {

 namespace {

+inline constexpr short get_pack_factor(int bits, int wsize = 8) {
+  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
+}
+
+inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
+  auto power_of_2_bits = (bits & (bits - 1)) == 0;
+  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
+}
+
 template <typename T, int bits>
 void extract_bits(const uint8_t* w_in, T* w_out) {
-  assert(bits == 3 || bits == 6);
+  static_assert(bits == 3 || bits == 5 || bits == 6);
   if (bits == 3) {
     w_out[0] = static_cast<T>(w_in[0] & 0x7);
     w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
@@ -25,6 +34,16 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
     w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
     w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
     w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
+  } else if (bits == 5) {
+    w_out[0] = static_cast<T>(w_in[0] & 0x1f);
+    w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
+    w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
+    w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
+    w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
+    w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
+    w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
+    w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
+
   } else if (bits == 6) {
     w_out[0] = static_cast<T>(w_in[0] & 0x3f);
     w_out[1] =
@@ -46,8 +65,8 @@ void _qmm(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
+  constexpr int pack_factor = get_pack_factor(bits, 8);
+  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
   constexpr int packs_in_group = group_size / pack_factor;

   for (int m = 0; m < M; m++) {
@@ -65,7 +84,7 @@ void _qmm(
         T scale = *scales_local++;
         T bias = *biases_local++;
         for (int ng = 0; ng < packs_in_group; ng++) {
-          if (bits == 3 || bits == 6) {
+          if constexpr (bits == 3 || bits == 5 || bits == 6) {
             T wl[pack_factor];
             extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
@@ -104,8 +123,9 @@ void _qmm_t(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
-  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
+
+  constexpr int pack_factor = get_pack_factor(bits, 8);
+  constexpr int bytes_per_pack = get_bytes_per_pack(bits);
   constexpr int packs_in_group = group_size / pack_factor;

   for (int m = 0; m < M; m++) {
@@ -121,7 +141,7 @@ void _qmm_t(
         T bias = *biases_local++;

         for (int kw = 0; kw < packs_in_group; kw++) {
-          if (bits == 3 || bits == 6) {
+          if constexpr (bits == 3 || bits == 5 || bits == 6) {
             T wl[pack_factor];
             extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
@@ -304,6 +324,10 @@ void _qmm_dispatch_typed(
       _qmm_dispatch_group<T, 4>(
           result, x, w, scales, biases, M, N, K, group_size, transposed_w);
       break;
+    case 5:
+      _qmm_dispatch_group<T, 5>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
     case 6:
       _qmm_dispatch_group<T, 6>(
           result, x, w, scales, biases, M, N, K, group_size, transposed_w);
@@ -505,7 +529,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
       return arr;
     } else {
       temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
-      copy(arr, temps.back(), CopyType::General, s);
+      copy_cpu(arr, temps.back(), CopyType::General, s);
       return temps.back();
     }
   };
@@ -555,7 +579,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       return arr;
     } else {
       temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
-      copy(arr, temps.back(), CopyType::General, s);
+      copy_cpu(arr, temps.back(), CopyType::General, s);
       return temps.back();
     }
   };
@@ -613,9 +637,8 @@ void quantize(
   float eps = 1e-7;

   bool power_of_2_bits = is_power_of_2(bits);
-  int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-  // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
-  int bytes_per_pack = power_of_2_bits ? 1 : 3;
+  int el_per_int = get_pack_factor(bits, 32);
+  int bytes_per_pack = get_bytes_per_pack(bits);
   int int_per_group = group_size * bytes_per_pack / el_per_int;
   size_t n_groups = w_size / group_size;

@@ -640,15 +663,21 @@ void quantize(
     }
     size_t out_idx = i * int_per_group;
     for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
-      uint32_t out_el = 0;
+      uint64_t out_el = 0;
       for (int k = 0; k < el_per_int; ++k) {
         float w_el = w[w_idx + j * el_per_int + k];
         w_el = std::rint((w_el - bias) / scale);
         w_el = std::min(std::max(w_el, 0.0f), n_bins);
-        out_el |= static_cast<uint32_t>(w_el) << (k * bits);
+        out_el |= static_cast<uint64_t>(w_el) << (k * bits);
       }
       if (power_of_2_bits) {
         out[out_idx + j] = out_el;
+      } else if (bits == 5) {
+        out[out_idx + bytes_per_pack * j] = out_el & 0xff;
+        out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
+        out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
+        out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
+        out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
       } else {
         out[out_idx + bytes_per_pack * j] = out_el & 0xff;
         out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
@@ -683,9 +712,7 @@ void fast::AffineQuantize::eval_cpu(
     if (arr.flags().row_contiguous) {
       return std::make_pair(arr, false);
     } else {
-      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-      copy(arr, arr_copy, CopyType::General, s);
-      return std::make_pair(arr_copy, true);
+      return std::make_pair(contiguous_copy_cpu(arr, s), true);
     }
   };

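Worked numbers for the new 5-bit path: get_pack_factor(5) is 8 and get_bytes_per_pack(5) is 5, since 8 values x 5 bits = 40 bits = 5 bytes. A standalone round-trip check that packs with the quantize() bit layout and unpacks with the extract_bits<T, 5> splices:

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  uint8_t vals[8] = {3, 17, 31, 0, 9, 22, 5, 28}; // arbitrary 5-bit values
  // Pack into a 64-bit accumulator (why out_el was widened to uint64_t:
  // 8 * 5 = 40 bits no longer fits in 32), then spill the low 5 bytes.
  uint64_t acc = 0;
  for (int k = 0; k < 8; ++k) {
    acc |= static_cast<uint64_t>(vals[k]) << (k * 5);
  }
  uint8_t w[5];
  for (int b = 0; b < 5; ++b) {
    w[b] = (acc >> (8 * b)) & 0xff;
  }
  // Unpack with the same bit splices as extract_bits<T, 5>.
  uint8_t out[8];
  out[0] = w[0] & 0x1f;
  out[1] = ((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3);
  out[2] = (w[1] & 0x7c) >> 2;
  out[3] = ((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1);
  out[4] = ((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4);
  out[5] = (w[3] & 0x3e) >> 1;
  out[6] = ((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2);
  out[7] = (w[4] & 0xf8) >> 3;
  for (int k = 0; k < 8; ++k) {
    assert(out[k] == vals[k]);
  }
  std::printf("5-bit pack/unpack round trip OK\n");
}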
@@ -325,7 +325,15 @@ struct MaxReduce {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
T operator()(simd::Simd<T, N> x) {
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
return simd::max(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
if (simd::any(x != x)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
}
|
||||||
return simd::max(x);
|
return simd::max(x);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
@@ -342,7 +350,15 @@ struct MinReduce {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
T operator()(simd::Simd<T, N> x) {
|
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
return simd::min(x);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N, typename T>
|
||||||
|
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||||
|
if (simd::any(x != x)) {
|
||||||
|
return static_cast<T>(NAN);
|
||||||
|
}
|
||||||
return simd::min(x);
|
return simd::min(x);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
@@ -527,10 +543,10 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
       reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
       break;
     case int8:
-      reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+      reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
       break;
     case int16:
-      reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+      reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
       break;
     case int32:
       reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);

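The dispatch fix above matters because comparing signed data through an unsigned type reorders negatives. A tiny standalone demonstration:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Comparing int8 data through uint8_t reinterprets -1 as 255, so min/max
// pick the wrong element; this prints "signed min=-1, unsigned min=2".
int main() {
  int8_t vals[] = {-1, 2};
  auto as_u8 = [](int8_t v) { return static_cast<uint8_t>(v); };
  std::printf(
      "signed min=%d, unsigned min=%d\n",
      static_cast<int>(std::min(vals[0], vals[1])),
      static_cast<int>(static_cast<int8_t>(
          std::min(as_u8(vals[0]), as_u8(vals[1])))));
  return 0;
}
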
@@ -3,6 +3,7 @@
 #include <cassert>

 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/binary_ops.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"

@@ -226,6 +227,16 @@ void scan_dispatch(
       scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
       break;
     }
+    case Scan::LogAddExp: {
+      auto op = [](U a, T b) {
+        return detail::LogAddExp{}(a, static_cast<U>(b));
+      };
+      auto init = (issubdtype(in.dtype(), floating))
+          ? static_cast<U>(-std::numeric_limits<float>::infinity())
+          : std::numeric_limits<U>::min();
+      scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
+      break;
+    }
   }
 }

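The hunk above wires LogAddExp into the CPU scan with -infinity as the identity element, so the first scanned element passes through unchanged. A minimal scalar sketch of the cumulative operation this enables, using the standard stable logaddexp formula (this is a stand-in, not MLX's detail::LogAddExp):

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Stable logaddexp: log(exp(a) + exp(b)) without overflow, by factoring out
// the larger argument.
double logaddexp(double a, double b) {
  if (std::isinf(a) && a < 0) return b; // -inf is the identity element
  double m = std::max(a, b);
  return m + std::log1p(std::exp(-std::fabs(a - b)));
}

// Inclusive scan seeded with the same -inf init used in the hunk above.
std::vector<double> logcumsumexp(const std::vector<double>& x) {
  std::vector<double> out(x.size());
  double acc = -std::numeric_limits<double>::infinity();
  for (size_t i = 0; i < x.size(); ++i) {
    acc = logaddexp(acc, x[i]);
    out[i] = acc;
  }
  return out;
}
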
@@ -239,10 +250,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Ensure contiguity
   auto in = inputs[0];
   if (!in.flags().row_contiguous) {
-    array arr_copy(in.shape(), in.dtype(), nullptr, {});
-    copy(in, arr_copy, CopyType::General, stream());
-    in = arr_copy;
-    encoder.add_temporary(arr_copy);
+    in = contiguous_copy_cpu(in, stream());
+    encoder.add_temporary(in);
   }
   out.set_data(allocator::malloc(out.nbytes()));

@@ -319,7 +328,8 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
           reduce_type_, in, out, axis_, reverse_, inclusive_);
       break;
     case complex64:
-      throw std::runtime_error("Scan ops do not support complex types yet");
+      scan_dispatch<complex64_t, complex64_t>(
+          reduce_type_, in, out, axis_, reverse_, inclusive_);
       break;
   }
 });

@@ -88,12 +88,33 @@ DEFAULT_UNARY(expm1, std::expm1)
 DEFAULT_UNARY(floor, std::floor)
 DEFAULT_UNARY(log, std::log)
 DEFAULT_UNARY(log10, std::log10)
-DEFAULT_UNARY(log1p, std::log1p)
 DEFAULT_UNARY(sinh, std::sinh)
 DEFAULT_UNARY(sqrt, std::sqrt)
 DEFAULT_UNARY(tan, std::tan)
 DEFAULT_UNARY(tanh, std::tanh)

+template <typename T>
+Simd<T, 1> log1p(Simd<T, 1> in) {
+  if constexpr (is_complex<T>) {
+    auto x = in.value.real();
+    auto y = in.value.imag();
+    auto zabs = std::abs(in.value);
+    auto theta = std::atan2(y, x + 1);
+    if (zabs < 0.5) {
+      auto r = x * (2 + x) + y * y;
+      if (r == 0) { // handle underflow
+        return Simd<T, 1>{T{x, theta}};
+      }
+      return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
+    } else {
+      auto z0 = std::hypot(x + 1, y);
+      return Simd<T, 1>{T{std::log(z0), theta}};
+    }
+  } else {
+    return Simd<T, 1>{std::log1p(in.value)};
+  }
+}
+
 template <typename T>
 Simd<T, 1> log2(Simd<T, 1> in) {
   if constexpr (is_complex<T>) {

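For small |z| the hunk above computes the real part of log(1 + z) as 0.5 * log1p(x*(2 + x) + y*y) rather than log|1 + z|, which avoids catastrophic cancellation when 1 + z is close to 1. A short standalone check of the two forms (a demonstration sketch, not MLX code):

#include <cmath>
#include <complex>
#include <cstdio>

// Compare the naive real part of log(1 + z) against the cancellation-free
// form used above; for tiny z the naive form loses almost all digits.
int main() {
  std::complex<double> z{1e-12, 1e-12};
  double naive = std::log(std::abs(1.0 + z));
  double stable = 0.5 * std::log1p(z.real() * (2 + z.real()) +
                                   z.imag() * z.imag());
  std::printf("naive=%.17g stable=%.17g\n", naive, stable);
  return 0;
}
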
@@ -131,8 +131,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
       }
       return x;
     } else {
-      array x_copy(x.shape(), x.dtype(), nullptr, {});
-      copy(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_cpu(x, s);
       out.copy_shared_buffer(x_copy);
       return x_copy;
     }

@@ -8,7 +8,7 @@
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
+#include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"

 namespace mlx::core {

@@ -333,45 +333,24 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];

+  int axis = axis_;
+  if (axis < 0) {
+    axis += in.ndim();
+  }
+
   // Copy input to output
-  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy(in, out, ctype, stream());
+  CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
+      ? CopyType::Vector
+      : CopyType::General;
+  copy_cpu(in, out, ctype, stream());

   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_output_array(out);
-  encoder.dispatch(
-      [out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
-        switch (out.dtype()) {
-          case bool_:
-            return sort<bool>(out, axis_);
-          case uint8:
-            return sort<uint8_t>(out, axis_);
-          case uint16:
-            return sort<uint16_t>(out, axis_);
-          case uint32:
-            return sort<uint32_t>(out, axis_);
-          case uint64:
-            return sort<uint64_t>(out, axis_);
-          case int8:
-            return sort<int8_t>(out, axis_);
-          case int16:
-            return sort<int16_t>(out, axis_);
-          case int32:
-            return sort<int32_t>(out, axis_);
-          case int64:
-            return sort<int64_t>(out, axis_);
-          case float32:
-            return sort<float>(out, axis_);
-          case float64:
-            return sort<double>(out, axis_);
-          case float16:
-            return sort<float16_t>(out, axis_);
-          case bfloat16:
-            return sort<bfloat16_t>(out, axis_);
-          case complex64:
-            return sort<complex64_t>(out, axis_);
-        }
-      });
+  encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
+    dispatch_all_types(out.dtype(), [&](auto type_tag) {
+      sort<MLX_GET_TYPE(type_tag)>(out, axis);
+    });
+  });
 }

 void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {

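The hunk above collapses a fourteen-case dtype switch into a single dispatch_all_types call taking a generic lambda. A minimal self-contained sketch of how that kind of tag dispatch works (simplified to three dtypes; not MLX's actual helper):

#include <cstdint>
#include <stdexcept>

// The callback is a generic lambda; the switch instantiates it once per
// concrete type, so call sites stay one line per operation.
enum class Dtype { int32, float32, float64 };

template <typename T>
struct TypeTag {
  using type = T;
};

template <typename F>
void dispatch_all_types(Dtype dt, F&& f) {
  switch (dt) {
    case Dtype::int32:
      return f(TypeTag<int32_t>{});
    case Dtype::float32:
      return f(TypeTag<float>{});
    case Dtype::float64:
      return f(TypeTag<double>{});
  }
  throw std::runtime_error("Unknown dtype");
}

// Usage mirrors the new Sort::eval_cpu body:
//   dispatch_all_types(dt, [&](auto tag) {
//     using T = typename decltype(tag)::type;
//     sort<T>(out, axis);
//   });
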
@@ -426,8 +405,10 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];

   // Copy input to output
-  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-  copy(in, out, ctype, stream());
+  CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
+      ? CopyType::Vector
+      : CopyType::General;
+  copy_cpu(in, out, ctype, stream());

   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_output_array(out);

@@ -31,7 +31,7 @@ void svd_impl(

   // lapack clobbers the input, so we have to make a copy.
   array in(a.shape(), a.dtype(), nullptr, {});
-  copy(
+  copy_cpu(
       a,
       in,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

@@ -1,5 +1,8 @@
 // Copyright © 2024 Apple Inc.

+// Required for using M_LN2 in MSVC.
+#define _USE_MATH_DEFINES
+
 #include <cassert>

 #include "mlx/backend/cpu/unary.h"

@@ -2,32 +2,13 @@

 #pragma once

-#include "mlx/allocator.h"
-#include "mlx/array.h"
-#include "mlx/backend/common/utils.h"
+#include "mlx/backend/common/unary.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/utils.h"

 namespace mlx::core {

-void set_unary_output_data(const array& in, array& out) {
-  if (in.flags().contiguous) {
-    if (is_donatable(in, out)) {
-      out.copy_shared_buffer(in);
-    } else {
-      auto size = in.data_size();
-      out.set_data(
-          allocator::malloc(size * out.itemsize()),
-          size,
-          in.strides(),
-          in.flags());
-    }
-  } else {
-    out.set_data(allocator::malloc(out.nbytes()));
-  }
-}
-
 template <typename T, typename U = T, typename Op>
 void unary_op(const T* a, U* out, size_t shape, size_t stride) {
   for (size_t i = 0; i < shape; i += 1) {

mlx/backend/cuda/CMakeLists.txt (new file, 170 lines)
@@ -0,0 +1,170 @@
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only code should be put in device/ subdir.
# * Files in device/ subdir should not include files outside.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
else()
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
endif()

target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

# Embed kernel sources in binary for JIT compilation.
file(
  GLOB MLX_JIT_SOURCES
  RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
add_custom_command(
  OUTPUT gen/cuda_jit_sources.h
  COMMAND
    ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
    -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
    "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
  DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
add_dependencies(mlx cuda_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")

# Enable defining device lambda functions.
target_compile_options(mlx
                       PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
  mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")

# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
# Explicitly pass this flag to suppress the warning; it is safe to set it to
# true but then the warning wouldn't be suppressed.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
  target_compile_options(
    mlx
    PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>")
endif()

# Suppress warning when building for compute capability 7 used by V100.
target_compile_options(
  mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--Wno-deprecated-gpu-targets>")

# Use stronger binaries compression. This feature was introduced in CUDA 12.8
# and requires drivers released after CUDA 12.4.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
  target_compile_options(
    mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()

# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
# managed memory.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
  set(MLX_CUDA_ARCHITECTURES "native")
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
                                     "${MLX_CUDA_ARCHITECTURES}")

# Use fixed version of CCCL.
FetchContent_Declare(
  cccl
  URL "https://github.com/NVIDIA/cccl/releases/download/v2.8.1/cccl-v2.8.1.zip")
FetchContent_MakeAvailable(cccl)
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")

# Use fixed version of NVTX.
FetchContent_Declare(
  nvtx3
  GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
  GIT_TAG v3.1.1
  GIT_SHALLOW TRUE
  SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nvtx3)
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)

# Make cuda runtime APIs available in non-cuda files.
find_package(CUDAToolkit REQUIRED)
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})

# Use cublasLt.
target_link_libraries(mlx PRIVATE CUDA::cublasLt)

# Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)

# Use the frontend APIs of cuDNN.
FetchContent_Declare(
  cudnn
  GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
  GIT_TAG v1.12.1
  GIT_SHALLOW TRUE
  EXCLUDE_FROM_ALL)
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
FetchContent_MakeAvailable(cudnn)
target_link_libraries(mlx PRIVATE cudnn_frontend)
# Link with the actual cuDNN libraries.
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)

# Suppress nvcc warnings on MLX headers.
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
                       --diag_suppress=997>)

# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

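The embedded device headers above exist so kernels can be JIT-compiled at runtime with NVRTC. A minimal hedged sketch of the NVRTC flow this build machinery supports (this is not MLX's jit_module.cpp, just the underlying toolkit calls; error handling is simplified):

#include <nvrtc.h>
#include <stdexcept>
#include <string>

// Compile a CUDA source string to PTX. In MLX the #include'd device headers
// would be supplied from the byte arrays in the generated
// gen/cuda_jit_sources.h; here the source is assumed self-contained.
std::string compile_to_ptx(const char* source, const char* name) {
  nvrtcProgram prog;
  if (nvrtcCreateProgram(&prog, source, name, 0, nullptr, nullptr) !=
      NVRTC_SUCCESS) {
    throw std::runtime_error("nvrtcCreateProgram failed");
  }
  const char* opts[] = {"--std=c++17"};
  if (nvrtcCompileProgram(prog, 1, opts) != NVRTC_SUCCESS) {
    nvrtcDestroyProgram(&prog);
    throw std::runtime_error("nvrtcCompileProgram failed");
  }
  size_t ptx_size;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::string ptx(ptx_size, '\0');
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);
  return ptx;
}
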
mlx/backend/cuda/allocator.cpp (new file, 258 lines)
@@ -0,0 +1,258 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/utils.h"

#include <cuda_runtime.h>
#include <fmt/format.h>
#include <unistd.h>

#include <cassert>

namespace mlx::core {

namespace cu {

constexpr int page_size = 16384;

// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;

// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;

SmallSizePool::SmallSizePool() {
  auto num_blocks = small_pool_size / small_block_size;
  buffer_ = new Block[num_blocks];

  next_free_ = buffer_;

  CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
  CHECK_CUDA_ERROR(
      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));

  auto curr = next_free_;
  for (size_t i = 1; i < num_blocks; ++i) {
    curr->next = buffer_ + i;
    curr = curr->next;
  }
  curr->next = nullptr;
}

SmallSizePool::~SmallSizePool() {
  CHECK_CUDA_ERROR(cudaFree(data_));
  delete[] buffer_;
}

CudaBuffer* SmallSizePool::malloc() {
  if (next_free_ == nullptr) {
    return nullptr;
  }
  Block* b = next_free_;
  uint64_t i = next_free_ - buffer_;
  next_free_ = next_free_->next;
  b->buf.data = static_cast<char*>(data_) + i * small_block_size;
  b->buf.size = small_block_size;
  return &b->buf;
}

void SmallSizePool::free(CudaBuffer* buf) {
  auto b = reinterpret_cast<Block*>(buf);
  b->next = next_free_;
  next_free_ = b;
}

bool SmallSizePool::in_pool(CudaBuffer* buf) {
  constexpr int num_blocks = (small_pool_size / small_block_size);
  auto b = reinterpret_cast<Block*>(buf);
  int64_t block_num = b - buffer_;
  return block_num >= 0 && block_num < num_blocks;
}

CudaAllocator::CudaAllocator()
    : buffer_cache_(
          page_size,
          [](CudaBuffer* buf) { return buf->size; },
          [this](CudaBuffer* buf) { cuda_free(buf); }) {
  // TODO: Set memory limit for multi-device.
  size_t free, total;
  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
  memory_limit_ = total * 0.8;
  max_pool_size_ = memory_limit_;
}

Buffer CudaAllocator::malloc(size_t size) {
  // Find available buffer from cache.
  auto orig_size = size;
  std::unique_lock lock(mutex_);
  if (size <= small_block_size) {
    size = 8;
  } else if (size < page_size) {
    size = next_power_of_2(size);
  } else {
    size = page_size * ((size + page_size - 1) / page_size);
  }

  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
  if (!buf) {
    // If we have a lot of memory pressure try to reclaim memory from the cache.
    int64_t mem_to_free =
        get_active_memory() + get_cache_memory() + size - memory_limit_;
    if (mem_to_free > 0) {
      buffer_cache_.release_cached_buffers(mem_to_free);
    }

    // Try the scalar pool first
    if (size <= small_block_size) {
      buf = scalar_pool_.malloc();
    }
    lock.unlock();
    if (!buf) {
      buf = new CudaBuffer{nullptr, size};
      cudaError_t err = cudaMallocManaged(&buf->data, size);
      if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
        throw std::runtime_error(fmt::format(
            "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
      }
    }
    lock.lock();
  }
  active_memory_ += size;
  peak_memory_ = std::max(active_memory_, peak_memory_);

  // Maintain the cache below the requested limit.
  if (get_cache_memory() > max_pool_size_) {
    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
  }
  return Buffer{buf};
}

void CudaAllocator::free(Buffer buffer) {
  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
  if (!buf) {
    return;
  }

  std::unique_lock lock(mutex_);
  active_memory_ -= buf->size;
  if (get_cache_memory() < max_pool_size_) {
    buffer_cache_.recycle_to_cache(buf);
  } else {
    cuda_free(buf);
  }
}

size_t CudaAllocator::size(Buffer buffer) const {
  auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
  if (!buf) {
    return 0;
  }
  return buf->size;
}

// This must be called with mutex_ acquired
void CudaAllocator::cuda_free(CudaBuffer* buf) {
  if (scalar_pool_.in_pool(buf)) {
    scalar_pool_.free(buf);
  } else {
    cudaFree(buf->data);
    delete buf;
  }
}

size_t CudaAllocator::get_active_memory() const {
  return active_memory_;
}

size_t CudaAllocator::get_peak_memory() const {
  return peak_memory_;
}

void CudaAllocator::reset_peak_memory() {
  std::lock_guard lock(mutex_);
  peak_memory_ = 0;
}

size_t CudaAllocator::get_memory_limit() {
  return memory_limit_;
}

size_t CudaAllocator::set_memory_limit(size_t limit) {
  std::lock_guard lock(mutex_);
  std::swap(limit, memory_limit_);
  return limit;
}

size_t CudaAllocator::get_cache_memory() const {
  return buffer_cache_.cache_size();
}

size_t CudaAllocator::set_cache_limit(size_t limit) {
  std::lock_guard lk(mutex_);
  std::swap(limit, max_pool_size_);
  return limit;
}

void CudaAllocator::clear_cache() {
  std::lock_guard lk(mutex_);
  buffer_cache_.clear();
}

CudaAllocator& allocator() {
  // By creating the |allocator_| on heap, the destructor of CudaAllocator
  // will not be called on exit and buffers in the cache will be leaked. This
  // can save some time at program exit.
  static CudaAllocator* allocator_ = new CudaAllocator;
  return *allocator_;
}

} // namespace cu

namespace allocator {

Allocator& allocator() {
  return cu::allocator();
}

void* Buffer::raw_ptr() {
  if (!ptr_) {
    return nullptr;
  }
  return static_cast<cu::CudaBuffer*>(ptr_)->data;
}

} // namespace allocator

size_t get_active_memory() {
  return cu::allocator().get_active_memory();
}
size_t get_peak_memory() {
  return cu::allocator().get_peak_memory();
}
void reset_peak_memory() {
  return cu::allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
  return cu::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
  return cu::allocator().get_memory_limit();
}
size_t get_cache_memory() {
  return cu::allocator().get_cache_memory();
}
size_t set_cache_limit(size_t limit) {
  return cu::allocator().set_cache_limit(limit);
}
void clear_cache() {
  cu::allocator().clear_cache();
}

// Not supported in CUDA.
size_t set_wired_limit(size_t) {
  return 0;
}

} // namespace mlx::core

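CudaAllocator::malloc above rounds every request into a size class before consulting the cache: tiny requests share one 8-byte bucket, sub-page requests round up to a power of two, and larger ones round up to whole pages, so recycled buffers match future requests. A standalone sketch of just that rounding rule (next_power_of_2 here is a local stand-in for the MLX utility of the same name):

#include <cstddef>

constexpr size_t kPageSize = 16384;
constexpr size_t kSmallBlock = 8;

// Stand-in for MLX's next_power_of_2 utility.
size_t next_power_of_2(size_t n) {
  size_t p = 1;
  while (p < n) {
    p <<= 1;
  }
  return p;
}

// Mirrors the bucketing in CudaAllocator::malloc: fewer distinct sizes means
// far more cache hits at a bounded cost in internal fragmentation.
size_t round_size(size_t size) {
  if (size <= kSmallBlock) {
    return kSmallBlock;
  } else if (size < kPageSize) {
    return next_power_of_2(size);
  }
  return kPageSize * ((size + kPageSize - 1) / kPageSize);
}
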
mlx/backend/cuda/allocator.h (new file, 77 lines)
@@ -0,0 +1,77 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"

#include <mutex>
#include <set>
#include <utility>

namespace mlx::core::cu {

using allocator::Buffer;

// Stores cuda-managed unified memory.
struct CudaBuffer {
  void* data;
  size_t size;
};

class SmallSizePool {
 private:
  union Block {
    Block* next;
    CudaBuffer buf;
  };

  Block* buffer_{nullptr};
  void* data_{nullptr};
  Block* next_free_{nullptr};

 public:
  SmallSizePool();
  ~SmallSizePool();

  SmallSizePool(const SmallSizePool&) = delete;
  SmallSizePool& operator=(const SmallSizePool&) = delete;

  CudaBuffer* malloc();
  void free(CudaBuffer* buf);
  bool in_pool(CudaBuffer* buf);
};

class CudaAllocator : public allocator::Allocator {
 public:
  Buffer malloc(size_t size) override;
  void free(Buffer buffer) override;
  size_t size(Buffer buffer) const override;

  size_t get_active_memory() const;
  size_t get_peak_memory() const;
  void reset_peak_memory();
  size_t get_memory_limit();
  size_t set_memory_limit(size_t limit);
  size_t get_cache_memory() const;
  size_t set_cache_limit(size_t limit);
  void clear_cache();

 private:
  void cuda_free(CudaBuffer* buf);

  CudaAllocator();
  friend CudaAllocator& allocator();

  std::mutex mutex_;
  size_t memory_limit_;
  size_t max_pool_size_;
  BufferCache<CudaBuffer> buffer_cache_;
  size_t active_memory_{0};
  size_t peak_memory_{0};
  SmallSizePool scalar_pool_;
};

CudaAllocator& allocator();

} // namespace mlx::core::cu

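The Block union above is the classic intrusive free-list trick: a free block stores the next pointer in the same bytes a live block uses for its payload, so the pool needs no side bookkeeping. A minimal self-contained sketch of the same idea (names are illustrative, not MLX's):

#include <cstddef>
#include <type_traits>

// Fixed-capacity pool in the style of SmallSizePool: free slots double as
// linked-list nodes. Payload must be trivial so the union array can be
// default-constructed and reinterpreted safely.
template <typename Payload, size_t N>
class FixedPool {
  static_assert(std::is_trivial_v<Payload>, "payload must be trivial");
  union Block {
    Block* next;
    Payload payload;
  };
  Block blocks_[N];
  Block* free_ = nullptr;

 public:
  FixedPool() {
    for (size_t i = 0; i + 1 < N; ++i) {
      blocks_[i].next = &blocks_[i + 1];
    }
    blocks_[N - 1].next = nullptr;
    free_ = &blocks_[0];
  }
  Payload* alloc() {
    if (!free_) {
      return nullptr; // pool exhausted; caller falls back to the heap
    }
    Block* b = free_;
    free_ = free_->next;
    return &b->payload;
  }
  void free(Payload* p) {
    // Payload and next share storage, so the cast recovers the block.
    Block* b = reinterpret_cast<Block*>(p);
    b->next = free_;
    free_ = b;
  }
};
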
mlx/backend/cuda/arange.cu (new file, 55 lines)
@@ -0,0 +1,55 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

namespace mlx::core {

namespace cu {

template <typename T>
struct Arange {
  const T start;
  const T step;

  __device__ T operator()(uint32_t i) const {
    return start + i * step;
  }
};

} // namespace cu

void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Arange::eval_gpu");
  if (out.size() == 0) {
    return;
  }
  out.set_data(allocator::malloc(out.nbytes()));

  auto& encoder = cu::get_command_encoder(stream());
  encoder.set_output_array(out);

  auto capture = encoder.capture_context();
  dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
    using CTYPE = MLX_GET_TYPE(type_tag);
    using OutType = cuda_type_t<CTYPE>;
    CTYPE step =
        static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
    thrust::transform(
        cu::thrust_policy(encoder.stream()),
        thrust::counting_iterator<uint32_t>(0),
        thrust::counting_iterator<uint32_t>(out.data_size()),
        thrust::device_pointer_cast(out.data<OutType>()),
        cu::Arange<OutType>{
            static_cast<OutType>(start_), static_cast<OutType>(step)});
  });
}

} // namespace mlx::core

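Note the step computation above: rather than using the nominal step directly, the host computes (start_ + step_) - start_ in the output dtype, so the spacing the kernel uses is the spacing actually representable at that precision. A tiny host-only demonstration of why, using float as a stand-in for a low-precision dtype like float16:

#include <cstdio>

// Near 1e8 the float ulp is 8, so a nominal step of 4 quantizes to 0 in the
// output dtype; computing the step as a round-tripped difference exposes
// this instead of silently accumulating drift element by element.
int main() {
  double start = 1e8, step = 4.0;
  float quantized_step =
      static_cast<float>(start + step) - static_cast<float>(start);
  std::printf("nominal=%g in-dtype=%g\n", step, quantized_step);
  return 0;
}
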
mlx/backend/cuda/arg_reduce.cu (new file, 188 lines)
@@ -0,0 +1,188 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>

#include <cassert>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename T>
struct IndexValPair {
  uint32_t index;
  T val;
};

template <typename T>
struct ArgMin {
  constexpr __device__ T init() {
    return Limits<T>::max();
  }

  __device__ IndexValPair<T> operator()(
      const IndexValPair<T>& best,
      const IndexValPair<T>& current) {
    if (best.val > current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  template <int N>
  __device__ IndexValPair<T> reduce_many(
      IndexValPair<T> best,
      const AlignedVector<T, N>& vals,
      uint32_t offset) {
#pragma unroll
    for (int i = 0; i < N; i++) {
      if (vals[i] < best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

template <typename T>
struct ArgMax {
  constexpr __device__ T init() {
    return Limits<T>::min();
  }

  __device__ IndexValPair<T> operator()(
      const IndexValPair<T>& best,
      const IndexValPair<T>& current) {
    if (best.val < current.val ||
        (best.val == current.val && best.index > current.index)) {
      return current;
    } else {
      return best;
    }
  }

  template <int N>
  __device__ IndexValPair<T> reduce_many(
      IndexValPair<T> best,
      const AlignedVector<T, N>& vals,
      uint32_t offset) {
#pragma unroll
    for (int i = 0; i < N; i++) {
      if (vals[i] > best.val) {
        best.val = vals[i];
        best.index = offset + i;
      }
    }
    return best;
  }
};

template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
__global__ void arg_reduce_general(
    const T* in,
    uint32_t* out,
    size_t size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides in_strides,
    const __grid_constant__ Strides out_strides,
    int32_t ndim,
    int64_t axis_stride,
    int32_t axis_size) {
  auto block = cg::this_thread_block();

  int64_t index = cg::this_grid().block_rank();
  if (index >= size) {
    return;
  }

  int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
  int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
  in += in_idx;

  Op op;
  T init = op.init();
  IndexValPair<T> best{0, init};

  for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
    auto tid = r * BLOCK_DIM + block.thread_index().x;
    auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
    best = op.reduce_many(best, vals, tid * N_READS);
  }

  typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
  __shared__ typename BlockReduceT::TempStorage temp;

  best = BlockReduceT(temp).Reduce(best, op);

  if (block.thread_rank() == 0) {
    out[out_idx] = best.index;
  }
}

} // namespace cu

void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ArgReduce::eval_gpu");
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc(out.nbytes()));
  auto& s = stream();

  // Prepare the shapes, strides and axis arguments.
  Shape shape = remove_index(in.shape(), axis_);
  Strides in_strides = remove_index(in.strides(), axis_);
  Strides out_strides = out.ndim() == in.ndim()
      ? remove_index(out.strides(), axis_)
      : out.strides();
  int64_t axis_stride = in.strides()[axis_];
  int32_t axis_size = in.shape()[axis_];
  int32_t ndim = shape.size();

  // ArgReduce.
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    constexpr uint32_t N_READS = 4;
    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
      dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
      auto kernel =
          cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
      if (reduce_type_ == ArgReduce::ArgMin) {
        kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
      }
      encoder.add_kernel_node(
          kernel,
          num_blocks,
          block_dim(),
          0,
          in.data<T>(),
          out.data<uint32_t>(),
          out.size(),
          const_param(shape),
          const_param(in_strides),
          const_param(out_strides),
          ndim,
          axis_stride,
          axis_size);
    });
  });
}

} // namespace mlx::core

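The pair combiners above break ties toward the smaller index (on equal values, best.index > current.index selects current), so the block reduction returns the first occurrence no matter how the reduction tree associates. A host-side sketch of just the ArgMax combiner for clarity:

#include <cstdint>

// Host mirror of the IndexValPair combiner: strictly larger value wins;
// on a tie the smaller index wins, giving first-occurrence semantics that
// are independent of reduction order.
struct Pair {
  uint32_t index;
  float val;
};

Pair argmax_combine(const Pair& best, const Pair& current) {
  if (best.val < current.val ||
      (best.val == current.val && best.index > current.index)) {
    return current;
  }
  return best;
}
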
mlx/backend/cuda/bin2h.cmake (new file, 150 lines)
@@ -0,0 +1,150 @@
# Based on: https://github.com/sivachandran/cmake-bin2h
#
# Copyright 2020 Sivachandran Paramasivam
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

include(CMakeParseArguments)

# Function to wrap a given string into multiple lines at the given column
# position.
#
# Parameters:
#
# * VARIABLE - The name of the CMake variable holding the string.
# * AT_COLUMN - The column position at which the string will be wrapped.
function(WRAP_STRING)
  set(oneValueArgs VARIABLE AT_COLUMN)
  cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})

  string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
  math(EXPR offset "0")

  while(stringLength GREATER 0)
    if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
      math(EXPR length "${WRAP_STRING_AT_COLUMN}")
    else()
      math(EXPR length "${stringLength}")
    endif()

    string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
    set(lines "${lines}\n ${line}")

    math(EXPR stringLength "${stringLength} - ${length}")
    math(EXPR offset "${offset} + ${length}")
  endwhile()

  set(${WRAP_STRING_VARIABLE}
      "${lines}"
      PARENT_SCOPE)
endfunction()

# Function to embed contents of a file as byte array in C/C++ header file(.h).
# The header file will contain a byte array and an integer variable holding the
# size of the array.
#
# Parameters:
#
# * SOURCE_FILES - The paths of source files whose contents will be embedded in
#   the header file.
# * VARIABLE_NAME - The name of the variable for the byte array. The string
#   "_SIZE" will be appended to this name and used as the name of the size
#   variable.
# * HEADER_FILE - The path of the header file.
# * APPEND - If specified, appends to the header file instead of overwriting it.
# * HEADER_NAMESPACE - The namespace in which the array should be located.
# * NULL_TERMINATE - If specified, a null byte (zero) will be appended to the
#   byte array.
#
# Usage:
#
# bin2h(SOURCE_FILE "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
function(BIN2H)
  set(options APPEND NULL_TERMINATE)
  set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
  set(multiValueArgs SOURCE_FILES)
  cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN})

  set(arrayDefinition "")
  foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
    # get filename without extension
    get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
    # convert the filename to a valid C identifier
    string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)

    # reads source file contents as hex string
    file(READ ${SOURCE_FILE} hexString HEX)

    # append null
    if(BIN2H_NULL_TERMINATE)
      string(APPEND hexString "00")
    endif()

    # wraps the hex string into multiple lines
    wrap_string(VARIABLE hexString AT_COLUMN 24)

    # strip the © in source code
    string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})

    string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
                         ${arrayValues})

    # make a full variable name for the array
    set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")

    # declares byte array and the length variables
    string(APPEND arrayDefinition
           "constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
  endforeach()

  # add namespace wrapper if defined
  if(DEFINED BIN2H_HEADER_NAMESPACE)
    set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
    set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
    set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
  endif()

  set(arrayIncludes "#pragma once")
  string(PREPEND declarations "${arrayIncludes}\n\n")

  if(BIN2H_APPEND)
    file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
  else()
    file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
  endif()
endfunction()

# ----------------------------- CLI args -----------------------------

string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
foreach(source ${MLX_JIT_SOURCES_LIST})
  list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
endforeach()

bin2h(
  SOURCE_FILES
  ${MLX_JIT_SOURCES_ABS}
  NULL_TERMINATE
  VARIABLE_NAME
  "jit_source"
  HEADER_NAMESPACE
  "mlx::core"
  HEADER_FILE
  "${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")

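For orientation, this is a hypothetical excerpt of what the generated gen/cuda_jit_sources.h looks like for an input file device/utils.cuh, following the arrayDefinition template above (variable name is "jit_source" plus the file stem; the real contents are the full hex dump of the file, null-terminated because NULL_TERMINATE is passed):

#pragma once

namespace mlx::core {

// Stand-in bytes; the generated file holds the whole source, one 0xNN per
// character, ending in 0x00.
constexpr char jit_source_utils[] = {
    0x2f, 0x2f, 0x20, 0x68, 0x69,
    0x00};

} // namespace mlx::core
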
357
mlx/backend/cuda/binary.cu
Normal file
357
mlx/backend/cuda/binary.cu
Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/binary.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/backend/cuda/device/binary_ops.cuh"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <cooperative_groups.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace cu {
|
||||||
|
|
||||||
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
|
__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
|
||||||
|
if ((index + 1) * N_READS > size) {
|
||||||
|
for (int i = index * N_READS; i < size; ++i) {
|
||||||
|
out[i] = Op{}(a[0], b[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = Op{}(a[0], b[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
|
__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
|
||||||
|
if ((index + 1) * N_READS > size) {
|
||||||
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
|
out[i] = Op{}(a[0], b[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto b_vec = load_vector<N_READS>(b, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = Op{}(a[0], b_vec[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
|
__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
|
||||||
|
if ((index + 1) * N_READS > size) {
|
||||||
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
|
out[i] = Op{}(a[i], b[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto a_vec = load_vector<N_READS>(a, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = Op{}(a_vec[i], b[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
||||||
|
__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
|
||||||
|
if ((index + 1) * N_READS > size) {
|
||||||
|
for (IdxT i = index * N_READS; i < size; ++i) {
|
||||||
|
out[i] = Op{}(a[i], b[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto a_vec = load_vector<N_READS>(a, index);
|
||||||
|
auto b_vec = load_vector<N_READS>(b, index);
|
||||||
|
|
||||||
|
AlignedVector<Out, N_READS> out_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < N_READS; ++i) {
|
||||||
|
out_vec[i] = Op{}(a_vec[i], b_vec[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
store_vector<N_READS>(out, index, out_vec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
|
||||||
|
__global__ void binary_g_nd(
|
||||||
|
const In* a,
|
||||||
|
const In* b,
|
||||||
|
Out* out,
|
||||||
|
IdxT size,
|
||||||
|
const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
|
||||||
|
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
|
||||||
|
index, shape.data(), a_strides.data(), b_strides.data());
|
||||||
|
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out, typename IdxT>
|
||||||
|
__global__ void binary_g(
|
||||||
|
const In* a,
|
||||||
|
const In* b,
|
||||||
|
Out* out,
|
||||||
|
IdxT size,
|
||||||
|
const __grid_constant__ Shape shape,
|
||||||
|
const __grid_constant__ Strides a_strides,
|
||||||
|
const __grid_constant__ Strides b_strides,
|
||||||
|
int ndim) {
|
||||||
|
IdxT index = cg::this_grid().thread_rank();
|
||||||
|
if (index < size) {
|
||||||
|
auto [a_idx, b_idx] = elem_to_loc(
|
||||||
|
index, shape.data(), a_strides.data(), b_strides.data(), ndim);
|
||||||
|
out[index] = Op{}(a[a_idx], b[b_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op, typename In, typename Out>
|
||||||
|
constexpr bool supports_binary_op() {
|
||||||
|
if (std::is_same_v<Op, Add> || std::is_same_v<Op, Divide> ||
|
||||||
|
std::is_same_v<Op, Maximum> || std::is_same_v<Op, Minimum> ||
|
||||||
|
std::is_same_v<Op, Multiply> || std::is_same_v<Op, Subtract> ||
|
||||||
|
std::is_same_v<Op, Power> || std::is_same_v<Op, Remainder>) {
|
||||||
|
return std::is_same_v<In, Out>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, Equal> || std::is_same_v<Op, Greater> ||
|
||||||
|
std::is_same_v<Op, GreaterEqual> || std::is_same_v<Op, Less> ||
|
||||||
|
std::is_same_v<Op, LessEqual> || std::is_same_v<Op, NotEqual>) {
|
||||||
|
return std::is_same_v<Out, bool>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, LogicalAnd> || std::is_same_v<Op, LogicalOr>) {
|
||||||
|
return std::is_same_v<Out, bool> && std::is_same_v<In, bool>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, NaNEqual>) {
|
||||||
|
return std::is_same_v<Out, bool> && is_inexact_v<In>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, LogAddExp>) {
|
||||||
|
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, ArcTan2>) {
|
||||||
|
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, BitwiseAnd> || std::is_same_v<Op, BitwiseOr> ||
|
||||||
|
std::is_same_v<Op, BitwiseXor>) {
|
||||||
|
return std::is_same_v<In, Out> && std::is_integral_v<In>;
|
||||||
|
}
|
||||||
|
if (std::is_same_v<Op, LeftShift> || std::is_same_v<Op, RightShift>) {
|
||||||
|
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||||
|
!std::is_same_v<In, bool>;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cu
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_op_gpu_inplace(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out,
|
||||||
|
const char* op,
|
||||||
|
const Stream& s) {
|
||||||
|
assert(inputs.size() > 1);
|
||||||
|
const auto& a = inputs[0];
|
||||||
|
const auto& b = inputs[1];
|
||||||
|
if (out.size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
|
||||||
|
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
||||||
|
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
||||||
|
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
||||||
|
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
||||||
|
using InType = cuda_type_t<CTYPE_IN>;
|
||||||
|
using OutType = cuda_type_t<CTYPE_OUT>;
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
if (bopt == BinaryOpType::General) {
|
||||||
|
dispatch_bool(
|
||||||
|
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
|
||||||
|
out.data_size() > INT32_MAX,
|
||||||
|
[&](auto large) {
|
||||||
|
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
||||||
|
Shape shape;
|
||||||
|
std::vector<Strides> strides;
|
||||||
|
std::tie(shape, strides) = collapse_contiguous_dims(a, b, out);
|
||||||
|
auto& a_strides = strides[0];
|
||||||
|
auto& b_strides = strides[1];
|
||||||
|
int ndim = shape.size();
|
||||||
|
if (ndim <= 3) {
|
||||||
|
dispatch_1_2_3(ndim, [&](auto dims_constant) {
|
||||||
|
auto [num_blocks, block_dims] =
|
||||||
|
get_launch_args(out, large());
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::binary_g_nd<
|
||||||
|
Op,
|
||||||
|
InType,
|
||||||
|
OutType,
|
||||||
|
IdxT,
|
||||||
|
dims_constant()>,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
a.data<InType>(),
|
||||||
|
b.data<InType>(),
|
||||||
|
out.data<OutType>(),
|
||||||
|
out.size(),
|
||||||
|
const_param<dims_constant()>(shape),
|
||||||
|
const_param<dims_constant()>(a_strides),
|
||||||
|
const_param<dims_constant()>(b_strides));
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
auto [num_blocks, block_dims] = get_launch_args(out, large());
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
cu::binary_g<Op, InType, OutType, IdxT>,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
a.data<InType>(),
|
||||||
|
b.data<InType>(),
|
||||||
|
out.data<OutType>(),
|
||||||
|
out.size(),
|
||||||
|
const_param(shape),
|
||||||
|
const_param(a_strides),
|
||||||
|
const_param(b_strides),
|
||||||
|
ndim);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
|
||||||
|
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
||||||
|
constexpr int N_READS = 16 / sizeof(InType);
|
||||||
|
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT, N_READS>;
|
||||||
|
if (bopt == BinaryOpType::ScalarVector) {
|
||||||
|
kernel = cu::binary_sv<Op, InType, OutType, IdxT, N_READS>;
|
||||||
|
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||||
|
kernel = cu::binary_vs<Op, InType, OutType, IdxT, N_READS>;
|
||||||
|
} else if (bopt == BinaryOpType::VectorVector) {
|
||||||
|
kernel = cu::binary_vv<Op, InType, OutType, IdxT, N_READS>;
|
||||||
|
}
|
||||||
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
|
out.data_size(), out.shape(), out.strides(), large(), N_READS);
|
||||||
|
encoder.add_kernel_node(
|
||||||
|
kernel,
|
||||||
|
num_blocks,
|
||||||
|
block_dims,
|
||||||
|
0,
|
||||||
|
a.data<InType>(),
|
||||||
|
b.data<InType>(),
|
||||||
|
out.data<OutType>(),
|
||||||
|
out.data_size());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(fmt::format(
|
||||||
|
"Can not do binary op {} on inputs of {} with result of {}.",
            op,
            dtype_to_string(a.dtype()),
            dtype_to_string(out.dtype())));
      }
    });
  });
}

template <typename Op>
void binary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const char* op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);
  binary_op_gpu_inplace<Op>(inputs, out, op, s);
}

#define BINARY_GPU(func)                                              \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
    nvtx3::scoped_range r(#func "::eval_gpu");                        \
    auto& s = out.primitive().stream();                               \
    binary_op_gpu<cu::func>(inputs, out, name(), s);                  \
  }

BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU(Remainder)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)

void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Equal::eval_gpu");
  auto& s = out.primitive().stream();
  if (equal_nan_) {
    binary_op_gpu<cu::NaNEqual>(inputs, out, name(), s);
  } else {
    binary_op_gpu<cu::Equal>(inputs, out, name(), s);
  }
}

void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
  auto& s = out.primitive().stream();
  switch (op_) {
    case BitwiseBinary::And:
      binary_op_gpu<cu::BitwiseAnd>(inputs, out, name(), s);
      break;
    case BitwiseBinary::Or:
      binary_op_gpu<cu::BitwiseOr>(inputs, out, name(), s);
      break;
    case BitwiseBinary::Xor:
      binary_op_gpu<cu::BitwiseXor>(inputs, out, name(), s);
      break;
    case BitwiseBinary::LeftShift:
      binary_op_gpu<cu::LeftShift>(inputs, out, name(), s);
      break;
    case BitwiseBinary::RightShift:
      binary_op_gpu<cu::RightShift>(inputs, out, name(), s);
      break;
  }
}

} // namespace mlx::core
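For reference, the BINARY_GPU macro above stamps out one eval_gpu definition per primitive. This is its expansion for Add, written out by hand; every entry in the list expands the same way, with only the primitive and cu:: functor name changing.

// Hand expansion of BINARY_GPU(Add), for readability only.
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Add::eval_gpu");
  auto& s = out.primitive().stream();
  binary_op_gpu<cu::Add>(inputs, out, name(), s);
}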
mlx/backend/cuda/binary_two.cu (new file, 333 lines)
@@ -0,0 +1,333 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/binary.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_ss(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[0], b[0]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a[0], b[0]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector<N_READS>(out_a, index, out_a_vec);
    store_vector<N_READS>(out_b, index, out_b_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_sv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[0], b[i]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    auto b_vec = load_vector<N_READS>(b, index);

    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a[0], b_vec[i]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector<N_READS>(out_a, index, out_a_vec);
    store_vector<N_READS>(out_b, index, out_b_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_vs(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[i], b[0]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    auto a_vec = load_vector<N_READS>(a, index);

    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a_vec[i], b[0]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector<N_READS>(out_a, index, out_a_vec);
    store_vector<N_READS>(out_b, index, out_b_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
__global__ void
binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      auto out = Op{}(a[i], b[i]);
      out_a[i] = out[0];
      out_b[i] = out[1];
    }
  } else {
    auto a_vec = load_vector<N_READS>(a, index);
    auto b_vec = load_vector<N_READS>(b, index);

    AlignedVector<Out, N_READS> out_a_vec;
    AlignedVector<Out, N_READS> out_b_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      auto out = Op{}(a_vec[i], b_vec[i]);
      out_a_vec[i] = out[0];
      out_b_vec[i] = out[1];
    }

    store_vector<N_READS>(out_a, index, out_a_vec);
    store_vector<N_READS>(out_b, index, out_b_vec);
  }
}

template <typename Op, typename In, typename Out, typename IdxT, int NDIM>
__global__ void binary_two_g_nd(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_strides) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [a_idx, b_idx] = elem_to_loc_nd<NDIM>(
        index, shape.data(), a_strides.data(), b_strides.data());
    auto out = Op{}(a[a_idx], b[b_idx]);
    out_a[index] = out[0];
    out_b[index] = out[1];
  }
}

template <typename Op, typename In, typename Out, typename IdxT>
__global__ void binary_two_g(
    const In* a,
    const In* b,
    Out* out_a,
    Out* out_b,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides a_strides,
    const __grid_constant__ Strides b_strides,
    int ndim) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [a_idx, b_idx] = elem_to_loc(
        index, shape.data(), a_strides.data(), b_strides.data(), ndim);
    auto out = Op{}(a[a_idx], b[b_idx]);
    out_a[index] = out[0];
    out_b[index] = out[1];
  }
}

template <typename Op, typename In, typename Out>
constexpr bool supports_binary_two_op() {
  if (std::is_same_v<Op, DivMod>) {
    return std::is_same_v<In, Out> &&
        (std::is_integral_v<Out> || is_floating_v<Out>);
  }
  return false;
}

} // namespace cu

template <typename Op>
void binary_two_op_gpu_inplace(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const char* op,
    const Stream& s) {
  assert(inputs.size() > 1);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  auto& out_a = outputs[0];
  auto& out_b = outputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out_a, bopt);
  set_binary_op_output_data(a, b, out_b, bopt);

  if (out_a.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out_a);
  encoder.set_output_array(out_b);
  dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
      using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
      using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
      if constexpr (cu::supports_binary_two_op<Op, CTYPE_IN, CTYPE_OUT>()) {
        using InType = cuda_type_t<CTYPE_IN>;
        using OutType = cuda_type_t<CTYPE_OUT>;

        auto bopt = get_binary_op_type(a, b);
        if (bopt == BinaryOpType::General) {
          dispatch_bool(
              a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
                  out_a.data_size() > INT32_MAX,
              [&](auto large) {
                using IdxT = std::conditional_t<large(), int64_t, int32_t>;
                Shape shape;
                std::vector<Strides> strides;
                std::tie(shape, strides) =
                    collapse_contiguous_dims(a, b, out_a);
                auto& a_strides = strides[0];
                auto& b_strides = strides[1];
                int ndim = shape.size();
                if (ndim <= 3) {
                  dispatch_1_2_3(ndim, [&](auto dims_constant) {
                    auto [num_blocks, block_dims] =
                        get_launch_args(out_a, large());
                    encoder.add_kernel_node(
                        cu::binary_two_g_nd<
                            Op,
                            InType,
                            OutType,
                            IdxT,
                            dims_constant()>,
                        num_blocks,
                        block_dims,
                        0,
                        a.data<InType>(),
                        b.data<InType>(),
                        out_a.data<OutType>(),
                        out_b.data<OutType>(),
                        out_a.size(),
                        const_param<dims_constant()>(shape),
                        const_param<dims_constant()>(a_strides),
                        const_param<dims_constant()>(b_strides));
                  });
                } else {
                  auto [num_blocks, block_dims] =
                      get_launch_args(out_a, large());
                  encoder.add_kernel_node(
                      cu::binary_two_g<Op, InType, OutType, IdxT>,
                      num_blocks,
                      block_dims,
                      0,
                      a.data<InType>(),
                      b.data<InType>(),
                      out_a.data<OutType>(),
                      out_b.data<OutType>(),
                      out_a.size(),
                      const_param(shape),
                      const_param(a_strides),
                      const_param(b_strides),
                      ndim);
                }
              });
        } else {
          dispatch_bool(out_a.data_size() > UINT32_MAX, [&](auto large) {
            using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
            constexpr int N_READS = 16 / sizeof(InType);
            auto kernel = cu::binary_two_ss<Op, InType, OutType, IdxT, N_READS>;
            if (bopt == BinaryOpType::ScalarVector) {
              kernel = cu::binary_two_sv<Op, InType, OutType, IdxT, N_READS>;
            } else if (bopt == BinaryOpType::VectorScalar) {
              kernel = cu::binary_two_vs<Op, InType, OutType, IdxT, N_READS>;
            } else if (bopt == BinaryOpType::VectorVector) {
              kernel = cu::binary_two_vv<Op, InType, OutType, IdxT, N_READS>;
            }
            auto [num_blocks, block_dims] = get_launch_args(
                out_a.data_size(),
                out_a.shape(),
                out_a.strides(),
                large(),
                N_READS);
            encoder.add_kernel_node(
                kernel,
                num_blocks,
                block_dims,
                0,
                a.data<InType>(),
                b.data<InType>(),
                out_a.data<OutType>(),
                out_b.data<OutType>(),
                out_a.data_size());
          });
        }
      } else {
        throw std::runtime_error(fmt::format(
"Can not do binary op {} on inputs of {} with result of {}.",
            op,
            dtype_to_string(a.dtype()),
            dtype_to_string(out_a.dtype())));
      }
    });
  });
}

template <typename Op>
void binary_two_op_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const char* op,
    const Stream& s) {
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, outputs[0], bopt);
  set_binary_op_output_data(a, b, outputs[1], bopt);
  binary_two_op_gpu_inplace<Op>(inputs, outputs, op, s);
}

void DivMod::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("DivMod::eval_gpu");
  auto& s = outputs[0].primitive().stream();
  binary_two_op_gpu<cu::DivMod>(inputs, outputs, name(), s);
}

} // namespace mlx::core
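The binary_two_* kernels above index the functor's result as out[0] and out[1], so cu::DivMod must return a two-element container. Below is a minimal sketch of such a functor, assuming integral operands; the authoritative cu::DivMod lives in mlx/backend/cuda/device/binary_ops.cuh and may handle floating types differently.

// Sketch only (assumes <cuda/std/array> from libcu++); DivModSketch is a
// hypothetical name, not the real cu::DivMod.
#include <cuda/std/array>

struct DivModSketch {
  template <typename T>
  __device__ cuda::std::array<T, 2> operator()(T x, T y) {
    return {x / y, x % y}; // quotient in out[0], remainder in out[1]
  }
};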
mlx/backend/cuda/compiled.cpp (new file, 340 lines)
@@ -0,0 +1,340 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"

#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>

namespace mlx::core {

namespace cu {

struct FusedKernelBuilder {
  std::string os;
  const std::string& kernel_name;
  const std::vector<array>& inputs;
  const std::vector<array>& outputs;
  const std::vector<array>& tape;
  const std::function<bool(size_t)>& is_constant;

  void build(const char* name, bool contiguous) {
    NodeNamer namer;

    // Function parameters.
    std::vector<std::string> params;
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (is_constant(i)) {
        continue;
      }
      const auto& x = inputs[i];
      const std::string& xname = namer.get_name(x);
      params.push_back(
          fmt::format("const {}* {}", dtype_to_cuda_type(x.dtype()), xname));
      if (!is_scalar(x) && !contiguous) {
        params.push_back(fmt::format(
            "const __grid_constant__ cuda::std::array<int64_t, NDIM> {}_strides",
            xname));
      }
    }
    for (const auto& x : outputs) {
      params.push_back(fmt::format(
          "{}* {}", dtype_to_cuda_type(x.dtype()), namer.get_name(x)));
    }
    if (!contiguous) {
      params.push_back(
          "const __grid_constant__ cuda::std::array<int32_t, NDIM> shape");
    }
    params.push_back("IdxT size");

    // Build function signature.
    if (contiguous) {
      os += "template <typename IdxT = uint32_t, int work_per_thread = 1>\n";
    } else {
      os +=
          "template <int NDIM, typename IdxT = uint32_t, int work_per_thread = 1>\n";
    }
    os += fmt::format("__global__ void {}(\n", kernel_name + name);
    for (size_t i = 0; i < params.size(); ++i) {
      os += " ";
      os += params[i];
      if (i != params.size() - 1) {
        os += ",\n";
      }
    }
    os += ") {\n";

    // Index. For non-contiguous kernels we create a separate index variable
    // per input; otherwise everything uses `index`.
    os +=
        " IdxT index = cg::this_grid().thread_rank() * work_per_thread;\n"
        " if (index >= size) {\n"
        " return;\n"
        " }\n";
    if (!contiguous) {
      for (size_t i = 0; i < inputs.size(); ++i) {
        const auto& x = inputs[i];
        const std::string& xname = namer.get_name(x);
        if (is_scalar(x) || is_constant(i)) {
          continue;
        }
        os += " IdxT " + xname + "_idx = 0;\n";
      }
      os += " {\n";
      os += " IdxT loc = index;\n";
      os +=
          " #pragma unroll\n"
          " for (int i = NDIM - 1; i >= 0; i--) {\n";
      for (size_t i = 0; i < inputs.size(); ++i) {
        const auto& x = inputs[i];
        const std::string& xname = namer.get_name(x);
        if (is_scalar(x) || is_constant(i)) {
          continue;
        }
os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname +
            "_strides[i]);\n";
      }
      os +=
          " loc /= shape[i];\n"
          " }\n"
          " }\n";
    }

    // Vectorized read loop
    if (contiguous) {
      for (size_t i = 0; i < inputs.size(); ++i) {
        const auto& x = inputs[i];
        if (is_scalar(x) || is_constant(i)) {
          continue;
        }
        const std::string& xname = namer.get_name(x);
        std::string type = dtype_to_cuda_type(x.dtype());
        os += fmt::format(
            " auto vec_{0} = load_vector<work_per_thread, {1}>({0} + index, 0, size - index, 0);\n",
            xname,
            type);
      }
    }

    // Create some space for the outputs
    for (const auto& x : outputs) {
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      os += fmt::format(
          " AlignedVector<{}, work_per_thread> vec_{};\n", type, xname);
    }

    // Work loop
    if (!contiguous) {
      os +=
          "\n"
          " for (int i = 0; i < work_per_thread && index < size; i++) {\n";
    } else {
      os +=
          "\n"
          " #pragma unroll\n"
          " for (int i = 0; i < work_per_thread; i++) {\n";
    }

    // Read inputs.
    for (size_t i = 0; i < inputs.size(); ++i) {
      const auto& x = inputs[i];
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      std::string value;
      if (is_constant(i)) {
        std::ostringstream ss;
        print_constant(ss, x);
        value = fmt::format("static_cast<{}>({})", type, ss.str());
      } else if (is_scalar(x)) {
        value = fmt::format("{}[0]", xname);
      } else if (contiguous) {
        value = fmt::format("vec_{}[i]", xname);
      } else {
        value = fmt::format("{}[{}_idx]", xname, xname);
      }
      os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
    }

    // Write tape.
    for (const auto& x : tape) {
      const std::string& xname = namer.get_name(x);
      std::string type = dtype_to_cuda_type(x.dtype());
      std::string value;
      if (is_static_cast(x.primitive())) {
        value = fmt::format(
            "static_cast<{}>(tmp_{})", type, namer.get_name(x.inputs()[0]));
      } else {
        value = x.primitive().name();
        value += "{}(";
        for (size_t i = 0; i < x.inputs().size() - 1; ++i) {
          value += fmt::format("tmp_{}, ", namer.get_name(x.inputs()[i]));
        }
        value += fmt::format("tmp_{})", namer.get_name(x.inputs().back()));
      }
      os += fmt::format(" {} tmp_{} = {};\n", type, xname, value);
    }

    // Write output.
    for (const auto& x : outputs) {
      os += fmt::format(" vec_{0}[i] = tmp_{0};\n", namer.get_name(x));
    }

    // End of work loop
    if (!contiguous) {
      os += "\n";
      for (size_t i = 0; i < inputs.size(); ++i) {
        const auto& x = inputs[i];
        const std::string& xname = namer.get_name(x);
        if (is_scalar(x) || is_constant(i)) {
          continue;
        }
        os += fmt::format(" {0}_idx += {0}_strides[NDIM - 1];\n", xname);
      }
    }
    os += " }\n";

    // Store the output to global memory
    for (const auto& x : outputs) {
      os += fmt::format(
          " store_vector({0} + index, 0, vec_{0}, size - index);\n",
          namer.get_name(x));
    }

    os += "}\n";
  }
};

} // namespace cu

constexpr const char* g_jit_includes = R"(
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/device/ternary_ops.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

#define inf cuda::std::numeric_limits<float>::infinity()
)";

void Compiled::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("Compiled::eval_gpu");
  auto& s = stream();

  // Determine the work per thread for the vectorized reads/writes. We take it
  // as 16 over the max itemsize for the outputs. Another heuristic could be
  // over the max itemsize of all arrays.
  int max_size = 1;
  for (const auto& x : outputs) {
    max_size = (max_size > x.itemsize()) ? max_size : x.itemsize();
  }
  int work_per_thread = 16 / max_size;

  cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
    // Build source code.
    cu::FusedKernelBuilder builder{
        g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
    builder.os +=
        "namespace mlx::core::cu {\n\n"
        "namespace cg = cooperative_groups;\n\n";
    builder.build("_contiguous", true);
    builder.os += "\n";
    builder.build("_strided", false);
    builder.os += "\n} // namespace mlx::core::cu\n";
    // Build kernel names.
    std::vector<std::string> kernel_names;
    kernel_names.push_back(fmt::format(
        "mlx::core::cu::{}_contiguous<uint32_t, {}>",
        lib_name(),
        work_per_thread));
    kernel_names.push_back(fmt::format(
        "mlx::core::cu::{}_contiguous<int64_t, {}>",
        lib_name(),
        work_per_thread));
    for (auto wpt : std::array<int, 2>{1, work_per_thread}) {
      for (int i = 1; i <= MAX_NDIM; ++i) {
        kernel_names.push_back(fmt::format(
            "mlx::core::cu::{}_strided<{}, uint32_t, {}>", lib_name(), i, wpt));
        kernel_names.push_back(fmt::format(
            "mlx::core::cu::{}_strided<{}, int64_t, {}>", lib_name(), i, wpt));
      }
    }

    return std::make_pair(std::move(builder.os), std::move(kernel_names));
  });

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
  auto [contiguous, shape, strides_vec] =
      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);

  // Whether to use large index.
  bool large = compiled_use_large_index(inputs, outputs, contiguous);

  cu::KernelArgs args;
  // Put inputs.
  int strides_index = 1;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (is_constant_(i)) {
      continue;
    }
    const auto& x = inputs[i];
    args.append(x);
    if (!contiguous && !is_scalar(x)) {
      args.append_ptr(strides_vec[strides_index++].data());
    }
  }

  // Put outputs.
  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
  for (auto& x : outputs) {
    args.append(x);
  }

  // Put shape and size.
  if (!contiguous) {
    args.append_ptr(shape.data());
  }
  if (large) {
    args.append<int64_t>(outputs[0].data_size());
  } else {
    args.append<uint32_t>(outputs[0].data_size());
  }

  // Choose work per thread
  if (!contiguous && shape.back() % work_per_thread != 0) {
    work_per_thread = 1;
  }

  // Launch kernel.
  const char* index_type = large ? "int64_t" : "uint32_t";
  std::string kernel_name = fmt::format("mlx::core::cu::{}", lib_name());
  if (contiguous) {
    kernel_name +=
        fmt::format("_contiguous<{}, {}>", index_type, work_per_thread);
  } else {
    kernel_name += fmt::format(
        "_strided<{}, {}, {}>", shape.size(), index_type, work_per_thread);
  }
  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  for (const auto& out : outputs) {
    encoder.set_output_array(out);
  }

  auto kernel = mod.get_kernel(kernel_name);
  auto [num_blocks, block_dims] =
      get_launch_args(outputs[0], large, work_per_thread);
  encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
}

} // namespace mlx::core
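To make the string building above concrete, here is the kind of source FusedKernelBuilder emits for a contiguous kernel with two non-constant float inputs and one output, reconstructed by tracing build() by hand. The names examplelib, in0, in1, and out0 are hypothetical stand-ins for what NodeNamer and lib_name() would produce, and the snippet presumes the jit preamble in g_jit_includes (load_vector, AlignedVector, cg, Add); real generated text differs in whitespace and naming.

// Hand-reconstructed output of build("_contiguous", true) for a fused add.
template <typename IdxT = uint32_t, int work_per_thread = 1>
__global__ void examplelib_contiguous(
    const float* in0,
    const float* in1,
    float* out0,
    IdxT size) {
  IdxT index = cg::this_grid().thread_rank() * work_per_thread;
  if (index >= size) {
    return;
  }
  auto vec_in0 = load_vector<work_per_thread, float>(in0 + index, 0, size - index, 0);
  auto vec_in1 = load_vector<work_per_thread, float>(in1 + index, 0, size - index, 0);
  AlignedVector<float, work_per_thread> vec_out0;

#pragma unroll
  for (int i = 0; i < work_per_thread; i++) {
    float tmp_in0 = vec_in0[i];
    float tmp_in1 = vec_in1[i];
    float tmp_out0 = Add{}(tmp_in0, tmp_in1);
    vec_out0[i] = tmp_out0;
  }
  store_vector(out0 + index, 0, vec_out0, size - index);
}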
mlx/backend/cuda/conv.cpp (new file, 546 lines)
@@ -0,0 +1,546 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/lru_cache.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

// cudnn_frontend.h redefines this macro.
#undef CHECK_CUDA_ERROR

#include <cudnn_frontend.h>
#include <cudnn_frontend_find_plan.h>
#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>

#include <cassert>

namespace mlx::core {

namespace {

// Not all engines support it, so we cannot use this API for now.
#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0

// Alias for better readability.
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR
#define CONV_BACKWARD_INPUT \
  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
#define CONV_BACKWARD_WEIGHT \
  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR

struct ConvCacheKey {
  int device_id;
  cudnnDataType_t cudnn_dtype;
  std::array<int, MAX_NDIM> input_shape;
  std::array<int, MAX_NDIM> weight_shape;
  std::array<int, MAX_NDIM> stride;
  std::array<int, MAX_NDIM> padding_lo;
  std::array<int, MAX_NDIM> padding_hi;
  std::array<int, MAX_NDIM> dilation;
  int groups;
  bool flip;
  uint8_t input_alignment;
  uint8_t weight_alignment;
  uint8_t output_alignment;
};

auto& conv_cache() {
  static LRUBytesKeyCache<
      ConvCacheKey,
      std::pair<cudnnBackendDescriptorType_t, cudnn_frontend::ExecutionPlan>>
      cache(/* capacity */ 128);
  return cache;
}

template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
  return SmallVector<T>(vec.begin(), vec.end());
}

template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
  if (vec.size() > MAX_NDIM) {
    throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
  }
  std::array<T, MAX_NDIM> result = {};
  std::copy_n(vec.begin(), vec.size(), result.begin());
  return result;
}

auto nhwc_to_nchw(const array& x) {
  auto shape = convert_vector<int64_t>(x.shape());
  shape.insert(shape.begin() + 1, shape.back());
  shape.erase(shape.end() - 1);
  auto strides = convert_vector<int64_t>(x.strides());
  strides.insert(strides.begin() + 1, strides.back());
  strides.erase(strides.end() - 1);
  return std::make_tuple(std::move(shape), std::move(strides));
}
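As a worked example of nhwc_to_nchw above: inserting the last element at position 1 and erasing it from the end turns an NHWC shape (N, H, W, C) into (N, C, H, W), and the same permutation is applied to the strides, so only the logical layout description changes, never the data. A host-side sketch of the reordering, using std::vector in place of the SmallVector the real code returns:

// Illustration only: the shape permutation performed by nhwc_to_nchw.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> shape = {8, 32, 32, 3}; // N, H, W, C
  shape.insert(shape.begin() + 1, shape.back());
  shape.erase(shape.end() - 1);
  assert((shape == std::vector<int64_t>{8, 3, 32, 32})); // N, C, H, W
}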
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
  switch (dtype) {
    case int8:
      return CUDNN_DATA_INT8;
    case int32:
      return CUDNN_DATA_INT32;
    case uint8:
      return CUDNN_DATA_UINT8;
    case float16:
      return CUDNN_DATA_HALF;
    case bfloat16:
      return CUDNN_DATA_BFLOAT16;
    case float32:
      return CUDNN_DATA_FLOAT;
    case float64:
      return CUDNN_DATA_DOUBLE;
    default:
      throw std::runtime_error(fmt::format(
          "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype)));
  }
}

inline uint8_t get_alignment(const array& x) {
  uint8_t alignment = 1;
  uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
  for (; alignment < 32; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment;
    }
  }
  return alignment;
}

inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) {
  auto [shape, strides] = nhwc_to_nchw(x);
  return cudnn_frontend::TensorBuilder()
      .setDim(shape.size(), shape.data())
      .setStrides(strides.size(), strides.data())
      .setId(id)
      .setAlignment(get_alignment(x))
      .setDataType(dtype_to_cudnn_type(x.dtype()))
      .build();
}

cudnn_frontend::EngineConfigList get_engine_configs(
    cudnnBackendDescriptorType_t backend_type,
    Dtype dtype,
    cudnn_frontend::OperationGraph& op_graph,
    bool use_fallback = false) {
  cudnn_frontend::GeneratorSource source;
  if (use_fallback) {
    source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) {
      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
                          .setOperationGraph(op_graph)
                          .setOperation(backend_type)
                          .build();
      return fallback.getFallbackList();
    };
  } else {
    source = [](cudnn_frontend::OperationGraph& op_graph) {
      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
                            .setOperationGraph(op_graph)
                            .setHeurMode(CUDNN_HEUR_MODE_A)
                            .build();
      return heuristics.getEngineConfig(heuristics.getEngineConfigCount());
    };
  }

  cudnn_frontend::EngineConfigGenerator generator(1, &source);
  auto configs = generator.generate_engine_config(op_graph);

  cudnn_frontend::EngineConfigList filtered_configs;
  cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) {
    if (cudnn_frontend::hasNumericalNote<
            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
      return true;
    }
    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c) &&
        dtype == float32 && !env::enable_tf32()) {
      return true;
    }
    return false;
  });
  return filtered_configs;
}

bool execute_plan(
    cu::CommandEncoder& encoder,
    cudnn_frontend::ExecutionPlan& plan,
    array& x,
    array& w,
    array& y) {
  int workspace_size = plan.getWorkspaceSize();
  array workspace(allocator::malloc(workspace_size), {workspace_size}, uint8);

  int64_t uids[3] = {'x', 'w', 'y'};
  void* data_ptrs[3] = {
      x.data<void>(),
      w.data<void>(),
      y.data<void>(),
  };

  auto variantPack = cudnn_frontend::VariantPackBuilder()
                         .setWorkspacePointer(workspace.data<void>())
                         .setDataPointers(3, data_ptrs)
                         .setUids(3, uids)
                         .build();

  auto handle = encoder.device().cudnn_handle();
  cudnnSetStream(handle, encoder.stream());

#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API
  cudaGraph_t graph;
  cudaGraphCreate(&graph, 0);
  std::unique_ptr<cudaGraph_t, void (*)(cudaGraph_t*)> graph_freer(
      &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); });
  if (cudnnBackendPopulateCudaGraph(
          handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) !=
      CUDNN_STATUS_SUCCESS) {
    return false;
  }
  encoder.add_graph_node(graph);
#else
  auto capture = encoder.capture_context();
  if (cudnnBackendExecute(
          handle, plan.get_raw_desc(), variantPack.get_raw_desc()) !=
      CUDNN_STATUS_SUCCESS) {
    // Discard the captured graph when failed.
    capture.discard = true;
    return false;
  }
#endif

  encoder.add_temporary(workspace);
  return true;
}

bool try_engines(
    cu::CommandEncoder& encoder,
    const ConvCacheKey& cache_key,
    cudnnBackendDescriptorType_t backend_type,
    cudnn_frontend::EngineConfigList& configs,
    const std::string& op_graph_tag,
    array& x,
    array& w,
    array& y) {
  for (auto& config : configs) {
    try {
      auto plan = cudnn_frontend::ExecutionPlanBuilder()
                      .setHandle(encoder.device().cudnn_handle())
                      .setEngineConfig(config, op_graph_tag)
                      .build();
      if (execute_plan(encoder, plan, x, w, y)) {
        conv_cache().emplace(
            cache_key, std::make_pair(backend_type, std::move(plan)));
        return true;
      }
    } catch (cudnn_frontend::cudnnException& error) {
      if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) {
        throw;
      }
    }
  }
  return false;
}

auto get_conv_op_settings(
    cudnnBackendDescriptorType_t backend_type,
    array& x,
    array& w,
    array& y,
    const std::vector<int>& kernel_strides,
    const std::vector<int>& padding_lo_,
    const std::vector<int>& padding_hi_,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation) {
  auto padding_lo = convert_vector<int64_t>(padding_lo_);
  auto padding_hi = convert_vector<int64_t>(padding_hi_);

  if (backend_type == CONV_BACKWARD_INPUT) {
    for (int i = 0; i < padding_lo.size(); ++i) {
      int wt_size = 1 + kernel_dilation[i] * (w.shape(1 + i) - 1);
      padding_lo[i] = wt_size - padding_lo[i] - 1;
      int in_size = 1 + kernel_strides[i] * (x.shape(1 + i) - 1);
      int out_size = 1 + input_dilation[i] * (y.shape(1 + i) - 1);
      padding_hi[i] = out_size - in_size + padding_hi[i];
    }
    return std::make_tuple(
        convert_vector<int64_t>(input_dilation),
        std::move(padding_lo),
        std::move(padding_hi),
        convert_vector<int64_t>(kernel_dilation));

  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
    padding_hi = padding_lo;
    return std::make_tuple(
        convert_vector<int64_t>(kernel_dilation),
        std::move(padding_lo),
        std::move(padding_hi),
        convert_vector<int64_t>(kernel_strides));

  } else {
    return std::make_tuple(
        convert_vector<int64_t>(kernel_strides),
        std::move(padding_lo),
        std::move(padding_hi),
        convert_vector<int64_t>(kernel_dilation));
  }
}

std::optional<cudnn_frontend::OperationGraph> build_op_graph(
    cu::CommandEncoder& encoder,
    cudnnBackendDescriptorType_t backend_type,
    Dtype dtype,
    array& x,
    array& w,
    array& y,
    const SmallVector<int64_t>& stride,
    const SmallVector<int64_t>& padding_lo,
    const SmallVector<int64_t>& padding_hi,
    const SmallVector<int64_t>& dilation) {
  try {
    auto compute_dtype = (dtype == float16 || dtype == bfloat16)
        ? CUDNN_DATA_FLOAT
        : dtype_to_cudnn_type(dtype);
    auto conv_desc = cudnn_frontend::ConvDescBuilder()
                         .setDataType(compute_dtype)
                         .setMathMode(CUDNN_CROSS_CORRELATION)
                         .setNDims(stride.size())
                         .setStrides(stride.size(), stride.data())
                         .setPrePadding(padding_lo.size(), padding_lo.data())
                         .setPostPadding(padding_hi.size(), padding_hi.data())
                         .setDilation(dilation.size(), dilation.data())
                         .build();

    auto op = cudnn_frontend::OperationBuilder(backend_type)
                  .setxDesc(build_tensor('x', x))
                  .setwDesc(build_tensor('w', w))
                  .setyDesc(build_tensor('y', y))
                  .setcDesc(conv_desc)
                  .build();

    std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
    return cudnn_frontend::OperationGraphBuilder()
        .setHandle(encoder.device().cudnn_handle())
        .setOperationGraph(ops.size(), ops.data())
        .build();
  } catch (cudnn_frontend::cudnnException& error) {
    if (error.getCudnnStatus() != CUDNN_STATUS_BAD_PARAM) {
      throw;
    }
    return std::nullopt;
  }
}

// Do necessary transposes and copies to prepare the inputs and outputs for
// building the cuDNN conv op. It is safe to be called multiple times in one
// eval_gpu, with cost of possible redundant copies.
std::tuple<array, array, array> prepare_args(
    cu::CommandEncoder& encoder,
    cudnnBackendDescriptorType_t backend_type,
    array in,
    array wt,
    array out,
    Stream s) {
  // Transpose the args depending on the backend type.
  // TODO: Handle groups.
  if (backend_type == CONV_BACKWARD_INPUT) {
    wt = swapaxes_in_eval(wt, 0, -1);
  } else if (backend_type == CONV_BACKWARD_WEIGHT) {
    in = swapaxes_in_eval(in, 0, -1);
    wt = swapaxes_in_eval(wt, 0, -1);
    // Create a contiguous array that shares the data with |out|, but with dim
    // C_in and C_out swapped.
    Shape shape(out.shape());
    std::swap(shape.front(), shape.back());
    Strides strides(shape.size(), 1);
    for (int i = shape.size() - 2; i >= 0; --i) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }
    array intermediate(std::move(shape), out.dtype(), nullptr, {});
    intermediate.copy_shared_buffer(
        out, std::move(strides), {true, true, false}, out.data_size());
    out = intermediate;
  }

  // cuDNN requires contiguous input.
  if (!in.flags().row_contiguous) {
    in = contiguous_copy_gpu(in, s);
    encoder.add_temporary(in);
  }
  if (!wt.flags().row_contiguous) {
    wt = contiguous_copy_gpu(wt, s);
    encoder.add_temporary(wt);
  }

  return {std::move(in), std::move(wt), std::move(out)};
}

// Get the x/w/y args from the in/wt/out args depending on backend type.
inline std::tuple<array&, array&, array&> dispatch_args(
    cudnnBackendDescriptorType_t backend_type,
    array& in,
    array& wt,
    array& out) {
  switch (backend_type) {
    case CONV_BACKWARD_INPUT:
      return {out, wt, in};
    case CONV_BACKWARD_WEIGHT:
      return {in, out, wt};
    default:
      return {in, wt, out};
  }
}

// Register inputs and outputs before actually running conv op. Can only be
// called once per eval_gpu.
void register_args(
    cu::CommandEncoder& encoder,
    cudnnBackendDescriptorType_t backend_type,
    array& in,
    array& wt,
    array& intermediate_out,
    array& final_out) {
  encoder.set_input_array(in);
  encoder.set_input_array(wt);
  encoder.set_output_array(final_out);

  if (backend_type == CONV_BACKWARD_WEIGHT) {
    // Turn |out| into a strided array, which will have C_in and C_out swapped
    // in vjp and the final |grad_weight| will then be contiguous.
    Strides strides = intermediate_out.strides();
    std::swap(strides.front(), strides.back());
    final_out.copy_shared_buffer(
        intermediate_out,
        std::move(strides),
        {false, false, false},
        intermediate_out.data_size());
  }
}

} // namespace

void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
  nvtx3::scoped_range r("Convolution::eval_gpu");
  if (out_.size() == 0) {
    return;
  }

  assert(inputs.size() == 2);
  array in = inputs[0];
  array wt = inputs[1];
  array out = out_;
  out.set_data(allocator::malloc(out.nbytes()));
  Dtype dtype = out.dtype();

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Search cache.
  ConvCacheKey cache_key{
      encoder.device().cuda_device(),
      dtype_to_cudnn_type(dtype),
      fixed_vector(in.shape()),
      fixed_vector(wt.shape()),
      fixed_vector(kernel_strides_),
      fixed_vector(padding_lo_),
      fixed_vector(padding_hi_),
      fixed_vector(kernel_dilation_),
      groups_,
      flip_,
      get_alignment(in),
      get_alignment(wt),
      get_alignment(out)};
  if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) {
    auto& [backend_type, plan] = it->second;
    std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s);
    register_args(encoder, backend_type, in, wt, out, out_);
    auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
    if (!execute_plan(encoder, plan, x, w, y)) {
      throw std::runtime_error("[conv] Cached plan failed to execute.");
    }
    return;
  }

  // There is no reliable way to deduce the proper cuDNN backend for the
  // convolution, so we make a best guess and then try.
  SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
  if (flip_) {
    // When weight is flipped, we assume it is backward input convolution.
    try_backends.push_back(CONV_BACKWARD_INPUT);
  } else {
    // Otherwise it could be a backward weight convolution or a forward
    // convolution; mathematically there is no difference, so we have to use
    // heuristics.
    // Empirically backward convolutions have large kernel dimensions, and
    // usually have |in| and |wt| transposed.
    if (!in.flags().row_contiguous && !wt.flags().row_contiguous &&
        wt.shape(2) > out.shape(2)) {
      try_backends = {CONV_BACKWARD_WEIGHT, CONV_FORWARD};
    } else {
      try_backends = {CONV_FORWARD, CONV_BACKWARD_WEIGHT};
    }
  }

  // Try to build op graph.
  cudnnBackendDescriptorType_t backend_type;
  std::optional<cudnn_frontend::OperationGraph> op_graph;
  for (auto try_backend : try_backends) {
    auto [in_copy, wt_copy, out_copy] =
        prepare_args(encoder, try_backend, in, wt, out, s);
    auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy);
    auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings(
        try_backend,
        x,
        w,
        y,
        kernel_strides_,
        padding_lo_,
        padding_hi_,
        kernel_dilation_,
        input_dilation_);
    op_graph = build_op_graph(
        encoder,
        try_backend,
        dtype,
        x,
        w,
        y,
        stride,
        padding_lo,
        padding_hi,
        dilation);
    if (op_graph) {
      backend_type = try_backend;
      in = std::move(in_copy);
      wt = std::move(wt_copy);
      out = std::move(out_copy);
      break;
    }
  }
  if (!op_graph) {
throw std::runtime_error("[conv] Can not build op graph.");
  }

  // Get ready to execute the graph.
  register_args(encoder, backend_type, in, wt, out, out_);

  // Try to run plans based on heuristics.
  auto configs = get_engine_configs(backend_type, dtype, *op_graph);
  auto tag = op_graph->getTag();
  auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
    return;
  }
  // Then try fallback plans.
  configs = get_engine_configs(
      backend_type, dtype, *op_graph, /* use_fallback */ true);
  if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) {
    return;
  }
  throw std::runtime_error("[conv] Unable to find a working engine.");
}

} // namespace mlx::core
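One detail worth checking in the file above: get_alignment returns the largest power of two, capped at 32 bytes, that divides the buffer address; it feeds both the cuDNN tensor descriptors and the cache key. A standalone, host-only sketch of the same loop on a raw address:

// Mirrors the loop in get_alignment above (sketch for illustration).
#include <cstdint>
#include <cstdio>

static uint8_t alignment_of(uintptr_t address) {
  uint8_t alignment = 1;
  for (; alignment < 32; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment;
    }
  }
  return alignment;
}

int main() {
  std::printf("%d\n", alignment_of(0x1000)); // 32: divisible by every power of 2 up to 32
  std::printf("%d\n", alignment_of(0x1004)); // 4: divisible by 4 but not by 8
}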
mlx/backend/cuda/copy.cu (new file, 87 lines)
@@ -0,0 +1,87 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/copy/copy.cuh"

namespace mlx::core {

void copy_gpu_inplace(
    const array& in,
    array& out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out,
    int64_t offset_in,
    int64_t offset_out,
    CopyType ctype,
    const Stream& s,
    const std::optional<array>& dynamic_offset_in,
    const std::optional<array>& dynamic_offset_out) {
  if (out.size() == 0) {
    return;
  }

  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
    copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
    return;
  }

  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
    auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
        shape, std::vector{strides_in, strides_out}, INT32_MAX);
    if (ctype == CopyType::General) {
      copy_general_input(
          encoder,
          ctype,
          in,
          out,
          offset_in,
          offset_out,
          shape_collapsed,
          strides_vec[0]);
    } else {
      if (dynamic_offset_in || dynamic_offset_out) {
        copy_general_dynamic(
            encoder,
            ctype,
            in,
            out,
            offset_in,
            offset_out,
            shape_collapsed,
            strides_vec[0],
            strides_vec[1],
            dynamic_offset_in ? *dynamic_offset_in : array(0, int64),
            dynamic_offset_out ? *dynamic_offset_out : array(0, int64));
      } else {
        copy_general(
            encoder,
            ctype,
            in,
            out,
            offset_in,
            offset_out,
            shape_collapsed,
            strides_vec[0],
            strides_vec[1]);
      }
    }
    return;
  }
}

void fill_gpu(const array& in, array& out, const Stream& s) {
  if (out.size() == 0) {
    return;
  }
  out.set_data(allocator::malloc(out.nbytes()));
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(in);
  encoder.set_output_array(out);
  copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
}

} // namespace mlx::core
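A compressed view of the dispatch in copy_gpu_inplace above, as a self-contained sketch. The CopyType enumerators match the ones used in the file (the real enum is defined in the MLX copy headers), and the bools stand in for the optional dynamic offsets:

// Local mirror of the routing logic only, for illustration.
#include <cstdio>

enum class CopyType { Scalar, Vector, General, GeneralGeneral };

const char* route(CopyType ctype, bool dynamic_in, bool dynamic_out) {
  if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
    return "copy_contiguous";
  }
  if (ctype == CopyType::General) {
    return "copy_general_input"; // only the input is strided
  }
  // CopyType::GeneralGeneral: both sides strided.
  return (dynamic_in || dynamic_out) ? "copy_general_dynamic" : "copy_general";
}

int main() {
  std::printf("%s\n", route(CopyType::GeneralGeneral, false, false)); // copy_general
}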
mlx/backend/cuda/copy/copy.cuh (new file, 55 lines)
@@ -0,0 +1,55 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"

namespace mlx::core {

void copy_contiguous(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out);

void copy_general(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out);

void copy_general_dynamic(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out,
    const array& dynamic_offset_in,
    const array& dynamic_offset_out);

void copy_general_input(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in);

} // namespace mlx::core
mlx/backend/cuda/copy/copy_contiguous.cu (new file, 88 lines)
@@ -0,0 +1,88 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_s(const In* in, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = cast_to<Out>(in[0]);
    }
  } else {
    AlignedVector<Out, N_READS> out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = cast_to<Out>(in[0]);
    }

    store_vector<N_READS>(out, index, out_vec);
  }
}

template <typename In, typename Out, typename IdxT, int N_READS>
__global__ void copy_v(const In* in, Out* out, IdxT size) {
  IdxT index = cg::this_grid().thread_rank();

  if ((index + 1) * N_READS > size) {
    for (IdxT i = index * N_READS; i < size; ++i) {
      out[i] = cast_to<Out>(in[i]);
    }
  } else {
    auto in_vec = load_vector<N_READS>(in, index);

    AlignedVector<Out, N_READS> out_vec;
#pragma unroll
    for (int i = 0; i < N_READS; ++i) {
      out_vec[i] = cast_to<Out>(in_vec[i]);
    }

    store_vector<N_READS>(out, index, out_vec);
  }
}

} // namespace cu

void copy_contiguous(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t in_offset,
    int64_t out_offset) {
  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) {
        using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
        using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
        constexpr int N_READS = 16 / sizeof(InType);
        auto kernel = cu::copy_s<InType, OutType, IdxT, N_READS>;
        if (ctype == CopyType::Vector) {
          kernel = cu::copy_v<InType, OutType, IdxT, N_READS>;
        }
        auto [num_blocks, block_dims] = get_launch_args(
            out.data_size(), out.shape(), out.strides(), large(), N_READS);
        encoder.add_kernel_node(
            kernel,
            num_blocks,
            block_dims,
            0,
            in.data<InType>() + in_offset,
            out.data<OutType>() + out_offset,
            out.data_size());
      });
    });
  });
}

} // namespace mlx::core
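The N_READS = 16 / sizeof(InType) heuristic above sizes each thread's vectorized access at 16 bytes regardless of element type. A quick compile-time check of what that means per dtype:

// Worked instances of the 16-byte-per-thread heuristic used above.
static_assert(16 / sizeof(float) == 4, "float32: 4 elements per thread");
static_assert(16 / sizeof(double) == 2, "float64: 2 elements per thread");
static_assert(16 / sizeof(unsigned char) == 16, "int8/uint8: 16 elements per thread");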
109
mlx/backend/cuda/copy/copy_general.cu
Normal file
109
mlx/backend/cuda/copy/copy_general.cu
Normal file
@@ -0,0 +1,109 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_nd(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
        index, shape.data(), strides_in.data(), strides_out.data());
    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_gg(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    const __grid_constant__ Strides strides_out,
    int ndim) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc(
        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
    out[idx_out] = CastOp<In, Out>{}(in[idx_in]);
  }
}

} // namespace cu

void copy_general(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out) {
  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      dispatch_bool(
          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
          [&](auto large) {
            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
            const InType* in_ptr = in.data<InType>() + offset_in;
            OutType* out_ptr = out.data<OutType>() + offset_out;
            int ndim = shape.size();
            size_t data_size = 1;
            for (auto& s : shape)
              data_size *= s;
            if (ndim <= 3) {
              dispatch_1_2_3(ndim, [&](auto ndim_constant) {
                auto [num_blocks, block_dims] =
                    get_launch_args(data_size, shape, out.strides(), large());
                encoder.add_kernel_node(
                    cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
                    num_blocks,
                    block_dims,
                    0,
                    in_ptr,
                    out_ptr,
                    data_size,
                    const_param<ndim_constant()>(shape),
                    const_param<ndim_constant()>(strides_in),
                    const_param<ndim_constant()>(strides_out));
              });
            } else { // ndim >= 4
              auto [num_blocks, block_dims] =
                  get_launch_args(data_size, shape, out.strides(), large());
              encoder.add_kernel_node(
                  cu::copy_gg<InType, OutType, IdxT>,
                  num_blocks,
                  block_dims,
                  0,
                  in_ptr,
                  out_ptr,
                  data_size,
                  const_param(shape),
                  const_param(strides_in),
                  const_param(strides_out),
                  ndim);
            }
          });
    });
  });
}

} // namespace mlx::core
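The real work in both kernels is elem_to_loc / elem_to_loc_nd, which turn a flat row-major index into a strided offset for each buffer. A hedged sketch of that arithmetic (illustrative names and out-parameters; MLX's helpers return the pair directly, and the _nd variant unrolls the loop for a compile-time NDIM):

// Peel coordinates off the flat index from the innermost axis outward,
// accumulating each buffer's strided offset as we go.
__device__ void elem_to_loc_sketch(
    int64_t index,
    const int32_t* shape,
    const int64_t* strides_in,
    const int64_t* strides_out,
    int ndim,
    int64_t& idx_in,
    int64_t& idx_out) {
  idx_in = 0;
  idx_out = 0;
  for (int axis = ndim - 1; axis >= 0; --axis) {
    int64_t coord = index % shape[axis];
    idx_in += coord * strides_in[axis];
    idx_out += coord * strides_out[axis];
    index /= shape[axis];
  }
}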
118 mlx/backend/cuda/copy/copy_general_dynamic.cu Normal file
@@ -0,0 +1,118 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_gg_dynamic_nd(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_out,
    const int64_t* offset_in,
    const int64_t* offset_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc_nd<NDIM>(
        index, shape.data(), strides_in.data(), strides_out.data());
    out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_gg_dynamic(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    const __grid_constant__ Strides strides_out,
    int ndim,
    const int64_t* offset_in,
    const int64_t* offset_out) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    auto [idx_in, idx_out] = elem_to_loc(
        index, shape.data(), strides_in.data(), strides_out.data(), ndim);
    out[idx_out + *offset_out] = CastOp<In, Out>{}(in[idx_in + *offset_in]);
  }
}

} // namespace cu

void copy_general_dynamic(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in,
    const Strides& strides_out,
    const array& dynamic_offset_in,
    const array& dynamic_offset_out) {
  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      dispatch_bool(
          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
          [&](auto large) {
            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
            const InType* in_ptr = in.data<InType>() + offset_in;
            OutType* out_ptr = out.data<OutType>() + offset_out;
            int ndim = shape.size();
            if (ndim <= 3) {
              dispatch_1_2_3(ndim, [&](auto dims_constant) {
                auto [num_blocks, block_dims] = get_launch_args(out, large());
                encoder.add_kernel_node(
                    cu::copy_gg_dynamic_nd<
                        InType,
                        OutType,
                        IdxT,
                        dims_constant()>,
                    num_blocks,
                    block_dims,
                    0,
                    in_ptr,
                    out_ptr,
                    out.size(),
                    const_param<dims_constant()>(shape),
                    const_param<dims_constant()>(strides_in),
                    const_param<dims_constant()>(strides_out),
                    dynamic_offset_in.data<int64_t>(),
                    dynamic_offset_out.data<int64_t>());
              });
            } else { // ndim >= 4
              auto [num_blocks, block_dims] = get_launch_args(out, large());
              encoder.add_kernel_node(
                  cu::copy_gg_dynamic<InType, OutType, IdxT>,
                  num_blocks,
                  block_dims,
                  0,
                  in_ptr,
                  out_ptr,
                  out.size(),
                  const_param(shape),
                  const_param(strides_in),
                  const_param(strides_out),
                  ndim,
                  dynamic_offset_in.data<int64_t>(),
                  dynamic_offset_out.data<int64_t>());
            }
          });
    });
  });
}

} // namespace mlx::core
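The difference from copy_general is that the slice offsets arrive as device pointers (const int64_t*) and are dereferenced inside the kernel: they may be produced by earlier GPU work, so the host cannot fold them into the base pointers when the kernel node is recorded. A minimal sketch of that pattern, assuming float data and a hypothetical one-dimensional kernel:

// The offset lives in device memory (e.g. written by a previous kernel),
// so it is read on-device at execution time rather than baked in on the host.
__global__ void dynamic_offset_copy_sketch(
    const float* in, float* out, int n, const int64_t* offset_in) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i + *offset_in];
  }
}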
97 mlx/backend/cuda/copy/copy_general_input.cu Normal file
@@ -0,0 +1,97 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/copy/copy.cuh"

#include <cooperative_groups.h>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename In, typename Out, typename IdxT, int NDIM>
__global__ void copy_g_nd(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> strides_in) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    IdxT idx_in = elem_to_loc_nd<NDIM>(index, shape.data(), strides_in.data());
    out[index] = CastOp<In, Out>{}(in[idx_in]);
  }
}

template <typename In, typename Out, typename IdxT>
__global__ void copy_g(
    const In* in,
    Out* out,
    IdxT size,
    const __grid_constant__ Shape shape,
    const __grid_constant__ Strides strides_in,
    int ndim) {
  IdxT index = cg::this_grid().thread_rank();
  if (index < size) {
    IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim);
    out[index] = CastOp<In, Out>{}(in[idx_in]);
  }
}

} // namespace cu

void copy_general_input(
    cu::CommandEncoder& encoder,
    CopyType ctype,
    const array& in,
    array& out,
    int64_t offset_in,
    int64_t offset_out,
    const Shape& shape,
    const Strides& strides_in) {
  dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
    dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
      dispatch_bool(
          in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
          [&](auto large) {
            using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
            using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
            using IdxT = std::conditional_t<large(), int64_t, int32_t>;
            const InType* in_ptr = in.data<InType>() + offset_in;
            OutType* out_ptr = out.data<OutType>() + offset_out;
            int ndim = shape.size();
            if (ndim <= 3) {
              dispatch_1_2_3(ndim, [&](auto dims_constant) {
                auto [num_blocks, block_dims] = get_launch_args(out, large());
                encoder.add_kernel_node(
                    cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
                    num_blocks,
                    block_dims,
                    0,
                    in_ptr,
                    out_ptr,
                    out.size(),
                    const_param<dims_constant()>(shape),
                    const_param<dims_constant()>(strides_in));
              });
            } else { // ndim >= 4
              auto [num_blocks, block_dims] = get_launch_args(out, large());
              encoder.add_kernel_node(
                  cu::copy_g<InType, OutType, IdxT>,
                  num_blocks,
                  block_dims,
                  0,
                  in_ptr,
                  out_ptr,
                  out.size(),
                  const_param(shape),
                  const_param(strides_in),
                  ndim);
            }
          });
    });
  });
}

} // namespace mlx::core
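Taken together, the four files split copies by layout: copy_contiguous handles scalar broadcast and contiguous-to-contiguous copies, copy_general_input a strided input into a contiguous output, copy_general strides on both sides, and copy_general_dynamic the case where the offsets are only known on-device. A hedged sketch of how a caller might route between them, assuming MLX's CopyType enum (Scalar, Vector, General, GeneralGeneral) and zero static offsets; the backend's real dispatch also threads through offsets and the dynamic-offset arrays where applicable:

// Illustrative routing only, not the backend's actual dispatch code.
switch (ctype) {
  case CopyType::Scalar: // broadcast in[0] into every output element
  case CopyType::Vector: // both buffers contiguous, elementwise cast
    copy_contiguous(encoder, ctype, in, out, 0, 0);
    break;
  case CopyType::General: // strided input, contiguous output
    copy_general_input(encoder, ctype, in, out, 0, 0, shape, strides_in);
    break;
  case CopyType::GeneralGeneral: // both buffers strided
    copy_general(
        encoder, ctype, in, out, 0, 0, shape, strides_in, strides_out);
    break;
}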
Some files were not shown because too many files have changed in this diff.