Mirror of https://github.com/ml-explore/mlx.git
synced 2025-09-11 22:44:38 +08:00
Compare commits
397 Commits
SHA1s of the compared commits, in page order (the author and date columns were empty in the mirror):

```text
eaf709b83e f0e70afff0 86984cad68 fbc89e3ced 38c1e720c2 600e87e03c 3836445241 1d2c9d6a07
e8ac6bd2f5 fdadc4f22c 79b527f45f dc4eada7f0 70ebc3b598 b13f2aed16 5f04c0f818 55935ccae7
b529515eb1 3cde719eb7 5de6d94a90 99eefd2ec0 e9e268336b 7275ac7523 c4189a38e4 68d1b3256b
9c6953bda7 ef7ece9851 ddaa4b7dcb dfae2c6989 515f104926 9ecefd56db e5d35aa187 00794c42bc
08a1bf3f10 60c4154346 f2c85308c1 1a28b69ee2 ba09f01ce8 6cf48872b7 7b3b8fa000 ec5e2aae61
86389bf970 3290bfa690 8777fd104f c41f7565ed 9ba81e3da4 c23888acd7 f98ce25ab9 de5f38fd48
ec2854b13a 90823d2938 5f5770e3a2 28f39e9038 b2d2b37888 fe597e141c 72ca1539e0 13b26775f1
05d7118561 98b901ad66 5580b47291 bc62932984 a6b5d6e759 a8931306e1 fecdb8717e 916fd273ea
0da8506552 eda7a7b43e 022eabb734 aba899cef8 6a40e1c176 9307b2ab8b 522d8d3917 a84cc0123f
f018e248cd cfd7237a80 4eef8102c9 69e4dd506b 25814a9458 2a980a76ce d343782c8b 4e1994e9d7
65a38c452b 7b7e2352cd 1177d28395 005e7efa64 b42d13ec84 9adcd1a650 3c164fca8c 95e335db7b
f90206ad74 3779150750 0a9777aa5c 45ad06aac8 c6ea2ba329 2770a10240 d2a94f9e6a 32da94507a
736a340478 117e1355a2 3c3e558c60 cffceda6ee 048805ad2c d14c9fe7ea 5db90ce822 d699cc1330
c4230747a1 5245f12a46 a198b2787e 04edad8c59 392b3060b0 85b34d59bc f599c11bc8 0792ff02ff
fd0d63ba5b 3835a428c5 9680f72cca a0737273d3 e613d0eaf0 6bcd6bcf70 ba12e4999a 4e7cd31d12
5e6c130d93 5d68082881 607181644f 89d327075f 6bf00ef631 7d042f17fe 28b8079e30 7face5d9fd
a44dc4bdb0 2d0f384b6f 8ff84b5c43 10b271d963 0ebc8a3d25 bbda0fdbdb c86422bdd4 c707b2b0a6
78ba24c37d 1a2cb72030 344a29506e 71de73a668 4c1dfa58b7 5274c3c43f 1762793989 6cec78d8f2
2dc307f2e6 7aea5b1895 9733e16496 7f2d1024f3 428f589364 5cd97f7ffe e425dc00c0 d274ae77f2
55c5ac7820 0145911bea 0a5215693e 2a45056ba8 142b77751d a5ededf1c3 7df3f792a2 9eb7d7362f
1c0c118f7c 1a1b2108ec b6c6552d20 83a0340fa7 a62fc1b39f af1b725fda 9174606d4c ca305afdbe
fe5987b81d a229c8cef0 f6c0499b8d 1156c84e86 ec7c7def40 2d8e667400 80c863b972 f5cc1eea72
b7c9f1d38f c6fc07f1f4 ded914f442 4758c8baa1 7064fed1b1 1017ac4a9e ccb61d7aae 2235dee906
28091aa1ff 121d9a0702 0cea88bcc5 72146fc4cd e6a7ab9675 1f4c127fb9 90532b1f37 a8666a757a
a4667da1eb 0c259961ac f288db8d34 33421c1dd3 5cc5201914 252e423e81 a4a2764a52 ab8e832c18
1ce0c0fcb0 657f466402 c7b0300af5 da8c885784 1ccaf80575 ec36bfa317 b8f76f717a d1766f2c70
516ded618b c9c81d0584 545f84d905 d5ec172c95 25b3a3e541 058d6ce683 eab93985b8 b51d70a83c
259025100e c9d30aa6ac 8544b42007 6fa0501387 ae69cb15e9 a64a8dfe45 491fa95b1f 92ec632ad5
8ecdfb718b 4ba0c24a8f 935c8c4bb1 88f993da38 ebfe64b92d 0308e9af71 c3628eea49 e03f0372b1
f17536af9c ed4ec81bca 7480059306 8bae22b0fa 49c34c4161 5548fcc96d 070bd433ab c8fb54951a
f110357aaa a6b426422e d03c01dfbc a82996e9fb af5a614aad f9640e049d 4768c61b57 dfccd17ab9
635117c5d4 50f3535693 9111999af3 6bd28d246e 4d595a2a39 3a21f61772 4e1e9520e1 0bf19037ca
f3dfa36a3a 4f9b60dd53 f76a49e555 310ad8d9db 56db268f47 92ab6bdeb8 0070e360a1 9df8fed046
a59fae040f 29a620cab2 87d7a2520e 40c62c1321 35b412c099 d0f471cff7 6f316b8bf5 7c10c93a1f
d92ea094f1 6ae5423b4a 9635cffdc8 96986fb362 3ceb341a75 50fa705125 69a2991614 fd3377dd1f
d0b6cb0425 95c4a2e3af bc2a29f033 3bb5b4a302 fc88fd9097 c5b0928c1f e047fd977d 9d40e521d7
1445dcaa60 e4eeb4e910 aa86876813 974bb54ab2 9bc2183a31 d4b222b6d3 af2af818a6 698e63a608
211411faf2 bb303c45a5 6f7986d592 7cbb4aef17 02bec0bb6d c79f6a4a8c 0c5eea226b dcca0d7477
0d5e7716ad d8c824c594 cb431dfc9f 61d787726a 5e89aace9b 2af7e8a9a6 2419edd5b2 bf481e8e5d
9d7fa6b8e6 073076ac7d 9bd03dd9b4 6931f84412 16ec0556a0 610af352d4 b35f1e3c9c dfa0b9aab4
a4c47b0276 111fefd5e9 c1fe1ef081 8c34c9dac4 91c0277356 9f0d5c12fc 59247c2b62 9a3842a2d9
726dbd9267 54f05e7195 26be608470 248431eb3c 76f275b4df f1951d6cce 62f297b51d 09bc32f62f
46d8b16ab4 42533931fa 9bd3a7102f 9e516b71ea eac961ddb1 57c6aa7188 cde5b4ad80 4f72c66911
960e3f0f05 884af42da2 048fabdabd 917252a5a1 1a992e31e8 d2ff04a4f2 015c247393 d3cd26820e
91f6c499d7 35e9c87ab9 8e88e30d95 0eb56d5be0 f70764a162 dad1b00b13 430ffef58a 3d17077187
c9b41d460f 32972a5924 f6afb9c09b 3ddc07e936 c26208f67d d15fa13daf 58a855682c 92d7cb71f8
50d8bed468 9dd72cd421 343aa46b78 b8ab89b413 f9f8c167d4 3f86399922 2b8ace6a03 0ab8e099e8
020f048cd0 881615b072 0eef4febfd b54a70ec2d bf6ec92216 c21331d47f e1c9600da3 1fa0d20a30
3274c6a087 9b12093739 f374b6ca4d 0070e1db40 95d04805b3 e4534dac17 fef3c4ec1d 1bdc038bf9
5523d9c426 d878015228 5900e3249f bacced53d3 4a64d4bff1
```
```diff
@@ -13,8 +13,62 @@ parameters:
   test_release:
     type: boolean
     default: false
+  linux_release:
+    type: boolean
+    default: false
 
 jobs:
+  build_documentation:
+    parameters:
+      upload-docs:
+        type: boolean
+        default: false
+    macos:
+      xcode: "16.2.0"
+    resource_class: m2pro.medium
+    steps:
+      - checkout
+      - run:
+          name: Install
+          command: |
+            brew install python@3.9
+            brew install doxygen
+            python3.9 -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install --upgrade cmake
+            pip install -r docs/requirements.txt
+            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
+      - when:
+          condition:
+            not: << parameters.upload-docs >>
+          steps:
+            - run:
+                name: Build documentation
+                command: |
+                  source env/bin/activate
+                  cd docs && doxygen && make html O=-W
+      - when:
+          condition: << parameters.upload-docs >>
+          steps:
+            - add_ssh_keys:
+                fingerprints:
+                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
+            - run:
+                name: Upload documentation
+                command: |
+                  source env/bin/activate
+                  git config user.email "mlx@group.apple.com"
+                  git config user.name "CircleCI Docs"
+                  git checkout gh-pages
+                  git rebase main
+                  cd docs
+                  git rm -rf build/html
+                  doxygen && make html O=-W
+                  git add -f build/html
+                  git commit -m "rebase"
+                  git push -f origin gh-pages
+
   linux_build_and_test:
     docker:
       - image: cimg/python:3.9
```
```diff
@@ -31,10 +85,11 @@ jobs:
           name: Install dependencies
           command: |
             pip install --upgrade cmake
-            pip install nanobind==2.1.0
+            pip install nanobind==2.4.0
             pip install numpy
             sudo apt-get update
             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
       - run:
           name: Install Python package
           command: |
```
```diff
@@ -54,6 +109,8 @@ jobs:
           name: Run Python tests
           command: |
             python3 -m unittest discover python/tests -v
+            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
+            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
       - run:
           name: Build CPP only
           command: |
```
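The two new lines above launch the distributed tests eight ways: `mpirun ... -np 8` starts eight copies of the test script as MPI ranks, and `mlx.launch -n 8` does the analogous launch over MLX's ring backend. As a minimal standalone illustration of the rank model (generic MPI, not MLX code; assumes an MPI implementation such as the Open MPI packages installed above):

```cpp
// rank_demo.cpp -- each of the N processes started by `mpirun -np N`
// gets a distinct rank in a world of size N.
#include <mpi.h>
#include <cstdio>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank = 0, size = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank); // index of this process
  MPI_Comm_size(MPI_COMM_WORLD, &size); // total processes launched
  std::printf("rank %d of %d\n", rank, size);
  MPI_Finalize();
}
```

Build with `mpicxx rank_demo.cpp -o rank_demo` and run with `mpirun -np 8 ./rank_demo`; the Python tests above are launched the same way, just with `python` as the program.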
```diff
@@ -68,22 +125,27 @@
     parameters:
       xcode_version:
         type: string
-        default: "15.2.0"
+        default: "16.2.0"
+      macosx_deployment_target:
+        type: string
+        default: ""
     macos:
       xcode: << parameters.xcode_version >>
-    resource_class: macos.m1.medium.gen1
+    environment:
+      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
+    resource_class: m2pro.medium
     steps:
       - checkout
       - run:
           name: Install dependencies
           command: |
-            brew install python@3.8
+            brew install python@3.9
+            brew install openmpi
-            python3.8 -m venv env
+            python3.9 -m venv env
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install nanobind==2.1.0
+            pip install nanobind==2.4.0
             pip install numpy
             pip install torch
             pip install tensorflow
@@ -92,7 +154,9 @@ jobs:
           name: Install Python package
           command: |
             source env/bin/activate
-            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
+            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+              CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+              pip install -e . -v
       - run:
           name: Generate package stubs
           command: |
@@ -105,7 +169,8 @@ jobs:
             source env/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
             LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
-            mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
+            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
+            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
       - run:
           name: Build example extension
           command: |
@@ -154,13 +219,18 @@ jobs:
         default: "3.9"
       xcode_version:
         type: string
-        default: "15.2.0"
+        default: "16.2.0"
       build_env:
         type: string
         default: ""
+      macosx_deployment_target:
+        type: string
+        default: ""
     macos:
       xcode: << parameters.xcode_version >>
-    resource_class: macos.m1.medium.gen1
+    resource_class: m2pro.medium
+    environment:
+      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
     steps:
       - checkout
       - run:
@@ -172,7 +242,7 @@ jobs:
             source env/bin/activate
             pip install --upgrade pip
             pip install --upgrade cmake
-            pip install nanobind==2.1.0
+            pip install nanobind==2.4.0
             pip install --upgrade setuptools
             pip install numpy
             pip install twine
@@ -181,7 +251,7 @@ jobs:
           name: Install Python package
           command: |
             source env/bin/activate
-            DEV_RELEASE=1 \
+            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
               CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
               pip install . -v
       - run:
@@ -208,7 +278,7 @@ jobs:
       - store_artifacts:
           path: dist/
 
-  build_linux_test_release:
+  build_linux_release:
     parameters:
       python_version:
         type: string
@@ -237,12 +307,13 @@ jobs:
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
-           pip install nanobind==2.1.0
+           pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install auditwheel
            pip install patchelf
            pip install build
+           pip install twine
            << parameters.extra_env >> \
              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
              pip install . -v
@@ -253,6 +324,11 @@ jobs:
            python -m build --wheel
            auditwheel show dist/*
            auditwheel repair dist/* --plat manylinux_2_31_x86_64
+      - run:
+          name: Upload package
+          command: |
+            source env/bin/activate
+            twine upload wheelhouse/*
       - store_artifacts:
           path: wheelhouse/
 
@@ -270,8 +346,9 @@ workflows:
       - mac_build_and_test:
          matrix:
            parameters:
-              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test
+      - build_documentation
 
   build_pypi_release:
     when:
@@ -288,9 +365,79 @@
               ignore: /.*/
           matrix:
             parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              xcode_version: ["15.0.0", "15.2.0"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              macosx_deployment_target: ["13.5", "14.0", "15.0"]
+              build_env: ["PYPI_RELEASE=1"]
+              xcode_version: ["16.2.0", "15.0.0"]
+            exclude:
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.9"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.10"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.11"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.12"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.13"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.9"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.10"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.11"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.12"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.13"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.9"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.10"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.11"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.12"
+                build_env: "PYPI_RELEASE=1"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.13"
+                build_env: "PYPI_RELEASE=1"
+      - build_documentation:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          upload-docs: true
 
   prb:
     when:
       matches:
@@ -305,7 +452,7 @@
           requires: [ hold ]
          matrix:
            parameters:
-              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              macosx_deployment_target: ["13.5", "14.0"]
      - linux_build_and_test:
          requires: [ hold ]
  nightly_build:
@@ -317,8 +464,55 @@
      - build_release:
          matrix:
            parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              xcode_version: ["15.0.0", "15.2.0"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              macosx_deployment_target: ["13.5", "14.0", "15.0"]
+              xcode_version: ["16.2.0", "15.0.0"]
+            exclude:
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.9"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.10"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.11"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.12"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.13"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.9"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.10"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.11"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.12"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.13"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.9"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.10"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.11"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.12"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.13"
  weekly_build:
    when:
      and:
@@ -328,17 +522,79 @@
      - build_release:
          matrix:
            parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              macosx_deployment_target: ["13.5", "14.0", "15.0"]
+              build_env: ["DEV_RELEASE=1"]
+              xcode_version: ["16.2.0", "15.0.0"]
+            exclude:
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.9"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.10"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.11"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.12"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "13.5"
+                xcode_version: "16.2.0"
+                python_version: "3.13"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.9"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.10"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.11"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.12"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "14.0"
+                xcode_version: "15.0.0"
+                python_version: "3.13"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.9"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.10"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.11"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.12"
+                build_env: "DEV_RELEASE=1"
+              - macosx_deployment_target: "15.0"
+                xcode_version: "15.0.0"
+                python_version: "3.13"
+                build_env: "DEV_RELEASE=1"
  linux_test_release:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
-        - << pipeline.parameters.test_release >>
+        - << pipeline.parameters.linux_release >>
    jobs:
-      - build_linux_test_release:
+      - build_linux_release:
          matrix:
            parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              extra_env: ["PYPI_RELEASE=1"]
```
`.gitignore` (vendored): 3 changes

```diff
@@ -76,6 +76,9 @@ build/
 *.out
 *.app
 
+# Debug symbols
+*.pdb
+
 # VSCode
 .vscode/
 .DS_Store
```
```diff
@@ -1,15 +1,16 @@
 repos:
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.8
+    rev: v19.1.7
     hooks:
       - id: clang-format
+  # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.8.0
+    rev: 25.1.0
     hooks:
       - id: black
 
   - repo: https://github.com/pycqa/isort
-    rev: 5.13.2
+    rev: 6.0.0
     hooks:
       - id: isort
         args:
```
```diff
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
```
CMakeLists.txt: 129 changes

```diff
@@ -1,6 +1,24 @@
-cmake_minimum_required(VERSION 3.24)
+cmake_minimum_required(VERSION 3.25)
 
-project(mlx LANGUAGES C CXX)
+if(NOT MLX_VERSION)
+  file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
+  string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
+  set(_major ${CMAKE_MATCH_1})
+  string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
+  set(_minor ${CMAKE_MATCH_1})
+  string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
+  set(_patch ${CMAKE_MATCH_1})
+  set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
+  set(MLX_VERSION ${MLX_PROJECT_VERSION})
+else()
+  string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
+                       ${MLX_VERSION})
+endif()
+
+project(
+  mlx
+  LANGUAGES C CXX
+  VERSION ${MLX_PROJECT_VERSION})
 
 # ----------------------------- Setup -----------------------------
 set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
```
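The `file(STRINGS ...)` and `string(REGEX MATCH ...)` calls above assume `mlx/version.h` defines the three version components as plain integer macros. A sketch of the shape being parsed (the numbers here are placeholders, not the actual MLX version):

```cpp
// Hypothetical mlx/version.h excerpt. CMake extracts the three integers
// below and assembles "<major>.<minor>.<patch>" into MLX_PROJECT_VERSION
// for the project() call.
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 23
#define MLX_VERSION_PATCH 0
```

This makes the header the single source of truth for the version and replaces the hard-coded `set(MLX_VERSION 0.18.0)` fallback removed in the next hunk.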
```diff
@@ -20,22 +38,16 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
+option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 
-if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.18.0)
-endif()
-
 # --------------------- Processor tests -------------------------
 
 message(
   STATUS
   "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
 )
 
 set(MLX_BUILD_ARM OFF)
 
 if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
   if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
     if(NOT MLX_ENABLE_X64_MAC)
@@ -47,14 +59,13 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
         "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
       )
     else()
-      set(MLX_BUILD_METAL OFF)
       message(WARNING "Building for x86_64 arch is not officially supported.")
     endif()
+    set(MLX_BUILD_METAL OFF)
   elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
     set(MLX_BUILD_ARM ON)
   endif()
 
 else()
   set(MLX_BUILD_METAL OFF)
   message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
 endif()
 
@@ -86,25 +97,26 @@ elseif(MLX_BUILD_METAL)
   # Throw an error if xcrun not found
   execute_process(
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-    OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
+    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
 
-  if(${MACOS_VERSION} LESS 14.0)
+  if(${MACOS_SDK_VERSION} LESS 14.0)
     message(
       FATAL_ERROR
       "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
   endif()
-  message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
+  message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
 
   set(METAL_CPP_URL
-      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
-  )
-  # Get the metal version
+      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
 
+  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
+    set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
+  endif()
   execute_process(
     COMMAND
       zsh "-c"
-      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
+      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
     OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
 
   FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
 
   FetchContent_MakeAvailable(metal_cpp)
```
```diff
@@ -112,20 +124,58 @@ elseif(MLX_BUILD_METAL)
     mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
                $<INSTALL_INTERFACE:include/metal_cpp>)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
 endif()
 
 add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
+if(WIN32)
+  if(MSVC)
+    # GGUF does not build with MSVC.
+    set(MLX_BUILD_GGUF OFF)
+    # There is no prebuilt OpenBLAS distribution for MSVC.
+    set(MLX_BUILD_BLAS_FROM_SOURCE ON)
+  endif()
+  # Windows implementation of dlfcn.h APIs.
+  FetchContent_Declare(
+    dlfcn-win32
+    GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
+    GIT_TAG v1.4.1
+    EXCLUDE_FROM_ALL)
+  block()
+  set(BUILD_SHARED_LIBS OFF)
+  FetchContent_MakeAvailable(dlfcn-win32)
+  endblock()
+  target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
+  target_link_libraries(mlx PRIVATE dl)
+endif()
 
 if(MLX_BUILD_CPU)
   find_library(ACCELERATE_LIBRARY Accelerate)
-  if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
+  if(ACCELERATE_LIBRARY)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
-    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
-    add_compile_definitions(ACCELERATE_NEW_LAPACK)
   else()
     message(STATUS "Accelerate or arm neon not found, using default backend.")
     set(MLX_BUILD_ACCELERATE OFF)
   endif()
 
   if(MLX_BUILD_ACCELERATE)
     target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
     add_compile_definitions(MLX_USE_ACCELERATE)
     add_compile_definitions(ACCELERATE_NEW_LAPACK)
+  elseif(MLX_BUILD_BLAS_FROM_SOURCE)
+    # Download and build OpenBLAS from source code.
+    FetchContent_Declare(
+      openblas
+      GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
+      GIT_TAG v0.3.28
+      EXCLUDE_FROM_ALL)
+    set(BUILD_STATIC_LIBS ON) # link statically
+    set(NOFORTRAN ON) # msvc has no fortran compiler
+    FetchContent_MakeAvailable(openblas)
+    target_link_libraries(mlx PRIVATE openblas)
+    target_include_directories(
+      mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
+                  "${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
   else()
     if(${CMAKE_HOST_APPLE})
       # The blas shipped in macOS SDK is not supported, search homebrew for
       # openblas instead.
@@ -143,7 +193,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
     message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
-    target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
+    target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
     # List blas after lapack otherwise we may accidentally incldue an old
     # version of lapack.h from the include dirs of blas.
     find_package(BLAS REQUIRED)
@@ -156,29 +206,19 @@ if(MLX_BUILD_CPU)
     message(STATUS "Blas lib " ${BLAS_LIBRARIES})
     message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
-    target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
+    target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
   endif()
 else()
   set(MLX_BUILD_ACCELERATE OFF)
 endif()
 
-find_package(MPI)
-if(MPI_FOUND)
-  execute_process(
-    COMMAND zsh "-c" "mpirun --version"
-    OUTPUT_VARIABLE MPI_VERSION
-    ERROR_QUIET)
-  if(${MPI_VERSION} MATCHES ".*Open MPI.*")
-    target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
-  elseif(MPI_VERSION STREQUAL "")
-    set(MPI_FOUND FALSE)
-    message(
-      WARNING "MPI found but mpirun is not available. Building without MPI.")
-  else()
-    set(MPI_FOUND FALSE)
-    message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
-  endif()
-endif()
+message(STATUS "Downloading json")
+FetchContent_Declare(
+  json
+  URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
+FetchContent_MakeAvailable(json)
+target_include_directories(
+  mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
 
 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
@@ -203,8 +243,7 @@ if(MLX_BUILD_PYTHON_BINDINGS)
   execute_process(
     COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
     OUTPUT_STRIP_TRAILING_WHITESPACE
-    OUTPUT_VARIABLE NB_DIR)
-  list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
+    OUTPUT_VARIABLE nanobind_ROOT)
   find_package(nanobind CONFIG REQUIRED)
   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
 endif()
```
````diff
@@ -5,26 +5,26 @@ possible.
 
 ## Pull Requests
 
-1. Fork and submit pull requests to the repo. 
+1. Fork and submit pull requests to the repo.
 2. If you've added code that should be tested, add tests.
 3. If a change is likely to impact efficiency, run some of the benchmarks before
    and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
 4. If you've changed APIs, update the documentation.
-5. Every PR should have passing tests and at least one review. 
+5. Every PR should have passing tests and at least one review.
 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
    This should install hooks for running `black` and `clang-format` to ensure
    consistent style for C++ and python code.
 
 You can also run the formatters manually as follows:
 
-```
-clang-format -i file.cpp
-```
-
-```
-black file.py
-```
+```shell
+clang-format -i file.cpp
+```
+
+```shell
+black file.py
+```
 
 or run `pre-commit run --all-files` to check all files in the repo.
 
 ## Issues
````
```diff
@@ -6,7 +6,7 @@
 
 [](https://circleci.com/gh/ml-explore/mlx)
 
-MLX is an array framework for machine learning research on Apple silicon,
+MLX is an array framework for machine learning on Apple silicon,
 brought to you by Apple machine learning research.
 
 Some key features of MLX include:
```
```diff
@@ -5,35 +5,35 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void time_value_and_grad() {
-  auto x = ones({200, 1000});
-  eval(x);
-  auto fn = [](array x) {
+  auto x = mx::ones({200, 1000});
+  mx::eval(x);
+  auto fn = [](mx::array x) {
     for (int i = 0; i < 20; ++i) {
-      x = log(exp(x));
+      x = mx::log(mx::exp(x));
     }
-    return sum(x);
+    return mx::sum(x);
   };
 
-  auto grad_fn = grad(fn);
+  auto grad_fn = mx::grad(fn);
   auto independent_value_and_grad = [&]() {
     auto value = fn(x);
     auto dfdx = grad_fn(x);
-    return std::vector<array>{value, dfdx};
+    return std::vector<mx::array>{value, dfdx};
   };
   TIME(independent_value_and_grad);
 
-  auto value_and_grad_fn = value_and_grad(fn);
+  auto value_and_grad_fn = mx::value_and_grad(fn);
   auto combined_value_and_grad = [&]() {
     auto [value, dfdx] = value_and_grad_fn(x);
-    return std::vector<array>{value, dfdx};
+    return std::vector<mx::array>{value, dfdx};
   };
   TIME(combined_value_and_grad);
 }
 
 int main() {
-  std::cout << "Benchmarks for " << default_device() << std::endl;
+  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
   time_value_and_grad();
 }
```
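The mechanical change in this file and the benchmark files below is replacing `using namespace mlx::core` with the `mx::` alias. One motivation for keeping calls qualified: `<cmath>` puts `abs`, `log`, `exp`, and `sqrt` in the global namespace, so an unqualified call can be ambiguous or silently bind to the C function (note how the diffs keep spelling `mlx::core::abs` and friends in full even after the change). A standalone toy, with a stand-in namespace, of the collision the alias avoids:

```cpp
// Toy example (no MLX dependency): a namespace alias keeps calls
// qualified, so they cannot collide with ::log from <cmath>.
#include <cmath>
#include <iostream>

namespace core { // stand-in for mlx::core
double log(double x) { return std::log(x) + 100.0; } // distinct from ::log
}

namespace cx = core; // the aliasing pattern adopted in these diffs

int main() {
  std::cout << cx::log(2.0) << "\n"; // unambiguous: always core::log
  // With `using namespace core;`, an unqualified log(2.0) would be
  // ambiguous between core::log and the ::log that <cmath> provides.
}
```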
```diff
@@ -4,21 +4,21 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void time_add_op() {
   std::vector<int> sizes(1, 1);
   for (int i = 0; i < 9; ++i) {
     sizes.push_back(10 * sizes.back());
   }
-  set_default_device(Device::cpu);
+  set_default_device(mx::Device::cpu);
   for (auto size : sizes) {
-    auto a = random::uniform({size});
-    auto b = random::uniform({size});
-    eval(a, b);
+    auto a = mx::random::uniform({size});
+    auto b = mx::random::uniform({size});
+    mx::eval(a, b);
     std::cout << "Size " << size << std::endl;
-    TIMEM("cpu", add, a, b, Device::cpu);
-    TIMEM("gpu", add, a, b, Device::gpu);
+    TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
+    TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
   }
 }
```
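`TIME` and `TIMEM` come from the benchmarks' `time_utils.h`, which this diff neither shows nor touches. As a rough sketch of what such a helper plausibly does (an assumption, not MLX's actual implementation): stringify the expression, run it repeatedly, and report an average wall-clock time.

```cpp
// Assumed shape of a TIME-style helper; the real time_utils.h is not
// shown in this diff. Runs a callable a fixed number of times and
// prints the mean duration per call.
#include <chrono>
#include <cstdio>

template <typename F>
void time_it(const char* name, F&& fn) {
  constexpr int kIters = 100;
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < kIters; ++i) {
    fn(); // expression under test
  }
  std::chrono::duration<double, std::milli> total =
      std::chrono::steady_clock::now() - start;
  std::printf("%s: %.3f msec per iteration\n", name, total.count() / kIters);
}

#define TIME(fn) time_it(#fn, [&]() { fn(); })

int main() {
  auto busy_work = [] {
    double acc = 0;
    for (int i = 0; i < 100000; ++i) acc += static_cast<double>(i) * 0.5;
    return acc;
  };
  TIME(busy_work); // prints e.g. "busy_work: 0.081 msec per iteration"
}
```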
```diff
@@ -6,105 +6,105 @@
 #include "mlx/mlx.h"
 #include "time_utils.h"
 
-using namespace mlx::core;
+namespace mx = mlx::core;
 
 void time_irregular_binary_ops_1D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 1000000;
   int step = 2;
-  auto a = random::uniform({size});
-  auto b = random::uniform({size});
-  eval(a, b);
+  auto a = mx::random::uniform({size});
+  auto b = mx::random::uniform({size});
+  mx::eval(a, b);
   a = slice(a, {0}, {size}, {step});
   b = slice(b, {0}, {size}, {step});
-  TIMEM("1D strided", add, a, b, device);
+  TIMEM("1D strided", mx::add, a, b, device);
 }
 
 void time_irregular_binary_ops_2D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 2048;
-  auto a = random::uniform({size, size});
-  auto b = random::uniform({size, size});
-  eval(a, b);
-  TIMEM("2D regular", add, a, b, device);
+  auto a = mx::random::uniform({size, size});
+  auto b = mx::random::uniform({size, size});
+  mx::eval(a, b);
+  TIMEM("2D regular", mx::add, a, b, device);
 
-  b = transpose(b);
-  eval(b);
-  TIMEM("2D transpose", add, a, b, device);
+  b = mx::transpose(b);
+  mx::eval(b);
+  TIMEM("2D mx::transpose", mx::add, a, b, device);
 
-  b = random::uniform({size});
-  eval(b);
-  TIMEM("2D broadcast dim 0", add, a, b, device);
+  b = mx::random::uniform({size});
+  mx::eval(b);
+  TIMEM("2D broadcast dim 0", mx::add, a, b, device);
 
-  b = reshape(b, {size, 1});
-  eval(b);
-  TIMEM("2D broadcast dim 1", add, a, b, device);
+  b = mx::reshape(b, {size, 1});
+  mx::eval(b);
+  TIMEM("2D broadcast dim 1", mx::add, a, b, device);
 }
 
 void time_irregular_binary_ops_3D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int d0 = 32;
   int d1 = 512;
   int d2 = 512;
-  auto a = random::uniform({d0, d1, d2});
-  auto b = random::uniform({d0, d1, d2});
-  TIMEM("3D regular", add, a, b, device);
+  auto a = mx::random::uniform({d0, d1, d2});
+  auto b = mx::random::uniform({d0, d1, d2});
+  TIMEM("3D regular", mx::add, a, b, device);
 
-  b = transpose(b, {0, 2, 1});
-  TIMEM("3D transpose", add, a, b, device);
+  b = mx::transpose(b, {0, 2, 1});
+  TIMEM("3D mx::transpose", mx::add, a, b, device);
 
-  b = random::uniform({d1, d2});
-  TIMEM("3D broadcast dim 0", add, a, b, device);
+  b = mx::random::uniform({d1, d2});
+  TIMEM("3D broadcast dim 0", mx::add, a, b, device);
 
-  b = random::uniform({d0, 1, d2});
-  TIMEM("3D broadcast dim 1", add, a, b, device);
+  b = mx::random::uniform({d0, 1, d2});
+  TIMEM("3D broadcast dim 1", mx::add, a, b, device);
 
-  b = random::uniform({d0, d1, 1});
-  TIMEM("3D broadcast dim 2", add, a, b, device);
+  b = mx::random::uniform({d0, d1, 1});
+  TIMEM("3D broadcast dim 2", mx::add, a, b, device);
 
-  b = random::uniform({d2});
-  TIMEM("3D broadcast dims 0, 1", add, a, b, device);
+  b = mx::random::uniform({d2});
+  TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
 
-  b = random::uniform({d1, 1});
-  TIMEM("3D broadcast dims 0, 2", add, a, b, device);
+  b = mx::random::uniform({d1, 1});
+  TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
 
-  b = random::uniform({d0, 1, 1});
-  TIMEM("3D broadcast dims 1, 2", add, a, b, device);
+  b = mx::random::uniform({d0, 1, 1});
+  TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
 }
 
 void time_irregular_binary_ops_4D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   std::vector<int> shape = {8, 8, 512, 512};
-  auto a = random::uniform(shape);
-  auto b = random::uniform(shape);
+  auto a = mx::random::uniform(shape);
+  auto b = mx::random::uniform(shape);
 
-  TIMEM("4D regular", add, a, b, device);
+  TIMEM("4D regular", mx::add, a, b, device);
 
-  b = transpose(b, {0, 1, 3, 2});
-  TIMEM("4D transpose", add, a, b, device);
+  b = mx::transpose(b, {0, 1, 3, 2});
+  TIMEM("4D mx::transpose", mx::add, a, b, device);
 
   std::string om = "4D broadcast dims ";
   for (int i = 0; i < shape.size(); ++i) {
     shape[i] = 1;
-    b = random::uniform(shape);
+    b = mx::random::uniform(shape);
     std::ostringstream msg;
     msg << om << i;
-    TIMEM(msg.str(), add, a, b, device);
+    TIMEM(msg.str(), mx::add, a, b, device);
 
     for (int j = i + 1; j < shape.size(); ++j) {
       shape[j] = 1;
       std::ostringstream msg;
       msg << om << i << ", " << j;
-      b = random::uniform(shape);
-      TIMEM(msg.str(), add, a, b, device);
+      b = mx::random::uniform(shape);
+      TIMEM(msg.str(), mx::add, a, b, device);
       shape[j] = a.shape(j);
 
       for (int k = j + 1; k < shape.size(); ++k) {
         shape[k] = 1;
         std::ostringstream msg;
         msg << om << i << ", " << j << ", " << k;
-        b = random::uniform(shape);
-        TIMEM(msg.str(), add, a, b, device);
+        b = mx::random::uniform(shape);
+        TIMEM(msg.str(), mx::add, a, b, device);
         shape[k] = a.shape(k);
       }
     }
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
 }
 
 void time_irregular_reshape() {
-  auto device = default_device();
+  auto device = mx::default_device();
   std::vector<int> shape;
-  auto reshape_fn = [&shape, device](const array& a) {
-    return reshape(a, shape, device);
+  auto reshape_fn = [&shape, device](const mx::array& a) {
+    return mx::reshape(a, shape, device);
   };
 
   int size = 64;
   int d = 2 * size;
 
-  auto a = random::uniform({d, d, d});
+  auto a = mx::random::uniform({d, d, d});
 
   shape = {8 * size, size, size};
   TIMEM("3D contiguous", reshape_fn, a);
 
-  a = transpose(a);
+  a = mx::transpose(a);
   shape = {8 * size, size, size};
-  TIMEM("3D transpose", reshape_fn, a);
+  TIMEM("3D mx::transpose", reshape_fn, a);
 
-  a = transpose(a, {1, 2, 0});
+  a = mx::transpose(a, {1, 2, 0});
   shape = {8 * size, size, size};
-  TIMEM("3D transpose dims 1 2", reshape_fn, a);
+  TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, d}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
   TIMEM("3D broadcast dim 0", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
   TIMEM("3D broadcast dim 1", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
   TIMEM("3D broadcast dim 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
   TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
   TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
   TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
 
-  a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
+  a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
   TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
 }
 
 void time_irregular_astype_1D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 1000000;
   int step = 2;
-  auto a = random::uniform({size});
+  auto a = mx::random::uniform({size});
   a = slice(a, {0}, {size}, {step});
-  TIMEM("1D strided", astype, a, int32, device);
+  TIMEM("1D strided", mx::astype, a, mx::int32, device);
 }
 
 void time_irregular_astype_2D() {
-  auto device = default_device();
+  auto device = mx::default_device();
   int size = 2048;
   std::vector<int> shape = {size, size};
 
-  auto a = random::uniform(shape);
-  TIMEM("2D regular", astype, a, int32, device);
+  auto a = mx::random::uniform(shape);
+  TIMEM("2D regular", mx::astype, a, mx::int32, device);
 
-  a = transpose(a);
-  TIMEM("2D transpose", astype, a, int32, device);
+  a = mx::transpose(a);
+  TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
 
-  a = broadcast_to(random::uniform({size}), shape);
-  TIMEM("2D broadcast dim 0", astype, a, int32, device);
+  a = mx::broadcast_to(mx::random::uniform({size}), shape);
+  TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
 
-  a = broadcast_to(random::uniform({size, 1}), shape);
-  TIMEM("2D broadcast dim 1", astype, a, int32, device);
+  a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
+  TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
 }
 
 int main(int argc, char** argv) {
   if (argc > 1) {
     bool use_gpu = !strcmp(argv[1], "gpu");
-    set_default_device(use_gpu ? Device::gpu : Device::cpu);
+    set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
   }
-  std::cout << "Benchmarks for " << default_device() << std::endl;
+  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
   time_irregular_binary_ops_1D();
   time_irregular_binary_ops_2D();
   time_irregular_binary_ops_3D();
```
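A note on why these benchmarks distinguish "regular" from "transpose"/"strided" inputs: a transpose typically swaps strides rather than moving data, so the transposed operand is traversed non-contiguously, which is exactly the slow path being timed. A small standalone illustration of row-major stride arithmetic (generic C++, not MLX internals):

```cpp
// Row-major 2-D strides and the effect of a transpose: the transposed
// view swaps the strides, so iterating its rows jumps through memory
// in steps of `cols` instead of streaming consecutive elements.
#include <cstdio>
#include <vector>

int main() {
  const int rows = 3, cols = 4;
  std::vector<float> data(rows * cols);
  for (int i = 0; i < rows * cols; ++i) data[i] = static_cast<float>(i);

  const long strides[2] = {cols, 1};   // contiguous row-major layout
  const long t_strides[2] = {1, cols}; // transposed view: strides swapped

  // Element (i, j) of the transposed (cols x rows) view:
  for (int i = 0; i < cols; ++i) {
    for (int j = 0; j < rows; ++j) {
      std::printf("%5.0f", data[i * t_strides[0] + j * t_strides[1]]);
    }
    std::printf("\n");
  }
  (void)strides; // kept for comparison with the contiguous case
}
```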
@@ -3,20 +3,20 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_creation_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto full_fp32 = [&]() { return full(shape, 3.3f); };
|
||||
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
|
||||
TIME(full_fp32);
|
||||
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
|
||||
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
|
||||
TIME(zeros_fp32);
|
||||
auto ones_fp32 = [&]() { return ones(shape, float32); };
|
||||
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
|
||||
TIME(ones_fp32);
|
||||
|
||||
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
|
||||
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
|
||||
TIME(arange_fp32);
|
||||
}
|
||||
|
||||
@@ -24,194 +24,196 @@ void time_type_conversions() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
|
||||
auto a = zeros(shape, float32);
|
||||
eval(a);
|
||||
TIMEM("float32 to int32", astype, a, int32, device);
|
||||
TIMEM("float32 to uint32", astype, a, uint32, device);
|
||||
auto a = mx::zeros(shape, mx::float32);
|
||||
mx::eval(a);
|
||||
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
|
||||
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||
|
||||
a = zeros(shape, int32);
|
||||
eval(a);
|
||||
TIMEM("int32 to float32", astype, a, float32, device);
|
||||
a = mx::zeros(shape, mx::int32);
|
||||
mx::eval(a);
|
||||
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
|
||||
|
||||
a = zeros(shape, bool_);
|
||||
eval(a);
|
||||
TIMEM("bool to float32", astype, a, float32, device);
|
||||
TIMEM("bool to int32", astype, a, int32, device);
|
||||
TIMEM("bool to uint32", astype, a, uint32, device);
|
||||
a = mx::zeros(shape, mx::bool_);
|
||||
mx::eval(a);
|
||||
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
|
||||
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
|
||||
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||
}
|
||||
|
||||
void time_random_generation() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
|
||||
auto uniform = [&]() { return random::uniform({M, N}, float32); };
|
||||
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
|
||||
TIME(uniform);
|
||||
auto normal = [&]() { return random::normal({M, N}, float32); };
|
||||
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
|
||||
TIME(normal);
|
||||
}
|
||||
|
||||
void time_unary_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
|
||||
auto a = random::normal({M, N});
|
||||
eval(a);
|
||||
auto a = mx::random::normal({M, N});
|
||||
mx::eval(a);
|
||||
TIME(mlx::core::abs, a, device);
|
||||
TIME(negative, a, device);
|
||||
TIME(sign, a, device);
|
||||
TIME(square, a, device);
|
||||
TIME(mx::negative, a, device);
|
||||
TIME(mx::sign, a, device);
|
||||
TIME(mx::square, a, device);
|
||||
TIME(mlx::core::sqrt, a, device);
|
||||
TIME(rsqrt, a, device);
|
||||
TIME(mx::rsqrt, a, device);
|
||||
TIME(mlx::core::exp, a, device);
|
||||
|
||||
a = random::uniform({M, N});
|
||||
a = mx::random::uniform({M, N});
|
||||
TIME(mlx::core::log, a, device);
|
||||
}
|
||||
|
||||
void time_binary_ops() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto condition = random::randint(0, 2, {M, N, K});
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
auto condition = mx::random::randint(0, 2, {M, N, K});
|
||||
auto a = mx::random::uniform({M, N, K});
|
||||
auto b = mx::random::uniform({M, N, K});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
|
||||
TIME(add, a, b, device);
|
||||
TIME(subtract, a, b, device);
|
||||
TIME(multiply, a, b, device);
|
||||
TIME(divide, a, b, device);
|
||||
TIME(maximum, a, b, device);
|
||||
TIME(minimum, a, b, device);
|
||||
TIME(where, condition, a, b, device);
|
||||
TIME(mx::add, a, b, device);
|
||||
TIME(mx::subtract, a, b, device);
|
||||
TIME(mx::multiply, a, b, device);
|
||||
TIME(mx::divide, a, b, device);
|
||||
TIME(mx::maximum, a, b, device);
|
||||
TIME(mx::minimum, a, b, device);
|
||||
TIME(mx::where, condition, a, b, device);
|
||||
|
||||
condition = array({true});
|
||||
b = random::uniform({1});
|
||||
eval(b);
|
||||
TIMEM("scalar", add, a, b, device);
|
||||
TIMEM("vector-scalar", subtract, a, b, device);
|
||||
TIMEM("scalar-vector", subtract, b, a, device);
|
||||
TIMEM("scalar", multiply, a, b, device);
|
||||
TIMEM("vector-scalar", divide, a, b, device);
|
||||
TIMEM("scalar-vector", divide, b, a, device);
|
||||
TIMEM("scalar-vector", where, condition, a, b, device);
|
||||
condition = mx::array({true});
|
||||
b = mx::random::uniform({1});
|
||||
mx::eval(b);
|
||||
TIMEM("scalar", mx::add, a, b, device);
|
||||
TIMEM("vector-scalar", mx::subtract, a, b, device);
|
||||
TIMEM("scalar-vector", mx::subtract, b, a, device);
|
||||
TIMEM("scalar", mx::multiply, a, b, device);
|
||||
TIMEM("vector-scalar", mx::divide, a, b, device);
|
||||
TIMEM("scalar-vector", mx::divide, b, a, device);
|
||||
TIMEM("scalar-vector", mx::where, condition, a, b, device);
|
||||
|
||||
condition = broadcast_to(array({true}), {1000, 100});
|
||||
a = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
b = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", divide, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
|
||||
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
|
||||
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||
mx::eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
|
||||
}

void time_strided_ops() {
  int M = 50, N = 50, O = 50, P = 50;
-  auto a = random::uniform({M, N, O, P});
-  auto b = random::uniform({M, N, O, P});
-  auto device = default_device();
-  eval(a, b);
-  TIMEM("non-strided", add, a, b, device);
-  a = transpose(a, {1, 0, 2, 3});
-  b = transpose(b, {3, 2, 0, 1});
-  eval(a, b);
-  TIMEM("strided", add, a, b, device);
+  auto a = mx::random::uniform({M, N, O, P});
+  auto b = mx::random::uniform({M, N, O, P});
+  auto device = mx::default_device();
+  mx::eval(a, b);
+  TIMEM("non-strided", mx::add, a, b, device);
+  a = mx::transpose(a, {1, 0, 2, 3});
+  b = mx::transpose(b, {3, 2, 0, 1});
+  mx::eval(a, b);
+  TIMEM("strided", mx::add, a, b, device);
}

void time_comparisons() {
  int M = 1000, N = 100, K = 10;
-  auto a = random::uniform({M, N, K});
-  auto b = random::uniform({M, N, K});
-  auto device = default_device();
-  eval(a, b);
-  TIME(equal, a, b, device);
-  TIME(greater, a, b, device);
-  TIME(greater_equal, a, b, device);
-  TIME(less, a, b, device);
-  TIME(less_equal, a, b, device);
+  auto a = mx::random::uniform({M, N, K});
+  auto b = mx::random::uniform({M, N, K});
+  auto device = mx::default_device();
+  mx::eval(a, b);
+  TIME(mx::equal, a, b, device);
+  TIME(mx::greater, a, b, device);
+  TIME(mx::greater_equal, a, b, device);
+  TIME(mx::less, a, b, device);
+  TIME(mx::less_equal, a, b, device);
}

void time_matvec() {
  int M = 2000, N = 200;
-  auto a = random::uniform({M, N});
-  auto b = random::uniform({N});
-  auto c = random::uniform({M});
-  eval(a, b, c);
-  auto matvec = [&]() { return matmul(a, b); };
+  auto a = mx::random::uniform({M, N});
+  auto b = mx::random::uniform({N});
+  auto c = mx::random::uniform({M});
+  mx::eval(a, b, c);
+  auto matvec = [&]() { return mx::matmul(a, b); };
  TIME(matvec);

-  auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
+  auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
  TIME(matvec_transpose);
}

void time_matmul() {
  int M = 1000, N = 1000, K = 1000;
-  auto a = random::uniform({M, K});
-  auto b = random::uniform({K, N});
-  auto device = default_device();
-  eval(a, b);
-  TIME(matmul, a, b, device);
+  auto a = mx::random::uniform({M, K});
+  auto b = mx::random::uniform({K, N});
+  auto device = mx::default_device();
+  mx::eval(a, b);
+  TIME(mx::matmul, a, b, device);

-  auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
+  auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
  TIME(transpose_matmul);
}

void time_reductions() {
-  auto a = random::normal({10000, 1000});
-  eval(a);
-  auto sum_all = [&a]() { return sum(a, false); };
+  auto a = mx::random::normal({10000, 1000});
+  mx::eval(a);
+  auto sum_all = [&a]() { return mx::sum(a, false); };
  TIME(sum_all);

-  auto sum_along_0 = [&a]() { return sum(a, 0, false); };
+  auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
  TIME(sum_along_0);

-  auto sum_along_1 = [&a]() { return sum(a, 1, false); };
+  auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
  TIME(sum_along_1);

-  auto prod_all = [&a]() { return prod(a, false); };
+  auto prod_all = [&a]() { return mx::prod(a, false); };
  TIME(prod_all);

-  auto all_true = [&a]() { return all(a, false); };
+  auto all_true = [&a]() { return mx::all(a, false); };
  TIME(all_true);

-  auto all_along_0 = [&a]() { return all(a, 0, false); };
+  auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
  TIME(all_along_0);

-  auto all_along_1 = [&a]() { return all(a, 1, false); };
+  auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
  TIME(all_along_1);

-  auto any_true = [&a]() { return any(a, false); };
+  auto any_true = [&a]() { return mx::any(a, false); };
  TIME(any_true);

-  auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
+  auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
  TIME(argmin_along_0);

-  auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
+  auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
  TIME(argmin_along_1);
}

void time_gather_scatter() {
-  auto a = random::normal({1000, 768});
-  eval(a);
-  auto indices = random::randint(0, 1000, {256});
-  eval(indices);
+  auto a = mx::random::normal({1000, 768});
+  mx::eval(a);
+  auto indices = mx::random::randint(0, 1000, {256});
+  mx::eval(indices);

-  auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
+  auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
  TIME(embedding_lookup);

-  indices = random::randint(0, 768 * 1000, {256 * 768});
-  eval(indices);
+  indices = mx::random::randint(0, 768 * 1000, {256 * 768});
+  mx::eval(indices);

-  auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
+  auto single_element_lookup = [&a, &indices]() {
+    return mx::take(a, indices);
+  };
  TIME(single_element_lookup);

-  indices = random::randint(0, 1000, {256});
-  auto updates = random::normal({256, 1, 768});
-  eval(indices, updates);
+  indices = mx::random::randint(0, 1000, {256});
+  auto updates = mx::random::normal({256, 1, 768});
+  mx::eval(indices, updates);

  auto embedding_update = [&a, &indices, &updates]() {
    return scatter(a, indices, updates, 0);
@@ -223,10 +225,10 @@ void time_gather_scatter() {
  };
  TIME(embedding_add);

-  a = reshape(a, {-1});
-  indices = random::randint(0, 768 * 1000, {768 * 256});
-  updates = random::normal({256 * 768, 1});
-  eval(a, indices, updates);
+  a = mx::reshape(a, {-1});
+  indices = mx::random::randint(0, 768 * 1000, {768 * 256});
+  updates = mx::random::normal({256 * 768, 1});
+  mx::eval(a, indices, updates);

  auto single_element_update = [&a, &indices, &updates]() {
    return scatter(a, indices, updates, 0);
@@ -240,21 +242,21 @@ void time_gather_scatter() {
}

void time_divmod() {
-  auto a = random::normal({1000});
-  auto b = random::normal({1000});
-  eval({a, b});
+  auto a = mx::random::normal({1000});
+  auto b = mx::random::normal({1000});
+  mx::eval({a, b});

-  auto divmod_fused = [&a, &b]() { return divmod(a, b); };
+  auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
  TIME(divmod_fused);

  auto divmod_separate = [&a, &b]() {
-    return std::vector<array>{floor_divide(a, b), remainder(a, b)};
+    return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
  };
  TIME(divmod_separate);
}
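
For reference, the same fused-versus-separate distinction is easy to see from Python. The following is a minimal sketch (not part of the diff) assuming the Python bindings of these ops: mx.divmod computes the quotient and remainder in one primitive, while floor_divide plus remainder issues two.

import mlx.core as mx

a = mx.array([7.0, -7.0])
b = mx.array([2.0, 2.0])

q, r = mx.divmod(a, b)        # one fused primitive
q2 = mx.floor_divide(a, b)    # two separate primitives
r2 = mx.remainder(a, b)
mx.eval(q, r, q2, r2)
print(mx.allclose(q, q2).item(), mx.allclose(r, r2).item())  # expected: True True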

int main() {
-  std::cout << "Benchmarks for " << default_device() << std::endl;
+  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
  time_creation_ops();
  time_type_conversions();
  time_unary_ops();

@@ -144,6 +144,13 @@ def reduction(op, axis, x):
     mx.eval(ys)


+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    mx.eval(z)
+
+
 def softmax(axis, x):
     ys = []
     for i in range(100):

@@ -505,5 +512,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))

+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError("Unknown benchmark")

@@ -1,7 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.

 import argparse
 from time import time

 import mlx.core as mx
 import torch

benchmarks/python/gather_mm_bench.py (new file, 74 lines)
@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_mm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = x @ w1.T
        x = x @ w2.T
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_mm()
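
A quick sanity check can clarify what the sort/unsort helpers above do. This sketch assumes gather_sort and scatter_unsort are in scope; the shapes are illustrative, not the benchmark's:

import mlx.core as mx

x = mx.random.normal((8, 1, 1, 16))                                # 8 tokens
indices = (mx.random.uniform(shape=(8, 2)) * 4).astype(mx.uint32)  # 2 experts each

x_s, idx_s, inv = gather_sort(x, indices)
assert mx.all(idx_s[:-1] <= idx_s[1:]).item()   # expert ids are now sorted
y = scatter_unsort(x_s, inv, indices.shape)     # rows restored to token order
mx.eval(y)

Grouping rows that route to the same expert is what the sorted_indices=True path of mx.gather_mm exploits to batch its matrix multiplies.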

benchmarks/python/gather_qmm_bench.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate(
            [
                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
                for i, j in enumerate(idx.tolist())
            ],
            axis=0,
        )
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_qmm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = mx.quantized_matmul(x, *w1, transpose=True)
        x = mx.quantized_matmul(x, *w2, transpose=True)
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_qmm()
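
The *w1 splats above work because mx.quantize returns a tuple rather than a single array. A small sketch, with illustrative values and the default quantization parameters assumed:

import mlx.core as mx

w = mx.random.normal((1024, 1024))
wq, scales, biases = mx.quantize(w)  # packed weights plus per-group scales/biases
x = mx.random.normal((1, 1024))
y = mx.quantized_matmul(x, wq, scales, biases, transpose=True)
mx.eval(y)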

@@ -10,7 +10,12 @@ def layer_norm(x, w, b, eps):
     x = x.astype(mx.float32)
     mu = mx.mean(x, -1, keepdims=True)
     v = mx.var(x, -1, keepdims=True)
-    return (x - mu) * mx.rsqrt(v + eps) * w + b
+    y = (x - mu) * mx.rsqrt(v + eps)
+    if w is not None:
+        y = y * w
+    if b is not None:
+        y = y + b
+    return y


 def time_layer_norm():

@@ -36,6 +41,28 @@ def time_layer_norm():
     time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
     time_fn(layer_norm_loop, mx.compile(g2), x, w, b)

+    f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
+    f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
+    g1 = mx.grad(f1, argnums=(0,))
+    g2 = mx.grad(f2, argnums=(0,))
+
+    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    mx.eval(x, w, b, y)
+
+    def layer_norm_loop(g, x):
+        gx = x
+        for _ in range(32):
+            gx = g(gx, y)
+        return gx
+
+    time_fn(layer_norm_loop, g1, x)
+    time_fn(layer_norm_loop, g2, x)
+    time_fn(layer_norm_loop, mx.compile(g1), x)
+    time_fn(layer_norm_loop, mx.compile(g2), x)


 if __name__ == "__main__":
     time_layer_norm()
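
With weight and bias now optional, the reference above can be checked directly against the fused kernel in the no-affine case. A minimal sketch, assuming layer_norm as defined in this file (it computes in float32 and does not cast back, so a loose tolerance is used):

import mlx.core as mx

x = mx.random.uniform(shape=(2, 8, 64)).astype(mx.float16)
y_ref = layer_norm(x, None, None, 1e-5)
y_fast = mx.fast.layer_norm(x, None, None, 1e-5)
print(mx.allclose(y_ref, y_fast, atol=1e-3).item())  # expected: True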

@@ -9,7 +9,10 @@ def rms_norm(x, w, eps):
     ot = x.dtype
     x = x.astype(mx.float32)
     n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
-    return (x * n).astype(ot) * w
+    y = (x * n).astype(ot)
+    if w is not None:
+        y = y * w
+    return y


 def time_rms_norm():

@@ -34,6 +37,27 @@ def time_rms_norm():
     time_fn(rms_norm_loop, mx.compile(g1), x, w)
     time_fn(rms_norm_loop, mx.compile(g2), x, w)

+    f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
+    f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
+    g1 = mx.grad(f1, argnums=(0,))
+    g2 = mx.grad(f2, argnums=(0,))
+
+    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
+    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    mx.eval(x, w, y)
+
+    def rms_norm_loop(g, x):
+        gx = x
+        for _ in range(32):
+            gx = g(gx, y)
+        return gx
+
+    time_fn(rms_norm_loop, g1, x)
+    time_fn(rms_norm_loop, g2, x)
+    time_fn(rms_norm_loop, mx.compile(g1), x)
+    time_fn(rms_norm_loop, mx.compile(g2), x)


 if __name__ == "__main__":
     time_rms_norm()

@@ -9,7 +9,7 @@ from time_utils import measure_runtime


 def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
     def scatter(dst, x, idx):
-        dst[*idx] = x
+        dst[tuple(idx)] = x
         mx.eval(dst)

     idx = []

@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):


 def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
-    def gather(dst, x, idx, device):
-        dst[*idx] = x
+    def scatter(dst, x, idx, device):
+        dst[tuple(idx)] = x
         if device == torch.device("mps"):
             torch.mps.synchronize()

@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
     x = torch.randn(x_shape, dtype=torch.float32).to(device)
     dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

-    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
+    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
     print(f"PyTorch: {runtime:.3f}ms")

@@ -54,7 +54,7 @@ if __name__ == "__main__":
         (100_000, 64),
         (1_000_000, 64),
         (100_000,),
-        (2_000_00,),
+        (200_000,),
         (20_000_000,),
         (10000, 64),
         (100, 64),

@@ -91,6 +91,6 @@ if __name__ == "__main__":

     for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
         print("=" * 20)
-        print(f"X {x_shape}, Indices {idx_shape}")
+        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
         benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
         benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
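
The dst[*idx] = x form replaced here is only valid syntax on Python 3.11 and later (star unpacking inside subscripts); indexing with a tuple of index arrays behaves the same and works on all supported versions. A small illustration with made-up shapes:

import mlx.core as mx

dst = mx.zeros((4, 4))
x = mx.ones((2,))
idx = [mx.array([0, 3]), mx.array([1, 2])]

dst[tuple(idx)] = x   # scatters x into dst[0, 1] and dst[3, 2]
mx.eval(dst)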

@@ -1,62 +1,223 @@
 # Copyright © 2024 Apple Inc.

 import argparse
 import math
+import os
+import subprocess
+import time

 import mlx.core as mx
-from time_utils import time_fn
+import numpy as np

-MAX_SEQ = 300
-START_SEQ = 100
-SEQ_INCREMENT = 50
+device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
+device_name = device_name.decode("utf-8").strip("\n")
+
+N_warmup = 5
+N_iter_bench = 40
+N_iter_func = 8


-def time_self_attention_primitives():
-    mx.random.seed(3)
-    B = 2
-    H = 38
-    D = 64
-    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
-        q = mx.random.uniform(shape=(B, H, R, D))
-        k = mx.random.uniform(shape=(B, H, R, D))
-        v = mx.random.uniform(shape=(B, H, R, D))
-        scale = 1.0 / math.sqrt(float(D))
-        mx.eval(q, k, v)
-
-        def sdpa_primitives(qs, ks, vs, alpha):
-            s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
-            p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
-            o = p @ vs
-            return o
-
-        time_fn(sdpa_primitives, q, k, v, scale)
-
-
-def time_self_attention_sdpa():
-    mx.random.seed(3)
-    B = 2
-    H = 38
-    D = 64
-    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
-        q = mx.random.uniform(shape=(B, H, R, D))
-        k = mx.random.uniform(shape=(B, H, R, D))
-        v = mx.random.uniform(shape=(B, H, R, D))
-        scale = 1.0 / math.sqrt(float(D))
-        mx.eval(q, k, v)
-
-        def sdpa_fused(qs, ks, vs, alpha):
-            o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
-            return o
-
-        time_fn(sdpa_fused, q, k, v, scale)
+def bench(f, *args):
+    for i in range(N_warmup):
+        f(*args)
+
+    s = time.perf_counter_ns()
+    for i in range(N_iter_bench):
+        f(*args)
+    e = time.perf_counter_ns()
+    return (e - s) * 1e-9
+
+
+def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
+    np_dtype = getattr(np, dtype)
+
+    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
+    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
+
+    scale = 1.0 / math.sqrt(D)
+
+    q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
+    k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
+    v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
+
+    q_mx = mx.array(q_np)
+    k_mx = mx.array(k_np)
+    v_mx = mx.array(v_np)
+
+    if mask is not None:
+        if mask == "additive":
+            mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
+            mask = mx.array(mask_np)
+        elif mask == "bool":
+            mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
+            mask = mx.array(mask_np)
+
+    return q_mx, k_mx, v_mx, scale, mask
+
+
+def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
+    q_dtype = q.dtype
+    q = q * mx.array(scale, q_dtype)
+    n_q_heads = q.shape[-3]
+    n_kv_heads = k.shape[-3]
+    n_repeats = n_q_heads // n_kv_heads
+
+    B = q.shape[0]
+    L = q.shape[2]
+    kL = k.shape[2]
+
+    if n_repeats > 1:
+        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
+        k = mx.expand_dims(k, 2)
+        v = mx.expand_dims(v, 2)
+
+    scores = q @ mx.swapaxes(k, -1, -2)
+
+    if mask is not None:
+
+        if mask == "causal":
+            q_offset = max(0, kL - L)
+            q_indices = mx.arange(q_offset, q_offset + L)
+            k_indices = mx.arange(kL)
+            mask = q_indices[:, None] >= k_indices[None]
+
+        if n_repeats > 1 and mask.ndim >= 3:
+            if mask.shape[-3] == 1:
+                mask = mx.expand_dims(mask, -3)
+            else:
+                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
+
+        if mask.dtype == mx.bool_:
+            scores = mx.where(mask, scores, -np.float32(np.inf))
+        else:
+            scores += mask
+
+    scores = mx.softmax(scores, axis=-1, precise=True)
+
+    out = scores @ v
+    if n_repeats > 1:
+        out = mx.reshape(out, [B, n_q_heads, L, -1])
+
+    return out
+
+
+def mlx_fused_attn(q, k, v, scale, mask):
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+
+
+def do_attention(f, q, k, v, scale, mask=None, transpose=False):
+    if transpose:
+        q_t = mx.transpose(q, (0, 2, 1, 3))
+        k_t = mx.transpose(k, (0, 2, 1, 3))
+        v_t = mx.transpose(v, (0, 2, 1, 3))
+        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
+        return mx.transpose(o_t, (0, 2, 1, 3))
+    else:
+        return f(q, k, v, scale=scale, mask=mask)
+
+
+def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
+    q_out = q
+
+    for i in range(N_iter_func):
+        q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
+
+    mx.eval(q_out)
+    return q_out
+
+
+def bench_shape(
+    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
+):
+    q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
+        B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
+    )
+
+    time_mlx_unfused = bench(
+        do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
+    )
+    time_mlx_fused = bench(
+        do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
+    )
+
+    o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
+    o_mlx_unfused = do_attention(
+        mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
+    )
+
+    atol = 1e-5 if dtype == "float32" else 2e-4
+
+    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
+        print(
+            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
+        )
+
+    return time_mlx_fused, time_mlx_unfused
+
+
+def get_gflop_count(B, M, N, K):
+    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser("MLX benchmarks.")
-    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
-    args = parser.parse_args()
-    if args.gpu:
-        mx.set_default_device(mx.gpu)
-    else:
-        mx.set_default_device(mx.cpu)
+    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

-    time_self_attention_sdpa()
-    time_self_attention_primitives()
+    dtypes = ("float16", "float32")[:1]
+    transposes = (False,)
+
+    # fmt: off
+    shapes_64 = (
+        # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
+        (    1,   32,   32,       64,   32,    32),
+        (    1,   64,   64,       64,   32,    32),
+        (    1,  128,  128,       64,   32,    32),
+        (    1,  256,  256,       64,   32,    32),
+        (    1,  512,  512,       64,   32,    32),
+        (    1, 1024, 1024,       64,   32,     8),
+        (    1, 2048, 2048,       64,   32,     8),
+        (    1, 4096, 4096,       64,   32,     8),
+    )
+
+    shapes_80 = (
+        # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
+        (    1, 1024, 1024,       80,   32,     8),
+        (    1, 2048, 2048,       80,   32,     8),
+        (    1, 4096, 4096,       80,   32,     8),
+    )
+
+    shapes_128 = (
+        # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
+        (    1, 1024, 1024,      128,   32,     8),
+        (    1, 2048, 2048,      128,   32,     8),
+        (    1, 4096, 4096,      128,   32,     8),
+    )
+    # fmt: on
+
+    shapes = shapes_64 + shapes_80 + shapes_128
+
+    masks = [None, "bool", "causal"]
+
+    print(
+        "  B,  qsl,  ksl, hdim, n_qh, n_kvh, t,  dtype,    mask, t_unfs, t_fuse, diff%"
+    )
+
+    for dtype in dtypes:
+        for transpose in transposes:
+            for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
+                for mask_in in masks:
+                    time_mlx_fused, time_mlx_unfused = bench_shape(
+                        B,
+                        qsl,
+                        ksl,
+                        head_dim,
+                        n_q_heads,
+                        n_kv_heads,
+                        dtype,
+                        transpose,
+                        mask_in,
+                    )
+                    diff = time_mlx_unfused / time_mlx_fused - 1.0
+                    t_str = 1 if transpose else 0
+                    print(
+                        f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
+                    )
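
Note that the benchmark passes the literal string "causal" straight through to mx.fast.scaled_dot_product_attention, while the reference mlx_ref_attn expands it into a boolean lower-triangular mask. A minimal sketch of the fused call, with illustrative shapes:

import mlx.core as mx

q = mx.random.normal((1, 8, 128, 64))
k = mx.random.normal((1, 8, 128, 64))
v = mx.random.normal((1, 8, 128, 64))
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=0.125, mask="causal")
mx.eval(out)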

benchmarks/python/sdpa_vector_bench.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

L = 16384
H = 32
H_k = H // 4
D = 128
V = 128
dtype = mx.float16
loops = 10


def upproject(x, w):
    if w is None:
        return x
    else:
        return x @ w.T


def attention(q, k, v, mask=None, w=None):
    def _sdpa(q, k, v):
        B, Hq, L, D = q.shape
        _, Hk, S, _ = k.shape
        _, _, _, V = v.shape
        q = q.reshape(B, Hk, Hq // Hk, L, D)
        k = k[:, :, None, :, :]
        v = v[:, :, None, :, :]
        s = q @ k.transpose(0, 1, 2, 4, 3)
        if mask is not None:
            m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
            s = mx.where(m, s, mx.finfo(s.dtype).min)
        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
        o = p @ v
        return o.reshape(B, Hq, L, V)

    for i in range(loops):
        q = _sdpa(q, k, v)
        q = upproject(q, w)
    return q


def sdpa(q, k, v, mask=None, w=None):
    for i in range(loops):
        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
        q = upproject(q, w)
    return q


def time_self_attention_primitives():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
    mx.eval(q, k, v, w)
    time_fn(attention, q, k, v, w=w)


def time_self_attention_sdpa():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
    mx.eval(q, k, v, w)
    time_fn(sdpa, q, k, v, w=w)


def time_self_attention_sdpa_with_mask():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
    mask = mx.full((L,), True)
    mask[L // 2 :] = False
    mx.eval(q, k, v, mask, w)

    def sdpa_mask(*args):
        return sdpa(*args, mask=mask, w=w)

    def attention_mask(*args):
        return attention(*args, mask=mask, w=w)

    time_fn(attention_mask, q, k, v)
    time_fn(sdpa_mask, q, k, v)


if __name__ == "__main__":
    time_self_attention_sdpa()
    time_self_attention_primitives()
    time_self_attention_sdpa_with_mask()

benchmarks/python/synchronize_bench.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import time

import mlx.core as mx

rank = mx.distributed.init().rank()


def timeit(fn, a):

    # warmup
    for _ in range(5):
        mx.eval(fn(a))

    its = 10
    tic = time.perf_counter()
    for _ in range(its):
        mx.eval(fn(a))
    toc = time.perf_counter()
    ms = 1000 * (toc - tic) / its
    return ms


def all_reduce_benchmark():
    a = mx.ones((5, 5), mx.int32)

    its_per_eval = 100

    def fn(x):
        for _ in range(its_per_eval):
            x = mx.distributed.all_sum(x)
            x = x - 1
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All Reduce: time per iteration {ms:.6f} (ms)")


def all_gather_benchmark():
    a = mx.ones((5, 5), mx.int32)
    its_per_eval = 100

    def fn(x):
        for _ in range(its_per_eval):
            x = mx.distributed.all_gather(x)[0]
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All gather: time per iteration {ms:.6f} (ms)")


if __name__ == "__main__":
    all_reduce_benchmark()
    all_gather_benchmark()

@@ -1,5 +1,7 @@
 include(CMakeParseArguments)

+# clang format off
+#
 # ##############################################################################
 # Build metal library
 #

@@ -11,6 +13,8 @@ include(CMakeParseArguments)
 # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
 # files (like headers)
 #
+# clang format on
+
 macro(mlx_build_metallib)
   # Parse args
   set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)

@@ -21,7 +25,7 @@ macro(mlx_build_metallib)
   set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")

   # Collect compile options
-  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
+  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)

   # Prepare metallib build command
   add_custom_command(

@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
 CREATE_SUBDIRS = NO
 FULL_PATH_NAMES = YES
 RECURSIVE = YES
-GENERATE_HTML = YES
+GENERATE_HTML = NO
 GENERATE_LATEX = NO
 GENERATE_XML = YES
 XML_PROGRAMLISTING = YES

@@ -60,6 +60,7 @@ html_theme_options = {
     },
 }

+html_favicon = html_theme_options["logo"]["image_light"]

 # -- Options for HTMLHelp output ---------------------------------------------

@@ -1,3 +1,5 @@
+.. _custom_metal_kernels:
+
 Custom Metal Kernels
 ====================

@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as

    template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

+Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
+This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
+For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
+
 Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
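
For instance, a call shaped like the following launches the ``myexp`` kernel with one thread per element. This is a sketch based on the ``mx.fast.metal_kernel`` API described above; the kernel body is abbreviated:

import mlx.core as mx

source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = metal::exp(inp[elem]);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

a = mx.random.normal(shape=(4096,))
(out,) = kernel(
    inputs=[a],
    grid=(a.size, 1, 1),       # mx.prod(grid) threads in total
    threadgroup=(256, 1, 1),   # split into groups of 256 threads
    output_shapes=[a.shape],
    output_dtypes=[a.dtype],
)
mx.eval(out)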

Using Shape/Strides

@@ -22,12 +22,12 @@ You can do that in MLX directly:
 This function performs that operation while leaving the implementation and
 function transformations to MLX.

-However you may need to customize the underlying implementation, perhaps to
-make it faster or for custom differentiation. In this tutorial we will go
-through adding custom extensions. It will cover:
+However, you may want to customize the underlying implementation, perhaps to
+make it faster. In this tutorial we will go through adding custom extensions.
+It will cover:

 * The structure of the MLX library.
-* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
+* Implementing a CPU operation.
 * Implementing a GPU operation using metal.
 * Adding the ``vjp`` and ``jvp`` function transformation.
 * Building a custom extension and binding it to python.

@@ -45,7 +45,7 @@ Operations
 Operations are the front-end functions that operate on arrays. They are defined
 in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

-We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
+We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
 ``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
 C++:

@@ -55,7 +55,7 @@ C++:
     * Scale and sum two vectors element-wise
     * z = alpha * x + beta * y
     *
-    * Follow numpy style broadcasting between x and y
+    * Use NumPy-style broadcasting between x and y
     * Inputs are upcasted to floats if needed
     **/
    array axpby(

@@ -66,7 +66,7 @@ C++:
        StreamOrDevice s = {} // Stream on which to schedule the operation
    );

-The simplest way to this operation is in terms of existing operations:
+The simplest way to implement this is with existing operations:

 .. code-block:: C++

@@ -93,9 +93,9 @@ Primitives
 ^^^^^^^^^^^

 A :class:`Primitive` is part of the computation graph of an :class:`array`. It
-defines how to create outputs arrays given a input arrays. Further, a
+defines how to create output arrays given input arrays. Further, a
 :class:`Primitive` has methods to run on the CPU or GPU and for function
-transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
+transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
 more concrete:

 .. code-block:: C++

@@ -128,7 +128,7 @@ more concrete:
       /** The vector-Jacobian product. */
       std::vector<array> vjp(
           const std::vector<array>& primals,
-          const array& cotan,
+          const std::vector<array>& cotangents,
           const std::vector<int>& argnums,
           const std::vector<array>& outputs) override;

@@ -153,9 +153,6 @@ more concrete:
      private:
       float alpha_;
       float beta_;
-
-      /** Fall back implementation for evaluation on CPU */
-      void eval(const std::vector<array>& inputs, array& out);
   };

 The :class:`Axpby` class derives from the base :class:`Primitive` class. The

@@ -188,7 +185,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
     auto promoted_dtype = promote_types(x.dtype(), y.dtype());

     // Upcast to float32 for non-floating point inputs x and y
-    auto out_dtype = is_floating_point(promoted_dtype)
+    auto out_dtype = issubdtype(promoted_dtype, float32)
         ? promoted_dtype
         : promote_types(promoted_dtype, float32);

@@ -234,49 +231,57 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
 Implementing the CPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Let's start by implementing a naive and generic version of
-:meth:`Axpby::eval_cpu`. We declared this as a private member function of
-:class:`Axpby` earlier called :meth:`Axpby::eval`.
+Let's start by implementing :meth:`Axpby::eval_cpu`.

-Our naive method will go over each element of the output array, find the
+The method will go over each element of the output array, find the
 corresponding input elements of ``x`` and ``y`` and perform the operation
 point-wise. This is captured in the templated function :meth:`axpby_impl`.

 .. code-block:: C++

-   template <typename T>
-   void axpby_impl(
-       const array& x,
-       const array& y,
-       array& out,
-       float alpha_,
-       float beta_) {
-     // We only allocate memory when we are ready to fill the output
-     // malloc_or_wait synchronously allocates available memory
-     // There may be a wait executed here if the allocation is requested
-     // under memory-pressured conditions
-     out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-     // Collect input and output data pointers
-     const T* x_ptr = x.data<T>();
-     const T* y_ptr = y.data<T>();
-     T* out_ptr = out.data<T>();
-
-     // Cast alpha and beta to the relevant types
-     T alpha = static_cast<T>(alpha_);
-     T beta = static_cast<T>(beta_);
-
-     // Do the element-wise operation for each output
-     for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
-       // Map linear indices to offsets in x and y
-       auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
-       auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
-
-       // We allocate the output to be contiguous and regularly strided
-       // (defaults to row major) and hence it doesn't need additional mapping
-       out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
-     }
-   }
+   template <typename T>
+   void axpby_impl(
+       const mx::array& x,
+       const mx::array& y,
+       mx::array& out,
+       float alpha_,
+       float beta_,
+       mx::Stream stream) {
+     out.set_data(mx::allocator::malloc(out.nbytes()));
+
+     // Get the CPU command encoder and register input and output arrays
+     auto& encoder = mx::cpu::get_command_encoder(stream);
+     encoder.set_input_array(x);
+     encoder.set_input_array(y);
+     encoder.set_output_array(out);
+
+     // Launch the CPU kernel
+     encoder.dispatch([x_ptr = x.data<T>(),
+                       y_ptr = y.data<T>(),
+                       out_ptr = out.data<T>(),
+                       size = out.size(),
+                       shape = out.shape(),
+                       x_strides = x.strides(),
+                       y_strides = y.strides(),
+                       alpha_,
+                       beta_]() {
+
+       // Cast alpha and beta to the relevant types
+       T alpha = static_cast<T>(alpha_);
+       T beta = static_cast<T>(beta_);
+
+       // Do the element-wise operation for each output
+       for (size_t out_idx = 0; out_idx < size; out_idx++) {
+         // Map linear indices to offsets in x and y
+         auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
+         auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
+
+         // We allocate the output to be contiguous and regularly strided
+         // (defaults to row major) and hence it doesn't need additional mapping
+         out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
+       }
+     });
+   }

 Our implementation should work for all incoming floating point arrays.
 Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and

@@ -284,112 +289,32 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and

 .. code-block:: C++

-   /** Fall back implementation for evaluation on CPU */
-   void Axpby::eval(
-       const std::vector<array>& inputs,
-       const std::vector<array>& outputs) {
-     auto& x = inputs[0];
-     auto& y = inputs[1];
-     auto& out = outputs[0];
-
-     // Dispatch to the correct dtype
-     if (out.dtype() == float32) {
-       return axpby_impl<float>(x, y, out, alpha_, beta_);
-     } else if (out.dtype() == float16) {
-       return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
-     } else if (out.dtype() == bfloat16) {
-       return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
-     } else if (out.dtype() == complex64) {
-       return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
-     } else {
-       throw std::runtime_error(
-           "[Axpby] Only supports floating point types.");
-     }
-   }
-
-This is good as a fallback implementation. We can use the ``axpby`` routine
-provided by the Accelerate_ framework for a faster implementation in certain
-cases:
-
-#. Accelerate does not provide implementations of ``axpby`` for half precision
-   floats. We can only use it for ``float32`` types.
-#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
-   elements have fixed strides between them. We only direct to Accelerate
-   if both ``x`` and ``y`` are row contiguous or column contiguous.
-#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
-   MLX expects to write the output to a new array. We must copy the elements
-   of ``y`` into the output and use that as an input to ``axpby``.
-
-Let's write an implementation that uses Accelerate in the right conditions.
-It allocates data for the output, copies ``y`` into it, and then calls the
-:func:`catlas_saxpby` from accelerate.
-
-.. code-block:: C++
-
-   template <typename T>
-   void axpby_impl_accelerate(
-       const array& x,
-       const array& y,
-       array& out,
-       float alpha_,
-       float beta_) {
-     // Accelerate library provides catlas_saxpby which does
-     // Y = (alpha * X) + (beta * Y) in place
-     // To use it, we first copy the data in y over to the output array
-     out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-     // We then copy over the elements using the contiguous vector specialization
-     copy_inplace(y, out, CopyType::Vector);
-
-     // Get x and y pointers for catlas_saxpby
-     const T* x_ptr = x.data<T>();
-     T* y_ptr = out.data<T>();
-
-     T alpha = static_cast<T>(alpha_);
-     T beta = static_cast<T>(beta_);
-
-     // Call the inplace accelerate operator
-     catlas_saxpby(
-         /* N = */ out.size(),
-         /* ALPHA = */ alpha,
-         /* X = */ x_ptr,
-         /* INCX = */ 1,
-         /* BETA = */ beta,
-         /* Y = */ y_ptr,
-         /* INCY = */ 1);
-   }
-
-For inputs that do not fit the criteria for accelerate, we fall back to
-:meth:`Axpby::eval`. With this in mind, let's finish our
-:meth:`Axpby::eval_cpu`.
-
-.. code-block:: C++
-
    /** Evaluate primitive on CPU using accelerate specializations */
    void Axpby::eval_cpu(
-       const std::vector<array>& inputs,
-       const std::vector<array>& outputs) {
-     assert(inputs.size() == 2);
-     auto& x = inputs[0];
-     auto& y = inputs[1];
-     auto& out = outputs[0];
-
-     // Accelerate specialization for contiguous single precision float arrays
-     if (out.dtype() == float32 &&
-         ((x.flags().row_contiguous && y.flags().row_contiguous) ||
-          (x.flags().col_contiguous && y.flags().col_contiguous))) {
-       axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
-       return;
-     }
-
-     // Fall back to common back-end if specializations are not available
-     eval(inputs, outputs);
+       const std::vector<mx::array>& inputs,
+       std::vector<mx::array>& outputs) {
+     auto& x = inputs[0];
+     auto& y = inputs[1];
+     auto& out = outputs[0];
+
+     // Dispatch to the correct dtype
+     if (out.dtype() == mx::float32) {
+       return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
+     } else if (out.dtype() == mx::float16) {
+       return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
+     } else if (out.dtype() == mx::bfloat16) {
+       return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
+     } else if (out.dtype() == mx::complex64) {
+       return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
+     } else {
+       throw std::runtime_error(
+           "Axpby is only supported for floating point types.");
+     }
    }

 Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
 you do not plan on running the operation on the GPU or using transforms on
 computation graphs that contain :class:`Axpby`, you can stop implementing the
-primitive here and enjoy the speed-ups you get from the Accelerate library.
+primitive here.

 Implementing the GPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -420,8 +345,8 @@ element in the output.
        constant const float& alpha [[buffer(3)]],
        constant const float& beta [[buffer(4)]],
        constant const int* shape [[buffer(5)]],
-       constant const size_t* x_strides [[buffer(6)]],
-       constant const size_t* y_strides [[buffer(7)]],
+       constant const int64_t* x_strides [[buffer(6)]],
+       constant const int64_t* y_strides [[buffer(7)]],
        constant const int& ndim [[buffer(8)]],
        uint index [[thread_position_in_grid]]) {
      // Convert linear indices to offsets in array

@@ -438,24 +363,10 @@ each instantiation a unique host name so we can identify it.

 .. code-block:: C++

-   #define instantiate_axpby(type_name, type) \
-     template [[host_name("axpby_general_" #type_name)]] \
-     [[kernel]] void axpby_general<type>( \
-         device const type* x [[buffer(0)]], \
-         device const type* y [[buffer(1)]], \
-         device type* out [[buffer(2)]], \
-         constant const float& alpha [[buffer(3)]], \
-         constant const float& beta [[buffer(4)]], \
-         constant const int* shape [[buffer(5)]], \
-         constant const size_t* x_strides [[buffer(6)]], \
-         constant const size_t* y_strides [[buffer(7)]], \
-         constant const int& ndim [[buffer(8)]], \
-         uint index [[thread_position_in_grid]]);
-
-   instantiate_axpby(float32, float);
-   instantiate_axpby(float16, half);
-   instantiate_axpby(bfloat16, bfloat16_t);
-   instantiate_axpby(complex64, complex64_t);
+   instantiate_kernel("axpby_general_float32", axpby_general, float)
+   instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
+   instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
+   instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)

 The logic to determine the kernel, set the inputs, resolve the grid dimensions,
 and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown

@@ -480,7 +391,7 @@ below.
     auto& d = metal::device(s.device);

     // Allocate output memory
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));

     // Resolve name of kernel
     std::ostringstream kname;

@@ -494,7 +405,7 @@ below.

     // Prepare to encode kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);

     // Kernel parameters are registered with buffer indices corresponding to
     // those in the kernel declaration at axpby.metal

@@ -509,14 +420,14 @@ below.
     compute_encoder.set_output_array(out, 2);

     // Encode alpha and beta
-    compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-    compute_encoder->setBytes(&beta_, sizeof(float), 4);
+    compute_encoder.set_bytes(alpha_, 3);
+    compute_encoder.set_bytes(beta_, 4);

     // Encode shape, strides and ndim
-    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    compute_encoder.set_vector_bytes(x.shape(), 5);
+    compute_encoder.set_vector_bytes(x.strides(), 6);
+    compute_encoder.set_bytes(y.strides(), 7);
+    compute_encoder.set_bytes(ndim, 8);

     // We launch 1 thread for each input and make sure that the number of
     // threads in any given threadgroup is not higher than the max allowed

@@ -530,7 +441,7 @@ below.

     // Launch the grid with the given number of threads divided among
     // the given threadgroups
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

 We can now call the :meth:`axpby` operation on both the CPU and the GPU!

@@ -558,7 +469,7 @@ one we just defined:
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) {
      // Forward mode diff that pushes along the tangents
-     // The jvp transform on the primitive can built with ops
+     // The jvp transform on the primitive can be built with ops
      // that are scheduled on the same stream as the primitive

      // If argnums = {0}, we only push along x in which case the

@@ -570,7 +481,7 @@ one we just defined:
        auto scale_arr = array(scale, tangents[0].dtype());
        return {multiply(scale_arr, tangents[0], stream())};
      }
-     // If, argnums = {0, 1}, we take contributions from both
+     // If argnums = {0, 1}, we take contributions from both
      // which gives us jvp = tangent_x * alpha + tangent_y * beta
      else {
        return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};

@@ -824,7 +735,7 @@ Let's look at a simple script and its results:

     print(f"c shape: {c.shape}")
     print(f"c dtype: {c.dtype}")
-    print(f"c correct: {mx.all(c == 6.0).item()}")
+    print(f"c is correct: {mx.all(c == 6.0).item()}")

 Output:

@@ -832,13 +743,13 @@ Output:

     c shape: [3, 4]
     c dtype: float32
-    c correctness: True
+    c is correct: True

 Results
 ^^^^^^^

 Let's run a quick benchmark and see how our new ``axpby`` operation compares
-with the naive :meth:`simple_axpby` we first defined on the CPU.
+with the naive :meth:`simple_axpby` we first defined.

 .. code-block:: python

@@ -846,13 +757,11 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
     from mlx_sample_extensions import axpby
     import time

-    mx.set_default_device(mx.cpu)
-
     def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
         return alpha * x + beta * y

-    M = 256
-    N = 512
+    M = 4096
+    N = 4096

     x = mx.random.normal((M, N))
     y = mx.random.normal((M, N))

@@ -863,24 +772,24 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.

     def bench(f):
         # Warm up
-        for i in range(100):
+        for i in range(5):
            z = f(x, y, alpha, beta)
            mx.eval(z)

         # Timed run
         s = time.time()
-        for i in range(5000):
+        for i in range(100):
            z = f(x, y, alpha, beta)
            mx.eval(z)
         e = time.time()
-        return e - s
+        return 1000 * (e - s) / 100

     simple_time = bench(simple_axpby)
     custom_time = bench(axpby)

-    print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
+    print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")

-The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
+The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
 modest improvements right away!

 This operation is now good to be used to build other operations, in

docs/src/dev/mlx_in_cpp.rst (new file, 121 lines)
@@ -0,0 +1,121 @@
.. _mlx_in_cpp:

Using MLX in C++
================

You can use MLX in a C++ project with CMake.

.. note::

   This guide is based on the following `example using MLX in C++
   <https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_

First install MLX:

.. code-block:: bash

   pip install -U mlx

You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.

Next make an example program in ``example.cpp``:

.. code-block:: C++

   #include <iostream>

   #include "mlx/mlx.h"

   namespace mx = mlx::core;

   int main() {
     auto x = mx::array({1, 2, 3});
     auto y = mx::array({1, 2, 3});
     std::cout << x + y << std::endl;
     return 0;
   }

The next step is to set up a CMake file in ``CMakeLists.txt``:

.. code-block:: cmake

   cmake_minimum_required(VERSION 3.27)

   project(example LANGUAGES CXX)

   set(CMAKE_CXX_STANDARD 17)
   set(CMAKE_CXX_STANDARD_REQUIRED ON)

Depending on how you installed MLX, you may need to tell CMake where to
find it.

If you installed MLX with Python, then add the following to the CMake file:

.. code-block:: cmake

   find_package(
     Python 3.9
     COMPONENTS Interpreter Development.Module
     REQUIRED)
   execute_process(
     COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
     OUTPUT_STRIP_TRAILING_WHITESPACE
     OUTPUT_VARIABLE MLX_ROOT)

If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:

.. code-block:: cmake

   set(MLX_ROOT "/path/to/mlx/")

Next, instruct CMake to find MLX:

.. code-block:: cmake

   find_package(MLX CONFIG REQUIRED)

Finally, add the ``example.cpp`` program as an executable and link MLX.

.. code-block:: cmake

   add_executable(example example.cpp)
   target_link_libraries(example PRIVATE mlx)

You can build the example with:

.. code-block:: bash

   cmake -B build -DCMAKE_BUILD_TYPE=Release
   cmake --build build

And run it with:

.. code-block:: bash

   ./build/example

Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:

.. list-table:: Package Variables
   :widths: 20 20
   :header-rows: 1

   * - Variable
     - Description
   * - MLX_FOUND
     - ``True`` if MLX is found
   * - MLX_INCLUDE_DIRS
     - Include directory
   * - MLX_LIBRARIES
     - Libraries to link against
   * - MLX_CXX_FLAGS
     - Additional compiler flags
   * - MLX_BUILD_ACCELERATE
     - ``True`` if MLX was built with Accelerate
   * - MLX_BUILD_METAL
     - ``True`` if MLX was built with Metal
|
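
As a quick sanity check of the example program, the same computation from the
Python front-end looks like this (a minimal sketch using only the documented
``mlx.core`` API):

.. code-block:: python

   import mlx.core as mx

   x = mx.array([1, 2, 3])
   y = mx.array([1, 2, 3])
   # Should print the same result as the C++ program: array([2, 4, 6], dtype=int32)
   print(x + y)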
@@ -45,6 +45,7 @@ are the CPU and GPU.
   usage/numpy
   usage/distributed
   usage/using_streams
   usage/export

.. toctree::
   :caption: Examples
@@ -61,6 +62,7 @@ are the CPU and GPU.
   python/array
   python/data_types
   python/devices_and_streams
   python/export
   python/ops
   python/random
   python/transforms
@@ -68,6 +70,7 @@ are the CPU and GPU.
   python/fft
   python/linalg
   python/metal
   python/memory_management
   python/nn
   python/optimizers
   python/distributed
@@ -86,3 +89,4 @@ are the CPU and GPU.
   dev/extensions
   dev/metal_debugger
   dev/custom_metal_kernels
   dev/mlx_in_cpp
@@ -1,3 +1,5 @@
.. _build_and_install:

Build and Install
=================
@@ -14,7 +16,7 @@ silicon computer is
To install from PyPI you must meet the following requirements:

- Using an M series chip (Apple silicon)
-- Using a native Python >= 3.8
- Using a native Python >= 3.9
- macOS >= 13.5

.. note::
@@ -53,7 +55,7 @@ Build Requirements
^^^^^^^^^^^^^^^^^^

- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
-- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0

.. note::
@@ -209,7 +211,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anywhere from a few hundred milliseconds to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
-Metal kernel cache persists accross reboots.
Metal kernel cache persists across reboots.

Troubleshooting
^^^^^^^^^^^^^^^
@@ -240,7 +242,7 @@ x86 Shell

.. _build shell:

-If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.

To fix this, find the application in Finder (``/Applications`` for iTerm,
@@ -264,4 +266,4 @@ Also check that cmake is using the correct architecture:
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
-wipe your build cahce with ``rm -rf build/`` and try again.
wipe your build cache with ``rm -rf build/`` and try again.
@@ -38,6 +38,7 @@ Array
   array.log10
   array.log1p
   array.log2
   array.logcumsumexp
   array.logsumexp
   array.max
   array.mean
@@ -51,11 +51,20 @@ The default floating point type is ``float32`` and the default integer type is
   * - ``float32``
     - 4
     - 32-bit float
   * - ``float64``
     - 8
     - 64-bit double
   * - ``complex64``
     - 8
     - 64-bit complex float

.. note::

   Arrays with type ``float64`` only work with CPU operations. Using
   ``float64`` arrays on the GPU will result in an exception.

Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
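
For illustration, a small sketch of the dtype rules described above (the
``float64`` example assumes CPU execution, since the note says GPU use raises):

.. code-block:: python

   import mlx.core as mx

   a = mx.array(1.0)                           # defaults to float32
   print(a.dtype)                              # mlx.core.float32
   b = mx.array(1.0, dtype=mx.float64)         # CPU-only dtype
   print(mx.issubdtype(b.dtype, mx.floating))  # True: float64 is a floating dtype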
@@ -66,3 +75,4 @@ documentation for more information. Use :func:`issubdtype` to determine if one
   Dtype
   DtypeCategory
   issubdtype
   finfo
docs/src/python/export.rst (new file, 14 lines)
@@ -0,0 +1,14 @@
.. _export:

Export Functions
================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   export_function
   import_function
   exporter
   export_to_dot
@@ -12,5 +12,4 @@ Fast
   layer_norm
   rope
   scaled_dot_product_attention
-  affine_quantize
   metal_kernel
@@ -5,8 +5,8 @@ Linear Algebra

.. currentmodule:: mlx.core.linalg

.. autosummary::
   :toctree: _autosummary

   inv
   tri_inv
@@ -16,3 +16,10 @@ Linear Algebra
   cross
   qr
   svd
   eigvalsh
   eigh
   lu
   lu_factor
   pinv
   solve
   solve_triangular
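
A small sketch of two of the listed functions; many of the linear algebra
operations currently require a CPU stream, so one is passed explicitly here
(an assumption based on the usual MLX ``stream`` parameter):

.. code-block:: python

   import mlx.core as mx

   a = mx.array([[4.0, 1.0], [1.0, 3.0]])
   b = mx.array([1.0, 2.0])

   print(mx.linalg.solve(a, b, stream=mx.cpu))   # x such that a @ x == b
   print(mx.linalg.eigvalsh(a, stream=mx.cpu))   # eigenvalues of the symmetric a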
docs/src/python/memory_management.rst (new file, 16 lines)
@@ -0,0 +1,16 @@
Memory Management
=================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   get_active_memory
   get_peak_memory
   reset_peak_memory
   get_cache_memory
   set_memory_limit
   set_cache_limit
   set_wired_limit
   clear_cache
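
A short sketch of how these functions compose (illustrative only; the exact
byte counts depend on the machine):

.. code-block:: python

   import mlx.core as mx

   mx.eval(mx.zeros((4096, 4096)).sum())  # run something that allocates
   print(mx.get_active_memory())          # bytes currently held by arrays
   print(mx.get_peak_memory())            # high-water mark since the last reset
   mx.reset_peak_memory()
   mx.clear_cache()                       # release cached, unused buffers
   print(mx.get_cache_memory())           # should now be 0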
@@ -8,12 +8,5 @@ Metal
   is_available
   device_info
-  get_active_memory
-  get_peak_memory
-  reset_peak_memory
-  get_cache_memory
-  set_memory_limit
-  set_cache_limit
-  clear_cache
   start_capture
   stop_capture
@@ -174,6 +174,7 @@ In detail:
   value_and_grad
   quantize
   average_gradients

.. toctree::
@@ -12,6 +12,7 @@ Layers
   ALiBi
   AvgPool1d
   AvgPool2d
   AvgPool3d
   BatchNorm
   CELU
   Conv1d
@@ -41,6 +42,7 @@ Layers
   LSTM
   MaxPool1d
   MaxPool2d
   MaxPool3d
   Mish
   MultiHeadAttention
   PReLU
@@ -32,13 +32,16 @@ Operations
   atleast_2d
   atleast_3d
   bitwise_and
   bitwise_invert
   bitwise_or
   bitwise_xor
   block_masked_mm
   broadcast_arrays
   broadcast_to
   ceil
   clip
   concatenate
   contiguous
   conj
   conjugate
   convolve
@@ -80,6 +83,7 @@ Operations
   greater_equal
   hadamard_transform
   identity
   imag
   inner
   isfinite
   isclose
@@ -88,6 +92,7 @@ Operations
   isneginf
   isposinf
   issubdtype
   kron
   left_shift
   less
   less_equal
@@ -98,6 +103,7 @@ Operations
   log10
   log1p
   logaddexp
   logcumsumexp
   logical_not
   logical_and
   logical_or
@@ -125,11 +131,13 @@ Operations
   quantize
   quantized_matmul
   radians
   real
   reciprocal
   remainder
   repeat
   reshape
   right_shift
   roll
   round
   rsqrt
   save
@@ -141,6 +149,8 @@ Operations
   sign
   sin
   sinh
   slice
   slice_update
   softmax
   sort
   split
@@ -165,6 +175,7 @@ Operations
   tri
   tril
   triu
   unflatten
   var
   view
   where
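
To illustrate a couple of the listed operations, here is a minimal sketch
(assuming the signatures ``unflatten(a, axis, shape)`` and ``roll(a, shift)``):

.. code-block:: python

   import mlx.core as mx

   x = mx.arange(6)
   print(mx.unflatten(x, 0, (2, 3)))  # shape (2, 3), same data as x.reshape(2, 3)
   print(mx.roll(x, 2))               # [4, 5, 0, 1, 2, 3]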
@@ -18,3 +18,4 @@ Common Optimizers
   AdamW
   Adamax
   Lion
   MultiOptimizer
@@ -45,3 +45,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
   truncated_normal
   uniform
   laplace
   permutation
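
A small sketch of :func:`permutation` from the list above (assuming it accepts
either an integer or an array, like its NumPy counterpart):

.. code-block:: python

   import mlx.core as mx

   print(mx.random.permutation(5))              # a random ordering of 0..4
   print(mx.random.permutation(mx.arange(10)))  # shuffled copy of the input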
@@ -9,6 +9,7 @@ Transforms
   :toctree: _autosummary

   eval
   async_eval
   compile
   custom_function
   disable_compile
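
For example, a minimal sketch of :func:`async_eval` from the list above:

.. code-block:: python

   import mlx.core as mx

   a = mx.ones((1024, 1024))
   b = (a @ a).sum()
   mx.async_eval(b)   # start evaluating b without blocking
   # ... unrelated Python work can happen here ...
   print(b.item())    # blocks until b is ready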
@@ -33,12 +33,12 @@ Let's start with a simple example:
   # Compile the function
   compiled_fun = mx.compile(fun)

   # Prints: array(2.36788, dtype=float32)
   print(compiled_fun(x, y))

The output of both the regular function and the compiled function is the same
up to numerical precision.

The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
@@ -96,7 +96,7 @@ element-wise operations:

.. code-block:: python

   def gelu(x):
       return x * (1 + mx.erf(x / math.sqrt(2))) / 2

If you use this function with small arrays, it will be overhead bound. If you
@@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

-.. note::
-
-   As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
-   functions can still be helpful, but won't typically result in as large a
-   speedup as compiling operations that run on the GPU.

Debugging
---------
@@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
   print(fun(mx.array(1.0)))


Compiling Training Graphs
-------------------------

This section will step through how to use :func:`compile` with a simple example
@@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
@@ -330,7 +323,7 @@ To start, here is the simple example without any compilation:
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
@@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled:
   # The state that will be captured as input and output
   state = [model.state, optimizer.state]

   @partial(mx.compile, inputs=state, outputs=state)
   def step(x, y):
       loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
@@ -410,7 +403,7 @@ Compiling transformed functions works just as expected:
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.

You can also compile functions which themselves call compiled functions. A
good practice is to compile the outermost function to give :func:`compile`
@@ -428,3 +421,77 @@ the most opportunity to optimize the computation graph:
   # Compiling the outer function is good to do as it will likely
   # be faster even though the inner functions are compiled
   fun = mx.compile(outer)


.. _shapeless_compile:

Shapeless Compilation
---------------------

When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.

.. code-block:: python

   def fun(x, y):
       return mx.abs(x + y)

   compiled_fun = mx.compile(fun, shapeless=True)

   x = mx.array(1.0)
   y = mx.array(-2.0)

   # First call compiles the function
   print(compiled_fun(x, y))

   # Second call with different shapes
   # does not recompile the function
   x = mx.array([1.0, -6.0])
   y = mx.array([-2.0, 3.0])
   print(compiled_fun(x, y))


Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:

.. code-block:: python

   def fun(x):
       return x.reshape(x.shape[0] * x.shape[1], -1)

   compiled_fun = mx.compile(fun, shapeless=True)

   x = mx.random.uniform(shape=(2, 3, 4))

   out = compiled_fun(x)

   x = mx.random.uniform(shape=(5, 5, 3))

   # Error, can't reshape (5, 5, 3) to (6, -1)
   out = compiled_fun(x)

The second call to the ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:

.. code-block:: python

   def fun(x):
       return x.flatten(0, 1)

   compiled_fun = mx.compile(fun, shapeless=True)

   x = mx.random.uniform(shape=(2, 3, 4))

   out = compiled_fun(x)

   x = mx.random.uniform(shape=(5, 5, 3))

   # Ok
   out = compiled_fun(x)
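
As a rough check that no retracing happens across shapes, one can time repeated
calls (a sketch; absolute timings will vary by machine):

.. code-block:: python

   import timeit

   import mlx.core as mx

   fun = mx.compile(lambda x: mx.abs(x * 2) + 1, shapeless=True)

   for n in (16, 1024, 65536):
       x = mx.random.uniform(shape=(n,))
       t = timeit.timeit(lambda: mx.eval(fun(x)), number=10)
       print(n, f"{t:.4f}s")  # no recompilation between the different shapes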
@@ -5,21 +5,27 @@ Distributed Communication

.. currentmodule:: mlx.core.distributed

-MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
-provide distributed communication operations that allow the computational cost
-of training or inference to be shared across many physical machines. You can
-see a list of the supported operations in the :ref:`API docs<distributed>`.
MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:

* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_, a
  full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
  faster for thunderbolt connections.

The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.

.. note::
-   A lot of operations may not be supported or not as fast as they should be.
   Some operations may not be supported or not as fast as they should be.
   We are adding more and tuning the ones we have as we are figuring out the
   best way to do distributed computing on Macs using MLX.

Getting Started
---------------

-MLX already comes with the ability to "talk" to MPI if it is installed on the
-machine. The minimal distributed program in MLX is as simple as:
A distributed program in MLX is as simple as:

.. code:: python
@@ -30,74 +36,79 @@ machine. The minimal distributed program in MLX is as simple as:
   print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all
-distributed processes. If simply run with ``python``, however, only one
-process is launched and no distributed communication takes place.
distributed processes. However, when this script is run with ``python`` only
one process is launched and no distributed communication takes place. Namely,
all operations in ``mx.distributed`` are noops when the distributed group has a
size of one. This property allows us to avoid code that checks if we are in a
distributed setting similar to the one below:

-To launch the program in distributed mode we need to use ``mpirun`` or
-``mpiexec`` depending on the MPI installation. The simplest possible way is the
-following:
.. code:: python

   import mlx.core as mx

   x = ...
   world = mx.distributed.init()
   # No need for the check, we can simply do x = mx.distributed.all_sum(x)
   if world.size() > 1:
       x = mx.distributed.all_sum(x)

Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

MLX provides ``mlx.launch``, a helper script to launch distributed programs.
Continuing with our initial example we can run it on localhost with 4 processes using

.. code:: shell

-   $ mpirun -np 2 python test.py
-   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
-   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   $ mlx.launch -n 4 my_script.py
   3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

-The above launches two processes on the same (local) machine and we can see
-both standard output streams. The processes send the array of 1s to each other
-and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
-print 4 etc.
-
-Installing MPI
----------------
-
-MPI can be installed with Homebrew, using the Anaconda package manager or
-compiled from source. Most of our testing is done using ``openmpi`` installed
-with the Anaconda package manager as follows:
We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh)

.. code:: shell

-   $ conda install openmpi
   $ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
   3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

-Installing with Homebrew may require specifying the location of ``libmpi.dyld``
-so that MLX can find it and load it at runtime. This can simply be achieved by
-passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
Consult the dedicated :doc:`usage guide<launching_distributed>` for more
information on using ``mlx.launch``.

-.. code:: shell
-
-   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
-
-Setting up Remote Hosts
------------------------
-
-MPI can automatically connect to remote hosts and set up the communication over
-the network if the remote hosts can be accessed via ssh. A good checklist to
-debug connectivity issues is the following:
-
-* ``ssh hostname`` works from all machines to all machines without asking for
-  password or host confirmation
-* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
-  full path to force all machines to use a specific path.
-* Ensure that the ``hostname`` used by MPI is the one that you have configured
-  in the ``.ssh/config`` files on all machines.
Selecting Backend
^^^^^^^^^^^^^^^^^

You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.

-.. note::
-   For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
-   the hostname passed to ssh if the current hostname matches ``*.bar.com``.
After a distributed backend is successfully initialized :func:`init` will
return **the same backend** if called without arguments or with backend set to
``any``.

-An easy way to pass the host names to MPI is using a host file. A host file
-looks like the following, where ``host1`` and ``host2`` should be the fully
-qualified domain names or IPs for these hosts.
The following examples aim to clarify the backend initialization logic in MLX:

-.. code::
.. code:: python

-   host1 slots=1
-   host2 slots=1
   # Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
   world = mx.distributed.init(backend="mpi")
   world2 = mx.distributed.init()  # subsequent calls return the MPI backend!

-When using MLX, it is very likely that you want to use 1 slot per host, ie one
-process per host. The hostfile also needs to contain the current
-host if you want to run on the local host. Passing the host file to
-``mpirun`` is simply done using the ``--hostfile`` command line argument.
   # Case 2: Initialize any backend
   world = mx.distributed.init(backend="any")  # equivalent to no arguments
   world2 = mx.distributed.init()  # same as above

   # Case 3: Initialize both backends at the same time
   world_mpi = mx.distributed.init(backend="mpi")
   world_ring = mx.distributed.init(backend="ring")
   world_any = mx.distributed.init()  # same as MPI because it was initialized first!

Training Example
----------------
@@ -141,12 +152,13 @@ everything else remaining the same.
   from mlx.utils import tree_map

   def all_reduce_grads(grads):
-      N = mx.distributed.init()
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       return tree_map(
           lambda x: mx.distributed.all_sum(x) / N,
-          grads)
           grads
       )

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
@@ -154,13 +166,179 @@ everything else remaining the same.
       optimizer.update(model, grads)
       return loss

-Tuning All Reduce
------------------
Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-We are working on improving the performance of all reduce on MLX but for now
-the two main things one can do to extract the most out of distributed training with MLX are:
Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.

-1. Perform a few large reductions instead of many small ones to improve
-   bandwidth and latency
-2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
-   connections between each host to improve bandwidth
This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = mlx.nn.average_gradients(grads)  # <---- This line was added
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())


Getting Started with MPI
------------------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.

The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:

.. code:: shell

   $ mlx.launch --backend mpi -n 2 test.py
   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.

Installing MPI
^^^^^^^^^^^^^^

MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

   $ conda install conda-forge::openmpi

Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.

.. code:: shell

   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
   $ # or simply
   $ mlx.launch -n 2 test.py

Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines.

Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^

.. note::

   For faster all reduce consider using the ring backend either with Thunderbolt
   connections or over Ethernet.

Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.

Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.

Getting Started with Ring
-------------------------

The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
and so on and so forth. As a result :func:`send` and :func:`recv` with
arbitrary sender and receiver is not supported in the ring backend.

Defining a Ring
^^^^^^^^^^^^^^^

The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.

For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.

.. code:: json

   [
     {"ssh": "hostname1", "ips": ["123.123.123.1"]},
     {"ssh": "hostname2", "ips": ["123.123.123.2"]},
     {"ssh": "hostname3", "ips": ["123.123.123.3"]},
     {"ssh": "hostname4", "ips": ["123.123.123.4"]}
   ]

Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node, run the script which will listen for connections in each of the provided
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4`` and so on and so forth.
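
Once the hostfile is in place, a minimal sketch to verify the ring is up (a
hypothetical ``ring_check.py`` run via ``mlx.launch --hostfile ring-4.json
ring_check.py``):

.. code-block:: python

   import mlx.core as mx

   world = mx.distributed.init(backend="ring")
   x = mx.distributed.all_sum(mx.ones(1))
   # Every rank should report the same sum: the ring size
   print(world.rank(), "/", world.size(), x.item())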
Thunderbolt Ring
^^^^^^^^^^^^^^^^

Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.

To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
utility as follows:

.. code:: shell

   mlx.distributed_config --verbose --hosts host1,host2,host3,host4

By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.

To validate your connection without configuring anything
``mlx.distributed_config`` can also plot the ring using DOT format.

.. code:: shell

   mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
   dot -Tpng ring.dot >ring.png
   open ring.png

If you want to go through the process manually, the steps are as follows:

* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
  corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
  interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
  and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
  ``192.168.0.2`` respectively to the two nodes. For more details you can see
  the commands prepared by the utility script.
docs/src/usage/export.rst (new file, 288 lines)
@@ -0,0 +1,288 @@
.. _export_usage:

Exporting Functions
===================

.. currentmodule:: mlx.core

MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).

This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check out the :ref:`API documentation
<export>`.

Basics of Exporting
-------------------

Let's start with a simple example:

.. code-block:: python

   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)
   mx.export_function("add.mlxfn", fun, x, y)

To export a function, provide sample input arrays that the function
can be called with. The data doesn't matter, but the shapes and types of the
arrays do. In the above example we exported ``fun`` with two ``float32``
scalar arrays. We can then import the function and run it:

.. code-block:: python

   add_fun = mx.import_function("add.mlxfn")

   out, = add_fun(mx.array(1.0), mx.array(2.0))
   # Prints: array(3, dtype=float32)
   print(out)

   out, = add_fun(mx.array(1.0), mx.array(3.0))
   # Prints: array(4, dtype=float32)
   print(out)

   # Raises an exception
   add_fun(mx.array(1), mx.array(3.0))

   # Raises an exception
   add_fun(mx.array([1.0, 2.0]), mx.array(3.0))

Notice the third and fourth calls to ``add_fun`` raise exceptions because the
shapes and types of the inputs are different than the shapes and types of the
example inputs we exported the function with.

Also notice that even though the original ``fun`` returns a single output
array, the imported function always returns a tuple of one or more arrays.

The inputs to :func:`export_function` and to an imported function can be
specified as variable positional arguments or as a tuple of arrays:

.. code-block:: python

   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)

   # Both arguments to fun are positional
   mx.export_function("add.mlxfn", fun, x, y)

   # Same as above
   mx.export_function("add.mlxfn", fun, (x, y))

   imported_fun = mx.import_function("add.mlxfn")

   # Ok
   out, = imported_fun(x, y)

   # Also ok
   out, = imported_fun((x, y))

You can pass example inputs to functions as positional or keyword arguments. If
you use keyword arguments to export the function, then you have to use the same
keyword arguments when calling the imported function.

.. code-block:: python

   def fun(x, y):
       return x + y

   # One argument to fun is positional, the other is a kwarg
   mx.export_function("add.mlxfn", fun, x, y=y)

   imported_fun = mx.import_function("add.mlxfn")

   # Ok
   out, = imported_fun(x, y=y)

   # Also ok
   out, = imported_fun((x,), {"y": y})

   # Raises since the keyword argument is missing
   out, = imported_fun(x, y)

   # Raises since the keyword argument has the wrong key
   out, = imported_fun(x, z=y)


Exporting Modules
-----------------

An :obj:`mlx.nn.Module` can be exported with or without the parameters included
in the exported function. Here's an example:

.. code-block:: python

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x):
       return model(x)

   mx.export_function("model.mlxfn", call, mx.zeros(4))

In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
parameters are also saved to the ``model.mlxfn`` file.

.. note::

   For enclosed arrays inside an exported function, be extra careful to ensure
   they are evaluated. The computation graph that gets exported will include
   the computation that produces enclosed inputs.

   If the above example was missing ``mx.eval(model.parameters())``, the
   exported function would include the random initialization of the
   :obj:`mlx.nn.Module` parameters.

If you only want to export the ``Module.__call__`` function without the
parameters, pass them as inputs to the ``call`` wrapper:

.. code-block:: python

   from mlx.utils import tree_flatten, tree_unflatten

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x, **params):
       # Set the model's parameters to the input parameters
       model.update(tree_unflatten(list(params.items())))
       return model(x)

   params = dict(tree_flatten(model.parameters()))
   mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
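
A quick sketch of using this parameter-carrying export: the imported function
takes the flattened parameters as keyword inputs (reusing ``params`` from the
snippet above):

.. code-block:: python

   imported = mx.import_function("model.mlxfn")

   # Positional inputs as a tuple, parameters as keyword inputs
   out, = imported((mx.zeros(4),), params)
   print(out)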
Shapeless Exports
-----------------

Just like :func:`compile`, functions can also be exported for dynamically shaped
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
to export a function which can be used for inputs with variable shapes:

.. code-block:: python

   mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
   imported_abs = mx.import_function("fun.mlxfn")

   # Ok
   out, = imported_abs(mx.array(-1.0))

   # Also ok
   out, = imported_abs(mx.array([-1.0, -2.0]))

With ``shapeless=False`` (which is the default), the second call to
``imported_abs`` would raise an exception with a shape mismatch.

Shapeless exporting works the same as shapeless compilation and should be
used carefully. See the :ref:`documentation on shapeless compilation
<shapeless_compile>` for more information.

Exporting Multiple Traces
-------------------------

In some cases, functions build different computation graphs for different
input arguments. A simple way to manage this is to export to a new file with
each set of inputs. This is a fine option in many cases. But it can be
suboptimal if the exported functions have a large amount of duplicate constant
data (for example the parameters of a :obj:`mlx.nn.Module`).

The export API in MLX lets you export multiple traces of the same function to
a single file by creating an exporting context manager with :func:`exporter`:

.. code-block:: python

   def fun(x, y=None):
       constant = mx.array(3.0)
       if y is not None:
           x += y
       return x + constant

   with mx.exporter("fun.mlxfn", fun) as exporter:
       exporter(mx.array(1.0))
       exporter(mx.array(1.0), y=mx.array(0.0))

   imported_function = mx.import_function("fun.mlxfn")

   # Call the function with y=None
   out, = imported_function(mx.array(1.0))
   print(out)

   # Call the function with y specified
   out, = imported_function(mx.array(1.0), y=mx.array(1.0))
   print(out)

In the above example the function's constant data (i.e. ``constant``) is only
saved once.

Transformations with Imported Functions
---------------------------------------

Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
on imported functions just like regular Python functions:

.. code-block:: python

   def fun(x):
       return mx.sin(x)

   x = mx.array(0.0)
   mx.export_function("sine.mlxfn", fun, x)

   imported_fun = mx.import_function("sine.mlxfn")

   # Take the derivative of the imported function
   dfdx = mx.grad(lambda x: imported_fun(x)[0])
   # Prints: array(1, dtype=float32)
   print(dfdx(x))

   # Compile the imported function
   compiled_fun = mx.compile(imported_fun)
   # Prints: array(0, dtype=float32)
   print(compiled_fun(x)[0])


Importing Functions in C++
--------------------------

Importing and running functions in C++ is basically the same as importing and
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
set up a simple C++ project that uses MLX as a library.

Next, export a simple function from Python:

.. code-block:: python

   def fun(x, y):
       return mx.exp(x + y)

   x = mx.array(1.0)
   y = mx.array(1.0)
   mx.export_function("fun.mlxfn", fun, x, y)

Import and run the function in C++ with only a few lines of code:

.. code-block:: c++

   auto fun = mx::import_function("fun.mlxfn");

   auto inputs = {mx::array(1.0), mx::array(1.0)};
   auto outputs = fun(inputs);

   // Prints: array(7.38906, dtype=float32)
   std::cout << outputs[0] << std::endl;

Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

More Examples
-------------

Here are a few more complete examples exporting more complex functions from
Python and importing and running them in C++:

* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_
@@ -25,7 +25,7 @@ Here is a simple example:
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:

.. code-block:: shell
@@ -50,7 +50,7 @@ Automatic Differentiation
.. _auto diff:

Automatic differentiation in MLX works on functions rather than on implicit
graphs.

.. note::
@@ -114,7 +114,7 @@ way to do that is the following:
   def loss_fn(params, x, y):
       w, b = params["weight"], params["bias"]
       h = w * x + b
       return mx.mean(mx.square(h - y))

   params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
@@ -132,7 +132,7 @@ way to do that is the following:
Notice the tree structure of the parameters is preserved in the gradients.

In some cases you may want to stop gradients from propagating through a
part of the function. You can use :func:`stop_gradient` for that.
@@ -161,19 +161,19 @@ A naive way to add the elements from two sets of vectors is with a loop:
   ys = mx.random.uniform(shape=(100, 4096))

   def naive_add(xs, ys):
-      return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
       return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

Instead you can use :func:`vmap` to automatically vectorize the addition:

.. code-block:: python

   # Vectorize over the first dimension of x and the
   # second dimension of y
-  vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs, as in the sketch below.
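
For instance, a sketch of moving the mapped axis in the output with
``out_axes`` (using the same ``xs`` and ``ys`` as above):

.. code-block:: python

   vadd = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
   print(vadd(xs, ys).shape)    # (4096, 100): mapped axis is axis 0 by default

   vadd_t = mx.vmap(lambda x, y: x + y, in_axes=(0, 1), out_axes=1)
   print(vadd_t(xs, ys).shape)  # (100, 4096): mapped axis moved to axis 1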
Let's time these two different versions:
@@ -184,8 +184,8 @@ Let's time these two different versions:
   print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
   print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

-On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
-vectorized version takes only ``0.025`` seconds, more than ten times faster.
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.

Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:
.. code-block:: shell

   >>> arr = mx.arange(10)
   >>> idx = mx.array([5, 7])
   >>> arr[idx]
   array([5, 7], dtype=int32)
@@ -77,12 +77,12 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.

Indexing with boolean masks is something that MLX may support in the future. In
-general, MLX has limited support for operations for which outputs
general, MLX has limited support for operations for which output
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.

In Place Updates
----------------

In place updates to indexed arrays are possible in MLX. For example:
docs/src/usage/launching_distributed.rst (new file, 105 lines)
@@ -0,0 +1,105 @@
:orphan:

.. _usage_launch_distributed:

Launching Distributed Programs
==============================

.. currentmodule:: mlx.core.distributed

Installing the MLX python package provides a helper script ``mlx.launch`` that
can be used to run python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.

Usage
-----

The minimal usage example of ``mlx.launch`` is simply

.. code:: shell

   mlx.launch --hosts ip1,ip2 my_script.py

or for testing on localhost

.. code:: shell

   mlx.launch -n 2 my_script.py

The ``mlx.launch`` command connects to the provided hosts and launches the input
script on each host. It monitors each of the launched processes and terminates
the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the output of each remote process to stdout
and stderr respectively.

Providing Hosts
^^^^^^^^^^^^^^^

Hosts can be provided as command line arguments, like above, but the way that
allows you to fully define a list of hosts is via a JSON hostfile. The hostfile
has a very simple schema. It is simply a list of objects that define each host
via a hostname to ssh to and a list of IPs to utilize for the communication.

.. code:: json

   [
     {"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
     {"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
   ]

You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
with IPs corresponding to the ``en0`` interface.

Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^

In order to be able to launch the script on each host we need to be able to
connect via ssh. Moreover the input script and python binary need to be on each
host and on the same path. A good checklist to debug errors is the following:

* ``ssh hostname`` works without asking for password or host confirmation
* the python binary is available on all hosts at the same path. You can use
  ``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path

.. _mpi_specifics:

MPI Specifics
-------------

One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,

* The IPs in the hostfile are ignored
* The ssh connectivity requirement is stronger as every node needs to be able
  to connect to every other node
* ``mpirun`` needs to be available on every node at the same path

Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance
to choose a specific interface for the byte-transfer-layer of MPI we can call
``mlx.launch`` as follows:

.. code:: shell

   mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py

.. _ring_specifics:

Ring Specifics
--------------

The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that are different from MPI:

* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
  ssh to a hostname that does not correspond to the IP we want to bind to we
  have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
  Specifically rank 0 for the first IP will use this port and each subsequent
  IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
  between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
  ``mpirun``.
@@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.

MLX uses lazy evaluation because it has some nice features, some of which we
describe below.

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -109,14 +109,14 @@ Here is a concrete example:
An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
-:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.

Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.
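
A compact sketch of these implicit evaluation points:

.. code-block:: python

   import mlx.core as mx

   a = mx.ones((2, 2)) * 3  # recorded lazily, nothing computed yet
   print(a)                 # printing forces the graph to evaluate
   s = a.sum()
   s.item()                 # .item() on a scalar array also evaluates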
@@ -3,10 +3,10 @@
Conversion to NumPy and Other Frameworks
========================================

MLX arrays support conversion to and from other frameworks with either:

* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.

Let's convert an array to NumPy and back.
@@ -21,11 +21,13 @@ Let's convert an array to NumPy and back.
.. note::

   Since NumPy does not support ``bfloat16`` arrays, you will need to convert
   to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
   Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
   buffer format string does not match the dtype V item size 0.``

By default, NumPy copies data to a new array. This can be prevented by creating
an array view:

.. code-block:: python
@@ -35,10 +37,16 @@ By default, NumPy copies data to a new array. This can be prevented by creating
   a_view[0] = 1
   print(a[0].item())  # 1

.. note::

   NumPy arrays with type ``float64`` will by default be converted to MLX
   arrays with type ``float32``.

A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.

While this is quite powerful to prevent copying arrays, it should be noted that
external changes to the memory of arrays cannot be reflected in gradients.

Let's demonstrate this in an example:
@@ -56,22 +64,24 @@ Let's demonstrate this in an example:

The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line outputting ``1.0``, representing the gradient of the sum operation
alone. The squaring of ``x`` occurs externally to MLX, meaning that no
gradient is incorporated. It's important to note that a similar issue arises
during array conversion and copying. For instance, a function defined as
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.

PyTorch
-------

.. warning::

   PyTorch Support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.

PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python
@@ -82,7 +92,8 @@ PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryvi
   b = torch.tensor(memoryview(a))
   c = mx.array(b.numpy())

Conversion from PyTorch tensors back to arrays must be done via intermediate
NumPy arrays with ``numpy()``.

JAX
---
@@ -100,7 +111,8 @@ JAX fully supports the buffer protocol.
TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.

.. code-block:: python
@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
and :func:`jvp` for Jacobian-vector products.

Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.
@@ -8,33 +8,33 @@ Saving and Loading Arrays

MLX supports multiple array serialization formats.

.. list-table:: Serialization Formats
   :widths: 20 8 25 25
   :header-rows: 1

   * - Format
     - Extension
     - Function
     - Notes
   * - NumPy
     - ``.npy``
     - :func:`save`
     - Single arrays only
   * - NumPy archive
     - ``.npz``
     - :func:`savez` and :func:`savez_compressed`
     - Multiple arrays
   * - Safetensors
     - ``.safetensors``
     - :func:`save_safetensors`
     - Multiple arrays
   * - GGUF
     - ``.gguf``
     - :func:`save_gguf`
     - Multiple arrays

The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extension. The output of
:func:`load` depends on the format.

Here's an example of saving a single array to a file:
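The example itself is cut off by the diff; a minimal sketch of the documented
pattern (the file name is invented):

.. code-block:: python

   import mlx.core as mx

   a = mx.array([1.0])
   mx.save("array", a)       # writes array.npy
   b = mx.load("array.npy")  # the format is inferred from the extension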
@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.

In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:

.. code-block:: python
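   # Body elided by the diff; a sketch of the documented pattern, where the
   # stream argument selects the device (assumed):
   c = mx.add(a, b, stream=mx.cpu)
   d = mx.add(a, b, stream=mx.gpu)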
examples/cmake_project/CMakeLists.txt (new file, 22 lines)
@@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.27)

project(example LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Comment out the following two commands if the MLX C++ library is already
# installed, and set(MLX_ROOT "/path/to/mlx") directly if needed.
find_package(
  Python 3.9
  COMPONENTS Interpreter Development.Module
  REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)

find_package(MLX CONFIG REQUIRED)

add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
examples/cmake_project/README.md (new file, 26 lines)
@@ -0,0 +1,26 @@
## Build and Run

Install MLX with Python:

```bash
pip install mlx>=0.22
```

Build the C++ example:

```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```

Run the C++ example:

```
./build/example
```

which should output:

```
array([2, 4, 6], dtype=int32)
```
examples/cmake_project/example.cpp (new file, 14 lines)
@@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.

#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto x = mx::array({1, 2, 3});
  auto y = mx::array({1, 2, 3});
  std::cout << x + y << std::endl;
  return 0;
}
@@ -4,19 +4,19 @@

 #include "mlx/mlx.h"

-using namespace mlx::core;
+namespace mx = mlx::core;

 int main() {
-  if (!distributed::is_available()) {
+  if (!mx::distributed::is_available()) {
     std::cout << "No communication backend found" << std::endl;
     return 1;
   }

-  auto global_group = distributed::init();
+  auto global_group = mx::distributed::init();
   std::cout << global_group.rank() << " / " << global_group.size() << std::endl;

-  array x = ones({10});
-  array out = distributed::all_sum(x, global_group);
+  mx::array x = mx::ones({10});
+  mx::array out = mx::distributed::all_sum(x, global_group);

   std::cout << out << std::endl;
 }
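Not shown in the diff: a program like the one above is typically launched
through an MPI runner so that several ranks execute at once. A hedged sketch,
assuming an MPI installation and that the binary is named
`distributed_example`:

```bash
# Hypothetical invocation; the runner and flags depend on your MPI setup.
mpirun -np 4 ./distributed_example
```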
@@ -10,7 +10,7 @@
 /**
  * An example of linear regression with MLX.
  */
-using namespace mlx::core;
+namespace mx = mlx::core;

 int main() {
   int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
   float learning_rate = 0.01;

   // True parameters
-  auto w_star = random::normal({num_features});
+  auto w_star = mx::random::normal({num_features});

   // The input examples (design matrix)
-  auto X = random::normal({num_examples, num_features});
+  auto X = mx::random::normal({num_examples, num_features});

   // Noisy labels
-  auto eps = 1e-2 * random::normal({num_examples});
-  auto y = matmul(X, w_star) + eps;
+  auto eps = 1e-2 * mx::random::normal({num_examples});
+  auto y = mx::matmul(X, w_star) + eps;

   // Initialize random parameters
-  array w = 1e-2 * random::normal({num_features});
+  mx::array w = 1e-2 * mx::random::normal({num_features});

-  auto loss_fn = [&](array w) {
-    auto yhat = matmul(X, w);
-    return (0.5f / num_examples) * sum(square(yhat - y));
+  auto loss_fn = [&](mx::array w) {
+    auto yhat = mx::matmul(X, w);
+    return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
   };

-  auto grad_fn = grad(loss_fn);
+  auto grad_fn = mx::grad(loss_fn);

   auto tic = timer::time();
   for (int it = 0; it < num_iters; ++it) {
-    auto grad = grad_fn(w);
-    w = w - learning_rate * grad;
-    eval(w);
+    auto grads = grad_fn(w);
+    w = w - learning_rate * grads;
+    mx::eval(w);
   }
   auto toc = timer::time();

   auto loss = loss_fn(w);
-  auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
+  auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
   auto throughput = num_iters / timer::seconds(toc - tic);
   std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
             << ", Throughput " << throughput << " (it/s)." << std::endl;
@@ -10,7 +10,7 @@
 /**
  * An example of logistic regression with MLX.
  */
-using namespace mlx::core;
+namespace mx = mlx::core;

 int main() {
   int num_features = 100;
@@ -19,35 +19,35 @@ int main() {
   float learning_rate = 0.1;

   // True parameters
-  auto w_star = random::normal({num_features});
+  auto w_star = mx::random::normal({num_features});

   // The input examples
-  auto X = random::normal({num_examples, num_features});
+  auto X = mx::random::normal({num_examples, num_features});

   // Labels
-  auto y = matmul(X, w_star) > 0;
+  auto y = mx::matmul(X, w_star) > 0;

   // Initialize random parameters
-  array w = 1e-2 * random::normal({num_features});
+  mx::array w = 1e-2 * mx::random::normal({num_features});

-  auto loss_fn = [&](array w) {
-    auto logits = matmul(X, w);
+  auto loss_fn = [&](mx::array w) {
+    auto logits = mx::matmul(X, w);
     auto scale = (1.0f / num_examples);
-    return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
+    return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
   };

-  auto grad_fn = grad(loss_fn);
+  auto grad_fn = mx::grad(loss_fn);

   auto tic = timer::time();
   for (int it = 0; it < num_iters; ++it) {
-    auto grad = grad_fn(w);
-    w = w - learning_rate * grad;
-    eval(w);
+    auto grads = grad_fn(w);
+    w = w - learning_rate * grads;
+    mx::eval(w);
   }
   auto toc = timer::time();

   auto loss = loss_fn(w);
-  auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
+  auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
   auto throughput = num_iters / timer::seconds(toc - tic);
   std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
             << throughput << " (it/s)." << std::endl;
@@ -5,27 +5,27 @@

 #include "mlx/mlx.h"

-using namespace mlx::core;
+namespace mx = mlx::core;

 int main() {
   // To use Metal debugging and profiling:
   // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
   // 2. Run with MTL_CAPTURE_ENABLED=1.
-  metal::start_capture("mlx_trace.gputrace");
+  mx::metal::start_capture("mlx_trace.gputrace");

   // Start at index two because the default GPU and CPU streams have indices
   // zero and one, respectively. This naming matches the label assigned to each
   // stream's command queue.
-  auto s2 = new_stream(Device::gpu);
-  auto s3 = new_stream(Device::gpu);
+  auto s2 = new_stream(mx::Device::gpu);
+  auto s3 = new_stream(mx::Device::gpu);

-  auto a = arange(1.f, 10.f, 1.f, float32, s2);
-  auto b = arange(1.f, 10.f, 1.f, float32, s3);
-  auto x = add(a, a, s2);
-  auto y = add(b, b, s3);
+  auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
+  auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
+  auto x = mx::add(a, a, s2);
+  auto y = mx::add(b, b, s3);

   // The multiply will happen on the default stream.
-  std::cout << multiply(x, y) << std::endl;
+  std::cout << mx::multiply(x, y) << std::endl;

-  metal::stop_capture();
+  mx::metal::stop_capture();
 }
@@ -5,11 +5,11 @@

 #include "mlx/mlx.h"

-using namespace mlx::core;
+namespace mx = mlx::core;

 void array_basics() {
   // Make a scalar array:
-  array x(1.0);
+  mx::array x(1.0);

   // Get the value out of it:
   auto s = x.item<float>();
@@ -29,31 +29,31 @@ void array_basics() {

   // The datatype should be float32:
   auto dtype = x.dtype();
-  assert(dtype == float32);
+  assert(dtype == mx::float32);

   // Specify the dtype when constructing the array:
-  x = array(1, int32);
-  assert(x.dtype() == int32);
+  x = mx::array(1, mx::int32);
+  assert(x.dtype() == mx::int32);
   x.item<int>(); // OK
   // x.item<float>(); // Undefined!

   // Make a multidimensional array:
-  x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+  x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
   // mlx is row-major by default so the first row of this array
   // is [1.0, 2.0] and the second row is [3.0, 4.0]

   // Make an array of shape {2, 2} filled with ones:
-  auto y = ones({2, 2});
+  auto y = mx::ones({2, 2});

   // Pointwise add x and y:
-  auto z = add(x, y);
+  auto z = mx::add(x, y);

   // Same thing:
   z = x + y;

   // mlx is lazy by default. At this point `z` only
   // has a shape and a type but no actual data:
-  assert(z.dtype() == float32);
+  assert(z.dtype() == mx::float32);
   assert(z.shape(0) == 2);
   assert(z.shape(1) == 2);
@@ -63,33 +63,33 @@ void array_basics() {
   // and inputs. When `eval` is called on an array (or arrays), the array and
   // all of its dependencies are recursively evaluated to produce the result.
   // Once an array is evaluated, it has data and is detached from its inputs.
-  eval(z);
+  mx::eval(z);

-  // Of course the array can still be an input to other operations. You can even
-  // call eval on the array again, this will just be a no-op:
-  eval(z); // no-op
+  // Of course the array can still be an input to other operations. You can
+  // even call eval on the array again, this will just be a no-op:
+  mx::eval(z); // no-op

   // Some functions or methods on arrays implicitly evaluate them. For example
   // accessing a value in an array or printing the array implicitly evaluate it:
-  z = ones({1});
+  z = mx::ones({1});
   z.item<float>(); // implicit evaluation

-  z = ones({2, 2});
+  z = mx::ones({2, 2});
   std::cout << z << std::endl; // implicit evaluation
 }

 void automatic_differentiation() {
-  auto fn = [](array x) { return square(x); };
+  auto fn = [](mx::array x) { return mx::square(x); };

   // Computing the derivative function of a function
-  auto grad_fn = grad(fn);
+  auto grad_fn = mx::grad(fn);
   // Call grad_fn on the input to get the derivative
-  auto x = array(1.5);
+  auto x = mx::array(1.5);
   auto dfdx = grad_fn(x);
   // dfdx is 2 * x

   // Get the second derivative by composing grad with grad
-  auto d2fdx2 = grad(grad(fn))(x);
+  auto d2fdx2 = mx::grad(mx::grad(fn))(x);
   // d2fdx2 is 2
 }
examples/export/CMakeLists.txt (new file, 22 lines)
@@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.27)

project(import_mlx LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(
  Python 3.9
  COMPONENTS Interpreter Development.Module
  REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

add_executable(eval_mlp eval_mlp.cpp)
target_link_libraries(eval_mlp PRIVATE mlx)

add_executable(train_mlp train_mlp.cpp)
target_link_libraries(train_mlp PRIVATE mlx)
examples/export/README.md (new file, 49 lines)
@@ -0,0 +1,49 @@
## Setup

Install MLX:

```bash
pip install mlx>=0.22
```

Build the C++ examples:

```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```

## Run

### Eval MLP

Run the Python script to export the eval function:

```bash
python eval_mlp.py
```

Then run the C++ program to import and run the function:

```
./build/eval_mlp
```

The Python and C++ programs should output the same result.

### Train MLP

Run the Python script to export the model initialization and training
functions:

```bash
python train_mlp.py
```

Then run the C++ program to import and run the functions:

```
./build/train_mlp
```

The Python and C++ programs should output the same results.
examples/export/eval_mlp.cpp (new file, 25 lines)
@@ -0,0 +1,25 @@
// Copyright © 2024 Apple Inc.

#include <mlx/mlx.h>
#include <iostream>

namespace mx = mlx::core;

int main() {
  int batch_size = 8;
  int input_dim = 32;

  // Make the input
  mx::random::seed(42);
  auto example_x = mx::random::uniform({batch_size, input_dim});

  // Import the function
  auto forward = mx::import_function("eval_mlp.mlxfn");

  // Call the imported function
  auto out = forward({example_x})[0];

  std::cout << out << std::endl;

  return 0;
}
examples/export/eval_mlp.py (new file, 52 lines)
@@ -0,0 +1,52 @@
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


if __name__ == "__main__":

    batch_size = 8
    input_dim = 32
    output_dim = 10

    # Load the model
    mx.random.seed(0)  # Seed for params
    model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
    mx.eval(model)

    # Note, the model parameters are saved in the export function
    def forward(x):
        return model(x)

    mx.random.seed(42)  # Seed for input
    example_x = mx.random.uniform(shape=(batch_size, input_dim))

    mx.export_function("eval_mlp.mlxfn", forward, example_x)

    # Import in Python
    imported_forward = mx.import_function("eval_mlp.mlxfn")
    expected = forward(example_x)
    (out,) = imported_forward(example_x)
    assert mx.allclose(expected, out)
    print(out)
examples/export/train_mlp.cpp (new file, 35 lines)
@@ -0,0 +1,35 @@
// Copyright © 2024 Apple Inc.

#include <mlx/mlx.h>
#include <iostream>

namespace mx = mlx::core;

int main() {
  int batch_size = 8;
  int input_dim = 32;
  int output_dim = 10;

  auto state = mx::import_function("init_mlp.mlxfn")({});

  // Make the input
  mx::random::seed(42);
  auto example_X = mx::random::normal({batch_size, input_dim});
  auto example_y = mx::random::randint(0, output_dim, {batch_size});

  // Import the function
  auto step = mx::import_function("train_mlp.mlxfn");

  // Call the imported function
  for (int it = 0; it < 100; ++it) {
    state.insert(state.end(), {example_X, example_y});
    state = step(state);
    eval(state);
    auto loss = state.back();
    state.pop_back();
    if (it % 10 == 0) {
      std::cout << "Loss " << loss.item<float>() << std::endl;
    }
  }
  return 0;
}
examples/export/train_mlp.py (new file, 76 lines)
@@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


if __name__ == "__main__":

    batch_size = 8
    input_dim = 32
    output_dim = 10

    def init():
        # Seed for the parameter initialization
        mx.random.seed(0)
        model = MLP(
            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
        )
        optimizer = optim.SGD(learning_rate=1e-1)
        optimizer.init(model.parameters())
        state = [model.parameters(), optimizer.state]
        tree_structure, state = zip(*mlx.utils.tree_flatten(state))
        return model, optimizer, tree_structure, state

    # Export the model parameter initialization
    model, optimizer, tree_structure, state = init()
    mx.eval(state)
    mx.export_function("init_mlp.mlxfn", lambda: init()[-1])

    def loss_fn(params, X, y):
        model.update(params)
        return nn.losses.cross_entropy(model(X), y, reduction="mean")

    def step(*inputs):
        *state, X, y = inputs
        params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        params = optimizer.apply_gradients(grads, params)
        _, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
        return *state, loss

    # Make some random data
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))

    # Export one step of SGD
    mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)

    imported_step = mx.import_function("train_mlp.mlxfn")

    for it in range(100):
        *state, loss = imported_step(*state, example_X, example_y)
        if it % 10 == 0:
            print(f"Loss {loss.item():.6}")
@@ -10,7 +10,6 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

 # ----------------------------- Dependencies -----------------------------
-find_package(MLX CONFIG REQUIRED)
 find_package(
   Python 3.8
   COMPONENTS Interpreter Development.Module
@@ -18,10 +17,15 @@ find_package(
 execute_process(
   COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
   OUTPUT_STRIP_TRAILING_WHITESPACE
-  OUTPUT_VARIABLE NB_DIR)
-list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
+  OUTPUT_VARIABLE nanobind_ROOT)
 find_package(nanobind CONFIG REQUIRED)

+execute_process(
+  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+  OUTPUT_VARIABLE MLX_ROOT)
+find_package(MLX CONFIG REQUIRED)
+
 # ----------------------------- Extensions -----------------------------

 # Add library
@@ -1,25 +1,20 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2023-2025 Apple Inc.

 #include <cassert>
 #include <iostream>
 #include <sstream>

-#include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/utils.h"

 #include "axpby/axpby.h"

-#ifdef ACCELERATE_NEW_LAPACK
-#include <vecLib/cblas_new.h>
-#endif
-
 #ifdef _METAL_
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/utils.h"
 #endif

-namespace mlx::core {
+namespace my_ext {

 ///////////////////////////////////////////////////////////////////////////////
 // Operation Implementation
@@ -32,24 +27,24 @@ namespace mlx::core {
  * Follow numpy style broadcasting between x and y
  * Inputs are upcasted to floats if needed
 **/
-array axpby(
-    const array& x, // Input array x
-    const array& y, // Input array y
+mx::array axpby(
+    const mx::array& x, // Input mx::array x
+    const mx::array& y, // Input mx::array y
     const float alpha, // Scaling factor for x
     const float beta, // Scaling factor for y
-    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
+    mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
 ) {
   // Promote dtypes between x and y as needed
   auto promoted_dtype = promote_types(x.dtype(), y.dtype());

   // Upcast to float32 for non-floating point inputs x and y
-  auto out_dtype = issubdtype(promoted_dtype, float32)
+  auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
       ? promoted_dtype
-      : promote_types(promoted_dtype, float32);
+      : promote_types(promoted_dtype, mx::float32);

   // Cast x and y up to the determined dtype (on the same stream s)
-  auto x_casted = astype(x, out_dtype, s);
-  auto y_casted = astype(y, out_dtype, s);
+  auto x_casted = mx::astype(x, out_dtype, s);
+  auto y_casted = mx::astype(y, out_dtype, s);

   // Broadcast the shapes of x and y (on the same stream s)
   auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
@@ -57,12 +52,12 @@ array axpby(

   // Construct the array as the output of the Axpby primitive
   // with the broadcasted and upcasted arrays as inputs
-  return array(
-      /* const std::vector<int>& shape = */ out_shape,
-      /* Dtype dtype = */ out_dtype,
-      /* std::unique_ptr<Primitive> primitive = */
+  return mx::array(
+      /* const mx::Shape& shape = */ out_shape,
+      /* mx::Dtype dtype = */ out_dtype,
+      /* std::shared_ptr<mx::Primitive> primitive = */
       std::make_shared<Axpby>(to_stream(s), alpha, beta),
-      /* const std::vector<array>& inputs = */ broadcasted_inputs);
+      /* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
 }

 ///////////////////////////////////////////////////////////////////////////////
@@ -71,140 +66,69 @@ array axpby(

 template <typename T>
 void axpby_impl(
-    const array& x,
-    const array& y,
-    array& out,
+    const mx::array& x,
+    const mx::array& y,
+    mx::array& out,
     float alpha_,
-    float beta_) {
-  // We only allocate memory when we are ready to fill the output
-  // malloc_or_wait synchronously allocates available memory
-  // There may be a wait executed here if the allocation is requested
-  // under memory-pressured conditions
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    float beta_,
+    mx::Stream stream) {
+  out.set_data(mx::allocator::malloc(out.nbytes()));

-  // Collect input and output data pointers
-  const T* x_ptr = x.data<T>();
-  const T* y_ptr = y.data<T>();
-  T* out_ptr = out.data<T>();
+  // Get the CPU command encoder and register input and output arrays
+  auto& encoder = mx::cpu::get_command_encoder(stream);
+  encoder.set_input_array(x);
+  encoder.set_input_array(y);
+  encoder.set_output_array(out);

-  // Cast alpha and beta to the relevant types
-  T alpha = static_cast<T>(alpha_);
-  T beta = static_cast<T>(beta_);
+  // Launch the CPU kernel
+  encoder.dispatch([x_ptr = x.data<T>(),
+                    y_ptr = y.data<T>(),
+                    out_ptr = out.data<T>(),
+                    size = out.size(),
+                    shape = out.shape(),
+                    x_strides = x.strides(),
+                    y_strides = y.strides(),
+                    alpha_,
+                    beta_]() {
+    // Cast alpha and beta to the relevant types
+    T alpha = static_cast<T>(alpha_);
+    T beta = static_cast<T>(beta_);

-  // Do the element-wise operation for each output
-  for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
-    // Map linear indices to offsets in x and y
-    auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
-    auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
+    // Do the element-wise operation for each output
+    for (size_t out_idx = 0; out_idx < size; out_idx++) {
+      // Map linear indices to offsets in x and y
+      auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
+      auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);

-    // We allocate the output to be contiguous and regularly strided
-    // (defaults to row major) and hence it doesn't need additional mapping
-    out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
-  }
+      // We allocate the output to be contiguous and regularly strided
+      // (defaults to row major) and hence it doesn't need additional mapping
+      out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
+    }
+  });
 }

-/** Fall back implementation for evaluation on CPU */
-void Axpby::eval(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  // Check the inputs (registered in the op while constructing the out array)
-  assert(inputs.size() == 2);
+void Axpby::eval_cpu(
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& outputs) {
   auto& x = inputs[0];
   auto& y = inputs[1];
   auto& out = outputs[0];

   // Dispatch to the correct dtype
-  if (out.dtype() == float32) {
-    return axpby_impl<float>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == float16) {
-    return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == bfloat16) {
-    return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == complex64) {
-    return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
+  if (out.dtype() == mx::float32) {
+    return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
+  } else if (out.dtype() == mx::float16) {
+    return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
+  } else if (out.dtype() == mx::bfloat16) {
+    return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
+  } else if (out.dtype() == mx::complex64) {
+    return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
   } else {
     throw std::runtime_error(
         "Axpby is only supported for floating point types.");
   }
 }

-///////////////////////////////////////////////////////////////////////////////
-// Primitive Accelerate Backend Implementation
-///////////////////////////////////////////////////////////////////////////////
-
-#ifdef ACCELERATE_NEW_LAPACK
-
-template <typename T>
-void axpby_impl_accelerate(
-    const array& x,
-    const array& y,
-    array& out,
-    float alpha_,
-    float beta_) {
-  // Accelerate library provides catlas_saxpby which does
-  // Y = (alpha * X) + (beta * Y) in place
-  // To use it, we first copy the data in y over to the output array
-
-  // This specialization requires both x and y be contiguous in the same mode
-  // i.e: corresponding linear indices in both point to corresponding elements
-  // The data in the output array is allocated to match the strides in y
-  // such that x, y, and out are contiguous in the same mode and
-  // no transposition is needed
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-  // We then copy over the elements using the contiguous vector specialization
-  copy_inplace(y, out, CopyType::Vector);
-
-  // Get x and y pointers for catlas_saxpby
-  const T* x_ptr = x.data<T>();
-  T* y_ptr = out.data<T>();
-
-  T alpha = static_cast<T>(alpha_);
-  T beta = static_cast<T>(beta_);
-
-  // Call the inplace accelerate operator
-  catlas_saxpby(
-      /* N = */ out.size(),
-      /* ALPHA = */ alpha,
-      /* X = */ x_ptr,
-      /* INCX = */ 1,
-      /* BETA = */ beta,
-      /* Y = */ y_ptr,
-      /* INCY = */ 1);
-}
-
-/** Evaluate primitive on CPU using accelerate specializations */
-void Axpby::eval_cpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  assert(inputs.size() == 2);
-  auto& x = inputs[0];
-  auto& y = inputs[1];
-  auto& out = outputs[0];
-
-  // Accelerate specialization for contiguous single precision float arrays
-  if (out.dtype() == float32 &&
-      ((x.flags().row_contiguous && y.flags().row_contiguous) ||
-       (x.flags().col_contiguous && y.flags().col_contiguous))) {
-    axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
-    return;
-  }
-
-  // Fall back to common backend if specializations are not available
-  eval(inputs, outputs);
-}
-
-#else // Accelerate not available
-
-/** Evaluate primitive on CPU falling back to common backend */
-void Axpby::eval_cpu(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs) {
-  eval(inputs, outputs);
-}
-
-#endif
-
 ///////////////////////////////////////////////////////////////////////////////
 // Primitive Metal Backend Implementation
 ///////////////////////////////////////////////////////////////////////////////
@@ -213,10 +137,9 @@ void Axpby::eval_cpu(

 /** Evaluate primitive on GPU */
 void Axpby::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& outputs) {
   // Prepare inputs
-  assert(inputs.size() == 2);
   auto& x = inputs[0];
   auto& y = inputs[1];
   auto& out = outputs[0];
@@ -225,7 +148,7 @@ void Axpby::eval_gpu(
   // and each stream carries its device identifiers
   auto& s = stream();
   // We get the needed metal device using the stream
-  auto& d = metal::device(s.device);
+  auto& d = mx::metal::device(s.device);

   // Prepare to specialize based on contiguity
   bool contiguous_kernel =
@@ -235,12 +158,12 @@ void Axpby::eval_gpu(
   // Allocate output memory with strides based on specialization
   if (contiguous_kernel) {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * out.itemsize()),
+        mx::allocator::malloc(x.data_size() * out.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(mx::allocator::malloc(out.nbytes()));
   }

   // Resolve name of kernel (corresponds to axpby.metal)
@@ -257,7 +180,7 @@ void Axpby::eval_gpu(

   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Kernel parameters are registered with buffer indices corresponding to
   // those in the kernel declaration at axpby.metal
@@ -272,15 +195,15 @@ void Axpby::eval_gpu(
   compute_encoder.set_output_array(out, 2);

   // Encode alpha and beta
-  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-  compute_encoder->setBytes(&beta_, sizeof(float), 4);
+  compute_encoder.set_bytes(alpha_, 3);
+  compute_encoder.set_bytes(beta_, 4);

   // Encode shape, strides and ndim if needed
   if (!contiguous_kernel) {
-    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    compute_encoder.set_vector_bytes(x.shape(), 5);
+    compute_encoder.set_vector_bytes(x.strides(), 6);
+    compute_encoder.set_vector_bytes(y.strides(), 7);
+    compute_encoder.set_bytes(ndim, 8);
   }

   // We launch 1 thread for each input and make sure that the number of
@@ -295,15 +218,15 @@ void Axpby::eval_gpu(

   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

 #else // Metal is not available

 /** Fail evaluation on GPU */
 void Axpby::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& out) {
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& out) {
   throw std::runtime_error("Axpby has no GPU implementation.");
 }

@@ -314,9 +237,9 @@ void Axpby::eval_gpu(
 ///////////////////////////////////////////////////////////////////////////////

 /** The Jacobian-vector product. */
-std::vector<array> Axpby::jvp(
-    const std::vector<array>& primals,
-    const std::vector<array>& tangents,
+std::vector<mx::array> Axpby::jvp(
+    const std::vector<mx::array>& primals,
+    const std::vector<mx::array>& tangents,
     const std::vector<int>& argnums) {
   // Forward mode diff that pushes along the tangents
   // The jvp transform on the primitive can be built with ops
@@ -328,8 +251,8 @@ std::vector<array> Axpby::jvp(
   // scaled by beta
   if (argnums.size() > 1) {
     auto scale = argnums[0] == 0 ? alpha_ : beta_;
-    auto scale_arr = array(scale, tangents[0].dtype());
-    return {multiply(scale_arr, tangents[0], stream())};
+    auto scale_arr = mx::array(scale, tangents[0].dtype());
+    return {mx::multiply(scale_arr, tangents[0], stream())};
   }
   // If argnums = {0, 1}, we take contributions from both,
   // which gives us jvp = tangent_x * alpha + tangent_y * beta
@@ -339,24 +262,24 @@ std::vector<array> Axpby::jvp(
 }

 /** The vector-Jacobian product. */
-std::vector<array> Axpby::vjp(
-    const std::vector<array>& primals,
-    const std::vector<array>& cotangents,
+std::vector<mx::array> Axpby::vjp(
+    const std::vector<mx::array>& primals,
+    const std::vector<mx::array>& cotangents,
     const std::vector<int>& argnums,
-    const std::vector<array>&) {
+    const std::vector<mx::array>&) {
   // Reverse mode diff
-  std::vector<array> vjps;
+  std::vector<mx::array> vjps;
   for (auto arg : argnums) {
     auto scale = arg == 0 ? alpha_ : beta_;
-    auto scale_arr = array(scale, cotangents[0].dtype());
-    vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
+    auto scale_arr = mx::array(scale, cotangents[0].dtype());
+    vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
   }
   return vjps;
 }

 /** Vectorize primitive along given axis */
-std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
-    const std::vector<array>& inputs,
+std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
+    const std::vector<mx::array>& inputs,
     const std::vector<int>& axes) {
   throw std::runtime_error("Axpby has no vmap implementation.");
 }
@@ -367,4 +290,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
   return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
 }

-} // namespace mlx::core
+} // namespace my_ext
@@ -1,11 +1,13 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2025 Apple Inc.

 #pragma once

 #include "mlx/ops.h"
 #include "mlx/primitives.h"

-namespace mlx::core {
+namespace mx = mlx::core;
+
+namespace my_ext {

 ///////////////////////////////////////////////////////////////////////////////
 // Operation
@@ -18,22 +20,22 @@ namespace mlx::core {
  * Follow numpy style broadcasting between x and y
  * Inputs are upcasted to floats if needed
 **/
-array axpby(
-    const array& x, // Input array x
-    const array& y, // Input array y
+mx::array axpby(
+    const mx::array& x, // Input array x
+    const mx::array& y, // Input array y
     const float alpha, // Scaling factor for x
     const float beta, // Scaling factor for y
-    StreamOrDevice s = {} // Stream on which to schedule the operation
+    mx::StreamOrDevice s = {} // Stream on which to schedule the operation
 );

 ///////////////////////////////////////////////////////////////////////////////
 // Primitive
 ///////////////////////////////////////////////////////////////////////////////

-class Axpby : public Primitive {
+class Axpby : public mx::Primitive {
  public:
-  explicit Axpby(Stream stream, float alpha, float beta)
-      : Primitive(stream), alpha_(alpha), beta_(beta) {};
+  explicit Axpby(mx::Stream stream, float alpha, float beta)
+      : mx::Primitive(stream), alpha_(alpha), beta_(beta) {};

   /**
    * A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,23 +44,25 @@ class Axpby : public Primitive {
    * To avoid unnecessary allocations, the evaluation function
    * is responsible for allocating space for the array.
    */
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override;
-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override;
+  void eval_cpu(
+      const std::vector<mx::array>& inputs,
+      std::vector<mx::array>& outputs) override;
+  void eval_gpu(
+      const std::vector<mx::array>& inputs,
+      std::vector<mx::array>& outputs) override;

   /** The Jacobian-vector product. */
-  std::vector<array> jvp(
-      const std::vector<array>& primals,
-      const std::vector<array>& tangents,
+  std::vector<mx::array> jvp(
+      const std::vector<mx::array>& primals,
+      const std::vector<mx::array>& tangents,
       const std::vector<int>& argnums) override;

   /** The vector-Jacobian product. */
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
+  std::vector<mx::array> vjp(
+      const std::vector<mx::array>& primals,
+      const std::vector<mx::array>& cotangents,
       const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
+      const std::vector<mx::array>& outputs) override;

   /**
    * The primitive must know how to vectorize itself across
@@ -66,8 +70,8 @@ class Axpby : public Primitive {
    * representing the vectorized computation and the axis which
    * corresponds to the output vectorized dimension.
    */
-  std::pair<std::vector<array>, std::vector<int>> vmap(
-      const std::vector<array>& inputs,
+  std::pair<std::vector<mx::array>, std::vector<int>> vmap(
+      const std::vector<mx::array>& inputs,
       const std::vector<int>& axes) override;

   /** Print the primitive. */
@@ -76,14 +80,11 @@ class Axpby : public Primitive {
   }

   /** Equivalence check **/
-  bool is_equivalent(const Primitive& other) const override;
+  bool is_equivalent(const mx::Primitive& other) const override;

  private:
   float alpha_;
   float beta_;

-  /** Fall back implementation for evaluation on CPU */
-  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };

-} // namespace mlx::core
+} // namespace my_ext
@@ -1,8 +1,7 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2025 Apple Inc.

 #include <metal_stdlib>

-#include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"

 template <typename T>
@@ -13,8 +12,8 @@ template <typename T>
     constant const float& alpha [[buffer(3)]],
     constant const float& beta [[buffer(4)]],
     constant const int* shape [[buffer(5)]],
-    constant const size_t* x_strides [[buffer(6)]],
-    constant const size_t* y_strides [[buffer(7)]],
+    constant const int64_t* x_strides [[buffer(6)]],
+    constant const int64_t* y_strides [[buffer(7)]],
     constant const int& ndim [[buffer(8)]],
     uint index [[thread_position_in_grid]]) {
   auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
@@ -35,29 +34,14 @@ template <typename T>
       static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
 }

-#define instantiate_axpby(type_name, type) \
-  template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
-  axpby_general<type>( \
-      device const type* x [[buffer(0)]], \
-      device const type* y [[buffer(1)]], \
-      device type* out [[buffer(2)]], \
-      constant const float& alpha [[buffer(3)]], \
-      constant const float& beta [[buffer(4)]], \
-      constant const int* shape [[buffer(5)]], \
-      constant const size_t* x_strides [[buffer(6)]], \
-      constant const size_t* y_strides [[buffer(7)]], \
-      constant const int& ndim [[buffer(8)]], \
-      uint index [[thread_position_in_grid]]); \
-  template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
-  axpby_contiguous<type>( \
-      device const type* x [[buffer(0)]], \
-      device const type* y [[buffer(1)]], \
-      device type* out [[buffer(2)]], \
-      constant const float& alpha [[buffer(3)]], \
-      constant const float& beta [[buffer(4)]], \
-      uint index [[thread_position_in_grid]]);
+// clang-format off
+#define instantiate_axpby(type_name, type) \
+  instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \
+  instantiate_kernel( \
+      "axpby_contiguous_" #type_name, axpby_contiguous, type)

 instantiate_axpby(float32, float);
 instantiate_axpby(float16, half);
 instantiate_axpby(bfloat16, bfloat16_t);
 instantiate_axpby(complex64, complex64_t);
+// clang-format on
@@ -8,14 +8,12 @@
 namespace nb = nanobind;
 using namespace nb::literals;

-using namespace mlx::core;
-
 NB_MODULE(_ext, m) {
   m.doc() = "Sample extension for MLX";

   m.def(
       "axpby",
-      &axpby,
+      &my_ext::axpby,
       "x"_a,
       "y"_a,
       "alpha"_a,
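For orientation (not part of the diff): once the extension is built, the bound
operation is called from Python through the generated module. A sketch under
the assumption that the package is importable as `mlx_sample_extensions`:

```python
import mlx.core as mx
from mlx_sample_extensions import axpby  # package name assumed

x = mx.ones((3,))
y = mx.ones((3,))
# out = alpha * x + beta * y, elementwise
out = axpby(x, y, alpha=2.0, beta=0.5)
print(out)  # array([2.5, 2.5, 2.5], dtype=float32)
```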
@@ -1,8 +1,8 @@
 [build-system]
 requires = [
     "setuptools>=42",
-    "cmake>=3.24",
-    "mlx>=0.17.0",
-    "nanobind==2.1.0",
+    "cmake>=3.25",
+    "mlx>=0.18.0",
+    "nanobind==2.4.0",
 ]
 build-backend = "setuptools.build_meta"
@@ -1,4 +1,4 @@
 setuptools>=42
-cmake>=3.24
-mlx>=0.17.0
-nanobind==2.1.0
+cmake>=3.25
+mlx>=0.21.0
+nanobind==2.2.0
mlx.pc.in (15 changed lines)
@@ -28,10 +28,19 @@ endif()
 if (@MLX_BUILD_METAL@)
   set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
   set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
-  set_and_check(MLX_INCLUDE_DIRS
-    ${MLX_INCLUDE_DIRS}
+  set(MLX_INCLUDE_DIRS
+    "${MLX_INCLUDE_DIRS};"
     @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
   )
+  if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
+    set(MLX_INCLUDE_DIRS
+      "${MLX_INCLUDE_DIRS};"
+      @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
+  else()
+    set(MLX_INCLUDE_DIRS
+      "${MLX_INCLUDE_DIRS};"
+      @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
+  endif()
 endif()

 set_target_properties(mlx PROPERTIES
@@ -40,4 +49,4 @@ set_target_properties(mlx PROPERTIES
 )

 include(FindPackageHandleStandardArgs)
 find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
@@ -5,6 +5,8 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@@ -18,21 +20,31 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)

+# Define MLX_VERSION only in the version.cpp file.
+add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
+target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
+target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
+
+if(MSVC)
+  # Disable some MSVC warnings to speed up compilation.
+  target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
+endif()
+
+if(WIN32)
+  # Export symbols by default to behave like macOS/linux.
+  set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
+endif()
+
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+
 if(MLX_BUILD_CPU)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu)
 else()
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
 endif()

 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
-if(MLX_BUILD_ACCELERATE)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
-elseif(MLX_BUILD_CPU)
-  target_sources(
-    mlx
-    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
-endif()

 if(MLX_BUILD_METAL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
@@ -4,12 +4,11 @@
 #include <sstream>

 #include "mlx/allocator.h"
-#include "mlx/scheduler.h"

 namespace mlx::core::allocator {

 Buffer malloc(size_t size) {
-  auto buffer = allocator().malloc(size, /* allow_swap */ true);
+  auto buffer = allocator().malloc(size);
   if (size && !buffer.ptr()) {
     std::ostringstream msg;
     msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -19,48 +18,7 @@ Buffer malloc(size_t size) {
 }

 void free(Buffer buffer) {
-  return allocator().free(buffer);
-}
-
-Buffer CommonAllocator::malloc(size_t size, bool) {
-  void* ptr = std::malloc(size + sizeof(size_t));
-  if (ptr != nullptr) {
-    *static_cast<size_t*>(ptr) = size;
-  }
-  return Buffer{ptr};
-}
-
-void CommonAllocator::free(Buffer buffer) {
-  std::free(buffer.ptr());
-}
-
-size_t CommonAllocator::size(Buffer buffer) const {
-  if (buffer.ptr() == nullptr) {
-    return 0;
-  }
-  return *static_cast<size_t*>(buffer.ptr());
-}
-
-Buffer malloc_or_wait(size_t size) {
-  auto buffer = allocator().malloc(size);
-
-  while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
-    scheduler::wait_for_one();
-    buffer = allocator().malloc(size);
-  }
-
-  // Try swapping if needed
-  if (size && !buffer.ptr()) {
-    buffer = allocator().malloc(size, /* allow_swap = */ true);
-  }
-
-  if (size && !buffer.ptr()) {
-    std::ostringstream msg;
-    msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
-    throw std::runtime_error(msg.str());
-  }
-
-  return buffer;
+  allocator().free(buffer);
 }

 } // namespace mlx::core::allocator
@@ -32,14 +32,10 @@ Buffer malloc(size_t size);

 void free(Buffer buffer);

-// Wait for running tasks to finish and free up memory
-// if allocation fails
-Buffer malloc_or_wait(size_t size);
-
 class Allocator {
   /** Abstract base class for a memory allocator. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
+  virtual Buffer malloc(size_t size) = 0;
   virtual void free(Buffer buffer) = 0;
   virtual size_t size(Buffer buffer) const = 0;
@@ -53,16 +49,4 @@ class Allocator {

 Allocator& allocator();

-class CommonAllocator : public Allocator {
-  /** A general CPU allocator. */
- public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
-  virtual void free(Buffer buffer) override;
-  virtual size_t size(Buffer buffer) const override;
-
- private:
-  CommonAllocator() = default;
-  friend Allocator& allocator();
-};
-
 } // namespace mlx::core::allocator
mlx/array.cpp (182 changed lines)
@@ -1,5 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
+#include <functional>
 #include <unordered_map>

 #include "mlx/array.h"
 #include "mlx/ops.h"
@@ -9,28 +10,14 @@

 namespace mlx::core {

-namespace {
-
-/** Return true if we are currently performing a function transformation in
- * order to keep the graph when evaluating tracer arrays. */
-bool in_tracing() {
-  return detail::InTracing::in_tracing();
-}
-
-bool retain_graph() {
-  return detail::RetainGraph::retain_graph();
-}
-
-} // namespace
-
 array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
-    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
+    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
   auto cval = static_cast<complex64_t>(val);
   init(&cval);
 }

 array::array(
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype,
     std::shared_ptr<Primitive> primitive,
     std::vector<array> inputs)
@@ -38,10 +25,21 @@ array::array(
         std::move(shape),
         dtype,
         std::move(primitive),
-        std::move(inputs))) {}
+        std::move(inputs))) {
+  if (has_primitive() && this->primitive().stream().device == Device::gpu) {
+    for (auto& in : this->inputs()) {
+      if (in.dtype() == float64) {
+        throw std::invalid_argument("float64 is not supported on the GPU");
+      }
+    }
+    if (this->dtype() == float64) {
+      throw std::invalid_argument("float64 is not supported on the GPU");
+    }
+  }
+}

 std::vector<array> array::make_arrays(
-    std::vector<std::vector<int>> shapes,
+    std::vector<Shape> shapes,
     const std::vector<Dtype>& dtypes,
     const std::shared_ptr<Primitive>& primitive,
     const std::vector<array>& inputs) {
@@ -58,58 +56,90 @@ std::vector<array> array::make_arrays(
|
||||
return outputs;
|
||||
}
|
||||
|
||||
array array::unsafe_weak_copy(const array& other) {
|
||||
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
|
||||
cpy.set_data(
|
||||
other.buffer(),
|
||||
other.data_size(),
|
||||
other.strides(),
|
||||
other.flags(),
|
||||
[](auto) {});
|
||||
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||
return cpy;
|
||||
}
|
||||
|
||||
array::array(std::initializer_list<float> data)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
float32)) {
|
||||
init(data.begin());
|
||||
}
|
||||
|
||||
array::array(std::initializer_list<int> data, Dtype dtype)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
dtype)) {
|
||||
init(data.begin());
|
||||
}
|
||||
|
||||
/* Build an array from a shared buffer */
|
||||
array::array(
|
||||
allocator::Buffer data,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter)
|
||||
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
set_data(data, deleter);
|
||||
}
|
||||
|
||||
void array::detach() {
|
||||
array_desc_->primitive = nullptr;
|
||||
for (auto& s : array_desc_->siblings) {
|
||||
s.array_desc_->primitive = nullptr;
|
||||
}
|
||||
for (auto& s : array_desc_->siblings) {
|
||||
s.array_desc_->inputs.clear();
|
||||
s.array_desc_->siblings.clear();
|
||||
s.array_desc_->position = 0;
|
||||
s.array_desc_->primitive = nullptr;
|
||||
}
|
||||
array_desc_->inputs.clear();
|
||||
array_desc_->siblings.clear();
|
||||
array_desc_->position = 0;
|
||||
array_desc_->primitive = nullptr;
|
||||
}
|
||||
|
||||
bool array::is_available() const {
|
||||
if (status() == Status::available) {
|
||||
return true;
|
||||
} else if (
|
||||
status() == Status::evaluated &&
|
||||
(!event().valid() || event().is_signaled())) {
|
||||
set_status(Status::available);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void array::wait() {
|
||||
if (!is_available()) {
|
||||
if (event().valid()) {
|
||||
event().wait();
|
||||
detach_event();
|
||||
}
|
||||
set_status(Status::available);
|
||||
}
|
||||
}
|
||||
|
||||
void array::eval() {
|
||||
// Ensure the array is ready to be read
|
||||
if (status() == Status::scheduled) {
|
||||
event().wait();
|
||||
set_status(Status::available);
|
||||
} else if (status() == Status::unscheduled) {
|
||||
if (status() == Status::unscheduled) {
|
||||
mlx::core::eval({*this});
|
||||
} else {
|
||||
wait();
|
||||
}
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing() || retain_graph();
|
||||
return (array_desc_->is_tracer && detail::in_tracing()) ||
|
||||
detail::retain_graph();
|
||||
}
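A side note on the is_tracer change: since && already binds tighter than ||, the new parentheses only make the existing grouping explicit; the substantive change is qualifying the helpers with detail::. A one-line check (illustrative, not part of the change):

static_assert((true && false || true) == ((true && false) || true),
              "&& binds tighter than ||");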

-void array::set_data(allocator::Buffer buffer, deleter_t d) {
+void array::set_data(allocator::Buffer buffer, Deleter d) {
  array_desc_->data = std::make_shared<Data>(buffer, d);
  array_desc_->data_ptr = buffer.raw_ptr();
  array_desc_->data_size = size();
@@ -122,9 +152,9 @@ void array::set_data(allocator::Buffer buffer, deleter_t d) {
void array::set_data(
    allocator::Buffer buffer,
    size_t data_size,
-    std::vector<size_t> strides,
+    Strides strides,
    Flags flags,
-    deleter_t d) {
+    Deleter d) {
  array_desc_->data = std::make_shared<Data>(buffer, d);
  array_desc_->data_ptr = buffer.raw_ptr();
  array_desc_->data_size = data_size;
@@ -134,7 +164,7 @@ void array::set_data(

void array::copy_shared_buffer(
    const array& other,
-    const std::vector<size_t>& strides,
+    const Strides& strides,
    Flags flags,
    size_t data_size,
    size_t offset /* = 0 */) {
@@ -151,32 +181,13 @@ void array::copy_shared_buffer(const array& other) {
  copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}

-void array::move_shared_buffer(
-    array other,
-    const std::vector<size_t>& strides,
-    Flags flags,
-    size_t data_size,
-    size_t offset /* = 0 */) {
-  array_desc_->data = std::move(other.array_desc_->data);
-  array_desc_->strides = strides;
-  array_desc_->flags = flags;
-  array_desc_->data_size = data_size;
-  auto char_offset = sizeof(char) * itemsize() * offset;
-  array_desc_->data_ptr = static_cast<void*>(
-      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
-}
-
-void array::move_shared_buffer(array other) {
-  move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
-}
-
array::~array() {
  if (array_desc_ == nullptr) {
    return;
  }

-  // Ignore arrays that might be detached during eval
-  if (status() == array::Status::scheduled) {
+  // Detached/detaching
+  if (array_desc_->primitive == nullptr) {
    return;
  }

@@ -196,6 +207,8 @@ array::~array() {
  if (do_detach) {
    for (auto& s : siblings()) {
      for (auto& ss : s.siblings()) {
+        // Set to null here to avoid descending into array destructor
+        // for siblings
        ss.array_desc_ = nullptr;
      }
      s.array_desc_->siblings.clear();
@@ -216,13 +229,13 @@ void array::ArrayDesc::init() {
  }
}

-array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
+array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
    : shape(std::move(shape)), dtype(dtype), status(Status::available) {
  init();
}

array::ArrayDesc::ArrayDesc(
-    std::vector<int> shape,
+    Shape shape,
    Dtype dtype,
    std::shared_ptr<Primitive> primitive,
    std::vector<array> inputs)
@@ -242,25 +255,58 @@ array::ArrayDesc::~ArrayDesc() {
  // This calls recursively the destructor and can result in stack overflow, we
  // instead put them in a vector and destroy them one at a time resulting in a
  // max stack depth of 2.
  if (inputs.empty()) {
    return;
  }

  std::vector<std::shared_ptr<ArrayDesc>> for_deletion;

-  for (array& a : inputs) {
-    if (a.array_desc_.use_count() == 1) {
-      for_deletion.push_back(std::move(a.array_desc_));
-    }
-  }
+  auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
+    std::unordered_map<std::uintptr_t, array> input_map;
+    for (array& a : ad.inputs) {
+      if (a.array_desc_) {
+        input_map.insert({a.id(), a});
+        for (auto& s : a.siblings()) {
+          input_map.insert({s.id(), s});
+        }
+      }
+    }
+    ad.inputs.clear();
+    for (auto& [_, a] : input_map) {
+      bool is_deletable =
+          (a.array_desc_.use_count() <= a.siblings().size() + 1);
+      // An array with siblings is deletable only if all of its siblings
+      // are deletable
+      for (auto& s : a.siblings()) {
+        if (!is_deletable) {
+          break;
+        }
+        int is_input = (input_map.find(s.id()) != input_map.end());
+        is_deletable &=
+            s.array_desc_.use_count() <= a.siblings().size() + is_input;
+      }
+      if (is_deletable) {
+        for_deletion.push_back(std::move(a.array_desc_));
+      }
+    }
+  };
+
+  append_deletable_inputs(*this);

  while (!for_deletion.empty()) {
    // top is going to be deleted at the end of the block *after* the arrays
    // with inputs have been moved into the vector
    auto top = std::move(for_deletion.back());
    for_deletion.pop_back();
+    append_deletable_inputs(*top);

-    for (array& a : top->inputs) {
-      if (a.array_desc_.use_count() == 1) {
-        for_deletion.push_back(std::move(a.array_desc_));
-      }
-    }
+    // Clear out possible siblings to break circular references
+    for (auto& s : top->siblings) {
+      // Set to null here to avoid descending into top-level
+      // array destructor for siblings
+      s.array_desc_ = nullptr;
+    }
+    top->siblings.clear();
  }
}
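The comment above this hunk explains the pattern: destruction is driven by an explicit worklist so the destructor never recurses through a long graph. A generic sketch of the same idea (illustrative, assuming sole ownership of the chain):

#include <memory>
#include <vector>

struct Node {
  std::shared_ptr<Node> next;
  ~Node() {
    // Unlink the whole chain into a flat vector; every node then dies with
    // next == nullptr, so destruction recurses at most one level deep
    // (the "max stack depth of 2" mentioned above).
    std::vector<std::shared_ptr<Node>> for_deletion;
    auto n = std::move(next);
    while (n) {
      for_deletion.push_back(std::move(n));
      n = std::move(for_deletion.back()->next);
    }
  }
};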

@@ -272,7 +318,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
}

array::ArrayIterator::reference array::ArrayIterator::operator*() const {
-  auto start = std::vector<int>(arr.ndim(), 0);
+  auto start = Shape(arr.ndim(), 0);
  auto end = arr.shape();
  auto shape = arr.shape();
  shape.erase(shape.begin());

mlx/array.h (120 changes)
@@ -15,7 +15,11 @@ namespace mlx::core {

// Forward declaration
class Primitive;
-using deleter_t = std::function<void(allocator::Buffer)>;
+
+using Deleter = std::function<void(allocator::Buffer)>;
+using ShapeElem = int32_t;
+using Shape = std::vector<ShapeElem>;
+using Strides = std::vector<int64_t>;

class array {
  /* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -31,33 +35,33 @@ class array {
  explicit array(const std::complex<float>& val, Dtype dtype = complex64);

  template <typename It>
-  array(
+  explicit array(
      It data,
-      std::vector<int> shape,
+      Shape shape,
      Dtype dtype =
          TypeToDtype<typename std::iterator_traits<It>::value_type>());

  template <typename T>
-  array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
+  explicit array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());

  /* Special case so empty lists default to float32. */
-  array(std::initializer_list<float> data);
+  explicit array(std::initializer_list<float> data);

  /* Special case so array({}, type) is an empty array. */
-  array(std::initializer_list<int> data, Dtype dtype);
+  explicit array(std::initializer_list<int> data, Dtype dtype);

  template <typename T>
-  array(
+  explicit array(
      std::initializer_list<T> data,
-      std::vector<int> shape,
+      Shape shape,
      Dtype dtype = TypeToDtype<T>());

  /* Build an array from a buffer */
-  array(
+  explicit array(
      allocator::Buffer data,
-      std::vector<int> shape,
+      Shape shape,
      Dtype dtype,
-      deleter_t deleter = allocator::free);
+      Deleter deleter = allocator::free);

  /** Assignment to rvalue does not compile. */
  array& operator=(const array& other) && = delete;
@@ -96,7 +100,7 @@ class array {
  }

  /** The shape of the array as a vector of integers. */
-  const std::vector<int>& shape() const {
+  const Shape& shape() const {
    return array_desc_->shape;
  }

@@ -105,12 +109,12 @@ class array {
   *
   * This function supports negative indexing and provides
   * bounds checking. */
-  int shape(int dim) const {
+  auto shape(int dim) const {
    return shape().at(dim < 0 ? dim + ndim() : dim);
  }

  /** The strides of the array. */
-  const std::vector<size_t>& strides() const {
+  const Strides& strides() const {
    return array_desc_->strides;
  }

@@ -119,7 +123,7 @@ class array {
   *
   * This function supports negative indexing and provides
   * bounds checking. */
-  size_t strides(int dim) const {
+  auto strides(int dim) const {
    return strides().at(dim < 0 ? dim + ndim() : dim);
  }

@@ -184,17 +188,24 @@ class array {
   */

  array(
-      std::vector<int> shape,
+      Shape shape,
      Dtype dtype,
      std::shared_ptr<Primitive> primitive,
      std::vector<array> inputs);

  static std::vector<array> make_arrays(
-      std::vector<std::vector<int>> shapes,
+      std::vector<Shape> shapes,
      const std::vector<Dtype>& dtypes,
      const std::shared_ptr<Primitive>& primitive,
      const std::vector<array>& inputs);

+  /**
+   * Get a new array that refers to the same data as the input but with a
+   * non-owning pointer to it. Note the array is detached from the graph and has
+   * no inputs, siblings or primitive.
+   */
+  static array unsafe_weak_copy(const array& other);
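The doc comment above describes a non-owning alias. A minimal sketch of the contract (illustrative; the helper sits in array's internal section, so real callers are backend code with access to it):

// Sketch, assuming access to array's internals:
void use_weak_copy(const mlx::core::array& src) {
  auto view = mlx::core::array::unsafe_weak_copy(src);
  // `view` shares src's buffer but was registered with a no-op deleter
  // ([](auto) {}), so destroying it frees nothing; it must not outlive
  // `src`. It also has no primitive, inputs, or siblings, so it is fully
  // detached from the compute graph.
}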

  /** A unique identifier for an array. */
  std::uintptr_t id() const {
    return reinterpret_cast<std::uintptr_t>(array_desc_.get());
@@ -207,8 +218,8 @@ class array {

  struct Data {
    allocator::Buffer buffer;
-    deleter_t d;
-    Data(allocator::Buffer buffer, deleter_t d = allocator::free)
+    Deleter d;
+    Data(allocator::Buffer buffer, Deleter d = allocator::free)
        : buffer(buffer), d(d) {}
    // Not copyable
    Data(const Data& d) = delete;
@@ -328,11 +339,11 @@ class array {
    return allocator::allocator().size(buffer());
  }

-  // Return a copy of the shared pointer
-  // to the array::Data struct
-  std::shared_ptr<Data> data_shared_ptr() const {
+  // Return the shared pointer to the array::Data struct
+  const std::shared_ptr<Data>& data_shared_ptr() const {
    return array_desc_->data;
  }

  // Return a raw pointer to the arrays data
  template <typename T>
  T* data() {
@@ -344,11 +355,28 @@ class array {
    return static_cast<T*>(array_desc_->data_ptr);
  }

-  enum Status { unscheduled, scheduled, available };
-
-  bool is_available() const {
-    return status() == Status::available;
-  }
+  enum Status {
+    // The output of a computation which has not been scheduled.
+    // For example, the status of `x` in `auto x = a + b`.
+    unscheduled,
+
+    // The array's `eval_*` function has been run, but the computation is not
+    // necessarily complete. The array will have memory allocated and if it is
+    // not a tracer then it will be detached from the graph.
+    evaluated,
+
+    // If the array is the output of a computation then the computation
+    // is complete. Constant arrays are always available (e.g. `array({1, 2,
+    // 3})`)
+    available
+  };
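For orientation, the three states map onto user code roughly like this (illustrative sketch, assuming the public API in "mlx/mlx.h"):

#include "mlx/mlx.h"
namespace mx = mlx::core;

void status_walkthrough() {
  auto a = mx::array({1, 2, 3});  // constant data: available from the start
  auto x = a + a;                 // freshly built graph node: unscheduled
  mx::eval({x});                  // schedules and runs it; x moves through
                                  // evaluated and, once its event signals,
                                  // to available
  auto* p = x.data<int32_t>();    // safe now: x.is_available() is true
}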

+  // Check if the array is safe to read.
+  bool is_available() const;
+
+  // Wait on the array to be available. After this `is_available` returns
+  // `true`.
+  void wait();

  Status status() const {
    return array_desc_->status;
@@ -368,6 +396,10 @@ class array {
    array_desc_->event = std::move(e);
  }

+  void detach_event() const {
+    array_desc_->event = Event{};
+  }
+
  // Mark the array as a tracer array (true) or not.
  void set_tracer(bool is_tracer) {
    array_desc_->is_tracer = is_tracer;
@@ -375,33 +407,24 @@ class array {
  // Check if the array is a tracer array
  bool is_tracer() const;

-  void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
+  void set_data(allocator::Buffer buffer, Deleter d = allocator::free);

  void set_data(
      allocator::Buffer buffer,
      size_t data_size,
-      std::vector<size_t> strides,
+      Strides strides,
      Flags flags,
-      deleter_t d = allocator::free);
+      Deleter d = allocator::free);

  void copy_shared_buffer(
      const array& other,
-      const std::vector<size_t>& strides,
+      const Strides& strides,
      Flags flags,
      size_t data_size,
      size_t offset = 0);

  void copy_shared_buffer(const array& other);

-  void move_shared_buffer(
-      array other,
-      const std::vector<size_t>& strides,
-      Flags flags,
-      size_t data_size,
-      size_t offset = 0);
-
-  void move_shared_buffer(array other);
-
  void overwrite_descriptor(const array& other) {
    array_desc_ = other.array_desc_;
  }
@@ -414,8 +437,8 @@ class array {
  void init(const It src);

  struct ArrayDesc {
-    std::vector<int> shape;
-    std::vector<size_t> strides;
+    Shape shape;
+    Strides strides;
    size_t size;
    Dtype dtype;
    std::shared_ptr<Primitive> primitive;
@@ -449,10 +472,10 @@ class array {
    // The arrays position in the output list
    uint32_t position{0};

-    explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
+    explicit ArrayDesc(Shape shape, Dtype dtype);

    explicit ArrayDesc(
-        std::vector<int> shape,
+        Shape shape,
        Dtype dtype,
        std::shared_ptr<Primitive> primitive,
        std::vector<array> inputs);
@@ -473,14 +496,14 @@ class array {

template <typename T>
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
-    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
+    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
  init(&val);
}

template <typename It>
array::array(
    It data,
-    std::vector<int> shape,
+    Shape shape,
    Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
    array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
  init(data);
@@ -491,7 +514,7 @@ array::array(
    std::initializer_list<T> data,
    Dtype dtype /* = TypeToDtype<T>() */)
    : array_desc_(std::make_shared<ArrayDesc>(
-          std::vector<int>{static_cast<int>(data.size())},
+          Shape{static_cast<ShapeElem>(data.size())},
          dtype)) {
  init(data.begin());
}
@@ -499,7 +522,7 @@ array::array(
template <typename T>
array::array(
    std::initializer_list<T> data,
-    std::vector<int> shape,
+    Shape shape,
    Dtype dtype /* = TypeToDtype<T>() */)
    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
  if (data.size() != size()) {
@@ -568,6 +591,9 @@ void array::init(It src) {
    case float32:
      std::copy(src, src + size(), data<float>());
      break;
+    case float64:
+      std::copy(src, src + size(), data<double>());
+      break;
    case bfloat16:
      std::copy(src, src + size(), data<bfloat16_t>());
      break;

mlx/backend/accelerate/CMakeLists.txt (file deleted)
@@ -1,8 +0,0 @@
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)

mlx/backend/accelerate/conv.cpp (file deleted)
@@ -1,20 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>

#include <Accelerate/Accelerate.h>
#include <simd/vector.h>

#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);

  // TODO: Add accelerate based optimizations for CPU conv
}

} // namespace mlx::core

mlx/backend/accelerate/matmul.cpp (file deleted)
@@ -1,253 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>

#include <Accelerate/Accelerate.h>

#include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {

std::tuple<bool, size_t, array> check_transpose(const array& arr) {
  auto stx = arr.strides()[arr.ndim() - 2];
  auto sty = arr.strides()[arr.ndim() - 1];
  if (stx == arr.shape(-1) && sty == 1) {
    return std::make_tuple(false, stx, arr);
  } else if (stx == 1 && sty == arr.shape(-2)) {
    return std::make_tuple(true, sty, arr);
  } else {
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
    copy(arr, arr_copy, CopyType::General);
    size_t stx = arr.shape(-1);
    return std::make_tuple(false, stx, arr_copy);
  }
}
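check_transpose above normalizes every operand to one of two layouts before calling GEMM. A small restatement of the stride patterns it recognizes (illustrative, not from the deleted file):

#include <cstdint>

bool is_row_major(int64_t stx, int64_t sty, int cols) {
  return stx == cols && sty == 1;  // e.g. a 3x4 matrix has strides {4, 1}
}
bool is_col_major(int64_t stx, int64_t sty, int rows) {
  return stx == 1 && sty == rows;  // e.g. a transposed 4x3 view has {1, 3}
}
// Anything else (slices, broadcasts, ...) falls through to a General copy,
// so the GEMM call always sees one of the two layouts plus a transpose flag.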

inline void matmul_cblas_general(
    const array& a_pre,
    const array& b_pre,
    array& out,
    float alpha = 1.0f,
    float beta = 0.0f) {
  if (out.dtype() != float32) {
    throw std::runtime_error(
        "[matmul_cblas] on CPU currently only supports float32");
  }

  auto [a_transposed, lda, a] = check_transpose(a_pre);
  auto [b_transposed, ldb, b] = check_transpose(b_pre);
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);

  if (M == 0 || N == 0) {
    return;
  }
  if (K == 0) {
    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
    return;
  }

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
    cblas_sgemm(
        CblasRowMajor,
        a_transposed ? CblasTrans : CblasNoTrans, // transA
        b_transposed ? CblasTrans : CblasNoTrans, // transB
        M,
        N,
        K,
        alpha, // alpha
        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
        lda,
        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
        ldb,
        beta, // beta
        out.data<float>() + M * N * i,
        out.shape(-1) // ldc
    );
  }
}

inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
  if (out.dtype() != float32) {
    throw std::runtime_error(
        "[matmul_cblas] on CPU currently only supports float32");
  }
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  return matmul_cblas_general(a_pre, b_pre, out);
}

inline void matmul_bnns_general(
    const array& a_pre,
    const array& b_pre,
    array& out,
    float alpha = 1.0f,
    float beta = 0.0f) {
  // TODO: Update to utilize BNNS broadcasting

  auto [a_transposed, lda, a] = check_transpose(a_pre);
  auto [b_transposed, ldb, b] = check_transpose(b_pre);
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);

  if (M == 0 || N == 0) {
    return;
  }
  if (K == 0) {
    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
    return;
  }

  BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());

  const BNNSLayerParametersBroadcastMatMul gemm_params{
      /* float alpha = */ alpha,
      /* float beta = */ beta,
      /* bool transA = */ a_transposed,
      /* bool transB = */ b_transposed,
      /* bool quadratic = */ false,
      /* bool a_is_weights = */ false,
      /* bool b_is_weights = */ false,
      /* BNNSNDArrayDescriptor iA_desc = */
      BNNSNDArrayDescriptor{
          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,

          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
          {lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
          {1, lda, 0, 0, 0, 0, 0, 0},

          /* void * _Nullable data = */ nullptr,
          /* BNNSDataType data_type = */ bnns_dtype,

          /* void * _Nullable table_data = */ nullptr,
          /* BNNSDataType table_data_type = */ bnns_dtype,

          /* float data_scale = */ 1.0,
          /* float data_bias = */ 0.0,
      },
      /* BNNSNDArrayDescriptor iB_desc = */
      BNNSNDArrayDescriptor{
          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,

          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
          {ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
          {1, ldb, 0, 0, 0, 0, 0, 0},

          /* void * _Nullable data = */ nullptr,
          /* BNNSDataType data_type = */ bnns_dtype,

          /* void * _Nullable table_data = */ nullptr,
          /* BNNSDataType table_data_type = */ bnns_dtype,

          /* float data_scale = */ 1.0,
          /* float data_bias = */ 0.0,
      },
      /* BNNSNDArrayDescriptor o_desc = */
      BNNSNDArrayDescriptor{
          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,

          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
          {N, M, 0, 0, 0, 0, 0, 0},
          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
          {1, N, 0, 0, 0, 0, 0, 0},

          /* void * _Nullable data = */ nullptr,
          /* BNNSDataType data_type = */ bnns_dtype,

          /* void * _Nullable table_data = */ nullptr,
          /* BNNSDataType table_data_type = */ bnns_dtype,

          /* float data_scale = */ 1.0,
          /* float data_bias = */ 0.0,
      },
  };

  auto bnns_filter =
      BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
    BNNSFilterApplyTwoInput(
        bnns_filter,
        a.data<uint8_t>() +
            elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
        b.data<uint8_t>() +
            elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
        out.data<uint8_t>() + M * N * i * out.itemsize());
  }

  BNNSFilterDestroy(bnns_filter);
}

inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
  // TODO: Update to utilize BNNS broadcasting
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  return matmul_bnns_general(a_pre, b_pre, out);
}

template <typename T>
inline void mask_matrix(
    T* data,
    const bool* mask,
    int tile_size,
    const int X,
    const int Y,
    const size_t X_data_str,
    const size_t Y_data_str,
    const size_t X_mask_str,
    const size_t Y_mask_str) {
  int tX = (X + tile_size - 1) / tile_size;
  int tY = (Y + tile_size - 1) / tile_size;

  for (int i = 0; i < tX; i++) {
    for (int j = 0; j < tY; j++) {
      bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
      if (!do_mask) {
        int loc_x = i * tile_size;
        int loc_y = j * tile_size;
        T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;

        int size_x = std::min(tile_size, X - loc_x);
        int size_y = std::min(tile_size, Y - loc_y);
        for (int ii = 0; ii < size_x; ii++) {
          for (int jj = 0; jj < size_y; jj++) {
            data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
          }
        }
      }
    }
  }
}

} // namespace

void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  if (out.dtype() == float32) {
    return matmul_cblas(inputs[0], inputs[1], out);
  }
  return matmul_bnns(inputs[0], inputs[1], out);
}

void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
  // Fill output with C
  auto& c = inputs[2];
  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
  copy(c, out, ctype);

  if (out.dtype() == float32) {
    return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
  }
  return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
}

} // namespace mlx::core

mlx/backend/accelerate/primitives.cpp (file deleted)
@@ -1,600 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <cmath>

#include <Accelerate/Accelerate.h>

#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/unary.h"
#include "mlx/primitives.h"

#define DEFAULT(primitive)                                                 \
  void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
    primitive::eval(inputs, out);                                          \
  }

#define DEFAULT_MULTI(primitive)                                       \
  void primitive::eval_cpu(                                            \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    primitive::eval(inputs, outputs);                                  \
  }

namespace mlx::core {

// Use the default implementation for the following primitives
DEFAULT(Arange)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT_MULTI(QRF)
DEFAULT(RandomBits)
DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (in.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
  } else if (in.dtype() == int32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
  } else {
    eval(inputs, out);
  }
}

void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  if (a.dtype() == float32) {
    binary_op<float>(
        a,
        b,
        out,
        [](auto x, auto y) { return x + y; },
        [](const auto* s, const auto* vec, auto* o, auto n) {
          vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* vec, const auto* s, auto* o, auto n) {
          vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
        });
  } else if (a.dtype() == int32) {
    binary_op<int>(
        a,
        b,
        out,
        [](auto x, auto y) { return x + y; },
        [](const auto* s, const auto* vec, auto* o, auto n) {
          vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
        },
        [](const auto* vec, const auto* s, auto* o, auto n) {
          vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
        });
  } else {
    eval(inputs, out);
  }
}

void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvacosf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvacoshf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvasinf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvasinhf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvatanf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  if (out.dtype() == float32 && a.flags().row_contiguous &&
      b.flags().row_contiguous) {
    if (a.is_donatable()) {
      out.copy_shared_buffer(a);
    } else if (b.is_donatable()) {
      out.copy_shared_buffer(b);
    } else {
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
    }
    int size = a.data_size();
    vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvatanhf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  if (in.flags().contiguous) {
    // Use accelerate functions if possible
    if (in.dtype() == float32 && out.dtype() == uint32) {
      set_unary_output_data(in, out);
      vDSP_vfixu32(
          in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
      return;
    } else if (in.dtype() == float32 && out.dtype() == int32) {
      set_unary_output_data(in, out);
      vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
      return;
    } else if (in.dtype() == uint32 && out.dtype() == float32) {
      set_unary_output_data(in, out);
      vDSP_vfltu32(
          in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
      return;
    } else if (in.dtype() == int32 && out.dtype() == float32) {
      set_unary_output_data(in, out);
      vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
      return;
    }
  }
  eval(inputs, out);
}

void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvcosf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvcoshf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  if (a.dtype() == int32) {
    binary_op<int>(
        a,
        b,
        out,
        [](auto x, auto y) { return x / y; },
        UseDefaultBinaryOp(),
        [](const auto* vec, const auto* s, auto* o, auto n) {
          vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
        });
  } else if (a.dtype() == float32) {
    binary_op<float>(
        a,
        b,
        out,
        [](auto x, auto y) { return x / y; },
        [](const auto* s, const auto* vec, auto* o, auto n) {
          vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
        },
        [](const auto* vec, const auto* s, auto* o, auto n) {
          vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
        });
  } else {
    eval(inputs, out);
  }
}

void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    auto size = in.data_size();
    vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
  } else {
    eval(inputs, out);
  }
}

void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    auto size = in.data_size();
    vvexpm1f(
        out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
  } else {
    eval(inputs, out);
  }
}

void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  assert(in.dtype() == out.dtype());
  if (in.data_size() == 1 && out.dtype() == float32) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
  } else {
    eval(inputs, out);
  }
}

void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    auto size = in.data_size();
    switch (base_) {
      case Base::e:
        vvlogf(
            out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
        break;
      case Base::two:
        vvlog2f(
            out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
        break;
      case Base::ten:
        vvlog10f(
            out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
        break;
    }
  } else {
    eval(inputs, out);
  }
}

void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    auto size = in.data_size();
    vvlog1pf(
        out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
  } else {
    eval(inputs, out);
  }
}

void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  if (a.dtype() == float32) {
    binary_op<float>(
        a,
        b,
        out,
        [](auto x, auto y) { return x * y; },
        [](const auto* s, const auto* vec, auto* o, auto n) {
          vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* vec, const auto* s, auto* o, auto n) {
          vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
        });
  } else {
    eval(inputs, out);
  }
}

void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (in.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
  } else {
    eval(inputs, out);
  }
}

void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  if (out.dtype() == float32 && a.flags().row_contiguous &&
      b.flags().row_contiguous) {
    int size = a.size();
    if (a.is_donatable() && a.itemsize() == out.itemsize()) {
      out.copy_shared_buffer(a);
    } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
      out.copy_shared_buffer(b);
    } else {
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
    }
    vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
      in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    int stride = in.shape(axis_);
    int count = in.size() / stride;
    const float* input = in.data<float>();
    float* output = out.data<float>();
    float s = 1.0;
    if (!reverse_) {
      for (int i = 0; i < count; i++) {
        vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
        input += stride;
        output += stride;
      }
    } else {
      for (int i = 0; i < count; i++) {
        input += stride - 1;
        output += stride - 1;
        vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
        input++;
        output++;
      }
    }
  } else {
    eval(inputs, out);
  }
}
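A note on the Scan fast path above: as used here, vDSP_vrsum's running sum does not include the element at the output index itself, which is why the call points at input - 1 (or input + 1 with a negative stride for the reversed case); the first output comes out as 0, exactly the exclusive (!inclusive_) scan this branch handles. In plain C++ the computed quantity is (illustrative, not from the deleted file):

#include <vector>

std::vector<float> exclusive_scan(const std::vector<float>& in) {
  std::vector<float> out(in.size());
  float acc = 0.0f;
  for (size_t i = 0; i < in.size(); ++i) {
    out[i] = acc;  // sum of in[0..i-1]; out[0] == 0
    acc += in[i];
  }
  return out;
}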

void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvsinf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvsinhf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (in.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    auto size = in.data_size();
    vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
  } else {
    eval(inputs, out);
  }
}

void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  if (in.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    if (recip_) {
      vvrsqrtf(out.data<float>(), in.data<float>(), &size);
    } else {
      vvsqrtf(out.data<float>(), in.data<float>(), &size);
    }
  } else {
    eval(inputs, out);
  }
}

void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];

  if (a.dtype() == float32) {
    binary_op<float>(
        a,
        b,
        out,
        [](auto x, auto y) { return x - y; },
        [](const auto* s, const auto* vec, auto* o, auto n) {
          float minus_1 = -1;
          vDSP_vsmsa(
              (const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
        },
        [](const auto* vec, const auto* s, auto* o, auto n) {
          float val = -(*s);
          vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
        },
        [](const auto* a, const auto* b, auto* o, auto n) {
          vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
        });
  } else if (a.dtype() == int32) {
    binary_op<int>(
        a,
        b,
        out,
        [](auto x, auto y) { return x - y; },
        UseDefaultBinaryOp(),
        [](const auto* vec, const auto* s, auto* o, auto n) {
          int val = -(*s);
          vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
        },
        UseDefaultBinaryOp());
  } else {
    eval(inputs, out);
  }
}

void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvtanf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.dtype() == float32 && in.flags().contiguous) {
    set_unary_output_data(in, out);
    int size = in.data_size();
    vvtanhf(out.data<float>(), in.data<float>(), &size);
  } else {
    eval(inputs, out);
  }
}

} // namespace mlx::core

mlx/backend/accelerate/quantized.cpp (file deleted)
@@ -1,101 +0,0 @@
// Copyright © 2023 Apple Inc.

#include <cassert>

#include <simd/vector.h>

#include "mlx/primitives.h"

namespace mlx::core {

namespace {

void _qmm_t_4_64(
    float* result,
    const float* x,
    const uint32_t* w,
    const float* scales,
    const float* biases,
    int M,
    int N,
    int K) {
  constexpr int bits = 4;
  constexpr int group_size = 64;
  constexpr int bitmask = (1 << bits) - 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int packs_in_group = group_size / pack_factor;

  for (int m = 0; m < M; m++) {
    const uint32_t* w_local = w;
    const float* scales_local = scales;
    const float* biases_local = biases;

    for (int n = 0; n < N; n++) {
      const simd_float16* x_local = (simd_float16*)x;
      simd_float16 sum = 0;
      for (int k = 0; k < K; k += group_size) {
        float scale = *scales_local++;
        float bias = *biases_local++;

        for (int kw = 0; kw < packs_in_group; kw += 2) {
          // TODO: vectorize this properly
          simd_uint16 wi;
          for (int e = 0; e < 2; e++) {
            uint32_t wii = *w_local++;
            for (int p = 0; p < 8; p++) {
              wi[e * 8 + p] = wii & bitmask;
              wii >>= bits;
            }
          }
          simd_float16 wf = simd_float(wi);
          wf *= scale;
          wf += bias;

          sum += (*x_local) * wf;
          x_local++;
        }
      }

      *result = simd_reduce_add(sum);
      result++;
    }

    x += K;
  }
}
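The inner loop above dequantizes 8 4-bit weights out of each uint32 word before the fused multiply. A standalone sketch of that unpack step (illustrative, not from the deleted file):

#include <array>
#include <cstdint>

// One uint32 packs 8 nibbles; each decodes to scale * w + bias.
std::array<float, 8> unpack_4bit(uint32_t word, float scale, float bias) {
  constexpr uint32_t bitmask = 0xF;  // (1 << 4) - 1
  std::array<float, 8> out{};
  for (int p = 0; p < 8; ++p) {
    out[p] = scale * static_cast<float>(word & bitmask) + bias;
    word >>= 4;
  }
  return out;
}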

} // namespace

void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);

  auto& x = inputs[0];
  auto& w = inputs[1];
  auto& scales = inputs[2];
  auto& biases = inputs[3];

  bool condition =
      (transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
       scales.flags().row_contiguous && biases.flags().row_contiguous &&
       x.dtype() == float32 && bits_ == 4 && group_size_ == 64);

  if (condition) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    int K = x.shape(-1);
    int M = x.size() / K;
    int N = out.shape(-1);
    _qmm_t_4_64(
        out.data<float>(),
        x.data<float>(),
        w.data<uint32_t>(),
        scales.data<float>(),
        biases.data<float>(),
        M,
        N,
        K);
  } else {
    eval(inputs, out);
  }
}

} // namespace mlx::core

mlx/backend/accelerate/reduce.cpp (file deleted)
@@ -1,139 +0,0 @@
// Copyright © 2023 Apple Inc.

#include <cassert>

#include <Accelerate/Accelerate.h>
#include <simd/vector.h>

#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename T, typename VT>
struct MinReduction {
  T operator()(const T& a, const T& b) {
    return std::min(a, b);
  }

  VT operator()(VT a, VT b) {
    return simd_min(a, b);
  }
};

template <typename T, typename VT>
struct MaxReduction {
  T operator()(const T& a, const T& b) {
    return std::max(a, b);
  }

  VT operator()(VT a, VT b) {
    return simd_max(a, b);
  }
};

template <typename T, typename VT>
struct SumReduction {
  T operator()(const T& a, const T& b) {
    return a + b;
  }

  VT operator()(VT a, VT b) {
    return a + b;
  }
};

template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
  void operator()(const T* x, T* accum, int size, size_t stride) {
    Reduction op;

    for (int i = 0; i < size; i++) {
      size_t s = stride;
      T* a = accum;
      while (s >= N) {
        *(VT*)a = op((*(VT*)x), (*(VT*)a));
        x += N;
        a += N;
        s -= N;
      }
      while (s-- > 0) {
        *a = op(*a, *x);
        a++;
        x++;
      }
    }
  }
};

} // namespace

void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  if (in.dtype() == float32) {
    if (reduce_type_ == Reduce::Sum) {
      reduction_op<float, float>(
          in,
          out,
          axes_,
          0,
          StridedReduce<
              float,
              simd_float16,
              16,
              SumReduction<float, simd_float16>>(),
          [](const auto* x, auto* accum, int size) {
            float acc;
            vDSP_sve((const float*)x, 1, &acc, size);
            (*accum) += acc;
          },
          [](auto* accum, auto x) { *accum += x; });
      return;
    } else if (reduce_type_ == Reduce::Max) {
      reduction_op<float, float>(
          in,
          out,
          axes_,
          -std::numeric_limits<float>::infinity(),
          StridedReduce<
              float,
              simd_float16,
              16,
              MaxReduction<float, simd_float16>>(),
          [](const auto* x, auto* accum, int size) {
            float max;
            vDSP_maxv((const float*)x, 1, &max, size);
            (*accum) = (*accum < max) ? max : *accum;
          },
          [](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
      return;
    } else if (reduce_type_ == Reduce::Min) {
      reduction_op<float, float>(
          in,
          out,
          axes_,
          std::numeric_limits<float>::infinity(),
          StridedReduce<
              float,
              simd_float16,
              16,
              MinReduction<float, simd_float16>>(),
          [](const auto* x, auto* accum, int size) {
            float min;
            vDSP_minv((const float*)x, 1, &min, size);
            (*accum) = (*accum > min) ? min : *accum;
          },
          [](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
      return;
    }
  }
  // TODO: Add integer addition and min/max using the templates above and
  // simd_int16 and friends.
  eval(inputs, out);
}

} // namespace mlx::core

mlx/backend/accelerate/softmax.cpp (file deleted)
@@ -1,391 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <limits>

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#endif

#include <simd/math.h>
#include <simd/vector.h>

#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

/**
 * Compute exp(x) in an optimizer friendly way as follows:
 *
 * First change the problem to computing 2**y where y = x / ln(2).
 *
 * Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
 * `ipart` and y2 is fractional part. For the integer part we perform bit
 * shifting and for the fractional part we use a polynomial approximation.
 *
 * The algorithm and constants of the polynomial taken from
 * https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
 * from Cephes math library.
 *
 * Note: The implementation below is a general fast exp. There could be faster
 * implementations for numbers strictly < 0.
 */
inline simd_float16 simd_fast_exp(simd_float16 x) {
  x *= 1.442695; // multiply with log_2(e)
  simd_float16 ipart, fpart;
  simd_int16 epart;
  x = simd_clamp(x, -80, 80);
  ipart = simd::floor(x + 0.5);
  fpart = x - ipart;

  x = 1.535336188319500e-4f;
  x = x * fpart + 1.339887440266574e-3f;
  x = x * fpart + 9.618437357674640e-3f;
  x = x * fpart + 5.550332471162809e-2f;
  x = x * fpart + 2.402264791363012e-1f;
  x = x * fpart + 6.931472028550421e-1f;
  x = x * fpart + 1.000000000000000f;

  // generate 2**ipart in the floating point representation using integer
  // bitshifting
  epart = (simd_int(ipart) + 127) << 23;

  return (*(simd_float16*)&epart) * x;
}
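For readers following the doc comment above, here is a scalar restatement of the same algorithm with the same constants (illustrative, not from the deleted file): compute 2**(x*log2(e)) as 2**ipart * 2**fpart, with the Cephes polynomial for the fractional part and an exponent-field bit trick for the integer part.

#include <cmath>
#include <cstdint>
#include <cstring>

float fast_exp_scalar(float x) {
  x *= 1.442695f;  // x / ln(2)
  x = std::fmin(std::fmax(x, -80.f), 80.f);
  float ipart = std::floor(x + 0.5f);
  float fpart = x - ipart;

  // Polynomial approximation of 2**fpart for fpart in [-0.5, 0.5]
  float p = 1.535336188319500e-4f;
  p = p * fpart + 1.339887440266574e-3f;
  p = p * fpart + 9.618437357674640e-3f;
  p = p * fpart + 5.550332471162809e-2f;
  p = p * fpart + 2.402264791363012e-1f;
  p = p * fpart + 6.931472028550421e-1f;
  p = p * fpart + 1.0f;

  // 2**ipart by writing (ipart + 127) into the float exponent field
  int32_t e = (static_cast<int32_t>(ipart) + 127) << 23;
  float two_ip;
  std::memcpy(&two_ip, &e, sizeof(two_ip));
  return two_ip * p;
}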

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
 * The ARM neon equivalent of the fast exp above.
 */
inline float16x8_t neon_fast_exp(float16x8_t x) {
  x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
  x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
  x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14

  float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
  float16x8_t fpart = vsubq_f16(x, ipart);

  x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
  x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);

  // generate 2**ipart in the floating point representation using integer
  // bitshifting
  int16x8_t epart = vcvtq_s16_f16(ipart);
  epart = vaddq_s16(epart, vdupq_n_s16(15));
  epart = vshlq_n_s16(epart, 10);

  return vmulq_f16(vreinterpretq_f16_s16(epart), x);
}

/**
 * Implementation of folding maximum for ARM neon. This should possibly be
 * refactored out of softmax.cpp at some point.
 */
inline float16_t neon_reduce_max(float16x8_t x) {
  float16x4_t y;
  y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
  y = vpmax_f16(y, y);
  y = vpmax_f16(y, y);
  return vget_lane_f16(y, 0);
}

/**
 * Implementation of folding sum for ARM neon. This should possibly be
 * refactored out of softmax.cpp at some point.
 */
inline float16_t neon_reduce_add(float16x8_t x) {
  float16x4_t y;
  float16x4_t zero = vdup_n_f16(0);
  y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
  y = vpadd_f16(y, zero);
  y = vpadd_f16(y, zero);
  return vget_lane_f16(y, 0);
}

template <typename T, typename VT>
struct NeonFp16SimdOps {
  VT init(T a) {
    return vdupq_n_f16(a);
  }

  VT load(const T* a) {
    return vld1q_f16(a);
  }

  void store(T* dst, VT x) {
    vst1q_f16(dst, x);
  }

  VT max(VT a, VT b) {
    return vmaxq_f16(a, b);
  }

  VT exp(VT x) {
    return neon_fast_exp(x);
  }

  VT add(VT a, VT b) {
    return vaddq_f16(a, b);
  }

  VT sub(VT a, T b) {
    return vsubq_f16(a, vdupq_n_f16(b));
  }

  VT mul(VT a, VT b) {
    return vmulq_f16(a, b);
  }

  VT mul(VT a, T b) {
    return vmulq_f16(a, vdupq_n_f16(b));
  }

  T reduce_max(VT x) {
    return neon_reduce_max(x);
  }

  T reduce_add(VT x) {
    return neon_reduce_add(x);
  }
};

#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

template <typename T, typename VT>
struct AccelerateSimdOps {
  VT init(T a) {
    return a;
  }

  VT load(const T* a) {
    return *(VT*)a;
  }

  void store(T* dst, VT x) {
    *(VT*)dst = x;
  }

  VT max(VT a, VT b) {
    return simd_max(a, b);
  }

  VT exp(VT x) {
    return simd_fast_exp(x);
  }

  VT add(VT a, VT b) {
    return a + b;
  }

  VT sub(VT a, T b) {
    return a - b;
  }

  VT mul(VT a, VT b) {
    return a * b;
  }

  VT mul(VT a, T b) {
    return a * b;
  }

  T reduce_max(VT x) {
    return simd_reduce_max(x);
  }

  T reduce_add(VT x) {
    return simd_reduce_add(x);
  }
};
|
||||
|
||||
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
  Ops ops;

  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();
  int M = in.shape().back();
  int L = in.data_size() / M;
  const T* current_in_ptr;
  T* current_out_ptr;

  for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
    // Find the maximum
    current_in_ptr = in_ptr;
    VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
    size_t s = M;
    while (s >= N) {
      VT vals;
      if constexpr (std::is_same<T, AccT>::value) {
        vals = ops.load(current_in_ptr);
      } else {
        for (int i = 0; i < N; ++i) {
          vals[i] = static_cast<AccT>(current_in_ptr[i]);
        }
      }
      vmaximum = ops.max(vals, vmaximum);
      current_in_ptr += N;
      s -= N;
    }
    AccT maximum = ops.reduce_max(vmaximum);
    while (s-- > 0) {
      maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
      current_in_ptr++;
    }

    // Compute the normalizer and the exponentials
    VT vnormalizer = ops.init(0.0);
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    s = M;
    while (s >= N) {
      VT vexp;
      if constexpr (std::is_same<T, AccT>::value) {
        vexp = ops.load(current_in_ptr);
      } else {
        for (int i = 0; i < N; ++i) {
          vexp[i] = static_cast<AccT>(current_in_ptr[i]);
        }
      }
      vexp = ops.exp(ops.sub(vexp, maximum));
      if constexpr (std::is_same<T, AccT>::value) {
        ops.store(current_out_ptr, vexp);
      }
      vnormalizer = ops.add(vnormalizer, vexp);
      current_in_ptr += N;
      current_out_ptr += N;
      s -= N;
    }
    AccT normalizer = ops.reduce_add(vnormalizer);
    while (s-- > 0) {
      AccT _exp = std::exp(*current_in_ptr - maximum);
      if (std::is_same<T, AccT>::value) {
        *current_out_ptr = _exp;
      }
      normalizer += _exp;
      current_in_ptr++;
      current_out_ptr++;
    }
    normalizer = 1 / normalizer;

    // Normalize
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    s = M;
    while (s >= N) {
      if constexpr (std::is_same<T, AccT>::value) {
        ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
      } else {
        VT vexp;
        for (int i = 0; i < N; ++i) {
          vexp[i] = static_cast<AccT>(current_in_ptr[i]);
        }
        vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
        for (int i = 0; i < N; ++i) {
          current_out_ptr[i] = vexp[i];
        }
        current_in_ptr += N;
      }
      current_out_ptr += N;
      s -= N;
    }
    while (s-- > 0) {
      if constexpr (std::is_same<T, AccT>::value) {
        *current_out_ptr *= normalizer;
      } else {
        AccT _exp = std::exp(*current_in_ptr - maximum);
        *current_out_ptr = static_cast<T>(_exp * normalizer);
        current_in_ptr++;
      }
      current_out_ptr++;
    }
  }
}

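// A plain scalar reference for the three passes above (find max, exponentiate
// and sum, normalize); illustrative only, assuming a non-empty input and
// float accumulation throughout. Subtracting the row maximum keeps exp() in
// range, e.g. softmax({1, 2, 3}) evaluates exp({-2, -1, 0}) / sum rather than
// exponentiating the raw inputs.
#include <algorithm>
#include <cmath>
#include <vector>

inline std::vector<float> softmax_reference(const std::vector<float>& x) {
  float maximum = *std::max_element(x.begin(), x.end()); // pass 1
  std::vector<float> out(x.size());
  float normalizer = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) { // pass 2
    out[i] = std::exp(x[i] - maximum);
    normalizer += out[i];
  }
  for (auto& v : out) { // pass 3
    v /= normalizer;
  }
  return out;
}
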
} // namespace

void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // Make sure that the last dimension is contiguous
  auto check_input = [](array x) {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy(x, x_copy, CopyType::General);
      return x_copy;
    }
  };
  array in = check_input(std::move(inputs[0]));
  out.set_data(
      allocator::malloc_or_wait(in.data_size() * in.itemsize()),
      in.data_size(),
      in.strides(),
      in.flags());

  switch (in.dtype()) {
    case bool_:
    case uint8:
    case uint16:
    case uint32:
    case uint64:
    case int8:
    case int16:
    case int32:
    case int64:
      throw std::invalid_argument(
          "Softmax is defined only for floating point types");
      break;
    case float32:
      softmax<
          float,
          float,
          simd_float16,
          AccelerateSimdOps<float, simd_float16>,
          16>(in, out);
      break;
    case float16:
      if (precise_) {
        softmax<
            float16_t,
            float,
            simd_float16,
            AccelerateSimdOps<float, simd_float16>,
            16>(in, out);
      } else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        softmax<
            float16_t,
            float16_t,
            float16x8_t,
            NeonFp16SimdOps<float16_t, float16x8_t>,
            8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
      }
      break;
    case bfloat16:
      eval(inputs, out);
      break;
    case complex64:
      eval(inputs, out);
      break;
  }
}

} // namespace mlx::core
@@ -1,28 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <Accelerate/Accelerate.h>
#include "mlx/dtype.h"

namespace mlx::core {

BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
  uint32_t size_bits = size_of(mlx_dtype) * 8;
  switch (kindof(mlx_dtype)) {
    case Dtype::Kind::b:
      return BNNSDataTypeBoolean;
    case Dtype::Kind::u:
      return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
    case Dtype::Kind::i:
      return BNNSDataType(BNNSDataTypeIntBit | size_bits);
    case Dtype::Kind::f:
      return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
    case Dtype::Kind::V:
      return BNNSDataTypeBFloat16;
    case Dtype::Kind::c:
      throw std::invalid_argument("BNNS does not support complex types");
  }
}

} // namespace mlx::core
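// How the mapping above composes a BNNS type tag (illustrative; the BNNS*Bit
// constants come from Accelerate): the dtype kind picks a class flag and the
// width in bits is OR'ed in, so for example:
//   to_bnns_dtype(float32) -> BNNSDataType(BNNSDataTypeFloatBit | 32)
//   to_bnns_dtype(int16)   -> BNNSDataType(BNNSDataTypeIntBit | 16)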
@@ -1,61 +1,9 @@
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
  set(COMPILER ${CMAKE_C_COMPILER})
  set(CLANG TRUE)
else()
  set(COMPILER ${CMAKE_CXX_COMPILER})
endif()

add_custom_command(
  OUTPUT compiled_preamble.cpp
  COMMAND
    /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
    ${PROJECT_SOURCE_DIR} ${CLANG}
  DEPENDS make_compiled_preamble.sh
          compiled_preamble.h
          ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
          ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
          ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
          ${PROJECT_SOURCE_DIR}/mlx/types/complex.h
          ops.h)

add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)

add_dependencies(mlx cpu_compiled_preamble)

target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
          ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)

if(IOS)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
endif()
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)

@@ -1,74 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/array.h"

namespace mlx::core {

namespace {

template <typename T>
void arange(T start, T next, array& out, size_t size) {
  auto ptr = out.data<T>();
  auto step_size = next - start;
  for (int i = 0; i < size; ++i) {
    ptr[i] = start;
    start += step_size;
  }
}

} // namespace

void arange(
    const std::vector<array>& inputs,
    array& out,
    double start,
    double step) {
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  switch (out.dtype()) {
    case bool_:
      throw std::runtime_error("Bool type unsupported for arange.");
      break;
    case uint8:
      arange<uint8_t>(start, start + step, out, out.size());
      break;
    case uint16:
      arange<uint16_t>(start, start + step, out, out.size());
      break;
    case uint32:
      arange<uint32_t>(start, start + step, out, out.size());
      break;
    case uint64:
      arange<uint64_t>(start, start + step, out, out.size());
      break;
    case int8:
      arange<int8_t>(start, start + step, out, out.size());
      break;
    case int16:
      arange<int16_t>(start, start + step, out, out.size());
      break;
    case int32:
      arange<int32_t>(start, start + step, out, out.size());
      break;
    case int64:
      arange<int64_t>(start, start + step, out, out.size());
      break;
    case float16:
      arange<float16_t>(start, start + step, out, out.size());
      break;
    case float32:
      arange<float>(start, start + step, out, out.size());
      break;
    case bfloat16:
      arange<bfloat16_t>(start, start + step, out, out.size());
      break;
    case complex64:
      arange<complex64_t>(start, start + step, out, out.size());
      break;
  }
}

} // namespace mlx::core
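// Usage sketch for the template above (illustrative): it receives the first
// two values of the sequence and recovers the step from their difference,
// which is why every dispatch case passes (start, start + step). Filling five
// floats from start = 1.0, step = 0.5:
//   arange<float>(1.0f, 1.5f, out, 5); // out: {1.0, 1.5, 2.0, 2.5, 3.0}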
@@ -1,112 +0,0 @@
// Copyright © 2023 Apple Inc.

#include <cassert>

#include "mlx/primitives.h"
#include "utils.h"

namespace mlx::core {

namespace {

template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
  auto axis_size = in.shape()[axis];
  auto axis_stride = in.strides()[axis];
  std::vector<size_t> strides = in.strides();
  std::vector<int> shape = in.shape();
  strides.erase(strides.begin() + axis);
  shape.erase(shape.begin() + axis);
  for (uint32_t i = 0; i < out.size(); ++i) {
    auto loc = elem_to_loc(i, shape, strides);
    auto in_ptr = in.data<InT>() + loc;
    uint32_t ind_v = 0;
    InT v = (*in_ptr);
    for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
      op(j, (*in_ptr), &ind_v, &v);
    }
    out.data<uint32_t>()[i] = ind_v;
  }
}

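// A sketch of what elem_to_loc computes for the loop above (our
// reconstruction of the helper's behavior, not the diff's definition): it
// maps a flat output index to an element offset in the strided input, using
// the shape/strides that remain after the reduced axis was erased.
#include <cstddef>
#include <vector>

inline size_t elem_to_loc_sketch(
    size_t elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i]; // coordinate along axis i
    elem /= shape[i];
  }
  return loc;
}
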
template <typename InT>
void arg_reduce_dispatch(
    const array& in,
    array& out,
    ArgReduce::ReduceType rtype,
    int axis) {
  switch (rtype) {
    case ArgReduce::ArgMin: {
      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
        if (x < (*y)) {
          (*y) = x;
          (*ind_y) = ind_x;
        }
      };
      arg_reduce<InT>(in, out, op, axis);
      break;
    }
    case ArgReduce::ArgMax: {
      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
        if (x > (*y)) {
          (*y) = x;
          (*ind_y) = ind_x;
        }
      };
      arg_reduce<InT>(in, out, op, axis);
      break;
    }
  }
}

} // namespace

void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  switch (in.dtype()) {
    case bool_:
      arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
      break;
    case uint8:
      arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
      break;
    case uint16:
      arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
      break;
    case uint32:
      arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
      break;
    case uint64:
      arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
      break;
    case int8:
      arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
      break;
    case int16:
      arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
      break;
    case int32:
      arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
      break;
    case int64:
      arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
      break;
    case float16:
      arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
      break;
    case float32:
      arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
      break;
    case bfloat16:
      arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
      break;
    case complex64:
      arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
      break;
  }
}

} // namespace mlx::core
@@ -1,331 +0,0 @@
// Copyright © 2023 Apple Inc.

#include <cassert>
#include <cmath>
#include <sstream>

#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/binary_two.h"
#include "mlx/backend/common/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {

template <typename T, typename U, typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
  DefaultScalarVector<T, U, Op> opsv(op);
  DefaultVectorScalar<T, U, Op> opvs(op);
  DefaultVectorVector<T, U, Op> opvv(op);
  binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
}

template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
  switch (a.dtype()) {
    case bool_:
      comparison_op<bool, bool>(a, b, out, op);
      break;
    case uint8:
      comparison_op<uint8_t, bool>(a, b, out, op);
      break;
    case uint16:
      comparison_op<uint16_t, bool>(a, b, out, op);
      break;
    case uint32:
      comparison_op<uint32_t, bool>(a, b, out, op);
      break;
    case uint64:
      comparison_op<uint64_t, bool>(a, b, out, op);
      break;
    case int8:
      comparison_op<int8_t, bool>(a, b, out, op);
      break;
    case int16:
      comparison_op<int16_t, bool>(a, b, out, op);
      break;
    case int32:
      comparison_op<int32_t, bool>(a, b, out, op);
      break;
    case int64:
      comparison_op<int64_t, bool>(a, b, out, op);
      break;
    case float16:
      comparison_op<float16_t, bool>(a, b, out, op);
      break;
    case float32:
      comparison_op<float, bool>(a, b, out, op);
      break;
    case bfloat16:
      comparison_op<bfloat16_t, bool>(a, b, out, op);
      break;
    case complex64:
      comparison_op<complex64_t, bool>(a, b, out, op);
      break;
  }
}

} // namespace

void Add::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Add());
}

void DivMod::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto integral_op = [](auto x, auto y) {
    return std::make_pair(x / y, x % y);
  };
  auto float_op = [](auto x, auto y) {
    return std::make_pair(std::trunc(x / y), std::fmod(x, y));
  };
  switch (outputs[0].dtype()) {
    case bool_:
      binary_op<bool>(a, b, outputs, integral_op);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, outputs, integral_op);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, outputs, integral_op);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, outputs, integral_op);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, outputs, integral_op);
      break;
    case int8:
      binary_op<int8_t>(a, b, outputs, integral_op);
      break;
    case int16:
      binary_op<int16_t>(a, b, outputs, integral_op);
      break;
    case int32:
      binary_op<int32_t>(a, b, outputs, integral_op);
      break;
    case int64:
      binary_op<int64_t>(a, b, outputs, integral_op);
      break;
    case float16:
      binary_op<float16_t>(a, b, outputs, float_op);
      break;
    case float32:
      binary_op<float>(a, b, outputs, float_op);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, outputs, float_op);
      break;
    case complex64:
      // Should never get here
      throw std::runtime_error("[DivMod] Complex type not supported");
      break;
  }
}

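// Worked example for the two lambdas above (illustrative): DivMod pairs the
// truncated quotient with a remainder that keeps the dividend's sign,
// matching C++ integer division and std::fmod.
//   integral_op(7, 3)    -> {2, 1}
//   integral_op(-7, 3)   -> {-2, -1}    // truncation toward zero
//   float_op(7.5f, 2.0f) -> {3.0f, 1.5f}
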
void Divide::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Divide());
}

void Remainder::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Remainder());
}

void Equal::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (equal_nan_) {
    comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
  } else {
    comparison_op(inputs[0], inputs[1], out, detail::Equal());
  }
}

void Greater::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::Greater());
}

void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
}

void Less::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::Less());
}

void LessEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
}

void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  if (out.dtype() == float32) {
    binary_op<float>(a, b, out, detail::LogAddExp());
  } else if (out.dtype() == float16) {
    binary_op<float16_t>(a, b, out, detail::LogAddExp());
  } else if (out.dtype() == bfloat16) {
    binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
  } else if (issubdtype(out.dtype(), inexact)) {
    std::ostringstream err;
    err << "[logaddexp] Does not support " << out.dtype();
    throw std::invalid_argument(err.str());
  } else {
    throw std::invalid_argument(
        "[logaddexp] Cannot compute logaddexp for arrays with"
        " non floating point type.");
  }
}

void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalAnd requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, detail::LogicalAnd());
}

void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalOr requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, detail::LogicalOr());
}

void Maximum::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Maximum());
}

void Minimum::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Minimum());
}

void Multiply::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Multiply());
}

void NotEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
}

void Power::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Power());
}

void Subtract::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Subtract());
}

void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto dispatch_type = [&a, &b, &out](auto op) {
    switch (out.dtype()) {
      case bool_:
        binary_op<bool>(a, b, out, op);
        break;
      case uint8:
        binary_op<uint8_t>(a, b, out, op);
        break;
      case uint16:
        binary_op<uint16_t>(a, b, out, op);
        break;
      case uint32:
        binary_op<uint32_t>(a, b, out, op);
        break;
      case uint64:
        binary_op<uint64_t>(a, b, out, op);
        break;
      case int8:
        binary_op<int8_t>(a, b, out, op);
        break;
      case int16:
        binary_op<int16_t>(a, b, out, op);
        break;
      case int32:
        binary_op<int32_t>(a, b, out, op);
        break;
      case int64:
        binary_op<int64_t>(a, b, out, op);
        break;
      default:
        throw std::runtime_error(
            "[BitwiseBinary::eval_cpu] Type not supported");
        break;
    }
  };
  switch (op_) {
    case BitwiseBinary::And:
      dispatch_type(detail::BitwiseAnd());
      break;
    case BitwiseBinary::Or:
      dispatch_type(detail::BitwiseOr());
      break;
    case BitwiseBinary::Xor:
      dispatch_type(detail::BitwiseXor());
      break;
    case BitwiseBinary::LeftShift:
      dispatch_type(detail::LeftShift());
      break;
    case BitwiseBinary::RightShift:
      dispatch_type(detail::RightShift());
      break;
  }
}

void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  if (out.dtype() == float32) {
    binary_op<float>(a, b, out, detail::ArcTan2());
  } else if (out.dtype() == float16) {
    binary_op<float16_t>(a, b, out, detail::ArcTan2());
  } else if (out.dtype() == bfloat16) {
    binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
  } else if (issubdtype(out.dtype(), inexact)) {
    std::ostringstream err;
    err << "[arctan2] Does not support " << out.dtype();
    throw std::invalid_argument(err.str());
  } else {
    throw std::invalid_argument(
        "[arctan2] Cannot compute inverse tangent for arrays"
        " with non floating point type.");
  }
}

} // namespace mlx::core
@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.

#pragma once
#include <cassert>

#include "mlx/allocator.h"
#include "mlx/array.h"
@@ -9,8 +8,6 @@

namespace mlx::core {

namespace {

enum class BinaryOpType {
  ScalarScalar,
  ScalarVector,
@@ -19,7 +16,7 @@ enum class BinaryOpType {
  General,
};

BinaryOpType get_binary_op_type(const array& a, const array& b) {
inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
  BinaryOpType bopt;
  if (a.data_size() == 1 && b.data_size() == 1) {
    bopt = BinaryOpType::ScalarScalar;
@@ -28,8 +25,8 @@ BinaryOpType get_binary_op_type(const array& a, const array& b) {
  } else if (b.data_size() == 1 && a.flags().contiguous) {
    bopt = BinaryOpType::VectorScalar;
  } else if (
      a.flags().row_contiguous && b.flags().row_contiguous ||
      a.flags().col_contiguous && b.flags().col_contiguous) {
      (a.flags().row_contiguous && b.flags().row_contiguous) ||
      (a.flags().col_contiguous && b.flags().col_contiguous)) {
    bopt = BinaryOpType::VectorVector;
  } else {
    bopt = BinaryOpType::General;
@@ -37,29 +34,24 @@ BinaryOpType get_binary_op_type(const array& a, const array& b) {
  return bopt;
}

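// How the classification above plays out (illustrative examples): a broadcast
// scalar has data_size() == 1, so combining it with a contiguous array takes
// a scalar fast path, while mismatched layouts fall back to General.
//   a: contiguous {4}, b: scalar {}               -> VectorScalar
//   a, b: both row-contiguous, same shape         -> VectorVector
//   a: transposed (neither flag set), b: contig.  -> General
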
void set_binary_op_output_data(
inline void set_binary_op_output_data(
    const array& a,
    const array& b,
    array& out,
    BinaryOpType bopt,
    bool donate_with_move = false) {
    BinaryOpType bopt) {
  bool b_donatable = is_donatable(b, out);
  bool a_donatable = is_donatable(a, out);
  switch (bopt) {
    case BinaryOpType::ScalarScalar:
      out.set_data(
          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
          allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
      break;
    case BinaryOpType::ScalarVector:
      if (b_donatable) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
        out.copy_shared_buffer(b);
      } else {
        out.set_data(
            allocator::malloc_or_wait(b.data_size() * out.itemsize()),
            allocator::malloc(b.data_size() * out.itemsize()),
            b.data_size(),
            b.strides(),
            b.flags());
@@ -67,14 +59,10 @@ void set_binary_op_output_data(
      break;
    case BinaryOpType::VectorScalar:
      if (a_donatable) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
        out.copy_shared_buffer(a);
      } else {
        out.set_data(
            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
            allocator::malloc(a.data_size() * out.itemsize()),
            a.data_size(),
            a.strides(),
            a.flags());
@@ -82,20 +70,12 @@ void set_binary_op_output_data(
      break;
    case BinaryOpType::VectorVector:
      if (a_donatable) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
        out.copy_shared_buffer(a);
      } else if (b_donatable) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
        out.copy_shared_buffer(b);
      } else {
        out.set_data(
            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
            allocator::malloc(a.data_size() * out.itemsize()),
            a.data_size(),
            a.strides(),
            a.flags());
@@ -103,428 +83,15 @@ void set_binary_op_output_data(
      break;
    case BinaryOpType::General:
      if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
        out.copy_shared_buffer(a);
      } else if (
          b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
        out.copy_shared_buffer(b);
      } else {
        out.set_data(allocator::malloc_or_wait(out.nbytes()));
        out.set_data(allocator::malloc(out.nbytes()));
      }
      break;
  }
}

struct UseDefaultBinaryOp {};

template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
  Op op;

  DefaultVectorScalar(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *b;
    while (size-- > 0) {
      *dst = op(*a, scalar);
      dst++;
      a++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultScalarVector {
  Op op;

  DefaultScalarVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *a;
    while (size-- > 0) {
      *dst = op(scalar, *b);
      dst++;
      b++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultVectorVector {
  Op op;

  DefaultVectorVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    while (size-- > 0) {
      *dst = op(*a, *b);
      dst++;
      a++;
      b++;
    }
  }
};

template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims(
    const T* a,
    const T* b,
    U* out,
    Op op,
    const std::vector<int>& shape,
    const std::vector<size_t>& a_strides,
    const std::vector<size_t>& b_strides,
    const std::vector<size_t>& out_strides,
    int axis) {
  auto stride_a = a_strides[axis];
  auto stride_b = b_strides[axis];
  auto stride_out = out_strides[axis];
  auto N = shape[axis];

  for (int i = 0; i < N; i++) {
    if constexpr (D > 1) {
      binary_op_dims<T, U, Op, D - 1, Strided>(
          a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
    } else {
      if constexpr (Strided) {
        op(a, b, out, stride_out);
      } else {
        *out = op(*a, *b);
      }
    }
    out += stride_out;
    a += stride_a;
    b += stride_b;
  }
}

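// What the compile-time recursion above unrolls to (illustrative): for D = 2
// it is two nested loops, and the innermost level either applies the scalar
// op per element or, when Strided, hands the op a whole contiguous run.
// Roughly, for a 2-D shape:
//   for (int i0 = 0; i0 < shape[0]; ++i0)
//     for (int i1 = 0; i1 < shape[1]; ++i1)
//       out[i0 * out_strides[0] + i1 * out_strides[1]] =
//           op(a[i0 * a_strides[0] + i1 * a_strides[1]],
//              b[i0 * b_strides[0] + i1 * b_strides[1]]);
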
template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int dim,
    const std::vector<int>& shape,
    const std::vector<size_t>& a_strides,
    const std::vector<size_t>& b_strides,
    const std::vector<size_t>& out_strides) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* out_ptr = out.data<U>();
  switch (dim) {
    case 1:
      binary_op_dims<T, U, Op, 1, Strided>(
          a_ptr,
          b_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
    case 2:
      binary_op_dims<T, U, Op, 2, Strided>(
          a_ptr,
          b_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
    case 3:
      binary_op_dims<T, U, Op, 3, Strided>(
          a_ptr,
          b_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
  }

  ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
  ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
  size_t stride = out_strides[dim - 4];
  for (size_t elem = 0; elem < a.size(); elem += stride) {
    binary_op_dims<T, U, Op, 3, Strided>(
        a_ptr + a_it.loc,
        b_ptr + b_it.loc,
        out_ptr + elem,
        op,
        shape,
        a_strides,
        b_strides,
        out_strides,
        dim - 3);
    a_it.step();
    b_it.step();
  }
}

template <
    typename T,
    typename U,
    typename Op,
    typename OpSV,
    typename OpVS,
    typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == BinaryOpType::ScalarScalar) {
    *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == BinaryOpType::ScalarVector) {
    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == BinaryOpType::VectorScalar) {
    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == BinaryOpType::VectorVector) {
    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
    return;
  }

  // General computation so let's try to optimize
  auto [new_shape, new_strides] = collapse_contiguous_dims(
      a.shape(), {a.strides(), b.strides(), out.strides()});
  const auto& a_strides = new_strides[0];
  const auto& b_strides = new_strides[1];
  const auto& strides = new_strides[2];

  // Get the left-most dim such that the array is row contiguous after
  auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
    int d = arr_strides.size() - 1;
    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a_strides);
  auto b_rc_dim = leftmost_rc_dim(b_strides);

  // Get the left-most dim such that the array is a broadcasted "scalar" after
  auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
    int d = arr_strides.size() - 1;
    for (; d >= 0 && arr_strides[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a_strides);
  auto b_s_dim = leftmost_s_dim(b_strides);

  auto ndim = new_shape.size();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = BinaryOpType::VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::ScalarVector;
    dim = d;
  }

  // Can be sure dim > 0 since otherwise we would have used one of the fully
  // contiguous methods above. Except for the case that the flags do not
  // correspond to the underlying contiguity.
  if (dim == 0 || strides[dim - 1] < 16) {
    bopt = BinaryOpType::General;
    dim = ndim;
  }

  switch (bopt) {
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U, true>(
          a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
      break;
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U, true>(
          a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
      break;
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U, true>(
          a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
      break;
    default:
      binary_op_dispatch_dims<T, U, false>(
          a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
      break;
  }
}

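// Example of the dimension folding above (illustrative): with out row
// contiguous of shape {4, 1024} and b broadcast along the first axis
// (b_strides = {0, 1}), leftmost_rc_dim(b) == 1, so Case 1 selects
// VectorVector with dim == 1 and each of the 4 rows becomes one contiguous
// vector call of length 1024. The final `strides[dim - 1] < 16` check sends
// runs shorter than 16 elements back to the General path, where the strided
// fast path is not worth it.
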
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  // TODO: The following mess of constexpr evaluations can probably be achieved
  //       with template specializations and overloading. Would it be simpler?

  if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
    if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
      if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            DefaultVectorVector<T, T, Op>(op));
      } else {
        // opsv and opvs were UseDefaultBinaryOp
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            opvv);
      }
    } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
                             value) {
      // opsv and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          DefaultScalarVector<T, T, Op>(op),
          opvs,
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opsv was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
    }
  } else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
                           value) {
    if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opvs and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          opsv,
          DefaultVectorScalar<T, T, Op>(op),
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opvs was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
    }
  } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
                           value) {
    // opvv was UseDefaultBinaryOp
    binary_op<T, T>(
        a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
  } else {
    // All ops provided
    binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
  }
}

template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
  DefaultScalarVector<T, T, Op> opsv(op);
  DefaultVectorScalar<T, T, Op> opvs(op);
  DefaultVectorVector<T, T, Op> opvv(op);
  binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}

template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
  switch (out.dtype()) {
    case bool_:
      binary_op<bool>(a, b, out, ops...);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, out, ops...);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, out, ops...);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, out, ops...);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, out, ops...);
      break;
    case int8:
      binary_op<int8_t>(a, b, out, ops...);
      break;
    case int16:
      binary_op<int16_t>(a, b, out, ops...);
      break;
    case int32:
      binary_op<int32_t>(a, b, out, ops...);
      break;
    case int64:
      binary_op<int64_t>(a, b, out, ops...);
      break;
    case float16:
      binary_op<float16_t>(a, b, out, ops...);
      break;
    case float32:
      binary_op<float>(a, b, out, ops...);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, ops...);
      break;
    case complex64:
      binary_op<complex64_t>(a, b, out, ops...);
      break;
  }
}

} // namespace

} // namespace mlx::core

24
mlx/backend/common/broadcasting.cpp
Normal file
@@ -0,0 +1,24 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/utils.h"

namespace mlx::core {

void broadcast(const array& in, array& out) {
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }
  Strides strides(out.ndim(), 0);
  int diff = out.ndim() - in.ndim();
  for (int i = in.ndim() - 1; i >= 0; --i) {
    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
  }
  auto flags = in.flags();
  if (out.size() > in.size()) {
    flags.row_contiguous = flags.col_contiguous = false;
  }
  out.copy_shared_buffer(in, strides, flags, in.data_size());
}

} // namespace mlx::core
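// Stride arithmetic performed by broadcast() above (illustrative): no data
// moves; axes of size 1 (and newly prepended axes) get stride 0 so every
// output index along them reads the same input element.
//   in:  shape {3, 1}, strides {1, 1}
//   out: shape {2, 3, 4}   =>   diff = 1, out strides {0, 1, 0}
//     axis 0 is new                      -> stride 0
//     axis 1 maps to in axis 0           -> stride 1
//     axis 2 maps to in axis 1 (size 1)  -> stride 0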
11
mlx/backend/common/broadcasting.h
Normal file
@@ -0,0 +1,11 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

void broadcast(const array& in, array& out);

} // namespace mlx::core
@@ -1,101 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

namespace mlx::core {

namespace {

// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
  int info;

#ifdef LAPACK_FORTRAN_STRLEN_END
  spotrf_(
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* info = */ &info,
      /* uplo_len = */ static_cast<size_t>(1));
#else
  spotrf_(
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* info = */ &info);
#endif

  return info;
}

} // namespace

void cholesky_impl(const array& a, array& factor, bool upper) {
  // Lapack uses the column-major convention. We take advantage of the fact
  // that the matrix should be symmetric:
  //   (A)ᵀ = A
  // and that a column-major lower triangular matrix is a row-major upper
  // triangular matrix, so uplo is the opposite of what we would expect from
  // upper

  char uplo = (upper) ? 'L' : 'U';

  // The decomposition is computed in place, so just copy the input to the
  // output.
  copy(
      a,
      factor,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

  const int N = a.shape(-1);
  const size_t num_matrices = a.size() / (N * N);

  float* matrix = factor.data<float>();

  for (int i = 0; i < num_matrices; i++) {
    // Compute Cholesky factorization.
    int info = spotrf_wrapper(uplo, matrix, N);

    // TODO: We do nothing when the matrix is not positive semi-definite
    // because throwing an error would result in a crash. If we figure out how
    // to catch errors from the implementation we should throw.
    if (info < 0) {
      std::stringstream msg;
      msg << "[cholesky] Cholesky decomposition failed with error code "
          << info;
      throw std::runtime_error(msg.str());
    }

    // Zero out the upper/lower triangle while advancing the pointer to the
    // next matrix at the same time.
    for (int row = 0; row < N; row++) {
      if (upper) {
        std::fill(matrix, matrix + row, 0);
      } else {
        std::fill(matrix + row + 1, matrix + N, 0);
      }
      matrix += N;
    }
  }
}

void Cholesky::eval(const std::vector<array>& inputs, array& output) {
  if (inputs[0].dtype() != float32) {
    throw std::runtime_error("[Cholesky::eval] only supports float32.");
  }
  cholesky_impl(inputs[0], output, upper_);
}

} // namespace mlx::core
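// Why uplo is flipped above, on a 2x2 example (illustrative): for
//   A = [[4, 2],
//        [2, 3]]
// the lower factor is L = [[2, 0], [1, sqrt(2)]], since L * L^T = A. Handing
// LAPACK row-major data that it interprets as column-major is the same as
// passing A^T, which equals A by symmetry; the triangle it fills comes back
// mirrored, so requesting 'U' yields the row-major lower factor and vice
// versa.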
@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <cassert>

#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

@@ -43,22 +44,11 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
}

void Broadcast::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }
  std::vector<size_t> strides(out.ndim(), 0);
  int diff = out.ndim() - in.ndim();
  for (int i = in.ndim() - 1; i >= 0; --i) {
    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
  }
  auto flags = in.flags();
  if (out.size() > in.size()) {
    flags.row_contiguous = flags.col_contiguous = false;
  }
  out.copy_shared_buffer(in, strides, flags, in.data_size());
  broadcast(inputs[0], out);
}

void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
  broadcast(inputs[0], out);
}

void Copy::eval(const std::vector<array>& inputs, array& out) {
@@ -85,9 +75,19 @@ void Depends::eval(
  }
}

void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  auto strides = in.strides();
  for (auto ax : axes_) {
    strides.insert(strides.begin() + ax, 1);
  }
  out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
}

void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  out.set_data(allocator::malloc(out.nbytes()));

  double numel = 1;
  for (auto ax : axes_) {
@@ -135,15 +135,16 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
    case bfloat16:
      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
      break;
    case float64:
      *out.data<double>() = static_cast<double>(numel);
      break;
    case complex64:
      *out.data<complex64_t>() = static_cast<complex64_t>(numel);
      break;
  }
}

std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
    const array& in,
    const array& out) {
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
  // Special case for empty arrays or row contiguous arrays
  if (in.size() == 0 || in.flags().row_contiguous) {
    return {false, out.strides()};
@@ -151,8 +152,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(

  // Special case for scalars
  if (in.ndim() == 0) {
    std::vector<size_t> out_strides(out.ndim(), 0);
    return {false, out_strides};
    return {false, Strides(out.ndim(), 0)};
  }

  // Firstly let's collapse all the contiguous dimensions of the input
@@ -160,7 +160,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(

  // If shapes fit exactly in the contiguous dims then no copy is necessary so
  // let's check.
  std::vector<size_t> out_strides;
  Strides out_strides;
  bool copy_necessary = false;
  int j = 0;
  for (int i = 0; i < out.ndim(); i++) {
@@ -181,9 +181,9 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
  return {copy_necessary, out_strides};
}

void Reshape::shared_buffer_reshape(
void shared_buffer_reshape(
    const array& in,
    const std::vector<size_t>& out_strides,
    const Strides& out_strides,
    array& out) {
  auto flags = in.flags();
  if (flags.row_contiguous) {
@@ -249,16 +249,18 @@ void Split::eval(
  }
}

std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
    const array& in) {
  int64_t data_offset = 0;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices_[i] * in.strides()[i];
    inp_strides[i] = in.strides()[i] * strides_[i];
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  Strides strides;
  for (int i = 0, j = 0; i < in.ndim(); ++i) {
    if (j < axes_.size() && i == axes_[j]) {
      j++;
    } else {
      strides.push_back(in.strides(i));
    }
  }

  return std::make_tuple(data_offset, inp_strides);
  out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
}

void StopGradient::eval(const std::vector<array>& inputs, array& out) {
@@ -268,7 +270,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {

void Transpose::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  std::vector<size_t> out_strides(out.ndim());
  Strides out_strides(out.ndim());
  auto& in = inputs[0];
  for (int ax = 0; ax < axes_.size(); ++ax) {
    out_strides[ax] = in.strides()[axes_[ax]];
@@ -285,8 +287,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
  // true, they stay true)
  auto flags = in.flags();
  if (flags.contiguous && in.data_size() == in.size()) {
    size_t f_stride = 1;
    size_t b_stride = 1;
    int64_t f_stride = 1;
    int64_t b_stride = 1;
    flags.col_contiguous = true;
    flags.row_contiguous = true;
    for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {

@@ -130,7 +130,7 @@ std::string build_lib_name(

bool compiled_check_contiguity(
    const std::vector<array>& inputs,
    const std::vector<int>& shape) {
    const Shape& shape) {
  bool contiguous = true;
  bool all_contig = true;
  bool all_row_contig = true;
@@ -161,11 +161,10 @@ void compiled_allocate_outputs(
    std::vector<array>& outputs,
    const std::vector<array>& inputs_,
    const std::unordered_set<uintptr_t>& constant_ids_,
    bool contiguous,
    bool move_buffers /* = false */) {
    bool contiguous) {
  if (contiguous) {
    int o = 0;
    std::vector<size_t> strides;
    Strides strides;
    size_t data_size;
    array::Flags flags;
    for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
@@ -178,11 +177,7 @@ void compiled_allocate_outputs(
      if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
          in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
        if (move_buffers) {
          outputs[o++].move_shared_buffer(in);
        } else {
          outputs[o++].copy_shared_buffer(in);
        }
        outputs[o++].copy_shared_buffer(in);
      }
      // Get representative input flags to properly set non-donated outputs
      if (strides.empty() && in.size() == outputs[0].size()) {
@@ -193,7 +188,7 @@ void compiled_allocate_outputs(
    }
    for (; o < outputs.size(); ++o) {
      outputs[o].set_data(
          allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
          allocator::malloc(data_size * outputs[o].itemsize()),
          data_size,
          strides,
          flags);
@@ -210,18 +205,13 @@ void compiled_allocate_outputs(
      if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
          in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
        if (move_buffers) {
          outputs[o].move_shared_buffer(
              in, outputs[o].strides(), in.flags(), in.data_size());
        } else {
          outputs[o].copy_shared_buffer(
              in, outputs[o].strides(), in.flags(), in.data_size());
        }
        outputs[o].copy_shared_buffer(
            in, outputs[o].strides(), in.flags(), in.data_size());
        o++;
      }
    }
    for (; o < outputs.size(); ++o) {
      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
      outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
    }
  }
}

@@ -11,9 +11,7 @@
namespace mlx::core {

inline bool is_static_cast(const Primitive& p) {
  return (
      typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
      typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
  return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
}

std::string build_lib_name(
@@ -56,7 +54,7 @@ inline bool is_scalar(const array& x) {
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
    const std::vector<array>& inputs,
    const std::vector<int>& shape);
    const Shape& shape);

// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(
@@ -64,7 +62,6 @@ void compiled_allocate_outputs(
    std::vector<array>& outputs,
    const std::vector<array>& inputs_,
    const std::unordered_set<uintptr_t>& constant_ids_,
    bool contiguous,
    bool move_buffers = false);
    bool contiguous);

} // namespace mlx::core
File diff suppressed because it is too large
Some files were not shown because too many files have changed in this diff