mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 00:39:06 +08:00
Compare commits
403 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54f1cc6e3e | ||
|
|
b3825ac149 | ||
|
|
7f4b7e553c | ||
|
|
ad16f41a7f | ||
|
|
f46877bc08 | ||
|
|
6f35017d1b | ||
|
|
b167f0df1c | ||
|
|
a9f0d6b160 | ||
|
|
940f4c7818 | ||
|
|
35f81728f1 | ||
|
|
4442ed86c1 | ||
|
|
698559c231 | ||
|
|
ecc4879b07 | ||
|
|
32b18d8b66 | ||
|
|
472c43a0c8 | ||
|
|
b7214ff01e | ||
|
|
76414c8971 | ||
|
|
49e4566df3 | ||
|
|
aad49f932f | ||
|
|
86765cce34 | ||
|
|
1bedcbd556 | ||
|
|
9ac7dbe877 | ||
|
|
1bf605d56d | ||
|
|
3c622ddd1d | ||
|
|
27ff069175 | ||
|
|
3b2ffcefc3 | ||
|
|
b65f882df3 | ||
|
|
b704e9e77a | ||
|
|
66519fb348 | ||
|
|
8973550ff3 | ||
|
|
3f866be665 | ||
|
|
23f81ed1c1 | ||
|
|
3fe2250c00 | ||
|
|
047114b988 | ||
|
|
9320eb89a8 | ||
|
|
75819d70ea | ||
|
|
60d80a3728 | ||
|
|
eba6a9d163 | ||
|
|
be9e2aebd6 | ||
|
|
df58b4133a | ||
|
|
27778156dc | ||
|
|
761f901a41 | ||
|
|
6ece97f69b | ||
|
|
d3bc6a9bff | ||
|
|
26ceb507eb | ||
|
|
910b3e3299 | ||
|
|
50fa315d18 | ||
|
|
1ff2b713b6 | ||
|
|
50514a6146 | ||
|
|
93d76b0f30 | ||
|
|
78678de0cd | ||
|
|
ed9c6b1117 | ||
|
|
39b04ce638 | ||
|
|
d9e6349657 | ||
|
|
b901a9f311 | ||
|
|
68c5fa1c95 | ||
|
|
793a31eeb6 | ||
|
|
74c1ed25bb | ||
|
|
ec72b44417 | ||
|
|
460691a0e8 | ||
|
|
969924cc69 | ||
|
|
d1e06117e8 | ||
|
|
539d8322d1 | ||
|
|
c4767d110f | ||
|
|
895217f25b | ||
|
|
0cfeeb60ca | ||
|
|
8f8af61a37 | ||
|
|
233384161e | ||
|
|
5bcf3a6794 | ||
|
|
7707196297 | ||
|
|
7e3471c987 | ||
|
|
9f0ba3ddf1 | ||
|
|
4bce5f9b2d | ||
|
|
e9eab527eb | ||
|
|
36ca62dba8 | ||
|
|
9cbb1b0148 | ||
|
|
9bfc476d72 | ||
|
|
25e2356316 | ||
|
|
226a1d24e0 | ||
|
|
630350ad3e | ||
|
|
380aeb58ae | ||
|
|
f37389d100 | ||
|
|
e89e8b4272 | ||
|
|
85a8824a8c | ||
|
|
f5d4397e5c | ||
|
|
343e33b6d5 | ||
|
|
0073096dd1 | ||
|
|
e3d004fed9 | ||
|
|
a393435d28 | ||
|
|
a7a94b29d7 | ||
|
|
22a5da76c8 | ||
|
|
287c63a093 | ||
|
|
1c9ae1eaa1 | ||
|
|
c2c3e0b0a2 | ||
|
|
b0cc71ae71 | ||
|
|
e88f2d4a8e | ||
|
|
9cee557423 | ||
|
|
bbf1423953 | ||
|
|
eb24267b56 | ||
|
|
dc371ae7a5 | ||
|
|
e76a8dd5c5 | ||
|
|
b466dea982 | ||
|
|
7a6adda1e6 | ||
|
|
1a9f820af6 | ||
|
|
d4f4ff3c5e | ||
|
|
7c7e48dbd1 | ||
|
|
fbbf3b9b3e | ||
|
|
bf01ad9367 | ||
|
|
ae438d05fa | ||
|
|
711a645807 | ||
|
|
aa9d44b3d4 | ||
|
|
ec2ab42888 | ||
|
|
787c0d90cd | ||
|
|
e8b604a6a3 | ||
|
|
50cc09887f | ||
|
|
3f730e77aa | ||
|
|
caecbe876a | ||
|
|
8afb6d62f2 | ||
|
|
6ccfa603cd | ||
|
|
36cad99a11 | ||
|
|
ee18e1cbf0 | ||
|
|
af120c2bc0 | ||
|
|
6a3acf2301 | ||
|
|
d6977f2a57 | ||
|
|
db5443e831 | ||
|
|
52b8384d10 | ||
|
|
44cc5da4bc | ||
|
|
dde3682b69 | ||
|
|
17310d91a6 | ||
|
|
b194d65a6a | ||
|
|
a44b27f5f8 | ||
|
|
e5a33f2223 | ||
|
|
c1e3340b23 | ||
|
|
8f163a367d | ||
|
|
89a3df9014 | ||
|
|
c5d2937aa5 | ||
|
|
b61a65e313 | ||
|
|
04cbb4191c | ||
|
|
c5460762e7 | ||
|
|
8ce49cd39e | ||
|
|
9c68b50853 | ||
|
|
111f1e71af | ||
|
|
827003d568 | ||
|
|
d363a76aa4 | ||
|
|
70560b6bd5 | ||
|
|
7ef8a6f2d5 | ||
|
|
31c6f6e33f | ||
|
|
584d48458e | ||
|
|
5cf984ca87 | ||
|
|
a9bac3d9e5 | ||
|
|
5458d43247 | ||
|
|
a4dba65220 | ||
|
|
3dcb286baf | ||
|
|
4822c3dbe9 | ||
|
|
2ca75bb529 | ||
|
|
db14e29a0b | ||
|
|
d2f540f4e0 | ||
|
|
333ffea273 | ||
|
|
f55b6f1f2f | ||
|
|
30561229c7 | ||
|
|
068a4612e9 | ||
|
|
5722c147de | ||
|
|
f6819a1f26 | ||
|
|
f93f87c802 | ||
|
|
9392fc3f88 | ||
|
|
e843c4d8d5 | ||
|
|
0c5fc63a36 | ||
|
|
e397177f6e | ||
|
|
f4c8888cbe | ||
|
|
25c1e03205 | ||
|
|
512281781c | ||
|
|
ac85ddfdb7 | ||
|
|
65d0d40232 | ||
|
|
cea9369610 | ||
|
|
e7c6e1db82 | ||
|
|
c5fcd5b61b | ||
|
|
1df9887998 | ||
|
|
73f22d6226 | ||
|
|
c422050ca7 | ||
|
|
1ba18ff7d9 | ||
|
|
37b440faa8 | ||
|
|
888b13ed63 | ||
|
|
4abb218d21 | ||
|
|
6441c21a94 | ||
|
|
dfb5022eab | ||
|
|
ac207ce7aa | ||
|
|
fce53b61d6 | ||
|
|
8ae4a76308 | ||
|
|
7fde1b6a1e | ||
|
|
aa7b47481a | ||
|
|
56be773610 | ||
|
|
a9bdd67baa | ||
|
|
f2adb5638d | ||
|
|
728d4db582 | ||
|
|
db5c7efcf6 | ||
|
|
7bb96e4249 | ||
|
|
fa89f0b150 | ||
|
|
ca973d1e83 | ||
|
|
828c5f1137 | ||
|
|
7d86a5c108 | ||
|
|
0b807893a7 | ||
|
|
6ad0889c8a | ||
|
|
737dd6d1ac | ||
|
|
aaf78f4c6b | ||
|
|
8831064493 | ||
|
|
be9bc96da4 | ||
|
|
86258f292f | ||
|
|
b26d88591c | ||
|
|
86c6a15571 | ||
|
|
8b25ce62d5 | ||
|
|
da5912e4f2 | ||
|
|
daafee676f | ||
|
|
d32519c8ee | ||
|
|
b405591249 | ||
|
|
3bf81ed1bd | ||
|
|
2204182bba | ||
|
|
3628e5d497 | ||
|
|
a0ae49d397 | ||
|
|
254476718b | ||
|
|
3adba92ebe | ||
|
|
ef631d63af | ||
|
|
970dbe8e25 | ||
|
|
641be9463b | ||
|
|
ab0e608862 | ||
|
|
1588659062 | ||
|
|
b9e88fb976 | ||
|
|
4ad53414dd | ||
|
|
d1165b215e | ||
|
|
dcb8319f3d | ||
|
|
5597fa089c | ||
|
|
9acec364c2 | ||
|
|
7d9d6ef456 | ||
|
|
6f5874a2f2 | ||
|
|
70dc336785 | ||
|
|
4e504039f5 | ||
|
|
d1f4d291e8 | ||
|
|
e1840853ce | ||
|
|
0f5ce173da | ||
|
|
588854195f | ||
|
|
28d068bce6 | ||
|
|
d107d8d495 | ||
|
|
1e496ddb82 | ||
|
|
74eccbf3fa | ||
|
|
08638223ca | ||
|
|
56cc858af9 | ||
|
|
f55c4ed1d6 | ||
|
|
93d70419e7 | ||
|
|
63f663d9c6 | ||
|
|
84b4d96efa | ||
|
|
aec67f2fa6 | ||
|
|
deee214a95 | ||
|
|
45adec102c | ||
|
|
31fc530c76 | ||
|
|
fbb3f65a1a | ||
|
|
6b1b8ea91b | ||
|
|
b2273733ea | ||
|
|
f409b229a4 | ||
|
|
30571e2326 | ||
|
|
d7734edd9f | ||
|
|
2ba69bc8fa | ||
|
|
cb349a291c | ||
|
|
f0a0b077a0 | ||
|
|
49114f28ab | ||
|
|
e7d2ebadd2 | ||
|
|
e569803d7c | ||
|
|
d34f887abc | ||
|
|
5201df5030 | ||
|
|
2d3c26c565 | ||
|
|
6325f60d52 | ||
|
|
42cc9cfbc7 | ||
|
|
8347575ba1 | ||
|
|
b6eec20260 | ||
|
|
0eb035b4b1 | ||
|
|
afb9817599 | ||
|
|
8fb3e7a26c | ||
|
|
8c7bc30ce4 | ||
|
|
85873cb162 | ||
|
|
e14ee12491 | ||
|
|
8b9a3f3cea | ||
|
|
fb4e8b896b | ||
|
|
2ca533b279 | ||
|
|
4a9b29a875 | ||
|
|
a4fcc893cd | ||
|
|
9d10239af7 | ||
|
|
19facd4b20 | ||
|
|
f5299f72cd | ||
|
|
0e0d9ac522 | ||
|
|
8917022deb | ||
|
|
ec0d5db67b | ||
|
|
e76e9b87f0 | ||
|
|
cfb6a244ea | ||
|
|
58f3860306 | ||
|
|
dd4f53db63 | ||
|
|
3d5e17e507 | ||
|
|
33bf1a244b | ||
|
|
772f471ff2 | ||
|
|
2c11d10f8d | ||
|
|
656ed7f780 | ||
|
|
81bb9a2a9e | ||
|
|
5adf185f86 | ||
|
|
c9a9180584 | ||
|
|
76831ed83d | ||
|
|
b3d7b85376 | ||
|
|
cad5c0241c | ||
|
|
b8022c578a | ||
|
|
bc53f8293f | ||
|
|
c552ff2451 | ||
|
|
4fda5fbdf9 | ||
|
|
580776559b | ||
|
|
a14aaa7c9d | ||
|
|
a6d780154f | ||
|
|
6871e2eeb7 | ||
|
|
8402a2acf4 | ||
|
|
fddb6933e1 | ||
|
|
c8b4787e4e | ||
|
|
2188199ff8 | ||
|
|
aa07429bad | ||
|
|
918761a25a | ||
|
|
a4fc671d3e | ||
|
|
f5f65ef48c | ||
|
|
c2dd81a8aa | ||
|
|
d7e680ffe4 | ||
|
|
c371baf53a | ||
|
|
ccf78f566c | ||
|
|
c9fa68664a | ||
|
|
c35f4d089a | ||
|
|
8590c0941e | ||
|
|
095163b8d1 | ||
|
|
99c33d011d | ||
|
|
62fecf3e13 | ||
|
|
7c4eb5d03e | ||
|
|
bae9a6b404 | ||
|
|
004c1d8ef2 | ||
|
|
7ebb2e0193 | ||
|
|
9ce77798b1 | ||
|
|
f8bad60609 | ||
|
|
5866b3857b | ||
|
|
1ca616844b | ||
|
|
2e8cf0b450 | ||
|
|
24f89173d1 | ||
|
|
c6a20b427a | ||
|
|
a5ac9244c4 | ||
|
|
c763fe1be0 | ||
|
|
52dc8c8cd5 | ||
|
|
aede70e81d | ||
|
|
85a8beb5e4 | ||
|
|
0bb89e9e5f | ||
|
|
5685ceb3c7 | ||
|
|
0408ba0a76 | ||
|
|
cbad6c3093 | ||
|
|
1b021f6984 | ||
|
|
95b7551d65 | ||
|
|
db5a7c6192 | ||
|
|
6ef2f67e7f | ||
|
|
f76ee1ffd2 | ||
|
|
54a71f270a | ||
|
|
55b4062dd8 | ||
|
|
79071bfba4 | ||
|
|
7774b87cbd | ||
|
|
35c87741cf | ||
|
|
4cbe605214 | ||
|
|
ab8883dd55 | ||
|
|
eebe73001a | ||
|
|
0359bf02c9 | ||
|
|
237f9e58a8 | ||
|
|
8576e6fe36 | ||
|
|
0654543dcc | ||
|
|
48ef3e74e2 | ||
|
|
7d4b378952 | ||
|
|
7ff5c41e06 | ||
|
|
602f43e3d1 | ||
|
|
a2cadb8218 | ||
|
|
c1eb9d05d9 | ||
|
|
cf6c939e86 | ||
|
|
130df35e1b | ||
|
|
0751263dec | ||
|
|
eca2f3eb97 | ||
|
|
3aa9cf3f9e | ||
|
|
8f3d208dce | ||
|
|
caaa3f1f8c | ||
|
|
659a51919f | ||
|
|
6661387066 | ||
|
|
a7fae8a176 | ||
|
|
0cae0bdac8 | ||
|
|
5a1a5d5ed1 | ||
|
|
1683975acf | ||
|
|
af705590ac | ||
|
|
825124af8f | ||
|
|
9c5e7da507 | ||
|
|
481349495b | ||
|
|
9daa6b003f | ||
|
|
a3a632d567 | ||
|
|
e496c5a4b4 | ||
|
|
ea890d8710 | ||
|
|
aa5d84f102 | ||
|
|
f1606486d2 | ||
|
|
87720a8908 | ||
|
|
bb6565ef14 | ||
|
|
7bb063bcb3 | ||
|
|
b36dd472bb | ||
|
|
167b759a38 | ||
|
|
99b9868859 | ||
|
|
6b2d5448f2 |
@@ -1,600 +0,0 @@
|
||||
version: 2.1
|
||||
|
||||
orbs:
|
||||
apple: ml-explore/pr-approval@0.1.0
|
||||
|
||||
parameters:
|
||||
nightly_build:
|
||||
type: boolean
|
||||
default: false
|
||||
weekly_build:
|
||||
type: boolean
|
||||
default: false
|
||||
test_release:
|
||||
type: boolean
|
||||
default: false
|
||||
linux_release:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build_documentation:
|
||||
parameters:
|
||||
upload-docs:
|
||||
type: boolean
|
||||
default: false
|
||||
macos:
|
||||
xcode: "16.2.0"
|
||||
resource_class: m2pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install
|
||||
command: |
|
||||
brew install python@3.9
|
||||
brew install doxygen
|
||||
python3.9 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install -r docs/requirements.txt
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
|
||||
- when:
|
||||
condition:
|
||||
not: << parameters.upload-docs >>
|
||||
steps:
|
||||
- run:
|
||||
name: Build documentation
|
||||
command: |
|
||||
source env/bin/activate
|
||||
cd docs && doxygen && make html O=-W
|
||||
- when:
|
||||
condition: << parameters.upload-docs >>
|
||||
steps:
|
||||
- add_ssh_keys:
|
||||
fingerprints:
|
||||
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
|
||||
- run:
|
||||
name: Upload documentation
|
||||
command: |
|
||||
source env/bin/activate
|
||||
git config user.email "mlx@group.apple.com"
|
||||
git config user.name "CircleCI Docs"
|
||||
git checkout gh-pages
|
||||
git rebase main
|
||||
cd docs
|
||||
git rm -rf build/html
|
||||
doxygen && make html O=-W
|
||||
git add -f build/html
|
||||
git commit -m "rebase"
|
||||
git push -f origin gh-pages
|
||||
|
||||
linux_build_and_test:
|
||||
docker:
|
||||
- image: cimg/python:3.9
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Run style checks
|
||||
command: |
|
||||
pip install pre-commit
|
||||
pre-commit run --all
|
||||
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py develop
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
echo "stubs"
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
python3 -m unittest discover python/tests -v
|
||||
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||
make -j `nproc`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: ./build/tests/tests
|
||||
|
||||
mac_build_and_test:
|
||||
parameters:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "16.2.0"
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
resource_class: m2pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@3.9
|
||||
brew install openmpi
|
||||
python3.9 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
pip install unittest-xml-reporting
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
||||
- run:
|
||||
name: Build example extension
|
||||
command: |
|
||||
source env/bin/activate
|
||||
cd examples/extensions
|
||||
pip install -r requirements.txt
|
||||
python setup.py build_ext -j8
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source env/bin/activate
|
||||
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: |
|
||||
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
|
||||
- run:
|
||||
name: Build small binary
|
||||
command: |
|
||||
source env/bin/activate
|
||||
cd build/
|
||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DMLX_BUILD_CPU=OFF \
|
||||
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||
-DMLX_BUILD_GGUF=OFF \
|
||||
-DMLX_METAL_JIT=ON
|
||||
make -j `sysctl -n hw.ncpu`
|
||||
- run:
|
||||
name: Run Python tests with JIT
|
||||
command: |
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
pip install -e . -v
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "16.2.0"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: m2pro.medium
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@<< parameters.python_version >>
|
||||
brew install openmpi
|
||||
python<< parameters.python_version >> -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install twine
|
||||
pip install build
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
pip install . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
<< parameters.build_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
python -m build -w
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload dist/*
|
||||
- store_artifacts:
|
||||
path: dist/
|
||||
|
||||
build_linux_release:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
extra_env:
|
||||
type: string
|
||||
default: "DEV_RELEASE=1"
|
||||
docker:
|
||||
- image: ubuntu:20.04
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build wheel
|
||||
command: |
|
||||
PYTHON=python<< parameters.python_version >>
|
||||
apt-get update
|
||||
apt-get upgrade -y
|
||||
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
|
||||
apt-get install -y apt-utils
|
||||
apt-get install -y software-properties-common
|
||||
add-apt-repository -y ppa:deadsnakes/ppa
|
||||
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
apt-get install -y build-essential git
|
||||
$PYTHON -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
pip install . -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python -m build --wheel
|
||||
auditwheel show dist/*
|
||||
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
workflows:
|
||||
build_and_test:
|
||||
when:
|
||||
and:
|
||||
- matches:
|
||||
pattern: "^(?!pull/)[-\\w]+$"
|
||||
value: << pipeline.git.branch >>
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.weekly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
when:
|
||||
and:
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.weekly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- build_release:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- build_documentation:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
upload-docs: true
|
||||
|
||||
prb:
|
||||
when:
|
||||
matches:
|
||||
pattern: "^pull/\\d+(/head)?$"
|
||||
value: << pipeline.git.branch >>
|
||||
jobs:
|
||||
- hold:
|
||||
type: approval
|
||||
- apple/authenticate:
|
||||
context: pr-approval
|
||||
- mac_build_and_test:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.nightly_build >>
|
||||
jobs:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
weekly_build:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.weekly_build >>
|
||||
jobs:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
linux_test_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.linux_release >>
|
||||
jobs:
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
20
.github/actions/build-cuda-release/action.yml
vendored
Normal file
20
.github/actions/build-cuda-release/action.yml
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
name: 'Build CUDA wheel'
|
||||
description: 'Build CUDA wheel'
|
||||
|
||||
inputs:
|
||||
toolkit:
|
||||
description: 'The CUDA toolkit'
|
||||
required: true
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Build package
|
||||
shell: bash
|
||||
env:
|
||||
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
|
||||
run: |
|
||||
pip install auditwheel build patchelf setuptools
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
bash python/scripts/repair_cuda.sh
|
||||
26
.github/actions/build-cuda/action.yml
vendored
Normal file
26
.github/actions/build-cuda/action.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
name: 'Build and Test with CUDA'
|
||||
description: 'Build and test MLX with CUDA'
|
||||
|
||||
inputs:
|
||||
toolkit:
|
||||
description: 'The CUDA toolkit'
|
||||
required: true
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install Python package
|
||||
shell: bash
|
||||
env:
|
||||
DEBUG: 1
|
||||
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
|
||||
run: pip install --no-build-isolation -e ".[dev]" -v
|
||||
|
||||
- name: Build CPP only
|
||||
shell: bash
|
||||
run: |
|
||||
cmake . -B build \
|
||||
-DMLX_BUILD_CUDA=ON \
|
||||
-DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
|
||||
-DCMAKE_BUILD_TYPE=DEBUG
|
||||
cmake --build build -j $(nproc)
|
||||
38
.github/actions/build-docs/action.yml
vendored
Normal file
38
.github/actions/build-docs/action.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: 'Build Documentation'
|
||||
description: 'Build documentation'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Setup machine
|
||||
uses: ./.github/actions/setup-linux
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get install -y doxygen
|
||||
source .venv/bin/activate
|
||||
pip install -r docs/requirements.txt
|
||||
pip install . -v
|
||||
|
||||
- name: Build documentation
|
||||
shell: bash
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
cd docs
|
||||
doxygen
|
||||
make html O=-W
|
||||
|
||||
- name: Create artifact tar
|
||||
shell: bash
|
||||
run: tar -cf artifact.tar -C docs --dereference build/html index.html
|
||||
|
||||
# Do it manually because upload-pages-artifact requires gtar
|
||||
- name: Upload artifact
|
||||
id: upload-artifact
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: github-pages
|
||||
path: artifact.tar
|
||||
retention-days: 1
|
||||
if-no-files-found: error
|
||||
40
.github/actions/build-linux-release/action.yml
vendored
Normal file
40
.github/actions/build-linux-release/action.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
name: 'Build Linux wheel'
|
||||
description: 'Build Linux wheel'
|
||||
|
||||
inputs:
|
||||
build-backend:
|
||||
description: 'Build the backend mlx-cpu package'
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
arch:
|
||||
description: 'Platform architecture tag'
|
||||
required: true
|
||||
type: choice
|
||||
options:
|
||||
- x86_64
|
||||
- aarch64
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Generate package stubs
|
||||
shell: bash
|
||||
run: |
|
||||
pip install -e ".[dev]" -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- name: Build Python package
|
||||
shell: bash
|
||||
run: |
|
||||
pip install auditwheel patchelf build
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 python -m build -w
|
||||
bash python/scripts/repair_linux.sh ${{ inputs.arch }}
|
||||
- name: Build backend package
|
||||
if: ${{ inputs.build-backend }}
|
||||
shell: bash
|
||||
run: |
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}
|
||||
25
.github/actions/build-linux/action.yml
vendored
Normal file
25
.github/actions/build-linux/action.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
name: 'Build and Test on Linux'
|
||||
description: 'Build and test MLX on Linux'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install Python package
|
||||
shell: sh
|
||||
env:
|
||||
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
||||
DEBUG: 1
|
||||
run: pip install --no-build-isolation -e ".[dev]" -v
|
||||
|
||||
- name: Generate package stubs
|
||||
shell: sh
|
||||
run: |
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
|
||||
- name: Build CPP only
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||
make -j $(nproc)
|
||||
30
.github/actions/build-macos-release/action.yml
vendored
Normal file
30
.github/actions/build-macos-release/action.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: 'Build macOS release'
|
||||
description: 'Build MLX releases macOS'
|
||||
|
||||
inputs:
|
||||
macos-target:
|
||||
description: 'macOS build target'
|
||||
required: false
|
||||
default: '15.0'
|
||||
build-backend:
|
||||
description: 'Build the backend mlx-metal package'
|
||||
type: boolean
|
||||
required: false
|
||||
default: false
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Build Python package
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pip install build
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 python -m build -w
|
||||
|
||||
- name: Build backend package
|
||||
if: ${{ inputs.build-backend }}
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=2 python -m build -w
|
||||
88
.github/actions/build-macos/action.yml
vendored
Normal file
88
.github/actions/build-macos/action.yml
vendored
Normal file
@@ -0,0 +1,88 @@
|
||||
name: 'Build and Test on macOS'
|
||||
description: 'Build and test MLX on macOS'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
env:
|
||||
DEBUG: 1
|
||||
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install cmake setuptools nanobind==2.4.0
|
||||
pip install -e . -v
|
||||
|
||||
- name: Generate package stubs
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
|
||||
- name: Install tests dependencies
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pip install numpy torch tensorflow unittest-xml-reporting
|
||||
|
||||
- name: Run Python tests
|
||||
shell: bash -l {0}
|
||||
env:
|
||||
LOW_MEMORY: 1
|
||||
run: |
|
||||
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||
|
||||
- name: Build example extension
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
cd examples/extensions
|
||||
pip install -r requirements.txt
|
||||
python setup.py build_ext --inplace
|
||||
python test.py
|
||||
|
||||
- name: Build CPP only
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake ..
|
||||
make -j $(sysctl -n hw.ncpu)
|
||||
|
||||
- name: Run CPP tests
|
||||
shell: bash -l {0}
|
||||
env:
|
||||
DEVICE: gpu
|
||||
METAL_DEVICE_WRAPPER_TYPE: 1
|
||||
METAL_DEBUG_ERROR_MODE: 0
|
||||
run: ./build/tests/tests
|
||||
|
||||
- name: Build small binary with JIT
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
mkdir -p build
|
||||
cd build
|
||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DMLX_BUILD_CPU=OFF \
|
||||
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||
-DMLX_BUILD_GGUF=OFF \
|
||||
-DMLX_METAL_JIT=ON
|
||||
make -j $(sysctl -n hw.ncpu)
|
||||
|
||||
- name: Run Python tests with JIT
|
||||
shell: bash -l {0}
|
||||
env:
|
||||
LOW_MEMORY: 1
|
||||
DEVICE: gpu
|
||||
METAL_DEVICE_WRAPPER_TYPE: 1
|
||||
METAL_DEBUG_ERROR_MODE: 0
|
||||
run: |
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
pip install -e . -v
|
||||
python -m xmlrunner discover \
|
||||
-v python/tests \
|
||||
-o test-results/gpu_jit
|
||||
85
.github/actions/setup-linux/action.yml
vendored
Normal file
85
.github/actions/setup-linux/action.yml
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
name: 'Setup Linux Environment'
|
||||
description: 'Install dependencies for Linux builds'
|
||||
|
||||
inputs:
|
||||
toolkit:
|
||||
description: 'Which toolkit to install'
|
||||
required: false
|
||||
default: 'cpu'
|
||||
python-version:
|
||||
description: 'Version of python to set up'
|
||||
required: false
|
||||
default: '3.10'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Use ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
with:
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
|
||||
max-size: 1GB
|
||||
|
||||
- name: Install common dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
|
||||
|
||||
- uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Setup Python venv
|
||||
shell: bash
|
||||
run: |
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install setuptools cmake nanobind==2.4.0
|
||||
echo PATH=$PATH >> $GITHUB_ENV
|
||||
# Make cmake search .venv for nanobind
|
||||
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
|
||||
|
||||
- name: Install MPI
|
||||
shell: bash
|
||||
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
|
||||
|
||||
- name: Install CUDA toolkit
|
||||
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
|
||||
shell: bash
|
||||
env:
|
||||
# Note: the CI machine does not meet CUDA 13's driver requirement.
|
||||
# Compatibility matrix:
|
||||
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
|
||||
# The `nvcc` is installed into `/usr/local/cuda-VERSION/bin/nvcc` - but
|
||||
# it's *not* on the default toolkit path.
|
||||
PACKAGES: |
|
||||
{
|
||||
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
|
||||
"cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
|
||||
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
|
||||
}
|
||||
run: |
|
||||
export ARCH=${{ runner.arch == 'arm64' && 'arm64' || 'x86_64' }}
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
|
||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y \
|
||||
libnccl2 libnccl-dev \
|
||||
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
|
||||
|
||||
- name: CUDA packages and driver report
|
||||
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get install -y ubuntu-drivers-common dkms
|
||||
echo "NVIDIA Driver Packages Available:"
|
||||
sudo ubuntu-drivers list --gpgpu
|
||||
echo "NVIDIA Driver Version:"
|
||||
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
|
||||
echo "Installed NVIDIA and CUDA packages:"
|
||||
dpkg -l | egrep "cuda|nvidia" -i
|
||||
echo "DKMS Status:"
|
||||
dkms status || echo "dkms not found"
|
||||
echo "NVIDIA-SMI Status:"
|
||||
nvidia-smi || echo "nvidia-smi not found"
|
||||
24
.github/actions/setup-macos/action.yml
vendored
Normal file
24
.github/actions/setup-macos/action.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: 'Setup macOS Environment'
|
||||
description: 'Install dependencies for macOS builds'
|
||||
|
||||
inputs:
|
||||
python-version:
|
||||
description: 'Python version to use'
|
||||
required: false
|
||||
default: '3.10'
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install Homebrew packages
|
||||
shell: sh
|
||||
run: /opt/homebrew/bin/brew install openmpi
|
||||
|
||||
- name: Verify MetalToolchain installed
|
||||
shell: bash
|
||||
run: xcodebuild -showComponent MetalToolchain
|
||||
|
||||
- uses: conda-incubator/setup-miniconda@v3
|
||||
with:
|
||||
miniconda-version: "latest"
|
||||
python-version: ${{ inputs.python-version }}
|
||||
69
.github/actions/test-linux/action.yml
vendored
Normal file
69
.github/actions/test-linux/action.yml
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
name: 'Run Linux tests'
|
||||
|
||||
inputs:
|
||||
cpu-only:
|
||||
description: 'Skip GPU tests'
|
||||
required: false
|
||||
default: false
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Run MPI tests
|
||||
shell: bash
|
||||
run: |
|
||||
echo "::group::MPI tests"
|
||||
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Run distributed tests
|
||||
if: ${{ inputs.cpu-only == 'true' }}
|
||||
shell: bash
|
||||
run: |
|
||||
echo "::group::Distributed tests"
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||
if grep -Fq '[WARN]' stderr.log ; then
|
||||
grep -F '[WARN]' stderr.log
|
||||
echo "Distributed ring test failed";
|
||||
exit 1;
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Run Python tests - CPU
|
||||
if: ${{ inputs.cpu-only == 'true' }}
|
||||
shell: bash
|
||||
env:
|
||||
DEVICE: cpu
|
||||
run: |
|
||||
echo "::group::Python tests - CPU"
|
||||
python -m unittest discover python/tests -v
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Run Python tests - GPU
|
||||
if: ${{ inputs.cpu-only == 'false' }}
|
||||
shell: bash
|
||||
env:
|
||||
DEVICE: gpu
|
||||
run: |
|
||||
echo "::group::Python tests - GPU"
|
||||
python -m tests discover python/tests -v
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Run CPP tests - CPU
|
||||
shell: bash
|
||||
env:
|
||||
DEVICE: cpu
|
||||
run: |
|
||||
echo "::group::CPP tests - CPU"
|
||||
./build/tests/tests
|
||||
echo "::endgroup::"
|
||||
|
||||
- name: Run CPP tests - GPU
|
||||
if: ${{ inputs.cpu-only == 'false' }}
|
||||
shell: bash
|
||||
env:
|
||||
DEVICE: gpu
|
||||
run: |
|
||||
echo "::group::CPP tests - GPU"
|
||||
./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
|
||||
echo "::endgroup::"
|
||||
6
.github/dependabot.yml
vendored
Normal file
6
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
27
.github/scripts/setup+build-cpp-linux-fedora-container.sh
vendored
Executable file
27
.github/scripts/setup+build-cpp-linux-fedora-container.sh
vendored
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
# [Setup] Install dependencies inside the container.
|
||||
dnf update -y
|
||||
dnf install -y \
|
||||
blas-devel \
|
||||
lapack-devel \
|
||||
openblas-devel \
|
||||
make \
|
||||
cmake \
|
||||
clang \
|
||||
git
|
||||
dnf clean all
|
||||
|
||||
# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
|
||||
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
||||
export DEBUG=1
|
||||
export CMAKE_C_COMPILER=/usr/bin/clang
|
||||
export CMAKE_CXX_COMPILER=/usr/bin/clang++
|
||||
|
||||
mkdir -p build
|
||||
pushd build
|
||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||
make -j $(nproc)
|
||||
./tests/tests
|
||||
popd
|
||||
28
.github/workflows/documentation.yml
vendored
Normal file
28
.github/workflows/documentation.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Documentation
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/build-docs
|
||||
|
||||
deploy:
|
||||
needs: build
|
||||
permissions:
|
||||
pages: write
|
||||
id-token: write
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
steps:
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@v4
|
||||
98
.github/workflows/nightly.yml
vendored
Normal file
98
.github/workflows/nightly.yml
vendored
Normal file
@@ -0,0 +1,98 @@
|
||||
name: Nightly Build
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: 33 6 * * 1-5
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build_linux_release:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.10", "3.14"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
- uses: ./.github/actions/build-linux-release
|
||||
with:
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
arch: "x86_64"
|
||||
- name: Upload mlx artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: linux-wheels-${{ matrix.python_version }}
|
||||
path: wheelhouse/mlx-*.whl
|
||||
retention-days: 7
|
||||
- name: Upload mlx-cpu artifacts
|
||||
if: matrix.python_version == '3.10'
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: mlx-cpu
|
||||
path: wheelhouse/mlx_cpu-*.whl
|
||||
retention-days: 7
|
||||
|
||||
build_linux_with_tests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11", "3.12", "3.13", "3.14"]
|
||||
runner:
|
||||
- ubuntu-22.04
|
||||
- ubuntu-22.04-arm
|
||||
runs-on: ${{ matrix.runner }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
- uses: ./.github/actions/build-linux
|
||||
- uses: ./.github/actions/test-linux
|
||||
with:
|
||||
cpu-only: true
|
||||
|
||||
build_mac_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10", "3.13"]
|
||||
runs-on: [self-hosted, macos]
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-macos
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- uses: ./.github/actions/build-macos
|
||||
- name: Build macOS 15 package
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
macos-target: 15.0
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
- name: Build macOS 14 package
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
macos-target: 14.0
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
|
||||
build_cuda_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: ubuntu-22-large
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
toolkit: 'cuda-12.9'
|
||||
- name: Build Python package
|
||||
uses: ./.github/actions/build-cuda-release
|
||||
with:
|
||||
toolkit: 'cuda-12.9'
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: mlx-cuda
|
||||
path: wheelhouse/mlx_cuda-*.whl
|
||||
retention-days: 7
|
||||
103
.github/workflows/pull_request.yml
vendored
103
.github/workflows/pull_request.yml
vendored
@@ -1,20 +1,103 @@
|
||||
name: Build and Test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
# For testing CI without starting a pull request:
|
||||
- test/*
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
|
||||
|
||||
jobs:
|
||||
check_lint:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v4
|
||||
- uses: actions/checkout@v5
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
|
||||
linux_build_and_test:
|
||||
needs: check_lint
|
||||
strategy:
|
||||
matrix:
|
||||
runner:
|
||||
- ubuntu-22.04
|
||||
- ubuntu-22.04-arm
|
||||
fail-fast: false
|
||||
runs-on: ${{ matrix.runner }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
- uses: ./.github/actions/build-linux
|
||||
- uses: ./.github/actions/test-linux
|
||||
with:
|
||||
python-version: 3.8
|
||||
- name: Install dependencies
|
||||
cpu-only: true
|
||||
|
||||
mac_build_and_test:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
strategy:
|
||||
matrix:
|
||||
macos-target: ["14.0", "15.0"]
|
||||
runs-on: [self-hosted, macos]
|
||||
env:
|
||||
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
|
||||
needs: check_lint
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-macos
|
||||
- uses: ./.github/actions/build-macos
|
||||
|
||||
cuda_build_and_test:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
toolkit: ['cuda-12.6', 'cuda-12.9']
|
||||
runs-on: gpu-t4-4-core
|
||||
needs: check_lint
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
toolkit: ${{ matrix.toolkit }}
|
||||
- uses: ./.github/actions/build-cuda
|
||||
with:
|
||||
toolkit: ${{ matrix.toolkit }}
|
||||
- uses: ./.github/actions/test-linux
|
||||
|
||||
build_documentation:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: ubuntu-22.04
|
||||
needs: check_lint
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/build-docs
|
||||
|
||||
linux_fedora_build_cpp:
|
||||
name: Linux Fedora CPP Build (${{ matrix.arch }})
|
||||
needs: check_lint
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- host: ubuntu-22.04
|
||||
arch: x86_64
|
||||
- host: ubuntu-22.04-arm
|
||||
arch: aarch64
|
||||
|
||||
runs-on: ${{ matrix.host }}
|
||||
container:
|
||||
image: fedora:42
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: CPP Build Test - No Release
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pre-commit black isort clang-format
|
||||
- name: Run lint
|
||||
run: |
|
||||
pre-commit run --all-files
|
||||
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
|
||||
|
||||
239
.github/workflows/release.yml
vendored
Normal file
239
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,239 @@
|
||||
name: PyPI Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dev_release:
|
||||
description: "Do a dev release or regular release"
|
||||
required: true
|
||||
default: "false"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Set publishing variables
|
||||
run: echo "Publishing setup complete"
|
||||
|
||||
build_documentation:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: [self-hosted, macos]
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/build-docs
|
||||
|
||||
deploy_documentation:
|
||||
needs: build_documentation
|
||||
permissions:
|
||||
pages: write
|
||||
id-token: write
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
steps:
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@v4
|
||||
|
||||
build_linux_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
strategy:
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
arch: ['x86_64', 'aarch64']
|
||||
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
- uses: ./.github/actions/build-linux-release
|
||||
with:
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
arch: ${{ matrix.arch }}
|
||||
- name: Upload MLX artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
overwrite: true
|
||||
name: linux-wheels-${{ matrix.python_version }}
|
||||
path: wheelhouse/mlx-*.whl
|
||||
- name: Upload CPU artifacts
|
||||
if: matrix.python_version == '3.10'
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
overwrite: true
|
||||
name: mlx-cpu
|
||||
path: wheelhouse/mlx_cpu-*.whl
|
||||
|
||||
build_mac_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
runs-on: [self-hosted, macos]
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-macos
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install cmake setuptools nanobind==2.4.0
|
||||
pip install -e . -v
|
||||
- name: Generate package stubs
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
- name: Build macOS 14 package
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
macos-target: 14.0
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
- name: Build macOS 15 package
|
||||
uses: ./.github/actions/build-macos-release
|
||||
with:
|
||||
macos-target: 15.0
|
||||
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||
- name: Upload MLX artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
overwrite: true
|
||||
name: mac-wheels-${{ matrix.python-version }}
|
||||
path: dist/mlx-*.whl
|
||||
- name: Upload Metal artifacts
|
||||
if: matrix.python-version == '3.10'
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
overwrite: true
|
||||
name: mlx-metal
|
||||
path: dist/mlx_metal-*.whl
|
||||
|
||||
build_cuda_release:
|
||||
if: github.repository == 'ml-explore/mlx'
|
||||
runs-on: ubuntu-22-large
|
||||
env:
|
||||
PYPI_RELEASE: 1
|
||||
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: ./.github/actions/setup-linux
|
||||
with:
|
||||
toolkit: 'cuda-12.9'
|
||||
- name: Build Python package
|
||||
uses: ./.github/actions/build-cuda-release
|
||||
with:
|
||||
toolkit: 'cuda-12.9'
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
overwrite: true
|
||||
name: mlx-cuda
|
||||
path: wheelhouse/mlx_cuda-*.whl
|
||||
|
||||
pypi-publish:
|
||||
name: Upload release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: [setup, build_linux_release, build_mac_release]
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
pattern: linux-wheels-*
|
||||
merge-multiple: true
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
pattern: mac-wheels-*
|
||||
merge-multiple: true
|
||||
path: dist
|
||||
- name: Display structure of downloaded files
|
||||
run: ls -R dist
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: https://upload.pypi.org/legacy/
|
||||
|
||||
pypi-publish-cuda:
|
||||
name: Upload CUDA release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: [setup, build_cuda_release]
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx-cuda
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: mlx-cuda
|
||||
path: dist
|
||||
- name: Display structure of downloaded files
|
||||
run: ls -R dist
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: https://upload.pypi.org/legacy/
|
||||
|
||||
pypi-publish-cpu:
|
||||
name: Upload CPU release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: [setup, build_linux_release]
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx-cpu
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: mlx-cpu
|
||||
path: dist
|
||||
- name: Display structure of downloaded files
|
||||
run: ls -R dist
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: https://upload.pypi.org/legacy/
|
||||
|
||||
pypi-publish-metal:
|
||||
name: Upload Metal release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: [setup, build_mac_release]
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mlx-metal
|
||||
steps:
|
||||
- uses: actions/download-artifact@v6
|
||||
with:
|
||||
name: mlx-metal
|
||||
path: dist
|
||||
- name: Display structure of downloaded files
|
||||
run: ls -R dist
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
repository-url: https://upload.pypi.org/legacy/
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -36,6 +36,7 @@ share/python-wheels/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
uv.lock
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v6.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
# - id: end-of-file-fixer
|
||||
# - id: trailing-whitespace
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v19.1.7
|
||||
hooks:
|
||||
|
||||
@@ -19,11 +19,17 @@ MLX was developed with contributions from the following individuals:
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
</a>
|
||||
|
||||
# Organizations
|
||||
|
||||
MLX has received contributions from the following companies:
|
||||
- NVIDIA Corporation & Affiliates
|
||||
|
||||
# Third-Party Software
|
||||
|
||||
MLX leverages several third-party software, listed here together with
|
||||
|
||||
@@ -26,6 +26,7 @@ set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
# ----------------------------- Configuration -----------------------------
|
||||
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
|
||||
@@ -34,13 +35,16 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
message(
|
||||
@@ -63,10 +67,18 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
else()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||
endif()
|
||||
|
||||
if(MLX_USE_CCACHE)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
|
||||
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# ----------------------------- Lib -----------------------------
|
||||
@@ -77,18 +89,26 @@ cmake_policy(SET CMP0135 NEW)
|
||||
|
||||
add_library(mlx)
|
||||
|
||||
if(MLX_BUILD_METAL)
|
||||
set(METAL_LIB "-framework Metal")
|
||||
set(FOUNDATION_LIB "-framework Foundation")
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
# Supress warnings: note: parameter passing for argument of type
|
||||
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||
# 10.1
|
||||
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
set(MLX_METAL_DEBUG OFF)
|
||||
elseif(MLX_BUILD_METAL)
|
||||
message(STATUS "Building METAL sources")
|
||||
if(MLX_BUILD_METAL)
|
||||
find_library(METAL_LIB Metal)
|
||||
find_library(FOUNDATION_LIB Foundation)
|
||||
find_library(QUARTZ_LIB QuartzCore)
|
||||
if(METAL_LIB)
|
||||
message(STATUS "Metal found ${METAL_LIB}")
|
||||
else()
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
|
||||
endif()
|
||||
|
||||
if(MLX_METAL_DEBUG)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
@@ -97,7 +117,8 @@ elseif(MLX_BUILD_METAL)
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
@@ -107,9 +128,12 @@ elseif(MLX_BUILD_METAL)
|
||||
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
||||
|
||||
set(METAL_CPP_URL
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
|
||||
|
||||
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||
if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
|
||||
message(FATAL_ERROR "MLX requires macOS >= 14.0")
|
||||
endif()
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
execute_process(
|
||||
@@ -118,7 +142,6 @@ elseif(MLX_BUILD_METAL)
|
||||
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
target_include_directories(
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
||||
@@ -126,6 +149,12 @@ elseif(MLX_BUILD_METAL)
|
||||
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||
endif()
|
||||
|
||||
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||
# With newer clang/gcc versions following libs are implicitly linked, but when
|
||||
# building on old distributions they need to be explicitly listed.
|
||||
target_link_libraries(mlx PRIVATE dl pthread)
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
if(MSVC)
|
||||
# GGUF does not build with MSVC.
|
||||
@@ -153,7 +182,7 @@ if(MLX_BUILD_CPU)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
message(STATUS "Accelerate not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
endif()
|
||||
|
||||
@@ -226,12 +255,19 @@ target_include_directories(
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||
$<INSTALL_INTERFACE:include>)
|
||||
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
# Do not add mlx_EXPORTS define for shared library.
|
||||
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||
|
||||
if(USE_SYSTEM_FMT)
|
||||
find_package(fmt REQUIRED)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
GIT_TAG 10.2.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(fmt)
|
||||
endif()
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||
|
||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
include CMakeLists.txt
|
||||
include mlx.pc.in
|
||||
recursive-include mlx/ *
|
||||
include cmake/*
|
||||
include python/src/*
|
||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||
|
||||
57
README.md
57
README.md
@@ -2,7 +2,7 @@
|
||||
|
||||
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
|
||||
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
|
||||
[**Examples**](#examples)
|
||||
[**Examples**](#examples)
|
||||
|
||||
[](https://circleci.com/gh/ml-explore/mlx)
|
||||
|
||||
@@ -11,37 +11,37 @@ brought to you by Apple machine learning research.
|
||||
|
||||
Some key features of MLX include:
|
||||
|
||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
||||
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
||||
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
||||
more complex models.
|
||||
|
||||
- **Composable function transformations**: MLX supports composable function
|
||||
transformations for automatic differentiation, automatic vectorization,
|
||||
and computation graph optimization.
|
||||
- **Composable function transformations**: MLX supports composable function
|
||||
transformations for automatic differentiation, automatic vectorization,
|
||||
and computation graph optimization.
|
||||
|
||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||
materialized when needed.
|
||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||
materialized when needed.
|
||||
|
||||
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
||||
dynamically. Changing the shapes of function arguments does not trigger
|
||||
slow compilations, and debugging is simple and intuitive.
|
||||
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
||||
dynamically. Changing the shapes of function arguments does not trigger
|
||||
slow compilations, and debugging is simple and intuitive.
|
||||
|
||||
- **Multi-device**: Operations can run on any of the supported devices
|
||||
(currently the CPU and the GPU).
|
||||
- **Multi-device**: Operations can run on any of the supported devices
|
||||
(currently the CPU and the GPU).
|
||||
|
||||
- **Unified memory**: A notable difference from MLX and other frameworks
|
||||
is the *unified memory model*. Arrays in MLX live in shared memory.
|
||||
Operations on MLX arrays can be performed on any of the supported
|
||||
device types without transferring data.
|
||||
- **Unified memory**: A notable difference from MLX and other frameworks
|
||||
is the *unified memory model*. Arrays in MLX live in shared memory.
|
||||
Operations on MLX arrays can be performed on any of the supported
|
||||
device types without transferring data.
|
||||
|
||||
MLX is designed by machine learning researchers for machine learning
|
||||
researchers. The framework is intended to be user-friendly, but still efficient
|
||||
to train and deploy models. The design of the framework itself is also
|
||||
conceptually simple. We intend to make it easy for researchers to extend and
|
||||
improve MLX with the goal of quickly exploring new ideas.
|
||||
improve MLX with the goal of quickly exploring new ideas.
|
||||
|
||||
The design of MLX is inspired by frameworks like
|
||||
[NumPy](https://numpy.org/doc/stable/index.html),
|
||||
@@ -68,25 +68,30 @@ in the documentation.
|
||||
|
||||
## Installation
|
||||
|
||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
|
||||
macOS, run:
|
||||
|
||||
**With `pip`**:
|
||||
|
||||
```
|
||||
```bash
|
||||
pip install mlx
|
||||
```
|
||||
|
||||
**With `conda`**:
|
||||
To install the CUDA backend on Linux, run:
|
||||
|
||||
```bash
|
||||
pip install mlx[cuda]
|
||||
```
|
||||
conda install -c conda-forge mlx
|
||||
|
||||
To install a CPU-only Linux package, run:
|
||||
|
||||
```bash
|
||||
pip install mlx[cpu]
|
||||
```
|
||||
|
||||
Checkout the
|
||||
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
|
||||
for more information on building the C++ and Python APIs from source.
|
||||
|
||||
## Contributing
|
||||
## Contributing
|
||||
|
||||
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
|
||||
on contributing to MLX. See the
|
||||
@@ -105,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
|
||||
MLX useful in your research and wish to cite it, please use the following
|
||||
BibTex entry:
|
||||
|
||||
```
|
||||
```text
|
||||
@software{mlx2023,
|
||||
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
||||
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
@@ -74,7 +75,7 @@ void time_irregular_binary_ops_3D() {
|
||||
|
||||
void time_irregular_binary_ops_4D() {
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape = {8, 8, 512, 512};
|
||||
mx::Shape shape = {8, 8, 512, 512};
|
||||
auto a = mx::random::uniform(shape);
|
||||
auto b = mx::random::uniform(shape);
|
||||
|
||||
@@ -114,7 +115,7 @@ void time_irregular_binary_ops_4D() {
|
||||
|
||||
void time_irregular_reshape() {
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape;
|
||||
mx::Shape shape;
|
||||
auto reshape_fn = [&shape, device](const mx::array& a) {
|
||||
return mx::reshape(a, shape, device);
|
||||
};
|
||||
@@ -169,7 +170,7 @@ void time_irregular_astype_1D() {
|
||||
void time_irregular_astype_2D() {
|
||||
auto device = mx::default_device();
|
||||
int size = 2048;
|
||||
std::vector<int> shape = {size, size};
|
||||
mx::Shape shape = {size, size};
|
||||
|
||||
auto a = mx::random::uniform(shape);
|
||||
TIMEM("2D regular", mx::astype, a, mx::int32, device);
|
||||
|
||||
@@ -192,6 +192,22 @@ void time_reductions() {
|
||||
|
||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||
TIME(argmin_along_1);
|
||||
|
||||
auto indices = mx::array({1});
|
||||
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
|
||||
std::vector<int> axes{0};
|
||||
auto b = scatter(a, {indices}, updates, axes);
|
||||
mx::eval(b);
|
||||
|
||||
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
|
||||
TIME(max_along_0);
|
||||
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||
TIME(max_along_1);
|
||||
|
||||
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||
TIME(min_along_0);
|
||||
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||
TIME(min_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
|
||||
@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
|
||||
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
|
||||
|
||||
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
|
||||
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
|
||||
np.float32
|
||||
)
|
||||
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
|
||||
|
||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||
|
||||
dtypes = ("float32", "float16")
|
||||
dtypes = ("float32", "float16", "complex64")
|
||||
transposes = ("nn", "nt", "tn")
|
||||
shapes = (
|
||||
(16, 234, 768, 3072),
|
||||
@@ -187,7 +185,7 @@ if __name__ == "__main__":
|
||||
diff = gflops_mx / gflops_pt - 1.0
|
||||
|
||||
print(
|
||||
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
|
||||
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
|
||||
)
|
||||
if gflops_pt >= 2.0 * gflops_mx:
|
||||
print("ATTENTION ^^^^^^^")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
@@ -196,7 +195,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
|
||||
|
||||
|
||||
for transpose in (False, True):
|
||||
for dtype in ("float32", "float16"):
|
||||
for dtype in ("float32", "float16", "complex64"):
|
||||
fig, axs = plt.subplots(
|
||||
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
|
||||
)
|
||||
@@ -215,7 +214,7 @@ for transpose in (False, True):
|
||||
fig.suptitle(f"{device_name}: {dtype} {op_name}")
|
||||
fig.savefig(
|
||||
os.path.join(
|
||||
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
|
||||
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
|
||||
)
|
||||
)
|
||||
plt.close(fig)
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.mps
|
||||
|
||||
|
||||
@@ -44,8 +45,10 @@ def bench(f, *args):
|
||||
|
||||
|
||||
def sync_if_needed(x):
|
||||
if x.device != torch.device("cpu"):
|
||||
if x.device == torch.device("mps"):
|
||||
torch.mps.synchronize()
|
||||
elif x.device == torch.device("cuda"):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sum_and_add(axis, x, y):
|
||||
z = x.sum(axis=axis, keepdims=True)
|
||||
for i in range(50):
|
||||
z = (z + y).sum(axis=axis, keepdims=True)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def softmax(axis, x):
|
||||
ys = []
|
||||
@@ -340,7 +351,11 @@ if __name__ == "__main__":
|
||||
args.axis.pop(0)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
device = "cpu" if args.cpu else "mps"
|
||||
device = "mps"
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
if args.cpu:
|
||||
device = "cpu"
|
||||
|
||||
types = args.dtype
|
||||
if not types:
|
||||
@@ -460,5 +475,8 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
elif args.benchmark == "sum_and_add":
|
||||
print(bench(sum_and_add, axis, *xs))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||
|
||||
107
benchmarks/python/conv_unaligned_bench.py
Normal file
107
benchmarks/python/conv_unaligned_bench.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtype = "float32"
|
||||
shapes = (
|
||||
(4, 32, 32, 21, 3, 3, 128),
|
||||
(4, 32, 32, 21, 3, 3, 37),
|
||||
(4, 32, 32, 370, 3, 3, 370),
|
||||
(4, 32, 32, 370, 7, 7, 128),
|
||||
(2, 320, 640, 21, 7, 7, 21),
|
||||
)
|
||||
for N, H, W, C, kh, kw, O in shapes:
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
||||
@@ -1,5 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from time_utils import time_fn
|
||||
@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
|
||||
return y
|
||||
|
||||
|
||||
def time_layer_norm():
|
||||
def time_layer_norm(N, dt):
|
||||
L = 1024
|
||||
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
||||
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x, w, b):
|
||||
def layer_norm_loop(f, x, w, b):
|
||||
for _ in range(32):
|
||||
x = f(x, w, b)
|
||||
return x
|
||||
|
||||
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
|
||||
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
|
||||
|
||||
def layer_norm_grad_loop(g, x, w, b):
|
||||
gx, gw, gb = x, w, b
|
||||
for _ in range(32):
|
||||
gx, gw, gb = g(gx, gw, gb, y)
|
||||
return gx, gw, gb
|
||||
|
||||
time_fn(layer_norm_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
||||
time_fn(layer_norm_grad_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_grad_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
|
||||
|
||||
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0,))
|
||||
g2 = mx.grad(f2, argnums=(0,))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x):
|
||||
def layer_norm_grad_x_loop(g, x):
|
||||
gx = x
|
||||
for _ in range(32):
|
||||
gx = g(gx, y)
|
||||
return gx
|
||||
|
||||
time_fn(layer_norm_loop, g1, x)
|
||||
time_fn(layer_norm_loop, g2, x)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x)
|
||||
time_fn(layer_norm_grad_x_loop, g1, x)
|
||||
time_fn(layer_norm_grad_x_loop, g2, x)
|
||||
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_layer_norm()
|
||||
for dt in [mx.float32, mx.float16, mx.bfloat16]:
|
||||
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
|
||||
print(dt, n)
|
||||
time_layer_norm(n, dt)
|
||||
|
||||
212
benchmarks/python/masked_scatter.py
Normal file
212
benchmarks/python/masked_scatter.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.ticker import FuncFormatter
|
||||
|
||||
RESULTS_DIR = "./results"
|
||||
|
||||
|
||||
if not os.path.isdir(RESULTS_DIR):
|
||||
os.mkdir(RESULTS_DIR)
|
||||
|
||||
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
|
||||
|
||||
TORCH_DEVICE = torch.device(
|
||||
"mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
)
|
||||
|
||||
|
||||
N_WARMUP = 5
|
||||
N_ITER_BENCH = 50
|
||||
N_ITER_FUNC = 20
|
||||
|
||||
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
|
||||
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
|
||||
D_TYPES = ("float32", "float16")
|
||||
|
||||
|
||||
def _power_of_two_formatter(value, _position):
|
||||
if value <= 0:
|
||||
return ""
|
||||
exponent = int(round(math.log2(value)))
|
||||
if abs(value - (1 << exponent)) / value > 1e-6:
|
||||
return f"{value:g}"
|
||||
return f"$2^{{{exponent}}}$"
|
||||
|
||||
|
||||
def torch_sync():
|
||||
if TORCH_DEVICE.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif TORCH_DEVICE.type == "mps":
|
||||
torch.mps.synchronize()
|
||||
|
||||
|
||||
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
|
||||
outs = []
|
||||
for _ in range(N_ITER_FUNC):
|
||||
out = copy(self_arr)
|
||||
out[mask_arr] = src_arr
|
||||
outs.append(out)
|
||||
mx.eval(outs)
|
||||
return outs
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
|
||||
outs = []
|
||||
for _ in range(N_ITER_FUNC):
|
||||
out = self_tensor.clone()
|
||||
out.masked_scatter_(mask_tensor, src_tensor)
|
||||
outs.append(out)
|
||||
torch_sync()
|
||||
return outs
|
||||
|
||||
|
||||
def measure(fn):
|
||||
for _ in range(N_WARMUP):
|
||||
fn()
|
||||
start = time.perf_counter_ns()
|
||||
for _ in range(N_ITER_BENCH):
|
||||
fn()
|
||||
end = time.perf_counter_ns()
|
||||
return (end - start) * 1e-9
|
||||
|
||||
|
||||
def bytes_touched(length, true_count, item_size):
|
||||
mask_bytes = length
|
||||
self_bytes = length * item_size * 2 # read + write
|
||||
src_bytes = true_count * item_size
|
||||
return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
|
||||
|
||||
|
||||
def build_case(length, density, np_dtype, torch_dtype):
|
||||
true_count = max(1, int(round(length * density)))
|
||||
|
||||
rng = np.random.default_rng()
|
||||
self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
|
||||
mask_np = np.zeros(length, dtype=bool)
|
||||
mask_np[:true_count] = True
|
||||
rng.shuffle(mask_np)
|
||||
src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
|
||||
|
||||
self_mlx = mx.array(self_np)
|
||||
mask_mlx = mx.array(mask_np)
|
||||
src_mlx = mx.array(src_np)
|
||||
|
||||
self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
|
||||
mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
|
||||
src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
|
||||
|
||||
# Correctness check once per configuration
|
||||
mx_out = mx.array(self_np)
|
||||
mx_out[mask_mlx] = src_mlx
|
||||
mx.eval(mx_out)
|
||||
torch_out = self_torch.clone()
|
||||
torch_out.masked_scatter_(mask_torch, src_torch)
|
||||
|
||||
atol = 5e-3 if np_dtype == np.float16 else 1e-5
|
||||
if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
|
||||
raise AssertionError("masked_scatter results diverged between MLX and Torch")
|
||||
|
||||
return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
|
||||
|
||||
|
||||
def bench_case(length, density, dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
torch_dtype = getattr(torch, dtype)
|
||||
(
|
||||
self_mlx,
|
||||
mask_mlx,
|
||||
src_mlx,
|
||||
self_torch,
|
||||
mask_torch,
|
||||
src_torch,
|
||||
true_count,
|
||||
) = build_case(length, density, np_dtype, torch_dtype)
|
||||
|
||||
time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
|
||||
time_torch = measure(
|
||||
partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
|
||||
)
|
||||
|
||||
total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
|
||||
bytes_per_gb = float(1024**3)
|
||||
mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
|
||||
torch_gbps = (total_bytes / bytes_per_gb) / time_torch
|
||||
|
||||
return time_mlx, time_torch, mlx_gbps, torch_gbps
|
||||
|
||||
|
||||
def plot_density(ax_perf, ax_speedup, density, dtype):
|
||||
mlx_gbps = []
|
||||
torch_gbps = []
|
||||
mlx_times = []
|
||||
torch_times = []
|
||||
|
||||
for length in VECTOR_LENGTHS:
|
||||
t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
|
||||
mlx_gbps.append(gbps_mlx)
|
||||
torch_gbps.append(gbps_torch)
|
||||
mlx_times.append(t_mlx)
|
||||
torch_times.append(t_torch)
|
||||
|
||||
ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
|
||||
ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
|
||||
ax_perf.set_xscale("log", base=2)
|
||||
ax_perf.set_xticks(VECTOR_LENGTHS)
|
||||
formatter = FuncFormatter(_power_of_two_formatter)
|
||||
ax_perf.xaxis.set_major_formatter(formatter)
|
||||
ax_perf.set_title(f"density={density:.2f}")
|
||||
ax_perf.set_ylabel("GB/s")
|
||||
ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
|
||||
ax_perf.legend()
|
||||
|
||||
speedup = np.array(torch_times) / np.array(mlx_times)
|
||||
ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
|
||||
ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
|
||||
ax_speedup.set_xscale("log", base=2)
|
||||
ax_speedup.set_xticks(VECTOR_LENGTHS)
|
||||
ax_speedup.xaxis.set_major_formatter(formatter)
|
||||
ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
|
||||
ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
|
||||
|
||||
|
||||
def main():
|
||||
for dtype in D_TYPES:
|
||||
fig, axs = plt.subplots(
|
||||
len(MASK_DENSITIES),
|
||||
2,
|
||||
figsize=(10, 12),
|
||||
layout="constrained",
|
||||
sharex=True,
|
||||
)
|
||||
|
||||
for i, density in enumerate(MASK_DENSITIES):
|
||||
plot_density(axs[i][0], axs[i][1], density, dtype)
|
||||
axs[i][0].set_xlabel("vector length")
|
||||
axs[i][1].set_xlabel("vector length")
|
||||
|
||||
fig.suptitle(
|
||||
f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
|
||||
)
|
||||
output_path = os.path.join(
|
||||
RESULTS_DIR,
|
||||
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
|
||||
)
|
||||
fig.savefig(output_path)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -51,6 +51,20 @@ def time_maximum():
|
||||
time_fn(mx.maximum, a, b)
|
||||
|
||||
|
||||
def time_max():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
a[1, 1] = mx.nan
|
||||
mx.eval(a)
|
||||
time_fn(mx.max, a, 0)
|
||||
|
||||
|
||||
def time_min():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
a[1, 1] = mx.nan
|
||||
mx.eval(a)
|
||||
time_fn(mx.min, a, 0)
|
||||
|
||||
|
||||
def time_negative():
|
||||
a = mx.random.uniform(shape=(10000, 1000))
|
||||
mx.eval(a)
|
||||
@@ -108,6 +122,8 @@ if __name__ == "__main__":
|
||||
|
||||
time_add()
|
||||
time_matmul()
|
||||
time_min()
|
||||
time_max()
|
||||
time_maximum()
|
||||
time_exp()
|
||||
time_negative()
|
||||
|
||||
54
cmake/FindNCCL.cmake
Normal file
54
cmake/FindNCCL.cmake
Normal file
@@ -0,0 +1,54 @@
|
||||
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
|
||||
# directories.
|
||||
|
||||
set(NCCL_ROOT_DIR
|
||||
$ENV{NCCL_ROOT_DIR}
|
||||
CACHE PATH "Folder contains NVIDIA NCCL")
|
||||
|
||||
find_path(
|
||||
NCCL_INCLUDE_DIRS
|
||||
NAMES nccl.h
|
||||
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||
|
||||
if($ENV{USE_STATIC_NCCL})
|
||||
message(
|
||||
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
||||
set(NCCL_LIBNAME "libnccl_static.a")
|
||||
else()
|
||||
set(NCCL_LIBNAME "nccl")
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
NCCL_LIBRARIES
|
||||
NAMES ${NCCL_LIBNAME}
|
||||
HINTS ${NCCL_LIB_DIR}
|
||||
${NCCL_ROOT_DIR}
|
||||
${NCCL_ROOT_DIR}/lib
|
||||
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
||||
${NCCL_ROOT_DIR}/lib64
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
||||
NCCL_LIBRARIES)
|
||||
|
||||
if(NCCL_FOUND)
|
||||
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||
message(
|
||||
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
||||
file(
|
||||
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
||||
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
||||
LIMIT_COUNT 1)
|
||||
if(NCCL_MAJOR_VERSION_DEFINED)
|
||||
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
||||
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
||||
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
||||
endif()
|
||||
message(
|
||||
STATUS
|
||||
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||
endif()
|
||||
3
cmake/Findnvpl.cmake
Normal file
3
cmake/Findnvpl.cmake
Normal file
@@ -0,0 +1,3 @@
|
||||
# This file does nothing but to suppress the cmake warning: "By not providing
|
||||
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
|
||||
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.
|
||||
@@ -11,13 +11,14 @@ include(CMakeParseArguments)
|
||||
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
||||
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
||||
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||
# files (like headers)
|
||||
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
|
||||
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
|
||||
#
|
||||
# clang format on
|
||||
|
||||
macro(mlx_build_metallib)
|
||||
# Parse args
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
|
||||
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
||||
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
@@ -26,6 +27,10 @@ macro(mlx_build_metallib)
|
||||
|
||||
# Collect compile options
|
||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
|
||||
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
|
||||
-frecord-sources)
|
||||
endif()
|
||||
|
||||
# Prepare metallib build command
|
||||
add_custom_command(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
sphinx
|
||||
breathe
|
||||
sphinx-book-theme
|
||||
sphinx-copybutton
|
||||
mlx
|
||||
|
||||
@@ -10,7 +10,7 @@ import mlx.core as mx
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "MLX"
|
||||
copyright = "2023, MLX Contributors"
|
||||
copyright = "2023, Apple"
|
||||
author = "MLX Contributors"
|
||||
version = ".".join(mx.__version__.split(".")[:3])
|
||||
release = version
|
||||
@@ -18,6 +18,7 @@ release = version
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
extensions = [
|
||||
"sphinx_copybutton",
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.intersphinx",
|
||||
|
||||
@@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
||||
Simple Example
|
||||
--------------
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
Every time you make a kernel, a new Metal library is created and possibly
|
||||
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||
:func:`fast.metal_kernel` and then use it many times.
|
||||
|
||||
.. note::
|
||||
We are only required to pass the body of the Metal kernel in ``source``.
|
||||
Only pass the body of the Metal kernel in ``source``. The function
|
||||
signature is generated automatically.
|
||||
|
||||
The full function signature will be generated using:
|
||||
|
||||
@@ -78,44 +86,52 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
||||
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
||||
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
||||
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
||||
dimension should be less than or equal to the corresponding grid dimension.
|
||||
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
||||
generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
-------------------
|
||||
|
||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||
when indexing.
|
||||
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
||||
is ``True`` by default. This will copy the array inputs if needed
|
||||
before the kernel is launched to ensure that the memory layout is row
|
||||
contiguous. Generally this makes writing the kernel easier, since we don't
|
||||
have to worry about gaps or the ordering of the dims when indexing.
|
||||
|
||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||
input array ``a`` if any are present in ``source``.
|
||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
||||
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
||||
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
||||
the right elements for each thread.
|
||||
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
||||
relying on a copy from ``ensure_row_contiguous``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@@ -123,7 +139,6 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
@@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
|
||||
return output
|
||||
return output
|
||||
|
||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
||||
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
|
||||
to write a fast GPU kernel for both the forward and backward passes.
|
||||
|
||||
First we'll implement the forward pass as a fused kernel:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
For a reasonably sized input such as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
|
||||
On an M1 Max, we see a big performance improvement:
|
||||
|
||||
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
|
||||
Grid Sample VJP
|
||||
---------------
|
||||
|
||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
||||
its custom vjp transform so MLX can differentiate it.
|
||||
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||
define its custom vjp transform so MLX can differentiate it.
|
||||
|
||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
||||
requires a few extra :func:`fast.metal_kernel` features:
|
||||
|
||||
* ``init_value=0``
|
||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||
@@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
|
||||
There's an even larger speed up for the vjp:
|
||||
|
||||
|
||||
@@ -138,13 +138,13 @@ more concrete:
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
void print(std::ostream& os) override {
|
||||
os << "Axpby";
|
||||
/** The name of primitive. */
|
||||
const char* name() const override {
|
||||
return "Axpby";
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
@@ -394,14 +394,14 @@ below.
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
std::stream kname;
|
||||
kname = "axpby_general_" + type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname, lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
@@ -70,6 +70,7 @@ are the CPU and GPU.
|
||||
python/fft
|
||||
python/linalg
|
||||
python/metal
|
||||
python/cuda
|
||||
python/memory_management
|
||||
python/nn
|
||||
python/optimizers
|
||||
|
||||
@@ -13,22 +13,48 @@ silicon computer is
|
||||
|
||||
pip install mlx
|
||||
|
||||
To install from PyPI you must meet the following requirements:
|
||||
To install from PyPI your system must meet the following requirements:
|
||||
|
||||
- Using an M series chip (Apple silicon)
|
||||
- Using a native Python >= 3.9
|
||||
- macOS >= 13.5
|
||||
- Using a native Python >= 3.10
|
||||
- macOS >= 14.0
|
||||
|
||||
.. note::
|
||||
MLX is only available on devices running macOS >= 13.5
|
||||
It is highly recommended to use macOS 14 (Sonoma)
|
||||
MLX is only available on devices running macOS >= 14.0 and higher.
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
MLX is also available on conda-forge. To install MLX with conda do:
|
||||
MLX has a CUDA backend which you can install with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
conda install conda-forge::mlx
|
||||
pip install mlx[cuda]
|
||||
|
||||
To install the CUDA package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
- Nvidia architecture >= SM 7.0 (Volta)
|
||||
- Nvidia driver >= 550.54.14
|
||||
- CUDA toolkit >= 12.0
|
||||
- Linux distribution with glibc >= 2.35
|
||||
- Python >= 3.10
|
||||
|
||||
|
||||
CPU-only (Linux)
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
For a CPU-only version of MLX that runs on Linux use:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install mlx[cpu]
|
||||
|
||||
To install the CPU-only package from PyPi your system must meet the following
|
||||
requirements:
|
||||
|
||||
- Linux distribution with glibc >= 2.35
|
||||
- Python >= 3.10
|
||||
|
||||
|
||||
Troubleshooting
|
||||
@@ -65,6 +91,8 @@ Build Requirements
|
||||
Python API
|
||||
^^^^^^^^^^
|
||||
|
||||
.. _python install:
|
||||
|
||||
To build and install the MLX python library from source, first, clone MLX from
|
||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||
|
||||
@@ -76,20 +104,20 @@ Then simply build and install MLX using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
|
||||
pip install .
|
||||
|
||||
For developing, install the package with development dependencies, and use an
|
||||
editable install:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
|
||||
pip install -e ".[dev]"
|
||||
|
||||
Once the development dependencies are installed, you can build faster with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
Run the tests with:
|
||||
|
||||
@@ -107,6 +135,8 @@ IDE:
|
||||
C++ API
|
||||
^^^^^^^
|
||||
|
||||
.. _cpp install:
|
||||
|
||||
Currently, MLX must be built and installed from source.
|
||||
|
||||
Similarly to the python library, to build and install the MLX C++ library start
|
||||
@@ -185,6 +215,7 @@ should point to the path to the built metal library.
|
||||
|
||||
xcrun -sdk macosx --show-sdk-version
|
||||
|
||||
|
||||
Binary Size Minimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -213,6 +244,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
||||
application. Once a kernel is compiled, it will be cached by the system. The
|
||||
Metal kernel cache persists across reboots.
|
||||
|
||||
Linux
|
||||
^^^^^
|
||||
|
||||
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||
For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
apt-get update -y
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
From here follow the instructions to install either the :ref:`Python <python
|
||||
install>` or :ref:`C++ <cpp install>` APIs.
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
apt-get update -y
|
||||
apt-get -y install cuda-toolkit-12-9
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
|
||||
|
||||
|
||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||
|
||||
To build the C++ package run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ Array
|
||||
array.ndim
|
||||
array.shape
|
||||
array.size
|
||||
array.real
|
||||
array.imag
|
||||
array.abs
|
||||
array.all
|
||||
array.any
|
||||
|
||||
9
docs/src/python/cuda.rst
Normal file
9
docs/src/python/cuda.rst
Normal file
@@ -0,0 +1,9 @@
|
||||
CUDA
|
||||
=====
|
||||
|
||||
.. currentmodule:: mlx.core.cuda
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
is_available
|
||||
@@ -13,3 +13,4 @@ Fast
|
||||
rope
|
||||
scaled_dot_product_attention
|
||||
metal_kernel
|
||||
cuda_kernel
|
||||
|
||||
@@ -20,3 +20,5 @@ FFT
|
||||
irfft2
|
||||
rfftn
|
||||
irfftn
|
||||
fftshift
|
||||
ifftshift
|
||||
|
||||
@@ -16,6 +16,8 @@ Linear Algebra
|
||||
cross
|
||||
qr
|
||||
svd
|
||||
eigvals
|
||||
eig
|
||||
eigvalsh
|
||||
eigh
|
||||
lu
|
||||
|
||||
@@ -27,6 +27,7 @@ simple functions.
|
||||
mish
|
||||
prelu
|
||||
relu
|
||||
relu2
|
||||
relu6
|
||||
selu
|
||||
sigmoid
|
||||
|
||||
@@ -50,6 +50,7 @@ Layers
|
||||
QuantizedLinear
|
||||
RMSNorm
|
||||
ReLU
|
||||
ReLU2
|
||||
ReLU6
|
||||
RNN
|
||||
RoPE
|
||||
|
||||
@@ -112,6 +112,7 @@ Operations
|
||||
max
|
||||
maximum
|
||||
mean
|
||||
median
|
||||
meshgrid
|
||||
min
|
||||
minimum
|
||||
|
||||
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
|
||||
optimizer.update(model, grads)
|
||||
|
||||
# Save the state
|
||||
state = tree_flatten(optimizer.state)
|
||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
||||
state = tree_flatten(optimizer.state, destination={})
|
||||
mx.save_safetensors("optimizer.safetensors", state)
|
||||
|
||||
# Later on, for example when loading from a checkpoint,
|
||||
# recreate the optimizer and load the state
|
||||
optimizer = optim.Adam(learning_rate=1e-2)
|
||||
|
||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
||||
state = tree_unflatten(mx.load("optimizer.safetensors"))
|
||||
optimizer.state = state
|
||||
|
||||
Note, not every optimizer configuation parameter is saved in the state. For
|
||||
|
||||
@@ -19,3 +19,4 @@ Common Optimizers
|
||||
Adamax
|
||||
Lion
|
||||
MultiOptimizer
|
||||
Muon
|
||||
|
||||
@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
|
||||
.. code-block:: python
|
||||
|
||||
x = mx.random.uniform(shape=(32, 1000, 4096))
|
||||
timeit(nn.gelu, x)
|
||||
timeit(mx.compile(nn.gelu), x)
|
||||
timeit(gelu, x)
|
||||
timeit(mx.compile(gelu), x)
|
||||
|
||||
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
||||
five times faster.
|
||||
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
|
||||
def fun(x, y):
|
||||
z = x + y
|
||||
state.append(z)
|
||||
return mx.exp(z), state
|
||||
return mx.exp(z)
|
||||
|
||||
fun(mx.array(1.0), mx.array(2.0))
|
||||
# Prints [array(3, dtype=float32)]
|
||||
|
||||
@@ -7,12 +7,13 @@ Distributed Communication
|
||||
|
||||
MLX supports distributed communication operations that allow the computational cost
|
||||
of training or inference to be shared across many physical machines. At the
|
||||
moment we support two different communication backends:
|
||||
moment we support three different communication backends:
|
||||
|
||||
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
|
||||
full-featured and mature distributed communications library
|
||||
* A **ring** backend of our own that uses native TCP sockets and should be
|
||||
faster for thunderbolt connections.
|
||||
* A **ring** backend of our own that uses native TCP sockets. It should be
|
||||
faster for thunderbolt connections, but it also works over Ethernet.
|
||||
* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.
|
||||
|
||||
The list of all currently supported operations and their documentation can be
|
||||
seen in the :ref:`API docs<distributed>`.
|
||||
@@ -84,9 +85,8 @@ Selecting Backend
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
You can select the backend you want to use when calling :func:`init` by passing
|
||||
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
|
||||
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
|
||||
both fail then a singleton group is created.
|
||||
one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
|
||||
available backends. If they all fail then a singleton group is created.
|
||||
|
||||
.. note::
|
||||
After a distributed backend is successfully initialized :func:`init` will
|
||||
@@ -184,7 +184,7 @@ almost identical to the example above:
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
grads = mlx.nn.average_gradients(grads) # <---- This line was added
|
||||
grads = mx.nn.average_gradients(grads) # <---- This line was added
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
@@ -220,7 +220,7 @@ print 4 etc.
|
||||
Installing MPI
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
MPI can be installed with Homebrew, using the Anaconda package manager or
|
||||
MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
|
||||
compiled from source. Most of our testing is done using ``openmpi`` installed
|
||||
with the Anaconda package manager as follows:
|
||||
|
||||
@@ -228,14 +228,16 @@ with the Anaconda package manager as follows:
|
||||
|
||||
$ conda install conda-forge::openmpi
|
||||
|
||||
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
|
||||
Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
|
||||
so that MLX can find it and load it at runtime. This can simply be achieved by
|
||||
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
|
||||
done automatically by ``mlx.launch``.
|
||||
done automatically by ``mlx.launch``. Some environments use a non-standard
|
||||
library filename that can be specified using the ``MPI_LIBNAME`` environment
|
||||
variable. This is automatically taken care of by ``mlx.launch`` as well.
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
|
||||
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
|
||||
$ # or simply
|
||||
$ mlx.launch -n 2 test.py
|
||||
|
||||
|
||||
@@ -7,17 +7,17 @@ Exporting Functions
|
||||
|
||||
MLX has an API to export and import functions to and from a file. This lets you
|
||||
run computations written in one MLX front-end (e.g. Python) in another MLX
|
||||
front-end (e.g. C++).
|
||||
front-end (e.g. C++).
|
||||
|
||||
This guide walks through the basics of the MLX export API with some examples.
|
||||
To see the full list of functions check-out the :ref:`API documentation
|
||||
<export>`.
|
||||
|
||||
Basics of Exporting
|
||||
Basics of Exporting
|
||||
-------------------
|
||||
|
||||
Let's start with a simple example:
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
|
||||
|
||||
# Both arguments to fun are positional
|
||||
mx.export_function("add.mlxfn", fun, x, y)
|
||||
|
||||
@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
|
||||
For enclosed arrays inside an exported function, be extra careful to ensure
|
||||
they are evaluated. The computation graph that gets exported will include
|
||||
the computation that produces enclosed inputs.
|
||||
|
||||
|
||||
If the above example was missing ``mx.eval(model.parameters()``, the
|
||||
exported function would include the random initialization of the
|
||||
:obj:`mlx.nn.Module` parameters.
|
||||
@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
|
||||
# Set the model's parameters to the input parameters
|
||||
model.update(tree_unflatten(list(params.items())))
|
||||
return model(x)
|
||||
|
||||
params = dict(tree_flatten(model.parameters()))
|
||||
|
||||
params = tree_flatten(model.parameters(), destination={})
|
||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||
|
||||
|
||||
@@ -164,13 +164,13 @@ to export a function which can be used for inputs with variable shapes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
|
||||
mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
|
||||
imported_abs = mx.import_function("fun.mlxfn")
|
||||
|
||||
# Ok
|
||||
out, = imported_abs(mx.array(-1.0))
|
||||
|
||||
# Also ok
|
||||
out, = imported_abs(mx.array([-1.0]))
|
||||
|
||||
# Also ok
|
||||
out, = imported_abs(mx.array([-1.0, -2.0]))
|
||||
|
||||
With ``shapeless=False`` (which is the default), the second call to
|
||||
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
||||
def fun(x, y=None):
|
||||
constant = mx.array(3.0)
|
||||
if y is not None:
|
||||
x += y
|
||||
x += y
|
||||
return x + constant
|
||||
|
||||
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||
@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
|
||||
print(out)
|
||||
|
||||
In the above example the function constant data, (i.e. ``constant``), is only
|
||||
saved once.
|
||||
saved once.
|
||||
|
||||
Transformations with Imported Functions
|
||||
---------------------------------------
|
||||
@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
|
||||
# Prints: array(1, dtype=float32)
|
||||
print(dfdx(x))
|
||||
|
||||
# Compile the imported function
|
||||
# Compile the imported function
|
||||
mx.compile(imported_fun)
|
||||
# Prints: array(0, dtype=float32)
|
||||
print(compiled_fun(x)[0])
|
||||
@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
|
||||
// Prints: array(2, dtype=float32)
|
||||
std::cout << outputs[0] << std::endl;
|
||||
|
||||
Imported functions can be transformed in C++ just like in Python. Use
|
||||
Imported functions can be transformed in C++ just like in Python. Use
|
||||
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
||||
mx::array>`` for keyword arguments when calling imported functions in C++.
|
||||
|
||||
|
||||
@@ -70,7 +70,8 @@ Differences from NumPy
|
||||
|
||||
* Indexing does not perform bounds checking. Indexing out of bounds is
|
||||
undefined behavior.
|
||||
* Boolean mask based indexing is not yet supported.
|
||||
* Boolean mask based indexing is supported for assignment only (see
|
||||
:ref:`boolean-mask-assignment`).
|
||||
|
||||
The reason for the lack of bounds checking is that exceptions cannot propagate
|
||||
from the GPU. Performing bounds checking for array indices before launching the
|
||||
@@ -107,6 +108,28 @@ same array:
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
Note that unlike NumPy, slicing an array creates a copy, not a view. So
|
||||
mutating it does not mutate the original array:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> b = a[:]
|
||||
>>> b[2] = 0
|
||||
>>> b
|
||||
array([1, 2, 0], dtype=int32)
|
||||
>>> a
|
||||
array([1, 2, 3], dtype=int32)
|
||||
|
||||
Also unlike NumPy, updates to the same location are nondeterministic:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> a[[0, 0]] = mx.array([4, 5])
|
||||
|
||||
The first element of ``a`` could be ``4`` or ``5``.
|
||||
|
||||
Transformations of functions which use in-place updates are allowed and work as
|
||||
expected. For example:
|
||||
|
||||
@@ -121,3 +144,51 @@ expected. For example:
|
||||
|
||||
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
|
||||
and ones elsewhere.
|
||||
|
||||
.. _boolean-mask-assignment:
|
||||
|
||||
Boolean Mask Assignment
|
||||
-----------------------
|
||||
|
||||
MLX supports boolean indices using NumPy syntax. A mask must already be
|
||||
a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
|
||||
Other index types are routed through the standard scatter code.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1.0, 2.0, 3.0])
|
||||
>>> mask = mx.array([True, False, True])
|
||||
>>> updates = mx.array([5.0, 6.0])
|
||||
>>> a[mask] = updates
|
||||
>>> a
|
||||
array([5.0, 2.0, 6.0], dtype=float32)
|
||||
|
||||
Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
|
||||
assignments, ``updates`` must provide at least as many elements as there are
|
||||
``True`` entries in ``mask``.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.zeros((2, 3))
|
||||
>>> mask = mx.array([[True, False, True],
|
||||
[False, False, True]])
|
||||
>>> a[mask] = 1.0
|
||||
>>> a
|
||||
array([[1.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 1.0]], dtype=float32)
|
||||
|
||||
Boolean masks follow NumPy semantics:
|
||||
|
||||
- The mask shape must match the shape of the axes it indexes exactly. No mask
|
||||
broadcasting occurs.
|
||||
- Any axes not covered by the mask are taken in full.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.arange(1000).reshape(10, 10, 10)
|
||||
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||
|
||||
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
|
||||
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
|
||||
Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
|
||||
axes and therefore raise errors.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
@@ -16,6 +17,19 @@
|
||||
|
||||
namespace my_ext {
|
||||
|
||||
// A helper function to find the location of the current binary on disk.
|
||||
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
|
||||
std::string current_binary_dir() {
|
||||
static std::string binary_dir = []() {
|
||||
Dl_info info;
|
||||
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
||||
throw std::runtime_error("Unable to get current binary dir.");
|
||||
}
|
||||
return std::filesystem::path(info.dli_fname).parent_path().string();
|
||||
}();
|
||||
return binary_dir;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation Implementation
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -167,16 +181,15 @@ void Axpby::eval_gpu(
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_";
|
||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname << type_to_name(out);
|
||||
std::string kname = "axpby_";
|
||||
kname += (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname += type_to_name(out);
|
||||
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
auto kernel = d.get_kernel(kname, lib);
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
void print(std::ostream& os) override {
|
||||
os << "Axpby";
|
||||
/** The name of primitive. */
|
||||
const char* name() const override {
|
||||
return "Axpby";
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.25
|
||||
mlx>=0.21.0
|
||||
nanobind==2.2.0
|
||||
nanobind==2.4.0
|
||||
|
||||
@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
|
||||
|
||||
a = mx.ones((3, 4))
|
||||
b = mx.ones((3, 4))
|
||||
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
print(f"c shape: {c_cpu.shape}")
|
||||
print(f"c dtype: {c_cpu.dtype}")
|
||||
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
|
||||
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
|
||||
|
||||
@@ -21,7 +21,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
# Define MLX_VERSION only in the version.cpp file.
|
||||
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
|
||||
|
||||
@@ -49,5 +49,19 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||
endif()
|
||||
|
||||
@@ -14,7 +14,7 @@ class Buffer {
|
||||
void* ptr_;
|
||||
|
||||
public:
|
||||
Buffer(void* ptr) : ptr_(ptr) {};
|
||||
explicit Buffer(void* ptr) : ptr_(ptr) {};
|
||||
|
||||
// Get the raw data pointer from the buffer
|
||||
void* raw_ptr();
|
||||
|
||||
@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
|
||||
other.strides(),
|
||||
other.flags(),
|
||||
[](auto) {});
|
||||
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||
cpy.array_desc_->offset = other.array_desc_->offset;
|
||||
return cpy;
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ bool array::is_tracer() const {
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, Deleter d) {
|
||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||
array_desc_->data_ptr = buffer.raw_ptr();
|
||||
array_desc_->offset = 0;
|
||||
array_desc_->data_size = size();
|
||||
array_desc_->flags.contiguous = true;
|
||||
array_desc_->flags.row_contiguous = true;
|
||||
@@ -156,7 +156,7 @@ void array::set_data(
|
||||
Flags flags,
|
||||
Deleter d) {
|
||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||
array_desc_->data_ptr = buffer.raw_ptr();
|
||||
array_desc_->offset = 0;
|
||||
array_desc_->data_size = data_size;
|
||||
array_desc_->strides = std::move(strides);
|
||||
array_desc_->flags = flags;
|
||||
@@ -167,14 +167,13 @@ void array::copy_shared_buffer(
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
int64_t offset /* = 0 */) {
|
||||
array_desc_->data = other.array_desc_->data;
|
||||
array_desc_->strides = strides;
|
||||
array_desc_->flags = flags;
|
||||
array_desc_->data_size = data_size;
|
||||
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||
array_desc_->data_ptr = static_cast<void*>(
|
||||
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
||||
array_desc_->offset =
|
||||
sizeof(char) * itemsize() * offset + other.array_desc_->offset;
|
||||
}
|
||||
|
||||
void array::copy_shared_buffer(const array& other) {
|
||||
@@ -241,8 +240,8 @@ array::ArrayDesc::ArrayDesc(
|
||||
std::vector<array> inputs)
|
||||
: shape(std::move(shape)),
|
||||
dtype(dtype),
|
||||
status(Status::unscheduled),
|
||||
primitive(std::move(primitive)),
|
||||
status(Status::unscheduled),
|
||||
inputs(std::move(inputs)) {
|
||||
init();
|
||||
}
|
||||
|
||||
36
mlx/array.h
36
mlx/array.h
@@ -10,6 +10,7 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/dtype.h"
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/small_vector.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -18,8 +19,8 @@ class Primitive;
|
||||
|
||||
using Deleter = std::function<void(allocator::Buffer)>;
|
||||
using ShapeElem = int32_t;
|
||||
using Shape = std::vector<ShapeElem>;
|
||||
using Strides = std::vector<int64_t>;
|
||||
using Shape = SmallVector<ShapeElem>;
|
||||
using Strides = SmallVector<int64_t>;
|
||||
|
||||
class array {
|
||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||
@@ -224,6 +225,10 @@ class array {
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
Data(Data&& o) : buffer(o.buffer), d(o.d) {
|
||||
o.buffer = allocator::Buffer(nullptr);
|
||||
o.d = [](allocator::Buffer) {};
|
||||
}
|
||||
~Data() {
|
||||
d(buffer);
|
||||
}
|
||||
@@ -289,6 +294,11 @@ class array {
|
||||
return array_desc_->siblings;
|
||||
}
|
||||
|
||||
/** The array's position in the sibling list. */
|
||||
int sibling_position() const {
|
||||
return array_desc_->position;
|
||||
}
|
||||
|
||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||
array_desc_->siblings = std::move(siblings);
|
||||
array_desc_->position = position;
|
||||
@@ -344,19 +354,27 @@ class array {
|
||||
return array_desc_->data;
|
||||
}
|
||||
|
||||
// Return a raw pointer to the arrays data
|
||||
// Return a raw pointer to the arrays data. This function may do a copy if
|
||||
// the underlying buffer is not accessible on the CPU. When accessing the
|
||||
// data for GPU kernels, be sure to use the correct method / function for the
|
||||
// given backend to access the GPU pointer.
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
return reinterpret_cast<T*>(
|
||||
(static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
return const_cast<array&>(*this).data<T>();
|
||||
}
|
||||
|
||||
int64_t offset() const {
|
||||
return array_desc_->offset;
|
||||
}
|
||||
|
||||
enum Status {
|
||||
// The ouptut of a computation which has not been scheduled.
|
||||
// The output of a computation which has not been scheduled.
|
||||
// For example, the status of `x` in `auto x = a + b`.
|
||||
unscheduled,
|
||||
|
||||
@@ -421,7 +439,7 @@ class array {
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
int64_t offset = 0);
|
||||
|
||||
void copy_shared_buffer(const array& other);
|
||||
|
||||
@@ -456,8 +474,8 @@ class array {
|
||||
// can share the underlying data buffer.
|
||||
std::shared_ptr<Data> data;
|
||||
|
||||
// Properly offset data pointer
|
||||
void* data_ptr{nullptr};
|
||||
// Offset from beginning of data pointer
|
||||
int64_t offset{0};
|
||||
|
||||
// The size in elements of the data buffer the array accesses
|
||||
size_t data_size;
|
||||
|
||||
@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
BinaryOpType bopt) {
|
||||
BinaryOpType bopt,
|
||||
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
||||
bool b_donatable = is_donatable(b, out);
|
||||
bool a_donatable = is_donatable(a, out);
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
|
||||
out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
if (b_donatable) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(b.data_size() * out.itemsize()),
|
||||
mallocfn(b.data_size() * out.itemsize()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
|
||||
out.copy_shared_buffer(a);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
mallocfn(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
mallocfn(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
|
||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(mallocfn(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
|
||||
157
mlx/backend/common/buffer_cache.h
Normal file
157
mlx/backend/common/buffer_cache.h
Normal file
@@ -0,0 +1,157 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(
|
||||
size_t page_size,
|
||||
std::function<size_t(T*)> get_size,
|
||||
std::function<void(T*)> free)
|
||||
: page_size_(page_size),
|
||||
get_size_(std::move(get_size)),
|
||||
free_(std::move(free)) {}
|
||||
|
||||
~BufferCache() {
|
||||
clear();
|
||||
}
|
||||
|
||||
BufferCache(const BufferCache&) = delete;
|
||||
BufferCache& operator=(const BufferCache&) = delete;
|
||||
|
||||
T* reuse_from_cache(size_t size) {
|
||||
// Find the closest buffer in pool.
|
||||
auto it = buffer_pool_.lower_bound(size);
|
||||
if (it == buffer_pool_.end() ||
|
||||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Collect from the cache.
|
||||
T* buf = it->second->buf;
|
||||
pool_size_ -= it->first;
|
||||
|
||||
// Remove from record.
|
||||
remove_from_list(it->second);
|
||||
buffer_pool_.erase(it);
|
||||
return buf;
|
||||
}
|
||||
|
||||
void recycle_to_cache(T* buf) {
|
||||
assert(buf);
|
||||
// Add to cache.
|
||||
BufferHolder* bh = new BufferHolder(buf);
|
||||
add_at_head(bh);
|
||||
size_t size = get_size_(buf);
|
||||
pool_size_ += size;
|
||||
buffer_pool_.emplace(size, bh);
|
||||
}
|
||||
|
||||
int release_cached_buffers(size_t min_bytes_to_free) {
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
return clear();
|
||||
} else {
|
||||
int n_release = 0;
|
||||
size_t total_bytes_freed = 0;
|
||||
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
// Release buffer.
|
||||
size_t size = get_size_(tail_->buf);
|
||||
total_bytes_freed += size;
|
||||
free_(tail_->buf);
|
||||
n_release++;
|
||||
|
||||
// Remove from record.
|
||||
auto its = buffer_pool_.equal_range(size);
|
||||
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
|
||||
return el.second == tail_;
|
||||
});
|
||||
assert(it != buffer_pool_.end());
|
||||
buffer_pool_.erase(it);
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return n_release;
|
||||
}
|
||||
}
|
||||
|
||||
int clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
free_(holder->buf);
|
||||
n_release++;
|
||||
delete holder;
|
||||
}
|
||||
buffer_pool_.clear();
|
||||
pool_size_ = 0;
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
return n_release;
|
||||
}
|
||||
|
||||
size_t cache_size() const {
|
||||
return pool_size_;
|
||||
}
|
||||
|
||||
size_t page_size() const {
|
||||
return page_size_;
|
||||
}
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
public:
|
||||
explicit BufferHolder(T* buf_) : buf(buf_) {}
|
||||
|
||||
BufferHolder* prev{nullptr};
|
||||
BufferHolder* next{nullptr};
|
||||
T* buf;
|
||||
};
|
||||
|
||||
void add_at_head(BufferHolder* to_add) {
|
||||
if (!head_) {
|
||||
head_ = to_add;
|
||||
tail_ = to_add;
|
||||
} else {
|
||||
head_->prev = to_add;
|
||||
to_add->next = head_;
|
||||
head_ = to_add;
|
||||
}
|
||||
}
|
||||
|
||||
void remove_from_list(BufferHolder* to_remove) {
|
||||
if (to_remove->prev && to_remove->next) { // if middle
|
||||
to_remove->prev->next = to_remove->next;
|
||||
to_remove->next->prev = to_remove->prev;
|
||||
} else if (to_remove->prev && to_remove == tail_) { // if tail
|
||||
tail_ = to_remove->prev;
|
||||
tail_->next = nullptr;
|
||||
} else if (to_remove == head_ && to_remove->next) { // if head
|
||||
head_ = to_remove->next;
|
||||
head_->prev = nullptr;
|
||||
} else if (to_remove == head_ && to_remove == tail_) { // if only element
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
}
|
||||
|
||||
delete to_remove;
|
||||
}
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_{nullptr};
|
||||
BufferHolder* tail_{nullptr};
|
||||
size_t pool_size_{0};
|
||||
|
||||
const size_t page_size_;
|
||||
std::function<size_t(T*)> get_size_;
|
||||
std::function<void(T*)> free_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,8 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -15,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
return print_float_constant<float16_t>(os, x);
|
||||
case bfloat16:
|
||||
return print_float_constant<bfloat16_t>(os, x);
|
||||
case float64:
|
||||
return print_float_constant<double>(os, x);
|
||||
case complex64:
|
||||
return print_complex_constant<complex64_t>(os, x);
|
||||
case int8:
|
||||
@@ -51,6 +52,8 @@ std::string get_type_string(Dtype d) {
|
||||
return "float16_t";
|
||||
case bfloat16:
|
||||
return "bfloat16_t";
|
||||
case float64:
|
||||
return "double";
|
||||
case complex64:
|
||||
return "complex64_t";
|
||||
case bool_:
|
||||
@@ -79,55 +82,6 @@ std::string get_type_string(Dtype d) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
NodeNamer namer;
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// Fill the input names. This is not really necessary, I just like having A,
|
||||
// B, C, ... as the inputs.
|
||||
for (auto& x : inputs) {
|
||||
namer.get_name(x);
|
||||
}
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
// name and type of output
|
||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||
// computation performed
|
||||
a.primitive().print(os);
|
||||
// name of inputs to the function
|
||||
for (auto& inp : a.inputs()) {
|
||||
os << namer.get_name(inp);
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << (is_scalar(x) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const Shape& shape) {
|
||||
@@ -159,9 +113,10 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous) {
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous,
|
||||
const std::function<allocator::Buffer(size_t)>&
|
||||
mallocfn /* = allocator::malloc */) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
Strides strides;
|
||||
@@ -175,8 +130,7 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
in.is_donatable() && is_constant(i)) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
@@ -188,7 +142,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||
mallocfn(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
@@ -204,16 +158,86 @@ void compiled_allocate_outputs(
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
is_constant(i)) {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
|
||||
outputs[o].set_data(mallocfn(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant) {
|
||||
const Shape& shape = out.shape();
|
||||
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||
if (contiguous) {
|
||||
return {true, shape, {}};
|
||||
}
|
||||
|
||||
std::vector<Strides> strides_vec{out.strides()};
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants.
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip scalar inputs.
|
||||
const auto& x = inputs[i];
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
Strides xstrides;
|
||||
size_t j = 0;
|
||||
for (; j < shape.size() - x.ndim(); ++j) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides_vec.push_back(std::move(xstrides));
|
||||
}
|
||||
|
||||
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
|
||||
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
|
||||
}
|
||||
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
for (const auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
for (const auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -14,19 +13,17 @@ inline bool is_static_cast(const Primitive& p) {
|
||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids);
|
||||
|
||||
std::string get_type_string(Dtype d);
|
||||
|
||||
template <typename T>
|
||||
void print_float_constant(std::ostream& os, const array& x) {
|
||||
auto old_precision = os.precision();
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
||||
<< x.item<T>() << std::setprecision(old_precision);
|
||||
if constexpr (std::is_same_v<T, double>) {
|
||||
os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
|
||||
} else {
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
|
||||
}
|
||||
os << x.item<T>() << std::setprecision(old_precision);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -60,8 +57,21 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous,
|
||||
const std::function<allocator::Buffer(size_t)>& mallocfn =
|
||||
allocator::malloc);
|
||||
|
||||
// Collapse contiguous dims ignoring scalars and constants.
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant);
|
||||
|
||||
// Return whether the kernel should use large index.
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -22,23 +22,27 @@ enum class CopyType {
|
||||
GeneralGeneral
|
||||
};
|
||||
|
||||
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
inline bool set_copy_output_data(
|
||||
const array& in,
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
// have the same size, then the input buffer can hold the output.
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
return true;
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
mallocfn(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(mallocfn(out.nbytes()));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
if (n > (1 << 26)) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where k <= 26");
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
||||
67
mlx/backend/common/matmul.h
Normal file
67
mlx/backend/common/matmul.h
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
const array& a,
|
||||
const array& b) {
|
||||
if (a.ndim() == 2) {
|
||||
return {Shape{1}, Strides{0}, Strides{0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] =
|
||||
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
|
||||
|
||||
auto a_batch_strides = batch_strides[0];
|
||||
auto b_batch_strides = batch_strides[1];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
a_batch_strides.push_back(0);
|
||||
b_batch_strides.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
|
||||
}
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||
collapse_batches(const array& a, const array& b, const array& c) {
|
||||
if (a.ndim() == 2) {
|
||||
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
|
||||
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
|
||||
|
||||
auto A_batch_stride = batch_strides[0];
|
||||
auto B_batch_stride = batch_strides[1];
|
||||
auto C_batch_stride = batch_strides[2];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
A_batch_stride.push_back(0);
|
||||
B_batch_stride.push_back(0);
|
||||
C_batch_stride.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(
|
||||
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -5,11 +5,9 @@
|
||||
namespace mlx::core {
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
return shapes_without_reduction_axes(
|
||||
std::move(shape), std::move(strides), axes);
|
||||
}
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
|
||||
@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes);
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
const std::vector<int>& axes);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
|
||||
data_offset += start_indices[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides[i];
|
||||
}
|
||||
// Normalize the offset
|
||||
if (data_offset < 0) {
|
||||
data_offset += in.data_size();
|
||||
}
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
size_t data_offset,
|
||||
int64_t data_offset,
|
||||
size_t data_size,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
@@ -45,23 +41,30 @@ void slice(
|
||||
const Shape& start_indices,
|
||||
const Shape& strides) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
out.set_data(allocator::malloc(0));
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate out strides, initial offset
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
int64_t data_end = 1;
|
||||
for (int i = 0; i < start_indices.size(); ++i) {
|
||||
if (in.shape()[i] > 1) {
|
||||
auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
|
||||
data_end += end_idx * in.strides()[i];
|
||||
|
||||
// Get the location of the end based on the inp strides and out.shape()
|
||||
int64_t low_idx = 0;
|
||||
int64_t high_idx = 0;
|
||||
for (int i = 0; i < inp_strides.size(); ++i) {
|
||||
auto delta = inp_strides[i] * (out.shape()[i] - 1);
|
||||
if (inp_strides[i] > 0) {
|
||||
high_idx += delta;
|
||||
} else {
|
||||
low_idx += delta;
|
||||
}
|
||||
}
|
||||
if (data_end < 0) {
|
||||
data_end += in.data_size();
|
||||
int64_t data_size = (high_idx - low_idx) + 1;
|
||||
if (data_size < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[slice] Computed invalid data size: " << data_size << ".";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
size_t data_size = (data_end - data_offset);
|
||||
shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ namespace mlx::core {
|
||||
enum class TernaryOpType {
|
||||
ScalarScalarScalar,
|
||||
VectorVectorVector,
|
||||
VectorVectorScalar,
|
||||
VectorScalarVector,
|
||||
General,
|
||||
};
|
||||
|
||||
@@ -25,6 +27,14 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
|
||||
(a.flags().col_contiguous && b.flags().col_contiguous &&
|
||||
c.flags().col_contiguous)) {
|
||||
topt = TernaryOpType::VectorVectorVector;
|
||||
} else if (
|
||||
b.data_size() == 1 && a.flags().row_contiguous &&
|
||||
c.flags().row_contiguous) {
|
||||
topt = TernaryOpType::VectorScalarVector;
|
||||
} else if (
|
||||
c.data_size() == 1 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
topt = TernaryOpType::VectorVectorScalar;
|
||||
} else {
|
||||
topt = TernaryOpType::General;
|
||||
}
|
||||
@@ -36,7 +46,8 @@ inline void set_ternary_op_output_data(
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
TernaryOpType topt) {
|
||||
TernaryOpType topt,
|
||||
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
||||
auto maybe_donate = [&out](const array& x) {
|
||||
if (is_donatable(x, out)) {
|
||||
out.copy_shared_buffer(x);
|
||||
@@ -47,24 +58,25 @@ inline void set_ternary_op_output_data(
|
||||
|
||||
switch (topt) {
|
||||
case TernaryOpType::ScalarScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
|
||||
out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
|
||||
break;
|
||||
case TernaryOpType::VectorVectorVector:
|
||||
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
||||
out.set_data(
|
||||
allocator::malloc(out.itemsize() * b.data_size()),
|
||||
mallocfn(out.itemsize() * b.data_size()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
}
|
||||
break;
|
||||
case TernaryOpType::VectorVectorScalar:
|
||||
case TernaryOpType::VectorScalarVector:
|
||||
case TernaryOpType::General:
|
||||
// Try to donate an input which is row_contiguous
|
||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||
(b.flags().row_contiguous && maybe_donate(b)) ||
|
||||
(c.flags().row_contiguous && maybe_donate(c)))) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(mallocfn(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
29
mlx/backend/common/unary.h
Normal file
29
mlx/backend/common/unary.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void set_unary_output_data(
|
||||
const array& in,
|
||||
array& out,
|
||||
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
mallocfn(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(mallocfn(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,9 +1,22 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::filesystem::path current_binary_dir() {
|
||||
static std::filesystem::path binary_dir = []() {
|
||||
Dl_info info;
|
||||
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
||||
throw std::runtime_error("Unable to get current binary dir.");
|
||||
}
|
||||
return std::filesystem::path(info.dli_fname).parent_path();
|
||||
}();
|
||||
return binary_dir;
|
||||
}
|
||||
|
||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const std::vector<Strides>& strides,
|
||||
@@ -101,4 +114,118 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
||||
}
|
||||
|
||||
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
|
||||
int pows[3] = {0, 0, 0};
|
||||
int sum = 0;
|
||||
while (true) {
|
||||
int presum = sum;
|
||||
// Check all the pows
|
||||
if (dim0 >= (1 << (pows[0] + 1))) {
|
||||
pows[0]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == 10) {
|
||||
break;
|
||||
}
|
||||
if (dim1 >= (1 << (pows[1] + 1))) {
|
||||
pows[1]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == 10) {
|
||||
break;
|
||||
}
|
||||
if (dim2 >= (1 << (pows[2] + 1))) {
|
||||
pows[2]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == presum || sum == pow2) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
|
||||
}
|
||||
|
||||
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
|
||||
// Dims with strides of 0 are ignored as they
|
||||
// correspond to broadcasted dimensions
|
||||
size_t grid_x = 1;
|
||||
size_t grid_y = 1;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
if (strides[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
if (grid_x * shape[i] < UINT32_MAX) {
|
||||
grid_x *= shape[i];
|
||||
} else {
|
||||
grid_y *= shape[i];
|
||||
}
|
||||
}
|
||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||
throw std::runtime_error("Unable to safely factor shape.");
|
||||
}
|
||||
if (grid_y > grid_x) {
|
||||
std::swap(grid_x, grid_y);
|
||||
}
|
||||
return std::make_tuple(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
Dims get_2d_grid_dims_common(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor) {
|
||||
// Compute the 2d grid dimensions such that the total size of the grid is
|
||||
// divided by divisor.
|
||||
size_t grid_x = 1;
|
||||
size_t grid_y = 1;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
if (strides[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// No need to add this shape we can just remove it from the divisor.
|
||||
if (divisor % shape[i] == 0) {
|
||||
divisor /= shape[i];
|
||||
continue;
|
||||
}
|
||||
|
||||
if (grid_x * shape[i] < UINT32_MAX) {
|
||||
grid_x *= shape[i];
|
||||
} else {
|
||||
grid_y *= shape[i];
|
||||
}
|
||||
|
||||
if (divisor > 1) {
|
||||
if (grid_x % divisor == 0) {
|
||||
grid_x /= divisor;
|
||||
divisor = 1;
|
||||
} else if (grid_y % divisor == 0) {
|
||||
grid_y /= divisor;
|
||||
divisor = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||
throw std::runtime_error("Unable to safely factor shape.");
|
||||
}
|
||||
if (grid_y > grid_x) {
|
||||
std::swap(grid_x, grid_y);
|
||||
}
|
||||
if (divisor > 1) {
|
||||
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
|
||||
}
|
||||
return std::make_tuple(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
||||
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
|
||||
auto gx = (dim0 + bx - 1) / bx;
|
||||
auto gy = (dim1 + by - 1) / by;
|
||||
auto gz = (dim2 + bz - 1) / bz;
|
||||
|
||||
return std::make_pair(
|
||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -2,12 +2,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <filesystem>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Return the directory that contains current shared library.
|
||||
std::filesystem::path current_binary_dir();
|
||||
|
||||
inline int64_t
|
||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||
int64_t loc = 0;
|
||||
@@ -70,6 +75,31 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
const array& a,
|
||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
|
||||
// Compute the thread block dimensions which fit the given
|
||||
// input dimensions.
|
||||
// - The thread block dimensions will be powers of two
|
||||
// - The thread block size will be less than 2^pow2
|
||||
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
|
||||
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||
|
||||
// Computes a 2D grid where each element is < UINT_MAX
|
||||
// Assumes:
|
||||
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
||||
// - shape and strides correspond to a contiguous (no holes) but
|
||||
// possibly broadcasted array
|
||||
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
|
||||
|
||||
// Same as above but we do an implicit division with divisor.
|
||||
// Basically, equivalent to factorizing
|
||||
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
|
||||
Dims get_2d_grid_dims_common(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor);
|
||||
|
||||
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
|
||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
|
||||
|
||||
struct ContiguousIterator {
|
||||
inline void step() {
|
||||
int dims = shape_.size();
|
||||
@@ -165,4 +195,11 @@ void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
|
||||
template <typename T>
|
||||
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
return vec;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -40,11 +40,13 @@ add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
|
||||
@@ -14,10 +14,8 @@ template <typename InT, typename OpT>
|
||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
auto axis_size = in.shape()[axis];
|
||||
auto axis_stride = in.strides()[axis];
|
||||
Strides strides = in.strides();
|
||||
Shape shape = in.shape();
|
||||
strides.erase(strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
Strides strides = remove_index(in.strides(), axis);
|
||||
Shape shape = remove_index(in.shape(), axis);
|
||||
auto in_ptr = in.data<InT>();
|
||||
auto out_ptr = out.data<uint32_t>();
|
||||
|
||||
|
||||
11
mlx/backend/cpu/available.cpp
Normal file
11
mlx/backend/cpu/available.cpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/available.h"
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cpu
|
||||
9
mlx/backend/cpu/available.h
Normal file
9
mlx/backend/cpu/available.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cpu
|
||||
@@ -14,233 +14,11 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_float(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_int(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[binary_int] Type not supported");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Add(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Add(), stream());
|
||||
}
|
||||
|
||||
void DivMod::eval_cpu(
|
||||
@@ -324,14 +102,14 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Divide(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Divide(), stream());
|
||||
}
|
||||
|
||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Remainder(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Remainder(), stream());
|
||||
}
|
||||
|
||||
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -372,89 +150,90 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
});
|
||||
} else {
|
||||
comparison_op(a, b, out, detail::Equal(), stream());
|
||||
comparison_op_cpu(a, b, out, detail::Equal(), stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
|
||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
||||
comparison_op_cpu(
|
||||
inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
||||
}
|
||||
|
||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
|
||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
|
||||
}
|
||||
|
||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
||||
}
|
||||
|
||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary_float(a, b, out, detail::LogAddExp(), stream());
|
||||
binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd(), stream());
|
||||
binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());
|
||||
}
|
||||
|
||||
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr(), stream());
|
||||
binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());
|
||||
}
|
||||
|
||||
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Maximum(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Maximum(), stream());
|
||||
}
|
||||
|
||||
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Minimum(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Minimum(), stream());
|
||||
}
|
||||
|
||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Multiply(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Multiply(), stream());
|
||||
}
|
||||
|
||||
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
||||
comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
||||
}
|
||||
|
||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Power(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Power(), stream());
|
||||
}
|
||||
|
||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Subtract(), stream());
|
||||
binary_op_cpu(a, b, out, detail::Subtract(), stream());
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -463,19 +242,19 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_int(a, b, out, detail::BitwiseAnd(), stream());
|
||||
binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_int(a, b, out, detail::BitwiseOr(), stream());
|
||||
binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_int(a, b, out, detail::BitwiseXor(), stream());
|
||||
binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_int(a, b, out, detail::LeftShift(), stream());
|
||||
binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_int(a, b, out, detail::RightShift(), stream());
|
||||
binary_int_op_cpu(a, b, out, detail::RightShift(), stream());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -484,7 +263,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
binary_float(a, b, out, detail::ArcTan2(), stream());
|
||||
binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -290,4 +291,227 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
||||
binary_op<T, T, Op>(a, b, out, bopt);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_op_cpu(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op_cpu(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_float_op_cpu(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_int_op_cpu(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[binary_int] Type not supported");
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
|
||||
|
||||
// The decomposition is computed in place, so just copy the input to the
|
||||
// output.
|
||||
copy(
|
||||
copy_cpu(
|
||||
a,
|
||||
factor,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "mlx/backend/cpu/jit_compiler.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/version.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -94,7 +95,11 @@ void* compile(
|
||||
kernel_file_name = kernel_name;
|
||||
}
|
||||
|
||||
auto output_dir = std::filesystem::temp_directory_path();
|
||||
auto output_dir =
|
||||
std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
|
||||
if (!std::filesystem::exists(output_dir)) {
|
||||
std::filesystem::create_directories(output_dir);
|
||||
}
|
||||
|
||||
std::string shared_lib_name = "lib" + kernel_file_name + ".so";
|
||||
auto shared_lib_path = (output_dir / shared_lib_name).string();
|
||||
@@ -146,18 +151,9 @@ inline void build_kernel(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous,
|
||||
int ndim) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
|
||||
// Constants are scalars that are captured by value and cannot change
|
||||
auto is_constant = [&constant_ids](const array& x) {
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
NodeNamer namer;
|
||||
|
||||
#ifdef _MSC_VER
|
||||
@@ -166,25 +162,28 @@ inline void build_kernel(
|
||||
#endif
|
||||
|
||||
// Start the kernel
|
||||
os << "void " << kernel_name << "(void** args) {" << std::endl;
|
||||
os << "void " << kernel_name
|
||||
<< "(int* shape, int64_t** strides, void** args) {" << std::endl;
|
||||
|
||||
// Add the input arguments
|
||||
int cnt = 0;
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants from the input list
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
auto tstr = get_type_string(x.dtype());
|
||||
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
// Scalars and contiguous need no strides
|
||||
if (!is_scalar(x) && !contiguous) {
|
||||
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
os << " const int64_t* " << xname << "_strides = strides["
|
||||
<< strides_index++ << "];" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,10 +193,8 @@ inline void build_kernel(
|
||||
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
|
||||
<< "*)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
|
||||
} else {
|
||||
// Add output size
|
||||
if (contiguous) {
|
||||
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
|
||||
}
|
||||
|
||||
@@ -211,10 +208,11 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
for (auto& x : inputs) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
if (is_constant(i)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
print_constant(os, x);
|
||||
os << ";" << std::endl;
|
||||
@@ -238,7 +236,7 @@ inline void build_kernel(
|
||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||
} else {
|
||||
x.primitive().print(os);
|
||||
os << x.primitive().name();
|
||||
os << "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
@@ -264,8 +262,9 @@ inline void build_kernel(
|
||||
} else {
|
||||
for (int d = ndim - 1; d >= 0; --d) {
|
||||
// Update pointers
|
||||
for (auto& x : inputs) {
|
||||
if (is_constant(x) || is_scalar(x)) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
if (is_constant(i) || is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
auto& xname = namer.get_name(x);
|
||||
@@ -287,65 +286,33 @@ inline void build_kernel(
|
||||
void Compiled::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (kernel_lib_.empty()) {
|
||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
auto contiguous = compiled_check_contiguity(inputs, shape);
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Collect function input arguments.
|
||||
std::vector<void*> args;
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
const auto& x = inputs[i];
|
||||
encoder.set_input_array(x);
|
||||
args.push_back((void*)x.data<void>());
|
||||
|
||||
if (contiguous || is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the input to the output shape.
|
||||
std::vector<size_t> xstrides;
|
||||
int j = 0;
|
||||
for (; j < shape.size() - x.ndim(); j++) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides.push_back(std::move(xstrides));
|
||||
args.push_back(strides.back().data());
|
||||
}
|
||||
|
||||
// Get the kernel name from the lib
|
||||
int ndim = shape.size();
|
||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||
if (!contiguous) {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
kernel_name += std::to_string(ndim);
|
||||
}
|
||||
|
||||
// Get the function
|
||||
auto fn_ptr = compile(kernel_name, [&]() {
|
||||
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
|
||||
std::ostringstream kernel;
|
||||
kernel << get_kernel_preamble() << std::endl;
|
||||
kernel << "extern \"C\" {" << std::endl;
|
||||
@@ -355,7 +322,7 @@ void Compiled::eval_cpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
is_constant_,
|
||||
contiguous,
|
||||
ndim);
|
||||
// Close extern "C"
|
||||
@@ -363,26 +330,26 @@ void Compiled::eval_cpu(
|
||||
return kernel.str();
|
||||
});
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
|
||||
for (auto& x : outputs) {
|
||||
args.push_back(x.data<void>());
|
||||
encoder.set_output_array(x);
|
||||
}
|
||||
Shape out_shape;
|
||||
if (!contiguous) {
|
||||
out_shape = outputs[0].shape();
|
||||
args.push_back((void*)out_shape.data());
|
||||
} else {
|
||||
if (contiguous) {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
encoder.dispatch(
|
||||
[fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
|
||||
auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
|
||||
encoder.dispatch([fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
shape = std::move(shape)]() mutable {
|
||||
SmallVector<int64_t*> strides_ptrs;
|
||||
for (auto& s : strides) {
|
||||
strides_ptrs.push_back(s.data());
|
||||
}
|
||||
fun(shape.data(), strides_ptrs.data(), args.data());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -295,7 +295,11 @@ inline void copy_inplace_dispatch(
|
||||
|
||||
} // namespace
|
||||
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
void copy_cpu_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
@@ -305,7 +309,7 @@ void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
|
||||
}
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
bool donated = set_copy_output_data(src, dst, ctype);
|
||||
if (donated && src.dtype() == dst.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
@@ -315,10 +319,10 @@ void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_inplace(src, dst, ctype, stream);
|
||||
copy_cpu_inplace(src, dst, ctype, stream);
|
||||
}
|
||||
|
||||
void copy_inplace(
|
||||
void copy_cpu_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
@@ -373,4 +377,10 @@ void copy_inplace(
|
||||
});
|
||||
}
|
||||
|
||||
array contiguous_copy_cpu(const array& arr, Stream stream) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
||||
return arr_copy;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -10,10 +10,14 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||
void copy_cpu_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream);
|
||||
|
||||
void copy_inplace(
|
||||
void copy_cpu_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
@@ -26,4 +30,7 @@ void copy_inplace(
|
||||
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
||||
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
||||
|
||||
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||
array contiguous_copy_cpu(const array& arr, Stream stream);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -13,9 +13,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return {arr, false};
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General, stream);
|
||||
return {arr_copy, true};
|
||||
return {contiguous_copy_cpu(arr, stream), true};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -34,8 +32,7 @@ void AllReduce::eval_cpu(
|
||||
}
|
||||
return in;
|
||||
} else {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy(in, arr_copy, CopyType::General, s);
|
||||
array arr_copy = contiguous_copy_cpu(in, s);
|
||||
out.copy_shared_buffer(arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
@@ -98,4 +95,9 @@ void Recv::eval_cpu(
|
||||
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||
}
|
||||
|
||||
void ReduceScatter::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
|
||||
}
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
173
mlx/backend/cpu/eig.cpp
Normal file
173
mlx/backend/cpu/eig.cpp
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void eig_impl(
|
||||
array& a,
|
||||
array& vectors,
|
||||
array& values,
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
using OT = std::complex<T>;
|
||||
auto a_ptr = a.data<T>();
|
||||
auto eig_ptr = values.data<OT>();
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(values);
|
||||
OT* vec_ptr = nullptr;
|
||||
if (compute_eigenvectors) {
|
||||
encoder.set_output_array(vectors);
|
||||
vec_ptr = vectors.data<OT>();
|
||||
}
|
||||
encoder.dispatch([a_ptr,
|
||||
vec_ptr,
|
||||
eig_ptr,
|
||||
compute_eigenvectors,
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
char jobr = 'N';
|
||||
char jobl = compute_eigenvectors ? 'V' : 'N';
|
||||
int n_vecs_r = 1;
|
||||
int n_vecs_l = compute_eigenvectors ? N : 1;
|
||||
int lwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
&work,
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
}
|
||||
|
||||
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
|
||||
auto vec_tmp_data =
|
||||
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
|
||||
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
|
||||
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
a_ptr,
|
||||
&N,
|
||||
eig_tmp,
|
||||
eig_tmp + N,
|
||||
vec_tmp,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
&info);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
|
||||
}
|
||||
if (vec_ptr) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
if (eig_ptr[i].imag() != 0) {
|
||||
// This vector and the next are a pair
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vec_ptr[i * N + j] = {
|
||||
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
|
||||
vec_ptr[(i + 1) * N + j] = {
|
||||
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
|
||||
}
|
||||
i += 1;
|
||||
} else {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
|
||||
}
|
||||
}
|
||||
}
|
||||
vec_ptr += N * N;
|
||||
}
|
||||
a_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
});
|
||||
encoder.add_temporary(a);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Eig::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const auto& a = inputs[0];
|
||||
auto& values = outputs[0];
|
||||
|
||||
auto vectors = compute_eigenvectors_
|
||||
? outputs[1]
|
||||
: array(a.shape(), complex64, nullptr, {});
|
||||
|
||||
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
|
||||
copy_cpu(
|
||||
a,
|
||||
a_copy,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream());
|
||||
|
||||
values.set_data(allocator::malloc(values.nbytes()));
|
||||
|
||||
if (compute_eigenvectors_) {
|
||||
// Set the strides and flags so the eigenvectors
|
||||
// are in the columns of the output
|
||||
auto flags = vectors.flags();
|
||||
auto strides = vectors.strides();
|
||||
auto ndim = a.ndim();
|
||||
std::swap(strides[ndim - 1], strides[ndim - 2]);
|
||||
|
||||
if (a.size() > 1) {
|
||||
flags.row_contiguous = false;
|
||||
if (ndim > 2) {
|
||||
flags.col_contiguous = false;
|
||||
} else {
|
||||
flags.col_contiguous = true;
|
||||
}
|
||||
}
|
||||
vectors.set_data(
|
||||
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
|
||||
}
|
||||
switch (a.dtype()) {
|
||||
case float32:
|
||||
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -12,6 +12,133 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, class Enable = void>
|
||||
struct EighWork {};
|
||||
|
||||
template <typename T>
|
||||
struct EighWork<
|
||||
T,
|
||||
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||
using R = T;
|
||||
|
||||
char jobz;
|
||||
char uplo;
|
||||
int N;
|
||||
int lwork;
|
||||
int liwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EighWork(char jobz_, char uplo_, int N_)
|
||||
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||
}
|
||||
|
||||
void run(T* vectors, T* values) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vectors,
|
||||
&N,
|
||||
values,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(buffers[1].buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct EighWork<std::complex<float>> {
|
||||
using T = std::complex<float>;
|
||||
using R = float;
|
||||
|
||||
char jobz;
|
||||
char uplo;
|
||||
int N;
|
||||
int lwork;
|
||||
int lrwork;
|
||||
int liwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EighWork(char jobz_, char uplo_, int N_)
|
||||
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
|
||||
T work;
|
||||
R rwork;
|
||||
int iwork;
|
||||
heevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&rwork,
|
||||
&lrwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work.real());
|
||||
lrwork = static_cast<int>(rwork);
|
||||
liwork = iwork;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||
}
|
||||
|
||||
void run(T* vectors, R* values) {
|
||||
heevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vectors,
|
||||
&N,
|
||||
values,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<R*>(buffers[1].buffer.raw_ptr()),
|
||||
&lrwork,
|
||||
static_cast<int*>(buffers[2].buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
if (jobz == 'V') {
|
||||
// We have pre-transposed the vectors but we also must conjugate them
|
||||
// when they are complex.
|
||||
//
|
||||
// We could vectorize this but it is so fast in comparison to heevd that
|
||||
// it doesn't really matter.
|
||||
for (int i = 0; i < N; i++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
*vectors = std::conj(*vectors);
|
||||
vectors++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void eigh_impl(
|
||||
array& vectors,
|
||||
@@ -19,8 +146,10 @@ void eigh_impl(
|
||||
const std::string& uplo,
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
using R = typename EighWork<T>::R;
|
||||
|
||||
auto vec_ptr = vectors.data<T>();
|
||||
auto eig_ptr = values.data<T>();
|
||||
auto eig_ptr = values.data<R>();
|
||||
char jobz = compute_eigenvectors ? 'V' : 'N';
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
@@ -33,49 +162,17 @@ void eigh_impl(
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
int lwork = -1;
|
||||
int liwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
EighWork<T> work(jobz, uplo, N);
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
|
||||
// Work loop
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vec_ptr,
|
||||
&N,
|
||||
eig_ptr,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
work.run(vec_ptr, eig_ptr);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
if (work.info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
<< work.info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -99,7 +196,7 @@ void Eigh::eval_cpu(
|
||||
|
||||
values.set_data(allocator::malloc(values.nbytes()));
|
||||
|
||||
copy(
|
||||
copy_cpu(
|
||||
a,
|
||||
vectors,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
@@ -131,6 +228,10 @@ void Eigh::eval_cpu(
|
||||
eigh_impl<double>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
eigh_impl<std::complex<float>>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Eigh::eval_cpu] only supports float32 or float64.");
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
@@ -49,9 +48,15 @@ void matmul_bnns(
|
||||
size_t K = a_shape[ndim - 1];
|
||||
|
||||
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
if (beta != 1.0 && beta != 0.0) {
|
||||
// scale the output
|
||||
for (auto i = 0; i < batch_size * M * N; ++i) {
|
||||
out[i] *= beta;
|
||||
}
|
||||
beta = 1.0;
|
||||
}
|
||||
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
||||
/* float alpha = */ alpha,
|
||||
/* float beta = */ beta,
|
||||
|
||||
@@ -88,4 +88,47 @@ void matmul<double>(
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void matmul<complex64_t>(
|
||||
const complex64_t* a,
|
||||
const complex64_t* b,
|
||||
complex64_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
auto calpha = static_cast<complex64_t>(alpha);
|
||||
auto cbeta = static_cast<complex64_t>(beta);
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
cblas_cgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
&calpha,
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||
lda,
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||
ldb,
|
||||
&cbeta,
|
||||
out + M * N * i,
|
||||
ldc);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (in.flags().row_contiguous && in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy(
|
||||
copy_cpu(
|
||||
in,
|
||||
out,
|
||||
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
|
||||
@@ -257,15 +257,11 @@ void gather_axis(
|
||||
const array& ind,
|
||||
array& out,
|
||||
const int axis) {
|
||||
auto strides = ind.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = ind.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
|
||||
|
||||
strides = src.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
|
||||
auto shape = remove_index(ind.shape(), axis);
|
||||
ContiguousIterator ind_it(
|
||||
shape, remove_index(ind.strides(), axis), src.ndim() - 1);
|
||||
ContiguousIterator src_it(
|
||||
shape, remove_index(src.strides(), axis), src.ndim() - 1);
|
||||
|
||||
auto ind_ptr = ind.data<IdxT>();
|
||||
auto src_ptr = src.data<T>();
|
||||
@@ -521,7 +517,7 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
auto ctype =
|
||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(src, out, ctype, stream());
|
||||
copy_cpu(src, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
std::vector<array> inds;
|
||||
@@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
template <typename T, typename IdxT, typename OpT>
|
||||
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
|
||||
auto strides = idx.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
auto shape = idx.shape();
|
||||
shape.erase(shape.begin() + axis);
|
||||
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
|
||||
|
||||
strides = upd.strides();
|
||||
strides.erase(strides.begin() + axis);
|
||||
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
|
||||
auto shape = remove_index(idx.shape(), axis);
|
||||
ContiguousIterator idx_it(
|
||||
shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
|
||||
ContiguousIterator upd_it(
|
||||
shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
|
||||
|
||||
auto idx_ptr = idx.data<IdxT>();
|
||||
auto upd_ptr = upd.data<T>();
|
||||
@@ -694,7 +686,7 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
auto ctype =
|
||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(src, out, ctype, stream());
|
||||
copy_cpu(src, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(idx);
|
||||
@@ -755,4 +747,108 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void masked_scatter_impl(const array& mask, const array& src, array& out) {
|
||||
ContiguousIterator mask_it(mask);
|
||||
ContiguousIterator src_it(src);
|
||||
ContiguousIterator out_it(out);
|
||||
|
||||
const bool* mask_ptr = mask.data<bool>();
|
||||
const T* src_ptr = src.data<T>();
|
||||
T* dst_ptr = out.data<T>();
|
||||
|
||||
const size_t batch_count = mask.shape(0);
|
||||
const size_t mask_batch_size = mask.size() / batch_count;
|
||||
const size_t src_batch_size = src.size() / batch_count;
|
||||
|
||||
for (uint b = 0; b < batch_count; ++b) {
|
||||
size_t src_consumed = 0;
|
||||
src_it.seek(b * src_batch_size);
|
||||
|
||||
for (size_t i = 0; i < mask_batch_size; ++i) {
|
||||
if (mask_ptr[mask_it.loc]) {
|
||||
if (src_consumed >= src_batch_size) {
|
||||
throw std::runtime_error(
|
||||
"[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
|
||||
}
|
||||
dst_ptr[out_it.loc] = src_ptr[src_it.loc];
|
||||
src_it.step();
|
||||
++src_consumed;
|
||||
}
|
||||
mask_it.step();
|
||||
out_it.step();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 3);
|
||||
|
||||
auto& dst = inputs[0];
|
||||
auto& mask = inputs[1];
|
||||
auto& src = inputs[2];
|
||||
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
auto ctype =
|
||||
dst.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy_cpu(dst, out, ctype, stream());
|
||||
|
||||
if (mask.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(mask);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([mask = array::unsafe_weak_copy(mask),
|
||||
src = array::unsafe_weak_copy(src),
|
||||
out = array::unsafe_weak_copy(out)]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
masked_scatter_impl<bool>(mask, src, out);
|
||||
break;
|
||||
case uint8:
|
||||
masked_scatter_impl<uint8_t>(mask, src, out);
|
||||
break;
|
||||
case uint16:
|
||||
masked_scatter_impl<uint16_t>(mask, src, out);
|
||||
break;
|
||||
case uint32:
|
||||
masked_scatter_impl<uint32_t>(mask, src, out);
|
||||
break;
|
||||
case uint64:
|
||||
masked_scatter_impl<uint64_t>(mask, src, out);
|
||||
break;
|
||||
case int8:
|
||||
masked_scatter_impl<int8_t>(mask, src, out);
|
||||
break;
|
||||
case int16:
|
||||
masked_scatter_impl<int16_t>(mask, src, out);
|
||||
break;
|
||||
case int32:
|
||||
masked_scatter_impl<int32_t>(mask, src, out);
|
||||
break;
|
||||
case int64:
|
||||
masked_scatter_impl<int64_t>(mask, src, out);
|
||||
break;
|
||||
case float16:
|
||||
masked_scatter_impl<float16_t>(mask, src, out);
|
||||
break;
|
||||
case float32:
|
||||
masked_scatter_impl<float>(mask, src, out);
|
||||
break;
|
||||
case float64:
|
||||
masked_scatter_impl<double>(mask, src, out);
|
||||
break;
|
||||
case bfloat16:
|
||||
masked_scatter_impl<bfloat16_t>(mask, src, out);
|
||||
break;
|
||||
case complex64:
|
||||
masked_scatter_impl<complex64_t>(mask, src, out);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -115,7 +115,7 @@ void inverse_impl(
|
||||
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
|
||||
|
||||
// The inverse is computed in place, so just copy the input to the output.
|
||||
copy(
|
||||
copy_cpu(
|
||||
a,
|
||||
inv,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "mlx/backend/cpu/jit_compiler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// Required for Visual Studio.
|
||||
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
|
||||
#ifdef _MSC_VER
|
||||
#include <complex>
|
||||
#define LAPACK_COMPLEX_CUSTOM
|
||||
#define lapack_complex_float std::complex<float>
|
||||
#define lapack_complex_double std::complex<double>
|
||||
#endif
|
||||
#define lapack_complex_float_real(z) ((z).real())
|
||||
#define lapack_complex_float_imag(z) ((z).imag())
|
||||
#define lapack_complex_double_real(z) ((z).real())
|
||||
#define lapack_complex_double_imag(z) ((z).imag())
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include <Accelerate/Accelerate.h>
|
||||
@@ -32,7 +32,7 @@
|
||||
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
|
||||
#define INSTANTIATE_LAPACK_REAL(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, float>) { \
|
||||
@@ -42,11 +42,24 @@
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_TYPES(geqrf)
|
||||
INSTANTIATE_LAPACK_TYPES(orgqr)
|
||||
INSTANTIATE_LAPACK_TYPES(syevd)
|
||||
INSTANTIATE_LAPACK_TYPES(potrf)
|
||||
INSTANTIATE_LAPACK_TYPES(gesvdx)
|
||||
INSTANTIATE_LAPACK_TYPES(getrf)
|
||||
INSTANTIATE_LAPACK_TYPES(getri)
|
||||
INSTANTIATE_LAPACK_TYPES(trtri)
|
||||
INSTANTIATE_LAPACK_REAL(geqrf)
|
||||
INSTANTIATE_LAPACK_REAL(orgqr)
|
||||
INSTANTIATE_LAPACK_REAL(syevd)
|
||||
INSTANTIATE_LAPACK_REAL(geev)
|
||||
INSTANTIATE_LAPACK_REAL(potrf)
|
||||
INSTANTIATE_LAPACK_REAL(gesdd)
|
||||
INSTANTIATE_LAPACK_REAL(getrf)
|
||||
INSTANTIATE_LAPACK_REAL(getri)
|
||||
INSTANTIATE_LAPACK_REAL(trtri)
|
||||
|
||||
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, std::complex<float>>) { \
|
||||
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
|
||||
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
|
||||
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
||||
|
||||
@@ -87,8 +87,7 @@ void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_cpu(x, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user