mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-02 22:04:45 +08:00
Compare commits
350 Commits
socket-dis
...
split_logs
Author | SHA1 | Date | |
---|---|---|---|
![]() |
7c99acb799 | ||
![]() |
5a1a5d5ed1 | ||
![]() |
1683975acf | ||
![]() |
af705590ac | ||
![]() |
825124af8f | ||
![]() |
9c5e7da507 | ||
![]() |
481349495b | ||
![]() |
9daa6b003f | ||
![]() |
a3a632d567 | ||
![]() |
e496c5a4b4 | ||
![]() |
ea890d8710 | ||
![]() |
aa5d84f102 | ||
![]() |
f1606486d2 | ||
![]() |
87720a8908 | ||
![]() |
bb6565ef14 | ||
![]() |
7bb063bcb3 | ||
![]() |
b36dd472bb | ||
![]() |
167b759a38 | ||
![]() |
99b9868859 | ||
![]() |
6b2d5448f2 | ||
![]() |
eaf709b83e | ||
![]() |
f0e70afff0 | ||
![]() |
86984cad68 | ||
![]() |
fbc89e3ced | ||
![]() |
38c1e720c2 | ||
![]() |
600e87e03c | ||
![]() |
3836445241 | ||
![]() |
1d2c9d6a07 | ||
![]() |
e8ac6bd2f5 | ||
![]() |
fdadc4f22c | ||
![]() |
79b527f45f | ||
![]() |
dc4eada7f0 | ||
![]() |
70ebc3b598 | ||
![]() |
b13f2aed16 | ||
![]() |
5f04c0f818 | ||
![]() |
55935ccae7 | ||
![]() |
b529515eb1 | ||
![]() |
3cde719eb7 | ||
![]() |
5de6d94a90 | ||
![]() |
99eefd2ec0 | ||
![]() |
e9e268336b | ||
![]() |
7275ac7523 | ||
![]() |
c4189a38e4 | ||
![]() |
68d1b3256b | ||
![]() |
9c6953bda7 | ||
![]() |
ef7ece9851 | ||
![]() |
ddaa4b7dcb | ||
![]() |
dfae2c6989 | ||
![]() |
515f104926 | ||
![]() |
9ecefd56db | ||
![]() |
e5d35aa187 | ||
![]() |
00794c42bc | ||
![]() |
08a1bf3f10 | ||
![]() |
60c4154346 | ||
![]() |
f2c85308c1 | ||
![]() |
1a28b69ee2 | ||
![]() |
ba09f01ce8 | ||
![]() |
6cf48872b7 | ||
![]() |
7b3b8fa000 | ||
![]() |
ec5e2aae61 | ||
![]() |
86389bf970 | ||
![]() |
3290bfa690 | ||
![]() |
8777fd104f | ||
![]() |
c41f7565ed | ||
![]() |
9ba81e3da4 | ||
![]() |
c23888acd7 | ||
![]() |
f98ce25ab9 | ||
![]() |
de5f38fd48 | ||
![]() |
ec2854b13a | ||
![]() |
90823d2938 | ||
![]() |
5f5770e3a2 | ||
![]() |
28f39e9038 | ||
![]() |
b2d2b37888 | ||
![]() |
fe597e141c | ||
![]() |
72ca1539e0 | ||
![]() |
13b26775f1 | ||
![]() |
05d7118561 | ||
![]() |
98b901ad66 | ||
![]() |
5580b47291 | ||
![]() |
bc62932984 | ||
![]() |
a6b5d6e759 | ||
![]() |
a8931306e1 | ||
![]() |
fecdb8717e | ||
![]() |
916fd273ea | ||
![]() |
0da8506552 | ||
![]() |
eda7a7b43e | ||
![]() |
022eabb734 | ||
![]() |
aba899cef8 | ||
![]() |
6a40e1c176 | ||
![]() |
9307b2ab8b | ||
![]() |
522d8d3917 | ||
![]() |
a84cc0123f | ||
![]() |
f018e248cd | ||
![]() |
cfd7237a80 | ||
![]() |
4eef8102c9 | ||
![]() |
69e4dd506b | ||
![]() |
25814a9458 | ||
![]() |
2a980a76ce | ||
![]() |
d343782c8b | ||
![]() |
4e1994e9d7 | ||
![]() |
65a38c452b | ||
![]() |
7b7e2352cd | ||
![]() |
1177d28395 | ||
![]() |
005e7efa64 | ||
![]() |
b42d13ec84 | ||
![]() |
9adcd1a650 | ||
![]() |
3c164fca8c | ||
![]() |
95e335db7b | ||
![]() |
f90206ad74 | ||
![]() |
3779150750 | ||
![]() |
0a9777aa5c | ||
![]() |
45ad06aac8 | ||
![]() |
c6ea2ba329 | ||
![]() |
2770a10240 | ||
![]() |
d2a94f9e6a | ||
![]() |
32da94507a | ||
![]() |
736a340478 | ||
![]() |
117e1355a2 | ||
![]() |
3c3e558c60 | ||
![]() |
cffceda6ee | ||
![]() |
048805ad2c | ||
![]() |
d14c9fe7ea | ||
![]() |
5db90ce822 | ||
![]() |
d699cc1330 | ||
![]() |
c4230747a1 | ||
![]() |
5245f12a46 | ||
![]() |
a198b2787e | ||
![]() |
04edad8c59 | ||
![]() |
392b3060b0 | ||
![]() |
85b34d59bc | ||
![]() |
f599c11bc8 | ||
![]() |
0792ff02ff | ||
![]() |
fd0d63ba5b | ||
![]() |
3835a428c5 | ||
![]() |
9680f72cca | ||
![]() |
a0737273d3 | ||
![]() |
e613d0eaf0 | ||
![]() |
6bcd6bcf70 | ||
![]() |
ba12e4999a | ||
![]() |
4e7cd31d12 | ||
![]() |
5e6c130d93 | ||
![]() |
5d68082881 | ||
![]() |
607181644f | ||
![]() |
89d327075f | ||
![]() |
6bf00ef631 | ||
![]() |
7d042f17fe | ||
![]() |
28b8079e30 | ||
![]() |
7face5d9fd | ||
![]() |
a44dc4bdb0 | ||
![]() |
2d0f384b6f | ||
![]() |
8ff84b5c43 | ||
![]() |
10b271d963 | ||
![]() |
0ebc8a3d25 | ||
![]() |
bbda0fdbdb | ||
![]() |
c86422bdd4 | ||
![]() |
c707b2b0a6 | ||
![]() |
78ba24c37d | ||
![]() |
1a2cb72030 | ||
![]() |
344a29506e | ||
![]() |
71de73a668 | ||
![]() |
4c1dfa58b7 | ||
![]() |
5274c3c43f | ||
![]() |
1762793989 | ||
![]() |
6cec78d8f2 | ||
![]() |
2dc307f2e6 | ||
![]() |
7aea5b1895 | ||
![]() |
9733e16496 | ||
![]() |
7f2d1024f3 | ||
![]() |
428f589364 | ||
![]() |
5cd97f7ffe | ||
![]() |
e425dc00c0 | ||
![]() |
d274ae77f2 | ||
![]() |
55c5ac7820 | ||
![]() |
0145911bea | ||
![]() |
0a5215693e | ||
![]() |
2a45056ba8 | ||
![]() |
142b77751d | ||
![]() |
a5ededf1c3 | ||
![]() |
7df3f792a2 | ||
![]() |
9eb7d7362f | ||
![]() |
1c0c118f7c | ||
![]() |
1a1b2108ec | ||
![]() |
b6c6552d20 | ||
![]() |
83a0340fa7 | ||
![]() |
a62fc1b39f | ||
![]() |
af1b725fda | ||
![]() |
9174606d4c | ||
![]() |
ca305afdbe | ||
![]() |
fe5987b81d | ||
![]() |
a229c8cef0 | ||
![]() |
f6c0499b8d | ||
![]() |
1156c84e86 | ||
![]() |
ec7c7def40 | ||
![]() |
2d8e667400 | ||
![]() |
80c863b972 | ||
![]() |
f5cc1eea72 | ||
![]() |
b7c9f1d38f | ||
![]() |
c6fc07f1f4 | ||
![]() |
ded914f442 | ||
![]() |
4758c8baa1 | ||
![]() |
7064fed1b1 | ||
![]() |
1017ac4a9e | ||
![]() |
ccb61d7aae | ||
![]() |
2235dee906 | ||
![]() |
28091aa1ff | ||
![]() |
121d9a0702 | ||
![]() |
0cea88bcc5 | ||
![]() |
72146fc4cd | ||
![]() |
e6a7ab9675 | ||
![]() |
1f4c127fb9 | ||
![]() |
90532b1f37 | ||
![]() |
a8666a757a | ||
![]() |
a4667da1eb | ||
![]() |
0c259961ac | ||
![]() |
f288db8d34 | ||
![]() |
33421c1dd3 | ||
![]() |
5cc5201914 | ||
![]() |
252e423e81 | ||
![]() |
a4a2764a52 | ||
![]() |
ab8e832c18 | ||
![]() |
1ce0c0fcb0 | ||
![]() |
657f466402 | ||
![]() |
c7b0300af5 | ||
![]() |
da8c885784 | ||
![]() |
1ccaf80575 | ||
![]() |
ec36bfa317 | ||
![]() |
b8f76f717a | ||
![]() |
d1766f2c70 | ||
![]() |
516ded618b | ||
![]() |
c9c81d0584 | ||
![]() |
545f84d905 | ||
![]() |
d5ec172c95 | ||
![]() |
25b3a3e541 | ||
![]() |
058d6ce683 | ||
![]() |
eab93985b8 | ||
![]() |
b51d70a83c | ||
![]() |
259025100e | ||
![]() |
c9d30aa6ac | ||
![]() |
8544b42007 | ||
![]() |
6fa0501387 | ||
![]() |
ae69cb15e9 | ||
![]() |
a64a8dfe45 | ||
![]() |
491fa95b1f | ||
![]() |
92ec632ad5 | ||
![]() |
8ecdfb718b | ||
![]() |
4ba0c24a8f | ||
![]() |
935c8c4bb1 | ||
![]() |
88f993da38 | ||
![]() |
ebfe64b92d | ||
![]() |
0308e9af71 | ||
![]() |
c3628eea49 | ||
![]() |
e03f0372b1 | ||
![]() |
f17536af9c | ||
![]() |
ed4ec81bca | ||
![]() |
7480059306 | ||
![]() |
8bae22b0fa | ||
![]() |
49c34c4161 | ||
![]() |
5548fcc96d | ||
![]() |
070bd433ab | ||
![]() |
c8fb54951a | ||
![]() |
f110357aaa | ||
![]() |
a6b426422e | ||
![]() |
d03c01dfbc | ||
![]() |
a82996e9fb | ||
![]() |
af5a614aad | ||
![]() |
f9640e049d | ||
![]() |
4768c61b57 | ||
![]() |
dfccd17ab9 | ||
![]() |
635117c5d4 | ||
![]() |
50f3535693 | ||
![]() |
9111999af3 | ||
![]() |
6bd28d246e | ||
![]() |
4d595a2a39 | ||
![]() |
3a21f61772 | ||
![]() |
4e1e9520e1 | ||
![]() |
0bf19037ca | ||
![]() |
f3dfa36a3a | ||
![]() |
4f9b60dd53 | ||
![]() |
f76a49e555 | ||
![]() |
310ad8d9db | ||
![]() |
56db268f47 | ||
![]() |
92ab6bdeb8 | ||
![]() |
0070e360a1 | ||
![]() |
9df8fed046 | ||
![]() |
a59fae040f | ||
![]() |
29a620cab2 | ||
![]() |
87d7a2520e | ||
![]() |
40c62c1321 | ||
![]() |
35b412c099 | ||
![]() |
d0f471cff7 | ||
![]() |
6f316b8bf5 | ||
![]() |
7c10c93a1f | ||
![]() |
d92ea094f1 | ||
![]() |
6ae5423b4a | ||
![]() |
9635cffdc8 | ||
![]() |
96986fb362 | ||
![]() |
3ceb341a75 | ||
![]() |
50fa705125 | ||
![]() |
69a2991614 | ||
![]() |
fd3377dd1f | ||
![]() |
d0b6cb0425 | ||
![]() |
95c4a2e3af | ||
![]() |
bc2a29f033 | ||
![]() |
3bb5b4a302 | ||
![]() |
fc88fd9097 | ||
![]() |
c5b0928c1f | ||
![]() |
e047fd977d | ||
![]() |
9d40e521d7 | ||
![]() |
1445dcaa60 | ||
![]() |
e4eeb4e910 | ||
![]() |
aa86876813 | ||
![]() |
974bb54ab2 | ||
![]() |
9bc2183a31 | ||
![]() |
d4b222b6d3 | ||
![]() |
af2af818a6 | ||
![]() |
698e63a608 | ||
![]() |
211411faf2 | ||
![]() |
bb303c45a5 | ||
![]() |
6f7986d592 | ||
![]() |
7cbb4aef17 | ||
![]() |
02bec0bb6d | ||
![]() |
c79f6a4a8c | ||
![]() |
0c5eea226b | ||
![]() |
dcca0d7477 | ||
![]() |
0d5e7716ad | ||
![]() |
d8c824c594 | ||
![]() |
cb431dfc9f | ||
![]() |
61d787726a | ||
![]() |
5e89aace9b | ||
![]() |
2af7e8a9a6 | ||
![]() |
2419edd5b2 | ||
![]() |
bf481e8e5d | ||
![]() |
9d7fa6b8e6 | ||
![]() |
073076ac7d | ||
![]() |
9bd03dd9b4 | ||
![]() |
6931f84412 | ||
![]() |
16ec0556a0 | ||
![]() |
610af352d4 | ||
![]() |
b35f1e3c9c | ||
![]() |
dfa0b9aab4 | ||
![]() |
a4c47b0276 | ||
![]() |
111fefd5e9 | ||
![]() |
c1fe1ef081 | ||
![]() |
8c34c9dac4 | ||
![]() |
91c0277356 | ||
![]() |
9f0d5c12fc | ||
![]() |
59247c2b62 | ||
![]() |
9a3842a2d9 | ||
![]() |
726dbd9267 | ||
![]() |
54f05e7195 |
@@ -24,8 +24,8 @@ jobs:
|
||||
type: boolean
|
||||
default: false
|
||||
macos:
|
||||
xcode: "15.2.0"
|
||||
resource_class: macos.m1.medium.gen1
|
||||
xcode: "16.2.0"
|
||||
resource_class: m2pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -85,10 +85,11 @@ jobs:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
@@ -108,6 +109,8 @@ jobs:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
python3 -m unittest discover python/tests -v
|
||||
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
@@ -122,10 +125,15 @@ jobs:
|
||||
parameters:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "15.2.0"
|
||||
default: "16.2.0"
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.medium.gen1
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
resource_class: m2pro.medium
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -137,7 +145,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
@@ -146,7 +154,9 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
@@ -160,6 +170,7 @@ jobs:
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
||||
- run:
|
||||
name: Build example extension
|
||||
command: |
|
||||
@@ -208,13 +219,18 @@ jobs:
|
||||
default: "3.9"
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "15.2.0"
|
||||
default: "16.2.0"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.medium.gen1
|
||||
resource_class: m2pro.medium
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -226,7 +242,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install twine
|
||||
@@ -235,7 +251,7 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEV_RELEASE=1 \
|
||||
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
pip install . -v
|
||||
- run:
|
||||
@@ -291,7 +307,7 @@ jobs:
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.2.0
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
@@ -330,7 +346,7 @@ workflows:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test
|
||||
- build_documentation
|
||||
|
||||
@@ -350,8 +366,70 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- build_documentation:
|
||||
filters:
|
||||
tags:
|
||||
@@ -374,7 +452,7 @@ workflows:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
@@ -387,7 +465,54 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
weekly_build:
|
||||
when:
|
||||
and:
|
||||
@@ -398,8 +523,70 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "DEV_RELEASE=1"
|
||||
linux_test_release:
|
||||
when:
|
||||
and:
|
||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@@ -36,6 +36,7 @@ share/python-wheels/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
uv.lock
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
@@ -76,6 +77,9 @@ build/
|
||||
*.out
|
||||
*.app
|
||||
|
||||
# Debug symbols
|
||||
*.pdb
|
||||
|
||||
# VSCode
|
||||
.vscode/
|
||||
.DS_Store
|
||||
|
@@ -1,15 +1,16 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v18.1.8
|
||||
rev: v19.1.7
|
||||
hooks:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 24.8.0
|
||||
rev: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.13.2
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: isort
|
||||
args:
|
||||
|
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
|
||||
|
||||
MLX was developed with contributions from the following individuals:
|
||||
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
|
||||
|
128
CMakeLists.txt
128
CMakeLists.txt
@@ -1,6 +1,24 @@
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
cmake_minimum_required(VERSION 3.25)
|
||||
|
||||
project(mlx LANGUAGES C CXX)
|
||||
if(NOT MLX_VERSION)
|
||||
file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
|
||||
string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
|
||||
set(_major ${CMAKE_MATCH_1})
|
||||
string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
|
||||
set(_minor ${CMAKE_MATCH_1})
|
||||
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
|
||||
set(_patch ${CMAKE_MATCH_1})
|
||||
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
|
||||
set(MLX_VERSION ${MLX_PROJECT_VERSION})
|
||||
else()
|
||||
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
|
||||
${MLX_VERSION})
|
||||
endif()
|
||||
|
||||
project(
|
||||
mlx
|
||||
LANGUAGES C CXX
|
||||
VERSION ${MLX_PROJECT_VERSION})
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
@@ -20,22 +38,16 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.19.3)
|
||||
endif()
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
|
||||
message(
|
||||
STATUS
|
||||
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
|
||||
)
|
||||
|
||||
set(MLX_BUILD_ARM OFF)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
if(NOT MLX_ENABLE_X64_MAC)
|
||||
@@ -57,10 +69,6 @@ else()
|
||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||
endif()
|
||||
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
set(MLX_BUILD_ARM ON)
|
||||
endif()
|
||||
|
||||
# ----------------------------- Lib -----------------------------
|
||||
|
||||
include(FetchContent)
|
||||
@@ -89,25 +97,26 @@ elseif(MLX_BUILD_METAL)
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
if(${MACOS_VERSION} LESS 14.0)
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
|
||||
endif()
|
||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
||||
|
||||
set(METAL_CPP_URL
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
|
||||
)
|
||||
# Get the metal version
|
||||
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
|
||||
|
||||
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
execute_process(
|
||||
COMMAND
|
||||
zsh "-c"
|
||||
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
@@ -115,20 +124,58 @@ elseif(MLX_BUILD_METAL)
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
||||
$<INSTALL_INTERFACE:include/metal_cpp>)
|
||||
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||
endif()
|
||||
|
||||
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
|
||||
if(WIN32)
|
||||
if(MSVC)
|
||||
# GGUF does not build with MSVC.
|
||||
set(MLX_BUILD_GGUF OFF)
|
||||
# There is no prebuilt OpenBLAS distribution for MSVC.
|
||||
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
|
||||
endif()
|
||||
# Windows implementation of dlfcn.h APIs.
|
||||
FetchContent_Declare(
|
||||
dlfcn-win32
|
||||
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
|
||||
GIT_TAG v1.4.1
|
||||
EXCLUDE_FROM_ALL)
|
||||
block()
|
||||
set(BUILD_SHARED_LIBS OFF)
|
||||
FetchContent_MakeAvailable(dlfcn-win32)
|
||||
endblock()
|
||||
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
|
||||
target_link_libraries(mlx PRIVATE dl)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CPU)
|
||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||
if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
||||
if(ACCELERATE_LIBRARY)
|
||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||
set(MLX_BUILD_ACCELERATE ON)
|
||||
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
else()
|
||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(MLX_USE_ACCELERATE)
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
|
||||
# Download and build OpenBLAS from source code.
|
||||
FetchContent_Declare(
|
||||
openblas
|
||||
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
|
||||
GIT_TAG v0.3.28
|
||||
EXCLUDE_FROM_ALL)
|
||||
set(BUILD_STATIC_LIBS ON) # link statically
|
||||
set(NOFORTRAN ON) # msvc has no fortran compiler
|
||||
FetchContent_MakeAvailable(openblas)
|
||||
target_link_libraries(mlx PRIVATE openblas)
|
||||
target_include_directories(
|
||||
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
|
||||
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
|
||||
else()
|
||||
if(${CMAKE_HOST_APPLE})
|
||||
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||
# openblas instead.
|
||||
@@ -146,7 +193,7 @@ if(MLX_BUILD_CPU)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
|
||||
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
|
||||
# List blas after lapack otherwise we may accidentally include an old
|
||||
# version of lapack.h from the include dirs of blas.
|
||||
find_package(BLAS REQUIRED)
|
||||
@@ -159,29 +206,19 @@ if(MLX_BUILD_CPU)
|
||||
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
|
||||
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
|
||||
endif()
|
||||
else()
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
endif()
|
||||
|
||||
find_package(MPI)
|
||||
if(MPI_FOUND)
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "mpirun --version"
|
||||
OUTPUT_VARIABLE MPI_VERSION
|
||||
ERROR_QUIET)
|
||||
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
|
||||
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
||||
elseif(MPI_VERSION STREQUAL "")
|
||||
set(MPI_FOUND FALSE)
|
||||
message(
|
||||
WARNING "MPI found but mpirun is not available. Building without MPI.")
|
||||
else()
|
||||
set(MPI_FOUND FALSE)
|
||||
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "Downloading json")
|
||||
FetchContent_Declare(
|
||||
json
|
||||
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
|
||||
FetchContent_MakeAvailable(json)
|
||||
target_include_directories(
|
||||
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||
|
||||
@@ -206,8 +243,7 @@ if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
OUTPUT_VARIABLE nanobind_ROOT)
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||
endif()
|
||||
|
@@ -5,26 +5,26 @@ possible.
|
||||
|
||||
## Pull Requests
|
||||
|
||||
1. Fork and submit pull requests to the repo.
|
||||
1. Fork and submit pull requests to the repo.
|
||||
2. If you've added code that should be tested, add tests.
|
||||
3. If a change is likely to impact efficiency, run some of the benchmarks before
|
||||
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
|
||||
4. If you've changed APIs, update the documentation.
|
||||
5. Every PR should have passing tests and at least one review.
|
||||
5. Every PR should have passing tests and at least one review.
|
||||
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
|
||||
This should install hooks for running `black` and `clang-format` to ensure
|
||||
consistent style for C++ and python code.
|
||||
|
||||
|
||||
You can also run the formatters manually as follows:
|
||||
|
||||
```
|
||||
clang-format -i file.cpp
|
||||
```
|
||||
|
||||
```
|
||||
black file.py
|
||||
```
|
||||
|
||||
|
||||
```shell
|
||||
clang-format -i file.cpp
|
||||
```
|
||||
|
||||
```shell
|
||||
black file.py
|
||||
```
|
||||
|
||||
or run `pre-commit run --all-files` to check all files in the repo.
|
||||
|
||||
## Issues
|
||||
|
@@ -1,4 +1,6 @@
|
||||
include CMakeLists.txt
|
||||
include mlx.pc.in
|
||||
recursive-include mlx/ *
|
||||
include cmake/*
|
||||
include python/src/*
|
||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||
|
@@ -5,35 +5,35 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_value_and_grad() {
|
||||
auto x = ones({200, 1000});
|
||||
eval(x);
|
||||
auto fn = [](array x) {
|
||||
auto x = mx::ones({200, 1000});
|
||||
mx::eval(x);
|
||||
auto fn = [](mx::array x) {
|
||||
for (int i = 0; i < 20; ++i) {
|
||||
x = log(exp(x));
|
||||
x = mx::log(mx::exp(x));
|
||||
}
|
||||
return sum(x);
|
||||
return mx::sum(x);
|
||||
};
|
||||
|
||||
auto grad_fn = grad(fn);
|
||||
auto grad_fn = mx::grad(fn);
|
||||
auto independent_value_and_grad = [&]() {
|
||||
auto value = fn(x);
|
||||
auto dfdx = grad_fn(x);
|
||||
return std::vector<array>{value, dfdx};
|
||||
return std::vector<mx::array>{value, dfdx};
|
||||
};
|
||||
TIME(independent_value_and_grad);
|
||||
|
||||
auto value_and_grad_fn = value_and_grad(fn);
|
||||
auto value_and_grad_fn = mx::value_and_grad(fn);
|
||||
auto combined_value_and_grad = [&]() {
|
||||
auto [value, dfdx] = value_and_grad_fn(x);
|
||||
return std::vector<array>{value, dfdx};
|
||||
return std::vector<mx::array>{value, dfdx};
|
||||
};
|
||||
TIME(combined_value_and_grad);
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
time_value_and_grad();
|
||||
}
|
||||
|
@@ -4,21 +4,21 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_add_op() {
|
||||
std::vector<int> sizes(1, 1);
|
||||
for (int i = 0; i < 9; ++i) {
|
||||
sizes.push_back(10 * sizes.back());
|
||||
}
|
||||
set_default_device(Device::cpu);
|
||||
set_default_device(mx::Device::cpu);
|
||||
for (auto size : sizes) {
|
||||
auto a = random::uniform({size});
|
||||
auto b = random::uniform({size});
|
||||
eval(a, b);
|
||||
auto a = mx::random::uniform({size});
|
||||
auto b = mx::random::uniform({size});
|
||||
mx::eval(a, b);
|
||||
std::cout << "Size " << size << std::endl;
|
||||
TIMEM("cpu", add, a, b, Device::cpu);
|
||||
TIMEM("gpu", add, a, b, Device::gpu);
|
||||
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
|
||||
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -6,105 +6,105 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_irregular_binary_ops_1D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 1000000;
|
||||
int step = 2;
|
||||
auto a = random::uniform({size});
|
||||
auto b = random::uniform({size});
|
||||
eval(a, b);
|
||||
auto a = mx::random::uniform({size});
|
||||
auto b = mx::random::uniform({size});
|
||||
mx::eval(a, b);
|
||||
a = slice(a, {0}, {size}, {step});
|
||||
b = slice(b, {0}, {size}, {step});
|
||||
TIMEM("1D strided", add, a, b, device);
|
||||
TIMEM("1D strided", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_2D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 2048;
|
||||
auto a = random::uniform({size, size});
|
||||
auto b = random::uniform({size, size});
|
||||
eval(a, b);
|
||||
TIMEM("2D regular", add, a, b, device);
|
||||
auto a = mx::random::uniform({size, size});
|
||||
auto b = mx::random::uniform({size, size});
|
||||
mx::eval(a, b);
|
||||
TIMEM("2D regular", mx::add, a, b, device);
|
||||
|
||||
b = transpose(b);
|
||||
eval(b);
|
||||
TIMEM("2D transpose", add, a, b, device);
|
||||
b = mx::transpose(b);
|
||||
mx::eval(b);
|
||||
TIMEM("2D mx::transpose", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({size});
|
||||
eval(b);
|
||||
TIMEM("2D broadcast dim 0", add, a, b, device);
|
||||
b = mx::random::uniform({size});
|
||||
mx::eval(b);
|
||||
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
|
||||
|
||||
b = reshape(b, {size, 1});
|
||||
eval(b);
|
||||
TIMEM("2D broadcast dim 1", add, a, b, device);
|
||||
b = mx::reshape(b, {size, 1});
|
||||
mx::eval(b);
|
||||
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_3D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int d0 = 32;
|
||||
int d1 = 512;
|
||||
int d2 = 512;
|
||||
auto a = random::uniform({d0, d1, d2});
|
||||
auto b = random::uniform({d0, d1, d2});
|
||||
TIMEM("3D regular", add, a, b, device);
|
||||
auto a = mx::random::uniform({d0, d1, d2});
|
||||
auto b = mx::random::uniform({d0, d1, d2});
|
||||
TIMEM("3D regular", mx::add, a, b, device);
|
||||
|
||||
b = transpose(b, {0, 2, 1});
|
||||
TIMEM("3D transpose", add, a, b, device);
|
||||
b = mx::transpose(b, {0, 2, 1});
|
||||
TIMEM("3D mx::transpose", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d1, d2});
|
||||
TIMEM("3D broadcast dim 0", add, a, b, device);
|
||||
b = mx::random::uniform({d1, d2});
|
||||
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, 1, d2});
|
||||
TIMEM("3D broadcast dim 1", add, a, b, device);
|
||||
b = mx::random::uniform({d0, 1, d2});
|
||||
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, d1, 1});
|
||||
TIMEM("3D broadcast dim 2", add, a, b, device);
|
||||
b = mx::random::uniform({d0, d1, 1});
|
||||
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d2});
|
||||
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
|
||||
b = mx::random::uniform({d2});
|
||||
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d1, 1});
|
||||
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
|
||||
b = mx::random::uniform({d1, 1});
|
||||
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
|
||||
|
||||
b = random::uniform({d0, 1, 1});
|
||||
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
|
||||
b = mx::random::uniform({d0, 1, 1});
|
||||
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_irregular_binary_ops_4D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape = {8, 8, 512, 512};
|
||||
auto a = random::uniform(shape);
|
||||
auto b = random::uniform(shape);
|
||||
auto a = mx::random::uniform(shape);
|
||||
auto b = mx::random::uniform(shape);
|
||||
|
||||
TIMEM("4D regular", add, a, b, device);
|
||||
TIMEM("4D regular", mx::add, a, b, device);
|
||||
|
||||
b = transpose(b, {0, 1, 3, 2});
|
||||
TIMEM("4D transpose", add, a, b, device);
|
||||
b = mx::transpose(b, {0, 1, 3, 2});
|
||||
TIMEM("4D mx::transpose", mx::add, a, b, device);
|
||||
|
||||
std::string om = "4D broadcast dims ";
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = 1;
|
||||
b = random::uniform(shape);
|
||||
b = mx::random::uniform(shape);
|
||||
std::ostringstream msg;
|
||||
msg << om << i;
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
|
||||
for (int j = i + 1; j < shape.size(); ++j) {
|
||||
shape[j] = 1;
|
||||
std::ostringstream msg;
|
||||
msg << om << i << ", " << j;
|
||||
b = random::uniform(shape);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
b = mx::random::uniform(shape);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
shape[j] = a.shape(j);
|
||||
|
||||
for (int k = j + 1; k < shape.size(); ++k) {
|
||||
shape[k] = 1;
|
||||
std::ostringstream msg;
|
||||
msg << om << i << ", " << j << ", " << k;
|
||||
b = random::uniform(shape);
|
||||
TIMEM(msg.str(), add, a, b, device);
|
||||
b = mx::random::uniform(shape);
|
||||
TIMEM(msg.str(), mx::add, a, b, device);
|
||||
shape[k] = a.shape(k);
|
||||
}
|
||||
}
|
||||
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
|
||||
}
|
||||
|
||||
void time_irregular_reshape() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
std::vector<int> shape;
|
||||
auto reshape_fn = [&shape, device](const array& a) {
|
||||
return reshape(a, shape, device);
|
||||
auto reshape_fn = [&shape, device](const mx::array& a) {
|
||||
return mx::reshape(a, shape, device);
|
||||
};
|
||||
|
||||
int size = 64;
|
||||
int d = 2 * size;
|
||||
|
||||
auto a = random::uniform({d, d, d});
|
||||
auto a = mx::random::uniform({d, d, d});
|
||||
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D contiguous", reshape_fn, a);
|
||||
|
||||
a = transpose(a);
|
||||
a = mx::transpose(a);
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D transpose", reshape_fn, a);
|
||||
TIMEM("3D mx::transpose", reshape_fn, a);
|
||||
|
||||
a = transpose(a, {1, 2, 0});
|
||||
a = mx::transpose(a, {1, 2, 0});
|
||||
shape = {8 * size, size, size};
|
||||
TIMEM("3D transpose dims 1 2", reshape_fn, a);
|
||||
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, d}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 0", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 1", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dim 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
|
||||
|
||||
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
|
||||
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
|
||||
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
|
||||
}
|
||||
|
||||
void time_irregular_astype_1D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 1000000;
|
||||
int step = 2;
|
||||
auto a = random::uniform({size});
|
||||
auto a = mx::random::uniform({size});
|
||||
a = slice(a, {0}, {size}, {step});
|
||||
TIMEM("1D strided", astype, a, int32, device);
|
||||
TIMEM("1D strided", mx::astype, a, mx::int32, device);
|
||||
}
|
||||
|
||||
void time_irregular_astype_2D() {
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
int size = 2048;
|
||||
std::vector<int> shape = {size, size};
|
||||
|
||||
auto a = random::uniform(shape);
|
||||
TIMEM("2D regular", astype, a, int32, device);
|
||||
auto a = mx::random::uniform(shape);
|
||||
TIMEM("2D regular", mx::astype, a, mx::int32, device);
|
||||
|
||||
a = transpose(a);
|
||||
TIMEM("2D transpose", astype, a, int32, device);
|
||||
a = mx::transpose(a);
|
||||
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
|
||||
|
||||
a = broadcast_to(random::uniform({size}), shape);
|
||||
TIMEM("2D broadcast dim 0", astype, a, int32, device);
|
||||
a = mx::broadcast_to(mx::random::uniform({size}), shape);
|
||||
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
|
||||
|
||||
a = broadcast_to(random::uniform({size, 1}), shape);
|
||||
TIMEM("2D broadcast dim 1", astype, a, int32, device);
|
||||
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
|
||||
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc > 1) {
|
||||
bool use_gpu = !strcmp(argv[1], "gpu");
|
||||
set_default_device(use_gpu ? Device::gpu : Device::cpu);
|
||||
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
|
||||
}
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
time_irregular_binary_ops_1D();
|
||||
time_irregular_binary_ops_2D();
|
||||
time_irregular_binary_ops_3D();
|
||||
|
@@ -3,20 +3,20 @@
|
||||
#include "mlx/mlx.h"
|
||||
#include "time_utils.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void time_creation_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto full_fp32 = [&]() { return full(shape, 3.3f); };
|
||||
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
|
||||
TIME(full_fp32);
|
||||
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
|
||||
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
|
||||
TIME(zeros_fp32);
|
||||
auto ones_fp32 = [&]() { return ones(shape, float32); };
|
||||
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
|
||||
TIME(ones_fp32);
|
||||
|
||||
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
|
||||
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
|
||||
TIME(arange_fp32);
|
||||
}
|
||||
|
||||
@@ -24,194 +24,196 @@ void time_type_conversions() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto shape = {M, N};
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
|
||||
auto a = zeros(shape, float32);
|
||||
eval(a);
|
||||
TIMEM("float32 to int32", astype, a, int32, device);
|
||||
TIMEM("float32 to uint32", astype, a, uint32, device);
|
||||
auto a = mx::zeros(shape, mx::float32);
|
||||
mx::eval(a);
|
||||
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
|
||||
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||
|
||||
a = zeros(shape, int32);
|
||||
eval(a);
|
||||
TIMEM("int32 to float32", astype, a, float32, device);
|
||||
a = mx::zeros(shape, mx::int32);
|
||||
mx::eval(a);
|
||||
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
|
||||
|
||||
a = zeros(shape, bool_);
|
||||
eval(a);
|
||||
TIMEM("bool to float32", astype, a, float32, device);
|
||||
TIMEM("bool to int32", astype, a, int32, device);
|
||||
TIMEM("bool to uint32", astype, a, uint32, device);
|
||||
a = mx::zeros(shape, mx::bool_);
|
||||
mx::eval(a);
|
||||
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
|
||||
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
|
||||
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||
}
|
||||
|
||||
void time_random_generation() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
|
||||
auto uniform = [&]() { return random::uniform({M, N}, float32); };
|
||||
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
|
||||
TIME(uniform);
|
||||
auto normal = [&]() { return random::normal({M, N}, float32); };
|
||||
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
|
||||
TIME(normal);
|
||||
}
|
||||
|
||||
void time_unary_ops() {
|
||||
int M = 2000;
|
||||
int N = 500;
|
||||
auto device = default_device();
|
||||
auto device = mx::default_device();
|
||||
|
||||
auto a = random::normal({M, N});
|
||||
eval(a);
|
||||
auto a = mx::random::normal({M, N});
|
||||
mx::eval(a);
|
||||
TIME(mlx::core::abs, a, device);
|
||||
TIME(negative, a, device);
|
||||
TIME(sign, a, device);
|
||||
TIME(square, a, device);
|
||||
TIME(mx::negative, a, device);
|
||||
TIME(mx::sign, a, device);
|
||||
TIME(mx::square, a, device);
|
||||
TIME(mlx::core::sqrt, a, device);
|
||||
TIME(rsqrt, a, device);
|
||||
TIME(mx::rsqrt, a, device);
|
||||
TIME(mlx::core::exp, a, device);
|
||||
|
||||
a = random::uniform({M, N});
|
||||
a = mx::random::uniform({M, N});
|
||||
TIME(mlx::core::log, a, device);
|
||||
}
|
||||
|
||||
void time_binary_ops() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto condition = random::randint(0, 2, {M, N, K});
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
auto condition = mx::random::randint(0, 2, {M, N, K});
|
||||
auto a = mx::random::uniform({M, N, K});
|
||||
auto b = mx::random::uniform({M, N, K});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
|
||||
TIME(add, a, b, device);
|
||||
TIME(subtract, a, b, device);
|
||||
TIME(multiply, a, b, device);
|
||||
TIME(divide, a, b, device);
|
||||
TIME(maximum, a, b, device);
|
||||
TIME(minimum, a, b, device);
|
||||
TIME(where, condition, a, b, device);
|
||||
TIME(mx::add, a, b, device);
|
||||
TIME(mx::subtract, a, b, device);
|
||||
TIME(mx::multiply, a, b, device);
|
||||
TIME(mx::divide, a, b, device);
|
||||
TIME(mx::maximum, a, b, device);
|
||||
TIME(mx::minimum, a, b, device);
|
||||
TIME(mx::where, condition, a, b, device);
|
||||
|
||||
condition = array({true});
|
||||
b = random::uniform({1});
|
||||
eval(b);
|
||||
TIMEM("scalar", add, a, b, device);
|
||||
TIMEM("vector-scalar", subtract, a, b, device);
|
||||
TIMEM("scalar-vector", subtract, b, a, device);
|
||||
TIMEM("scalar", multiply, a, b, device);
|
||||
TIMEM("vector-scalar", divide, a, b, device);
|
||||
TIMEM("scalar-vector", divide, b, a, device);
|
||||
TIMEM("scalar-vector", where, condition, a, b, device);
|
||||
condition = mx::array({true});
|
||||
b = mx::random::uniform({1});
|
||||
mx::eval(b);
|
||||
TIMEM("scalar", mx::add, a, b, device);
|
||||
TIMEM("vector-scalar", mx::subtract, a, b, device);
|
||||
TIMEM("scalar-vector", mx::subtract, b, a, device);
|
||||
TIMEM("scalar", mx::multiply, a, b, device);
|
||||
TIMEM("vector-scalar", mx::divide, a, b, device);
|
||||
TIMEM("scalar-vector", mx::divide, b, a, device);
|
||||
TIMEM("scalar-vector", mx::where, condition, a, b, device);
|
||||
|
||||
condition = broadcast_to(array({true}), {1000, 100});
|
||||
a = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
b = broadcast_to(random::uniform({1}), {1000, 100});
|
||||
eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", divide, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
|
||||
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
|
||||
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||
mx::eval(a, b);
|
||||
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
|
||||
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
|
||||
}
|
||||
|
||||
void time_strided_ops() {
|
||||
int M = 50, N = 50, O = 50, P = 50;
|
||||
auto a = random::uniform({M, N, O, P});
|
||||
auto b = random::uniform({M, N, O, P});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIMEM("non-strided", add, a, b, device);
|
||||
a = transpose(a, {1, 0, 2, 3});
|
||||
b = transpose(b, {3, 2, 0, 1});
|
||||
eval(a, b);
|
||||
TIMEM("strided", add, a, b, device);
|
||||
auto a = mx::random::uniform({M, N, O, P});
|
||||
auto b = mx::random::uniform({M, N, O, P});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIMEM("non-strided", mx::add, a, b, device);
|
||||
a = mx::transpose(a, {1, 0, 2, 3});
|
||||
b = mx::transpose(b, {3, 2, 0, 1});
|
||||
mx::eval(a, b);
|
||||
TIMEM("strided", mx::add, a, b, device);
|
||||
}
|
||||
|
||||
void time_comparisons() {
|
||||
int M = 1000, N = 100, K = 10;
|
||||
auto a = random::uniform({M, N, K});
|
||||
auto b = random::uniform({M, N, K});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIME(equal, a, b, device);
|
||||
TIME(greater, a, b, device);
|
||||
TIME(greater_equal, a, b, device);
|
||||
TIME(less, a, b, device);
|
||||
TIME(less_equal, a, b, device);
|
||||
auto a = mx::random::uniform({M, N, K});
|
||||
auto b = mx::random::uniform({M, N, K});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIME(mx::equal, a, b, device);
|
||||
TIME(mx::greater, a, b, device);
|
||||
TIME(mx::greater_equal, a, b, device);
|
||||
TIME(mx::less, a, b, device);
|
||||
TIME(mx::less_equal, a, b, device);
|
||||
}
|
||||
|
||||
void time_matvec() {
|
||||
int M = 2000, N = 200;
|
||||
auto a = random::uniform({M, N});
|
||||
auto b = random::uniform({N});
|
||||
auto c = random::uniform({M});
|
||||
eval(a, b, c);
|
||||
auto matvec = [&]() { return matmul(a, b); };
|
||||
auto a = mx::random::uniform({M, N});
|
||||
auto b = mx::random::uniform({N});
|
||||
auto c = mx::random::uniform({M});
|
||||
mx::eval(a, b, c);
|
||||
auto matvec = [&]() { return mx::matmul(a, b); };
|
||||
TIME(matvec);
|
||||
|
||||
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
|
||||
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
|
||||
TIME(matvec_transpose);
|
||||
}
|
||||
|
||||
void time_matmul() {
|
||||
int M = 1000, N = 1000, K = 1000;
|
||||
auto a = random::uniform({M, K});
|
||||
auto b = random::uniform({K, N});
|
||||
auto device = default_device();
|
||||
eval(a, b);
|
||||
TIME(matmul, a, b, device);
|
||||
auto a = mx::random::uniform({M, K});
|
||||
auto b = mx::random::uniform({K, N});
|
||||
auto device = mx::default_device();
|
||||
mx::eval(a, b);
|
||||
TIME(mx::matmul, a, b, device);
|
||||
|
||||
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
|
||||
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
|
||||
TIME(transpose_matmul);
|
||||
}
|
||||
|
||||
void time_reductions() {
|
||||
auto a = random::normal({10000, 1000});
|
||||
eval(a);
|
||||
auto sum_all = [&a]() { return sum(a, false); };
|
||||
auto a = mx::random::normal({10000, 1000});
|
||||
mx::eval(a);
|
||||
auto sum_all = [&a]() { return mx::sum(a, false); };
|
||||
TIME(sum_all);
|
||||
|
||||
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
|
||||
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
|
||||
TIME(sum_along_0);
|
||||
|
||||
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
|
||||
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
|
||||
TIME(sum_along_1);
|
||||
|
||||
auto prod_all = [&a]() { return prod(a, false); };
|
||||
auto prod_all = [&a]() { return mx::prod(a, false); };
|
||||
TIME(prod_all);
|
||||
|
||||
auto all_true = [&a]() { return all(a, false); };
|
||||
auto all_true = [&a]() { return mx::all(a, false); };
|
||||
TIME(all_true);
|
||||
|
||||
auto all_along_0 = [&a]() { return all(a, 0, false); };
|
||||
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
|
||||
TIME(all_along_0);
|
||||
|
||||
auto all_along_1 = [&a]() { return all(a, 1, false); };
|
||||
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
|
||||
TIME(all_along_1);
|
||||
|
||||
auto any_true = [&a]() { return any(a, false); };
|
||||
auto any_true = [&a]() { return mx::any(a, false); };
|
||||
TIME(any_true);
|
||||
|
||||
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
|
||||
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
|
||||
TIME(argmin_along_0);
|
||||
|
||||
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
|
||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||
TIME(argmin_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
auto a = random::normal({1000, 768});
|
||||
eval(a);
|
||||
auto indices = random::randint(0, 1000, {256});
|
||||
eval(indices);
|
||||
auto a = mx::random::normal({1000, 768});
|
||||
mx::eval(a);
|
||||
auto indices = mx::random::randint(0, 1000, {256});
|
||||
mx::eval(indices);
|
||||
|
||||
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
|
||||
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
|
||||
TIME(embedding_lookup);
|
||||
|
||||
indices = random::randint(0, 768 * 1000, {256 * 768});
|
||||
eval(indices);
|
||||
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
|
||||
mx::eval(indices);
|
||||
|
||||
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
|
||||
auto single_element_lookup = [&a, &indices]() {
|
||||
return mx::take(a, indices);
|
||||
};
|
||||
TIME(single_element_lookup);
|
||||
|
||||
indices = random::randint(0, 1000, {256});
|
||||
auto updates = random::normal({256, 1, 768});
|
||||
eval(indices, updates);
|
||||
indices = mx::random::randint(0, 1000, {256});
|
||||
auto updates = mx::random::normal({256, 1, 768});
|
||||
mx::eval(indices, updates);
|
||||
|
||||
auto embedding_update = [&a, &indices, &updates]() {
|
||||
return scatter(a, indices, updates, 0);
|
||||
@@ -223,10 +225,10 @@ void time_gather_scatter() {
|
||||
};
|
||||
TIME(embedding_add);
|
||||
|
||||
a = reshape(a, {-1});
|
||||
indices = random::randint(0, 768 * 1000, {768 * 256});
|
||||
updates = random::normal({256 * 768, 1});
|
||||
eval(a, indices, updates);
|
||||
a = mx::reshape(a, {-1});
|
||||
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
|
||||
updates = mx::random::normal({256 * 768, 1});
|
||||
mx::eval(a, indices, updates);
|
||||
|
||||
auto single_element_update = [&a, &indices, &updates]() {
|
||||
return scatter(a, indices, updates, 0);
|
||||
@@ -240,21 +242,21 @@ void time_gather_scatter() {
|
||||
}
|
||||
|
||||
void time_divmod() {
|
||||
auto a = random::normal({1000});
|
||||
auto b = random::normal({1000});
|
||||
eval({a, b});
|
||||
auto a = mx::random::normal({1000});
|
||||
auto b = mx::random::normal({1000});
|
||||
mx::eval({a, b});
|
||||
|
||||
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
|
||||
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
|
||||
TIME(divmod_fused);
|
||||
|
||||
auto divmod_separate = [&a, &b]() {
|
||||
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
|
||||
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
|
||||
};
|
||||
TIME(divmod_separate);
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
||||
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||
time_creation_ops();
|
||||
time_type_conversions();
|
||||
time_unary_ops();
|
||||
|
@@ -1,7 +1,6 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
from time import time
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
|
74
benchmarks/python/gather_mm_bench.py
Normal file
74
benchmarks/python/gather_mm_bench.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
N = 1024
|
||||
D = 1024
|
||||
M = 1024
|
||||
E = 32
|
||||
I = 4
|
||||
|
||||
|
||||
def gather_sort(x, indices):
|
||||
N, M = indices.shape
|
||||
indices = indices.flatten()
|
||||
order = mx.argsort(indices)
|
||||
inv_order = mx.argsort(order)
|
||||
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||
|
||||
|
||||
def scatter_unsort(x, inv_order, shape=None):
|
||||
x = x[inv_order]
|
||||
if shape is not None:
|
||||
x = mx.unflatten(x, 0, shape)
|
||||
return x
|
||||
|
||||
|
||||
def gather_mm_simulate(x, w, indices):
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
for i in range(2):
|
||||
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
|
||||
x = y[:, None]
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
|
||||
def time_gather_mm():
|
||||
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||
|
||||
def gather_mm(x, w1, w2, indices, sort):
|
||||
idx = indices
|
||||
inv_order = None
|
||||
if sort:
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||
if sort:
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||
|
||||
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||
mx.eval(x, w1, w2)
|
||||
|
||||
def equivalent_matmul(x, w1, w2):
|
||||
x = x @ w1.T
|
||||
x = x @ w2.T
|
||||
return x
|
||||
|
||||
time_fn(equivalent_matmul, x, w1, w2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_gather_mm()
|
84
benchmarks/python/gather_qmm_bench.py
Normal file
84
benchmarks/python/gather_qmm_bench.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
N = 1024
|
||||
D = 1024
|
||||
M = 1024
|
||||
E = 32
|
||||
I = 4
|
||||
|
||||
|
||||
def gather_sort(x, indices):
|
||||
N, M = indices.shape
|
||||
indices = indices.flatten()
|
||||
order = mx.argsort(indices)
|
||||
inv_order = mx.argsort(order)
|
||||
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||
|
||||
|
||||
def scatter_unsort(x, inv_order, shape=None):
|
||||
x = x[inv_order]
|
||||
if shape is not None:
|
||||
x = mx.unflatten(x, 0, shape)
|
||||
return x
|
||||
|
||||
|
||||
def gather_mm_simulate(x, w, indices):
    """Emulate two chained gathered quantized matmuls one row at a time.

    Args:
        x: Input array, sorted with ``gather_sort`` before use.
        w: Indexable triple; ``w[0]``, ``w[1]`` and ``w[2]`` are each indexed by
            the selected index and passed positionally to
            ``mx.quantized_matmul`` (presumably the weights, scales and biases
            produced by ``mx.quantize`` -- TODO confirm).
        indices: 2-D index array selecting which slice of ``w`` each row uses.

    Returns:
        The result restored to the original ordering and reshaped with
        ``indices.shape`` as its leading axes.
    """
    x, idx, inv_order = gather_sort(x, indices)
    # Two passes; note that both passes use the same ``w``.
    for _ in range(2):
        per_row_outputs = []
        for row, expert in enumerate(idx.tolist()):
            per_row_outputs.append(
                mx.quantized_matmul(
                    x[row], w[0][expert], w[1][expert], w[2][expert], transpose=True
                )
            )
        # Re-insert a singleton axis so the next pass can index rows again.
        x = mx.concatenate(per_row_outputs, axis=0)[:, None]
    return scatter_unsort(x, inv_order, indices.shape)
|
||||
|
||||
|
||||
def time_gather_qmm():
    """Time ``mx.gather_qmm`` against a plain quantized matmul baseline.

    Three gathered variants are timed: unsorted indices, pre-sorted indices,
    and unsorted indices that are sorted and unsorted inside the timed
    function. A dense ``mx.quantized_matmul`` pair on ``N * I`` rows is then
    timed as the baseline.
    """
    norm = 1024**0.5

    x = mx.random.normal((N, 1, 1, D)) / norm
    w1 = mx.quantize(mx.random.normal((E, M, D)) / norm)
    w2 = mx.quantize(mx.random.normal((E, D, M)) / norm)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    # Materialize the inputs up front so their creation is not timed.
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        else:
            idx, inv_order = indices, None
        for packed in (w1, w2):
            x = mx.gather_qmm(
                x, *packed, transpose=True, rhs_indices=idx, sorted_indices=sort
            )
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    runs = ((indices, False), (sorted_indices, False), (indices, True))
    for idx_array, presort in runs:
        time_fn(gather_mm, x, w1, w2, idx_array, presort)

    # Baseline: the same two projections without any gathering.
    x = mx.random.normal((N * I, D)) / norm
    w1 = mx.quantize(mx.random.normal((M, D)) / norm)
    w2 = mx.quantize(mx.random.normal((D, M)) / norm)
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        for packed in (w1, w2):
            x = mx.quantized_matmul(x, *packed, transpose=True)
        return x

    time_fn(equivalent_matmul, x, w1, w2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_gather_qmm()
|
@@ -10,7 +10,12 @@ def layer_norm(x, w, b, eps):
|
||||
x = x.astype(mx.float32)
|
||||
mu = mx.mean(x, -1, keepdims=True)
|
||||
v = mx.var(x, -1, keepdims=True)
|
||||
return (x - mu) * mx.rsqrt(v + eps) * w + b
|
||||
y = (x - mu) * mx.rsqrt(v + eps)
|
||||
if w is not None:
|
||||
y = y * w
|
||||
if b is not None:
|
||||
y = y + b
|
||||
return y
|
||||
|
||||
|
||||
def time_layer_norm():
|
||||
@@ -36,6 +41,28 @@ def time_layer_norm():
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
||||
|
||||
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0,))
|
||||
g2 = mx.grad(f2, argnums=(0,))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(g, x):
|
||||
gx = x
|
||||
for _ in range(32):
|
||||
gx = g(gx, y)
|
||||
return gx
|
||||
|
||||
time_fn(layer_norm_loop, g1, x)
|
||||
time_fn(layer_norm_loop, g2, x)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_layer_norm()
|
||||
|
@@ -9,7 +9,10 @@ def rms_norm(x, w, eps):
|
||||
ot = x.dtype
|
||||
x = x.astype(mx.float32)
|
||||
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
||||
return (x * n).astype(ot) * w
|
||||
y = (x * n).astype(ot)
|
||||
if w is not None:
|
||||
y = y * w
|
||||
return y
|
||||
|
||||
|
||||
def time_rms_norm():
|
||||
@@ -34,6 +37,27 @@ def time_rms_norm():
|
||||
time_fn(rms_norm_loop, mx.compile(g1), x, w)
|
||||
time_fn(rms_norm_loop, mx.compile(g2), x, w)
|
||||
|
||||
f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
|
||||
f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0,))
|
||||
g2 = mx.grad(f2, argnums=(0,))
|
||||
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
mx.eval(x, w, y)
|
||||
|
||||
def rms_norm_loop(g, x):
|
||||
gx = x
|
||||
for _ in range(32):
|
||||
gx = g(gx, y)
|
||||
return gx
|
||||
|
||||
time_fn(rms_norm_loop, g1, x)
|
||||
time_fn(rms_norm_loop, g2, x)
|
||||
time_fn(rms_norm_loop, mx.compile(g1), x)
|
||||
time_fn(rms_norm_loop, mx.compile(g2), x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_rms_norm()
|
||||
|
@@ -1,62 +1,223 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
import numpy as np
|
||||
|
||||
MAX_SEQ = 300
|
||||
START_SEQ = 100
|
||||
SEQ_INCREMENT = 50
|
||||
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||
device_name = device_name.decode("utf-8").strip("\n")
|
||||
|
||||
N_warmup = 5
|
||||
N_iter_bench = 40
|
||||
N_iter_func = 8
|
||||
|
||||
|
||||
def time_self_attention_primitives():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
def bench(f, *args):
|
||||
for i in range(N_warmup):
|
||||
f(*args)
|
||||
|
||||
def sdpa_primitives(qs, ks, vs, alpha):
|
||||
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ vs
|
||||
return o
|
||||
|
||||
time_fn(sdpa_primitives, q, k, v, scale)
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(*args)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def time_self_attention_sdpa():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
def sdpa_fused(qs, ks, vs, alpha):
|
||||
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
|
||||
return o
|
||||
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
|
||||
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
|
||||
|
||||
time_fn(sdpa_fused, q, k, v, scale)
|
||||
scale = 1.0 / math.sqrt(D)
|
||||
|
||||
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
|
||||
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
if mask is not None:
|
||||
if mask == "additive":
|
||||
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
|
||||
mask = mx.array(mask_np)
|
||||
elif mask == "bool":
|
||||
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
|
||||
mask = mx.array(mask_np)
|
||||
|
||||
return q_mx, k_mx, v_mx, scale, mask
|
||||
|
||||
|
||||
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
    """Reference (unfused) scaled dot-product attention built from MLX ops.

    Args:
        q: Queries. The code reads the batch size from ``shape[0]``, the head
            count from ``shape[-3]`` and the sequence length from ``shape[2]``.
        k: Keys; head count from ``shape[-3]``, sequence length from
            ``shape[2]``.
        v: Values, laid out like ``k``.
        scale: Scalar multiplied into ``q`` (in ``q``'s dtype) before the
            score computation.
        mask: ``None``, the string ``"causal"``, a boolean array (``False``
            entries get a score of ``-inf``) or an array that is added to the
            scores.

    Returns:
        The attention output, reshaped back to ``[B, n_q_heads, L, -1]`` when
        query heads are grouped over fewer key/value heads.
    """
    q_dtype = q.dtype
    q = q * mx.array(scale, q_dtype)
    n_q_heads = q.shape[-3]
    n_kv_heads = k.shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    B = q.shape[0]
    L = q.shape[2]
    kL = k.shape[2]

    if n_repeats > 1:
        # Group query heads per key/value head; k and v get a broadcast axis.
        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
        k = mx.expand_dims(k, 2)
        v = mx.expand_dims(v, 2)

    scores = q @ mx.swapaxes(k, -1, -2)

    if mask is not None:
        # Only compare against the literal when the mask really is a string,
        # so array masks never go through an array-vs-str ``==``.
        if isinstance(mask, str) and mask == "causal":
            q_offset = max(0, kL - L)
            q_indices = mx.arange(q_offset, q_offset + L)
            k_indices = mx.arange(kL)
            mask = q_indices[:, None] >= k_indices[None]

        if n_repeats > 1 and mask.ndim >= 3:
            # Give the mask the same grouped-head layout as the scores.
            if mask.shape[-3] == 1:
                mask = mx.expand_dims(mask, -3)
            else:
                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))

        if mask.dtype == mx.bool_:
            scores = mx.where(mask, scores, -np.float32(np.inf))
        else:
            scores += mask

    scores = mx.softmax(scores, axis=-1, precise=True)

    out = scores @ v
    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, -1])

    return out
|
||||
|
||||
|
||||
def mlx_fused_attn(q, k, v, scale, mask):
    """Run attention via ``mx.fast.scaled_dot_product_attention``.

    All arguments are forwarded unchanged; ``scale`` and ``mask`` are passed
    by keyword.
    """
    fused_sdpa = mx.fast.scaled_dot_product_attention
    return fused_sdpa(q, k, v, scale=scale, mask=mask)
|
||||
|
||||
|
||||
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
    """Call the attention function ``f``, optionally swapping axes 1 and 2.

    Args:
        f: Callable invoked as ``f(q, k, v, scale=..., mask=...)``.
        q, k, v: Attention inputs.
        scale: Forwarded to ``f`` by keyword.
        mask: Forwarded to ``f`` by keyword.
        transpose: When truthy, ``q``, ``k`` and ``v`` are transposed with the
            permutation ``(0, 2, 1, 3)`` before the call and the result is
            transposed back with the same permutation.

    Returns:
        Whatever ``f`` returns, transposed back if ``transpose`` was set.
    """
    if not transpose:
        return f(q, k, v, scale=scale, mask=mask)

    swap_axes_1_2 = (0, 2, 1, 3)
    q_t = mx.transpose(q, swap_axes_1_2)
    k_t = mx.transpose(k, swap_axes_1_2)
    v_t = mx.transpose(v, swap_axes_1_2)
    o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
    return mx.transpose(o_t, swap_axes_1_2)
|
||||
|
||||
|
||||
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
    """Chain ``N_iter_func`` attention calls and evaluate the final result.

    Each call's output is used as the query of the next call; ``k``, ``v``,
    ``scale`` and ``mask`` stay fixed. ``mx.eval`` is called once, on the
    final output.
    """
    chained = q
    for _ in range(N_iter_func):
        chained = do_attention(
            f, chained, k, v, scale, mask=mask, transpose=transpose
        )

    mx.eval(chained)
    return chained
|
||||
|
||||
|
||||
def bench_shape(
    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
):
    """Time reference vs. fused attention for one shape and compare outputs.

    Args:
        B, qsl, ksl, head_dim, n_q_heads, n_kv_heads: Problem sizes forwarded
            to ``prepare_inputs``.
        dtype: Dtype name; ``"float32"`` selects the tighter tolerance.
        transpose: Forwarded to ``prepare_inputs`` and ``do_attention``.
        mask_in: Mask specification forwarded to ``prepare_inputs``.

    Returns:
        ``(time_mlx_fused, time_mlx_unfused)`` as returned by ``bench``.
    """
    q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
        B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
    )

    time_mlx_unfused = bench(
        do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
    )
    time_mlx_fused = bench(
        do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
    )

    # Name each output after the implementation that produced it (the
    # original code had these two names swapped).
    o_mlx_unfused = do_attention(
        mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
    )
    o_mlx_fused = do_attention(
        mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
    )

    atol = 1e-5 if dtype == "float32" else 2e-4

    # Argument order (reference output first, fused output second) matches
    # what the original call passed, so the tolerance check is unchanged.
    if not mx.allclose(o_mlx_unfused, o_mlx_fused, atol=atol, rtol=atol):
        print(
            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
        )

    return time_mlx_fused, time_mlx_unfused
|
||||
|
||||
|
||||
def get_gflop_count(B, M, N, K):
    """Return the benchmark's operation count, scaled by ``1024**3``.

    The count is ``2 * B * M * N * K`` multiplied by the module-level
    ``N_iter_bench`` and ``N_iter_func`` iteration counts.
    """
    total_ops = 2.0 * N_iter_bench * N_iter_func * B * M * N * K
    giga = 1024.0**3
    return float(total_ops) / float(giga)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if args.gpu:
|
||||
mx.set_default_device(mx.gpu)
|
||||
else:
|
||||
mx.set_default_device(mx.cpu)
|
||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||
|
||||
time_self_attention_sdpa()
|
||||
time_self_attention_primitives()
|
||||
dtypes = ("float16", "float32")[:1]
|
||||
transposes = (False,)
|
||||
|
||||
# fmt: off
|
||||
shapes_64 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 32, 32, 64, 32, 32),
|
||||
( 1, 64, 64, 64, 32, 32),
|
||||
( 1, 128, 128, 64, 32, 32),
|
||||
( 1, 256, 256, 64, 32, 32),
|
||||
( 1, 512, 512, 64, 32, 32),
|
||||
( 1, 1024, 1024, 64, 32, 8),
|
||||
( 1, 2048, 2048, 64, 32, 8),
|
||||
( 1, 4096, 4096, 64, 32, 8),
|
||||
)
|
||||
|
||||
shapes_80 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 80, 32, 8),
|
||||
( 1, 2048, 2048, 80, 32, 8),
|
||||
( 1, 4096, 4096, 80, 32, 8),
|
||||
)
|
||||
|
||||
shapes_128 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 128, 32, 8),
|
||||
( 1, 2048, 2048, 128, 32, 8),
|
||||
( 1, 4096, 4096, 128, 32, 8),
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
shapes = shapes_64 + shapes_80 + shapes_128
|
||||
|
||||
masks = [None, "bool", "causal"]
|
||||
|
||||
print(
|
||||
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
for transpose in transposes:
|
||||
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
||||
for mask_in in masks:
|
||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||
B,
|
||||
qsl,
|
||||
ksl,
|
||||
head_dim,
|
||||
n_q_heads,
|
||||
n_kv_heads,
|
||||
dtype,
|
||||
transpose,
|
||||
mask_in,
|
||||
)
|
||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||
t_str = 1 if transpose else 0
|
||||
print(
|
||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
|
@@ -4,46 +4,92 @@ import math
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
L = 1024
|
||||
L = 16384
|
||||
H = 32
|
||||
H_k = 32 // 4
|
||||
H_k = H // 4
|
||||
D = 128
|
||||
V = 128
|
||||
dtype = mx.float16
|
||||
loops = 10
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||
k = k[:, :, None, :, :]
|
||||
v = v[:, :, None, :, :]
|
||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ v
|
||||
return o.reshape(B, Hq, L, D)
|
||||
def upproject(x, w):
    """Return ``x @ w.T``, or ``x`` unchanged when ``w`` is ``None``."""
    if w is not None:
        return x @ w.T
    return x
|
||||
|
||||
|
||||
def sdpa(q, k, v):
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
||||
def attention(q, k, v, mask=None, w=None):
    """Unfused attention applied ``loops`` times, feeding the output back as ``q``.

    Args:
        q: 4-D queries, unpacked as ``(B, Hq, L, D)``.
        k: 4-D keys, unpacked as ``(_, Hk, S, _)``.
        v: 4-D values; the last axis size ``V`` sets the output width.
        mask: Optional mask broadcast to ``(B, Hq, L, S)``; where it is false
            the score is replaced by the smallest finite value of the score
            dtype.
        w: Optional matrix handed to ``upproject`` after every iteration.

    Returns:
        The result of the last iteration.
    """

    def _sdpa(q, k, v):
        B, Hq, L, D = q.shape
        _, Hk, S, _ = k.shape
        _, _, _, V = v.shape
        group = Hq // Hk

        # Group query heads over the key/value heads; k and v broadcast.
        q_grouped = q.reshape(B, Hk, group, L, D)
        k_expanded = k[:, :, None, :, :]
        v_expanded = v[:, :, None, :, :]

        s = q_grouped @ k_expanded.transpose(0, 1, 2, 4, 3)
        if mask is not None:
            m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, group, L, S)
            s = mx.where(m, s, mx.finfo(s.dtype).min)
        # Softmax runs in float32 and is cast back to the score dtype.
        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
        return (p @ v_expanded).reshape(B, Hq, L, V)

    for _ in range(loops):
        q = upproject(_sdpa(q, k, v), w)
    return q
|
||||
|
||||
|
||||
def sdpa(q, k, v, mask=None, w=None):
    """Fused attention applied ``loops`` times, feeding the output back as ``q``.

    Each iteration calls ``mx.fast.scaled_dot_product_attention`` with
    ``scale=1.0`` and the given ``mask``, then passes the result through
    ``upproject`` with ``w``.
    """
    out = q
    for _ in range(loops):
        attended = mx.fast.scaled_dot_product_attention(
            out, k, v, scale=1.0, mask=mask
        )
        out = upproject(attended, w)
    return out
|
||||
|
||||
|
||||
def time_self_attention_primitives():
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
mx.eval(q, k, v)
|
||||
time_fn(attention, q, k, v)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
|
||||
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
|
||||
mx.eval(q, k, v, w)
|
||||
time_fn(attention, q, k, v, w=w)
|
||||
|
||||
|
||||
def time_self_attention_sdpa():
|
||||
mx.random.seed(3)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
mx.eval(q, k, v)
|
||||
time_fn(sdpa, q, k, v)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
|
||||
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
|
||||
mx.eval(q, k, v, w)
|
||||
time_fn(sdpa, q, k, v, w=w)
|
||||
|
||||
|
||||
def time_self_attention_sdpa_with_mask():
    """Time the unfused and fused attention paths with a boolean mask.

    The mask has length ``L``; its first half is ``True`` and everything from
    ``L // 2`` onward is ``False``.
    """
    mx.random.seed(3)

    def _uniform(shape):
        return mx.random.uniform(shape=shape).astype(dtype)

    q = _uniform((1, H, 1, D))
    k = _uniform((1, H_k, L, D))
    v = _uniform((1, H_k, L, V))
    # A projection is only needed when the value width differs from D.
    if V != D:
        w = _uniform((D, V))
    else:
        w = None

    mask = mx.full((L,), True)
    mask[L // 2 :] = False
    mx.eval(q, k, v, mask, w)

    def attention_mask(*args):
        return attention(*args, mask=mask, w=w)

    def sdpa_mask(*args):
        return sdpa(*args, mask=mask, w=w)

    for timed in (attention_mask, sdpa_mask):
        time_fn(timed, q, k, v)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_self_attention_sdpa()
|
||||
time_self_attention_primitives()
|
||||
time_self_attention_sdpa_with_mask()
|
||||
|
55
benchmarks/python/synchronize_bench.py
Normal file
55
benchmarks/python/synchronize_bench.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
rank = mx.distributed.init().rank()
|
||||
|
||||
|
||||
def timeit(fn, a):
    """Return the mean wall-clock time, in milliseconds, of ``mx.eval(fn(a))``.

    Five untimed warmup evaluations run first, followed by ten timed ones.
    """
    warmup_runs = 5
    timed_runs = 10

    for _ in range(warmup_runs):
        mx.eval(fn(a))

    start = time.perf_counter()
    for _ in range(timed_runs):
        mx.eval(fn(a))
    stop = time.perf_counter()

    return 1000 * (stop - start) / timed_runs
|
||||
|
||||
|
||||
def all_reduce_benchmark():
    """Benchmark ``mx.distributed.all_sum`` and print the result on rank 0.

    Each evaluation chains 100 all-sum calls, each followed by subtracting 1;
    the reported figure is the time per all-sum in milliseconds.
    """
    its_per_eval = 100
    a = mx.ones((5, 5), mx.int32)

    def fn(x):
        out = x
        for _ in range(its_per_eval):
            out = mx.distributed.all_sum(out) - 1
        return out

    ms = timeit(fn, a) / its_per_eval
    if rank != 0:
        return
    print(f"All Reduce: time per iteration {ms:.6f} (ms)")
|
||||
|
||||
|
||||
def all_gather_benchmark():
    """Benchmark ``mx.distributed.all_gather`` and print the result on rank 0.

    Each evaluation chains 100 all-gather calls, keeping element ``[0]`` of
    every result as the next input; the reported figure is the time per
    all-gather in milliseconds.
    """
    its_per_eval = 100
    a = mx.ones((5, 5), mx.int32)

    def fn(x):
        out = x
        for _ in range(its_per_eval):
            gathered = mx.distributed.all_gather(out)
            out = gathered[0]
        return out

    ms = timeit(fn, a) / its_per_eval
    if rank != 0:
        return
    print(f"All gather: time per iteration {ms:.6f} (ms)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
all_reduce_benchmark()
|
||||
all_gather_benchmark()
|
@@ -1,5 +1,7 @@
|
||||
include(CMakeParseArguments)
|
||||
|
||||
# clang format off
|
||||
#
|
||||
# ##############################################################################
|
||||
# Build metal library
|
||||
#
|
||||
@@ -11,6 +13,8 @@ include(CMakeParseArguments)
|
||||
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||
# files (like headers)
|
||||
#
|
||||
# clang format on
|
||||
|
||||
macro(mlx_build_metallib)
|
||||
# Parse args
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
||||
@@ -21,7 +25,7 @@ macro(mlx_build_metallib)
|
||||
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
||||
|
||||
# Collect compile options
|
||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
|
||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||
|
||||
# Prepare metallib build command
|
||||
add_custom_command(
|
||||
|
@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
|
||||
CREATE_SUBDIRS = NO
|
||||
FULL_PATH_NAMES = YES
|
||||
RECURSIVE = YES
|
||||
GENERATE_HTML = YES
|
||||
GENERATE_HTML = NO
|
||||
GENERATE_LATEX = NO
|
||||
GENERATE_XML = YES
|
||||
XML_PROGRAMLISTING = YES
|
||||
|
@@ -22,12 +22,12 @@ You can do that in MLX directly:
|
||||
This function performs that operation while leaving the implementation and
|
||||
function transformations to MLX.
|
||||
|
||||
However you may need to customize the underlying implementation, perhaps to
|
||||
make it faster or for custom differentiation. In this tutorial we will go
|
||||
through adding custom extensions. It will cover:
|
||||
However, you may want to customize the underlying implementation, perhaps to
|
||||
make it faster. In this tutorial we will go through adding custom extensions.
|
||||
It will cover:
|
||||
|
||||
* The structure of the MLX library.
|
||||
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
|
||||
* Implementing a CPU operation.
|
||||
* Implementing a GPU operation using Metal.
|
||||
* Adding the ``vjp`` and ``jvp`` function transformation.
|
||||
* Building a custom extension and binding it to python.
|
||||
@@ -45,7 +45,7 @@ Operations
|
||||
Operations are the front-end functions that operate on arrays. They are defined
|
||||
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
|
||||
|
||||
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
|
||||
We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
|
||||
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
|
||||
C++:
|
||||
|
||||
@@ -55,7 +55,7 @@ C++:
|
||||
* Scale and sum two vectors element-wise
|
||||
* z = alpha * x + beta * y
|
||||
*
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Use NumPy-style broadcasting between x and y
|
||||
* Inputs are upcast to floats if needed
|
||||
**/
|
||||
array axpby(
|
||||
@@ -66,7 +66,7 @@ C++:
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
The simplest way to this operation is in terms of existing operations:
|
||||
The simplest way to implement this is with existing operations:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -93,9 +93,9 @@ Primitives
|
||||
^^^^^^^^^^^
|
||||
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create outputs arrays given a input arrays. Further, a
|
||||
defines how to create output arrays given input arrays. Further, a
|
||||
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
||||
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
|
||||
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
|
||||
more concrete:
|
||||
|
||||
.. code-block:: C++
|
||||
@@ -128,7 +128,7 @@ more concrete:
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const array& cotan,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
@@ -153,9 +153,6 @@ more concrete:
|
||||
private:
|
||||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
|
||||
@@ -188,7 +185,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
|
||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||
|
||||
// Upcast to float32 for non-floating point inputs x and y
|
||||
auto out_dtype = is_floating_point(promoted_dtype)
|
||||
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||
? promoted_dtype
|
||||
: promote_types(promoted_dtype, float32);
|
||||
|
||||
@@ -234,49 +231,57 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
|
||||
Implementing the CPU Back-end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Let's start by implementing a naive and generic version of
|
||||
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
||||
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
||||
Let's start by implementing :meth:`Axpby::eval_cpu`.
|
||||
|
||||
Our naive method will go over each element of the output array, find the
|
||||
The method will go over each element of the output array, find the
|
||||
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||
point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// We only allocate memory when we are ready to fill the output
|
||||
// malloc_or_wait synchronously allocates available memory
|
||||
// There may be a wait executed here if the allocation is requested
|
||||
// under memory-pressured conditions
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
template <typename T>
|
||||
void axpby_impl(
|
||||
const mx::array& x,
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
float alpha_,
|
||||
float beta_,
|
||||
mx::Stream stream) {
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
|
||||
// Collect input and output data pointers
|
||||
const T* x_ptr = x.data<T>();
|
||||
const T* y_ptr = y.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
// Get the CPU command encoder and register input and output arrays
|
||||
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(y);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// Cast alpha and beta to the relevant types
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
// Launch the CPU kernel
|
||||
encoder.dispatch([x_ptr = x.data<T>(),
|
||||
y_ptr = y.data<T>(),
|
||||
out_ptr = out.data<T>(),
|
||||
size = out.size(),
|
||||
shape = out.shape(),
|
||||
x_strides = x.strides(),
|
||||
y_strides = y.strides(),
|
||||
alpha_,
|
||||
beta_]() {
|
||||
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
// Cast alpha and beta to the relevant types
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
}
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < size; out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
|
||||
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Our implementation should work for all incoming floating point arrays.
|
||||
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
|
||||
@@ -284,112 +289,32 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == float16) {
|
||||
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == complex64) {
|
||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[Axpby] Only supports floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
This is good as a fallback implementation. We can use the ``axpby`` routine
|
||||
provided by the Accelerate_ framework for a faster implementation in certain
|
||||
cases:
|
||||
|
||||
#. Accelerate does not provide implementations of ``axpby`` for half precision
|
||||
floats. We can only use it for ``float32`` types.
|
||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
|
||||
elements have fixed strides between them. We only direct to Accelerate
|
||||
if both ``x`` and ``y`` are row contiguous or column contiguous.
|
||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
|
||||
MLX expects to write the output to a new array. We must copy the elements
|
||||
of ``y`` into the output and use that as an input to ``axpby``.
|
||||
|
||||
Let's write an implementation that uses Accelerate in the right conditions.
|
||||
It allocates data for the output, copies ``y`` into it, and then calls the
|
||||
:func:`catlas_saxpby` from accelerate.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl_accelerate(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
// Y = (alpha * X) + (beta * Y) in place
|
||||
// To use it, we first copy the data in y over to the output array
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
|
||||
// Get x and y pointers for catlas_saxpby
|
||||
const T* x_ptr = x.data<T>();
|
||||
T* y_ptr = out.data<T>();
|
||||
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// Call the inplace accelerate operator
|
||||
catlas_saxpby(
|
||||
/* N = */ out.size(),
|
||||
/* ALPHA = */ alpha,
|
||||
/* X = */ x_ptr,
|
||||
/* INCX = */ 1,
|
||||
/* BETA = */ beta,
|
||||
/* Y = */ y_ptr,
|
||||
/* INCY = */ 1);
|
||||
}
|
||||
|
||||
For inputs that do not fit the criteria for accelerate, we fall back to
|
||||
:meth:`Axpby::eval`. With this in mind, let's finish our
|
||||
:meth:`Axpby::eval_cpu`.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to common back-end if specializations are not available
|
||||
eval(inputs, outputs);
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == mx::float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::float16) {
|
||||
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::bfloat16) {
|
||||
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::complex64) {
|
||||
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
|
||||
you do not plan on running the operation on the GPU or using transforms on
|
||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
||||
primitive here.
|
||||
|
||||
Implementing the GPU Back-end
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
@@ -420,8 +345,8 @@ element in the output.
|
||||
constant const float& alpha [[buffer(3)]],
|
||||
constant const float& beta [[buffer(4)]],
|
||||
constant const int* shape [[buffer(5)]],
|
||||
constant const size_t* x_strides [[buffer(6)]],
|
||||
constant const size_t* y_strides [[buffer(7)]],
|
||||
constant const int64_t* x_strides [[buffer(6)]],
|
||||
constant const int64_t* y_strides [[buffer(7)]],
|
||||
constant const int& ndim [[buffer(8)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
// Convert linear indices to offsets in array
|
||||
@@ -438,24 +363,10 @@ each instantiation a unique host name so we can identify it.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] \
|
||||
[[kernel]] void axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
instantiate_axpby(float16, half);
|
||||
instantiate_axpby(bfloat16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
instantiate_kernel("axpby_general_float32", axpby_general, float)
|
||||
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
|
||||
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
|
||||
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
|
||||
|
||||
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
|
||||
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
|
||||
@@ -480,7 +391,7 @@ below.
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Allocate output memory
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel
|
||||
std::ostringstream kname;
|
||||
@@ -494,7 +405,7 @@ below.
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
// those in the kernel declaration at axpby.metal
|
||||
@@ -509,14 +420,14 @@ below.
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode alpha and beta
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||
compute_encoder.set_bytes(alpha_, 3);
|
||||
compute_encoder.set_bytes(beta_, 4);
|
||||
|
||||
// Encode shape, strides and ndim
|
||||
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||
compute_encoder.set_vector_bytes(x.shape(), 5);
|
||||
compute_encoder.set_vector_bytes(x.strides(), 6);
|
||||
compute_encoder.set_vector_bytes(y.strides(), 7);
|
||||
compute_encoder.set_bytes(ndim, 8);
|
||||
|
||||
// We launch 1 thread for each input and make sure that the number of
|
||||
// threads in any given threadgroup is not higher than the max allowed
|
||||
@@ -530,7 +441,7 @@ below.
|
||||
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||
@@ -558,7 +469,7 @@ one we just defined:
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the primitive can built with ops
|
||||
// The jvp transform on the primitive can be built with ops
|
||||
// that are scheduled on the same stream as the primitive
|
||||
|
||||
// If argnums = {0}, we only push along x in which case the
|
||||
@@ -570,7 +481,7 @@ one we just defined:
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// If argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
else {
|
||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||
@@ -824,7 +735,7 @@ Let's look at a simple script and its results:
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
print(f"c is correct: {mx.all(c == 6.0).item()}")
|
||||
|
||||
Output:
|
||||
|
||||
@@ -832,13 +743,13 @@ Output:
|
||||
|
||||
c shape: [3, 4]
|
||||
c dtype: float32
|
||||
c correctness: True
|
||||
c is correct: True
|
||||
|
||||
Results
|
||||
^^^^^^^
|
||||
|
||||
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||
with the naive :meth:`simple_axpby` we first defined on the CPU.
|
||||
with the naive :meth:`simple_axpby` we first defined.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -846,13 +757,11 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
|
||||
from mlx_sample_extensions import axpby
|
||||
import time
|
||||
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
||||
return alpha * x + beta * y
|
||||
|
||||
M = 256
|
||||
N = 512
|
||||
M = 4096
|
||||
N = 4096
|
||||
|
||||
x = mx.random.normal((M, N))
|
||||
y = mx.random.normal((M, N))
|
||||
@@ -863,24 +772,24 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
|
||||
|
||||
def bench(f):
|
||||
# Warm up
|
||||
for i in range(100):
|
||||
for i in range(5):
|
||||
z = f(x, y, alpha, beta)
|
||||
mx.eval(z)
|
||||
|
||||
# Timed run
|
||||
s = time.time()
|
||||
for i in range(5000):
|
||||
for i in range(100):
|
||||
z = f(x, y, alpha, beta)
|
||||
mx.eval(z)
|
||||
e = time.time()
|
||||
return e - s
|
||||
return 1000 * (e - s) / 100
|
||||
|
||||
simple_time = bench(simple_axpby)
|
||||
custom_time = bench(axpby)
|
||||
|
||||
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
|
||||
print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
|
||||
|
||||
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
|
||||
The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
|
||||
modest improvements right away!
|
||||
|
||||
This operation is now good to be used to build other operations, in
|
||||
|
121
docs/src/dev/mlx_in_cpp.rst
Normal file
121
docs/src/dev/mlx_in_cpp.rst
Normal file
@@ -0,0 +1,121 @@
|
||||
.. _mlx_in_cpp:
|
||||
|
||||
Using MLX in C++
|
||||
================
|
||||
|
||||
You can use MLX in a C++ project with CMake.
|
||||
|
||||
.. note::
|
||||
|
||||
This guide is based on the following `example using MLX in C++
|
||||
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
|
||||
|
||||
First install MLX:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U mlx
|
||||
|
||||
You can also install the MLX Python package from source or just the C++
|
||||
library. For more information see the :ref:`documentation on installing MLX
|
||||
<build_and_install>`.
|
||||
|
||||
Next make an example program in ``example.cpp``:
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
auto x = mx::array({1, 2, 3});
|
||||
auto y = mx::array({1, 2, 3});
|
||||
std::cout << x + y << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
The next step is to set up a CMake file in ``CMakeLists.txt``:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(example LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
|
||||
Depending on how you installed MLX, you may need to tell CMake where to
|
||||
find it.
|
||||
|
||||
If you installed MLX with Python, then add the following to the CMake file:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
find_package(
|
||||
Python 3.9
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
|
||||
If you installed the MLX C++ package to a system path, then CMake should be
|
||||
able to find it. If you installed it to a non-standard location or CMake can't
|
||||
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
set(MLX_ROOT "/path/to/mlx/")
|
||||
|
||||
Next, instruct CMake to find MLX:
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
Finally, add the ``example.cpp`` program as an executable and link MLX.
|
||||
|
||||
.. code-block:: cmake
|
||||
|
||||
add_executable(example example.cpp)
|
||||
target_link_libraries(example PRIVATE mlx)
|
||||
|
||||
You can build the example with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
|
||||
And run it with:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./build/example
|
||||
|
||||
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
|
||||
|
||||
.. list-table:: Package Variables
|
||||
:widths: 20 20
|
||||
:header-rows: 1
|
||||
|
||||
* - Variable
|
||||
- Description
|
||||
* - MLX_FOUND
|
||||
- ``True`` if MLX is found
|
||||
* - MLX_INCLUDE_DIRS
|
||||
- Include directory
|
||||
* - MLX_LIBRARIES
|
||||
- Libraries to link against
|
||||
* - MLX_CXX_FLAGS
|
||||
- Additional compiler flags
|
||||
* - MLX_BUILD_ACCELERATE
|
||||
- ``True`` if MLX was built with Accelerate
|
||||
* - MLX_BUILD_METAL
|
||||
- ``True`` if MLX was built with Metal
|
@@ -45,6 +45,7 @@ are the CPU and GPU.
|
||||
usage/numpy
|
||||
usage/distributed
|
||||
usage/using_streams
|
||||
usage/export
|
||||
|
||||
.. toctree::
|
||||
:caption: Examples
|
||||
@@ -61,6 +62,7 @@ are the CPU and GPU.
|
||||
python/array
|
||||
python/data_types
|
||||
python/devices_and_streams
|
||||
python/export
|
||||
python/ops
|
||||
python/random
|
||||
python/transforms
|
||||
@@ -68,6 +70,7 @@ are the CPU and GPU.
|
||||
python/fft
|
||||
python/linalg
|
||||
python/metal
|
||||
python/memory_management
|
||||
python/nn
|
||||
python/optimizers
|
||||
python/distributed
|
||||
@@ -86,3 +89,4 @@ are the CPU and GPU.
|
||||
dev/extensions
|
||||
dev/metal_debugger
|
||||
dev/custom_metal_kernels
|
||||
dev/mlx_in_cpp
|
||||
|
@@ -1,3 +1,5 @@
|
||||
.. _build_and_install:
|
||||
|
||||
Build and Install
|
||||
=================
|
||||
|
||||
@@ -53,7 +55,7 @@ Build Requirements
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
||||
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
||||
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
|
||||
- Xcode >= 15.0 and macOS SDK >= 14.0
|
||||
|
||||
.. note::
|
||||
@@ -209,7 +211,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
|
||||
on a given machine. Note run-time compilation incurs a cold-start cost which can
|
||||
be anywhere from a few hundred milliseconds to a few seconds depending on the
|
||||
application. Once a kernel is compiled, it will be cached by the system. The
|
||||
Metal kernel cache persists accross reboots.
|
||||
Metal kernel cache persists across reboots.
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
@@ -38,6 +38,7 @@ Array
|
||||
array.log10
|
||||
array.log1p
|
||||
array.log2
|
||||
array.logcumsumexp
|
||||
array.logsumexp
|
||||
array.max
|
||||
array.mean
|
||||
|
@@ -51,11 +51,20 @@ The default floating point type is ``float32`` and the default integer type is
|
||||
* - ``float32``
|
||||
- 4
|
||||
- 32-bit float
|
||||
* - ``float64``
|
||||
- 8
|
||||
- 64-bit double
|
||||
* - ``complex64``
|
||||
- 8
|
||||
- 64-bit complex float
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
Arrays with type ``float64`` only work with CPU operations. Using
|
||||
``float64`` arrays on the GPU will result in an exception.
|
||||
|
||||
|
||||
Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
|
||||
documentation for more information. Use :func:`issubdtype` to determine if one
|
||||
``dtype`` (or category) is a subtype of another category.
|
||||
@@ -66,3 +75,4 @@ documentation for more information. Use :func:`issubdtype` to determine if one
|
||||
Dtype
|
||||
DtypeCategory
|
||||
issubdtype
|
||||
finfo
|
||||
|
14
docs/src/python/export.rst
Normal file
14
docs/src/python/export.rst
Normal file
@@ -0,0 +1,14 @@
|
||||
.. _export:
|
||||
|
||||
Export Functions
|
||||
================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
export_function
|
||||
import_function
|
||||
exporter
|
||||
export_to_dot
|
@@ -12,5 +12,4 @@ Fast
|
||||
layer_norm
|
||||
rope
|
||||
scaled_dot_product_attention
|
||||
affine_quantize
|
||||
metal_kernel
|
||||
|
@@ -20,3 +20,5 @@ FFT
|
||||
irfft2
|
||||
rfftn
|
||||
irfftn
|
||||
fftshift
|
||||
ifftshift
|
||||
|
@@ -5,8 +5,8 @@ Linear Algebra
|
||||
|
||||
.. currentmodule:: mlx.core.linalg
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
inv
|
||||
tri_inv
|
||||
@@ -18,3 +18,8 @@ Linear Algebra
|
||||
svd
|
||||
eigvalsh
|
||||
eigh
|
||||
lu
|
||||
lu_factor
|
||||
pinv
|
||||
solve
|
||||
solve_triangular
|
||||
|
16
docs/src/python/memory_management.rst
Normal file
16
docs/src/python/memory_management.rst
Normal file
@@ -0,0 +1,16 @@
|
||||
Memory Management
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
get_active_memory
|
||||
get_peak_memory
|
||||
reset_peak_memory
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
set_wired_limit
|
||||
clear_cache
|
@@ -8,13 +8,5 @@ Metal
|
||||
|
||||
is_available
|
||||
device_info
|
||||
get_active_memory
|
||||
get_peak_memory
|
||||
reset_peak_memory
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
set_wired_limit
|
||||
clear_cache
|
||||
start_capture
|
||||
stop_capture
|
||||
|
@@ -174,6 +174,7 @@ In detail:
|
||||
|
||||
value_and_grad
|
||||
quantize
|
||||
average_gradients
|
||||
|
||||
.. toctree::
|
||||
|
||||
|
@@ -12,6 +12,7 @@ Layers
|
||||
ALiBi
|
||||
AvgPool1d
|
||||
AvgPool2d
|
||||
AvgPool3d
|
||||
BatchNorm
|
||||
CELU
|
||||
Conv1d
|
||||
@@ -41,6 +42,7 @@ Layers
|
||||
LSTM
|
||||
MaxPool1d
|
||||
MaxPool2d
|
||||
MaxPool3d
|
||||
Mish
|
||||
MultiHeadAttention
|
||||
PReLU
|
||||
|
@@ -32,13 +32,16 @@ Operations
|
||||
atleast_2d
|
||||
atleast_3d
|
||||
bitwise_and
|
||||
bitwise_invert
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
block_masked_mm
|
||||
broadcast_arrays
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
contiguous
|
||||
conj
|
||||
conjugate
|
||||
convolve
|
||||
@@ -89,6 +92,7 @@ Operations
|
||||
isneginf
|
||||
isposinf
|
||||
issubdtype
|
||||
kron
|
||||
left_shift
|
||||
less
|
||||
less_equal
|
||||
@@ -99,6 +103,7 @@ Operations
|
||||
log10
|
||||
log1p
|
||||
logaddexp
|
||||
logcumsumexp
|
||||
logical_not
|
||||
logical_and
|
||||
logical_or
|
||||
@@ -144,6 +149,8 @@ Operations
|
||||
sign
|
||||
sin
|
||||
sinh
|
||||
slice
|
||||
slice_update
|
||||
softmax
|
||||
sort
|
||||
split
|
||||
@@ -168,6 +175,7 @@ Operations
|
||||
tri
|
||||
tril
|
||||
triu
|
||||
unflatten
|
||||
var
|
||||
view
|
||||
where
|
||||
|
@@ -18,3 +18,4 @@ Common Optimizers
|
||||
AdamW
|
||||
Adamax
|
||||
Lion
|
||||
MultiOptimizer
|
||||
|
@@ -9,6 +9,7 @@ Transforms
|
||||
:toctree: _autosummary
|
||||
|
||||
eval
|
||||
async_eval
|
||||
compile
|
||||
custom_function
|
||||
disable_compile
|
||||
|
@@ -421,3 +421,77 @@ the most opportunity to optimize the computation graph:
|
||||
# Compiling the outer function is good to do as it will likely
|
||||
# be faster even though the inner functions are compiled
|
||||
fun = mx.compile(outer)
|
||||
|
||||
|
||||
|
||||
.. _shapeless_compile:
|
||||
|
||||
Shapeless Compilation
|
||||
---------------------
|
||||
|
||||
When the shape of an input to a compiled function changes, the function is
|
||||
recompiled. You can compile a function once and run it on inputs with
|
||||
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
|
||||
case changes to the shapes of the inputs do not cause the function to be
|
||||
recompiled.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return mx.abs(x + y)
|
||||
|
||||
compiled_fun = mx.compile(fun, shapeless=True)
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(-2.0)
|
||||
|
||||
# First call compiles the function
|
||||
print(compiled_fun(x, y))
|
||||
|
||||
# Second call with different shapes
|
||||
# does not recompile the function
|
||||
x = mx.array([1.0, -6.0])
|
||||
y = mx.array([-2.0, 3.0])
|
||||
print(compiled_fun(x, y))
|
||||
|
||||
|
||||
Use shapeless compilations carefully. Since compilation is not triggered when
|
||||
shapes change, any graphs which are conditional on the input shapes will not
|
||||
work as expected. Shape-dependent computations are common and sometimes subtle
|
||||
to detect. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x):
|
||||
return x.reshape(x.shape[0] * x.shape[1], -1)
|
||||
|
||||
compiled_fun = mx.compile(fun, shapeless=True)
|
||||
|
||||
x = mx.random.uniform(shape=(2, 3, 4))
|
||||
|
||||
out = compiled_fun(x)
|
||||
|
||||
x = mx.random.uniform(shape=(5, 5, 3))
|
||||
|
||||
# Error, can't reshape (5, 5, 3) to (6, -1)
|
||||
out = compiled_fun(x)
|
||||
|
||||
The second call to the ``compiled_fun`` fails because of the call to
|
||||
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
|
||||
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x):
|
||||
return x.flatten(0, 1)
|
||||
|
||||
compiled_fun = mx.compile(fun, shapeless=True)
|
||||
|
||||
x = mx.random.uniform(shape=(2, 3, 4))
|
||||
|
||||
out = compiled_fun(x)
|
||||
|
||||
x = mx.random.uniform(shape=(5, 5, 3))
|
||||
|
||||
# Ok
|
||||
out = compiled_fun(x)
|
||||
|
@@ -5,21 +5,27 @@ Distributed Communication
|
||||
|
||||
.. currentmodule:: mlx.core.distributed
|
||||
|
||||
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
|
||||
provide distributed communication operations that allow the computational cost
|
||||
of training or inference to be shared across many physical machines. You can
|
||||
see a list of the supported operations in the :ref:`API docs<distributed>`.
|
||||
MLX supports distributed communication operations that allow the computational cost
|
||||
of training or inference to be shared across many physical machines. At the
|
||||
moment we support two different communication backends:
|
||||
|
||||
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
|
||||
full-featured and mature distributed communications library
|
||||
* A **ring** backend of our own that uses native TCP sockets and should be
|
||||
faster for thunderbolt connections.
|
||||
|
||||
The list of all currently supported operations and their documentation can be
|
||||
seen in the :ref:`API docs<distributed>`.
|
||||
|
||||
.. note::
|
||||
A lot of operations may not be supported or not as fast as they should be.
|
||||
Some operations may not be supported or not as fast as they should be.
|
||||
We are adding more and tuning the ones we have as we are figuring out the
|
||||
best way to do distributed computing on Macs using MLX.
|
||||
|
||||
Getting Started
|
||||
---------------
|
||||
|
||||
MLX already comes with the ability to "talk" to MPI if it is installed on the
|
||||
machine. The minimal distributed program in MLX is as simple as:
|
||||
A distributed program in MLX is as simple as:
|
||||
|
||||
.. code:: python
|
||||
|
||||
@@ -30,74 +36,79 @@ machine. The minimal distributed program in MLX is as simple as:
|
||||
print(world.rank(), x)
|
||||
|
||||
The program above sums the array ``mx.ones(10)`` across all
|
||||
distributed processes. If simply run with ``python``, however, only one
|
||||
process is launched and no distributed communication takes place.
|
||||
distributed processes. However, when this script is run with ``python`` only
|
||||
one process is launched and no distributed communication takes place. Namely,
|
||||
all operations in ``mx.distributed`` are noops when the distributed group has a
|
||||
size of one. This property allows us to avoid code that checks if we are in a
|
||||
distributed setting similar to the one below:
|
||||
|
||||
To launch the program in distributed mode we need to use ``mpirun`` or
|
||||
``mpiexec`` depending on the MPI installation. The simplest possible way is the
|
||||
following:
|
||||
.. code:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
x = ...
|
||||
world = mx.distributed.init()
|
||||
# No need for the check we can simply do x = mx.distributed.all_sum(x)
|
||||
if world.size() > 1:
|
||||
x = mx.distributed.all_sum(x)
|
||||
|
||||
Running Distributed Programs
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
MLX provides ``mlx.launch``, a helper script to launch distributed programs.
|
||||
Continuing with our initial example we can run it on localhost with 4 processes using:
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ mpirun -np 2 python test.py
|
||||
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||
$ mlx.launch -n 4 my_script.py
|
||||
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||
|
||||
The above launches two processes on the same (local) machine and we can see
|
||||
both standard output streams. The processes send the array of 1s to each other
|
||||
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
|
||||
print 4 etc.
|
||||
|
||||
Installing MPI
|
||||
---------------
|
||||
|
||||
MPI can be installed with Homebrew, using the Anaconda package manager or
|
||||
compiled from source. Most of our testing is done using ``openmpi`` installed
|
||||
with the Anaconda package manager as follows:
|
||||
We can also run it on some remote hosts by providing their IPs (provided that
|
||||
the script exists on all hosts and they are reachable by ssh)
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ conda install openmpi
|
||||
$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
|
||||
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||
|
||||
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
|
||||
so that MLX can find it and load it at runtime. This can simply be achieved by
|
||||
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
|
||||
Consult the dedicated :doc:`usage guide<launching_distributed>` for more
|
||||
information on using ``mlx.launch``.
|
||||
|
||||
.. code:: shell
|
||||
Selecting Backend
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
|
||||
|
||||
Setting up Remote Hosts
|
||||
-----------------------
|
||||
|
||||
MPI can automatically connect to remote hosts and set up the communication over
|
||||
the network if the remote hosts can be accessed via ssh. A good checklist to
|
||||
debug connectivity issues is the following:
|
||||
|
||||
* ``ssh hostname`` works from all machines to all machines without asking for
|
||||
password or host confirmation
|
||||
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
|
||||
full path to force all machines to use a specific path.
|
||||
* Ensure that the ``hostname`` used by MPI is the one that you have configured
|
||||
in the ``.ssh/config`` files on all machines.
|
||||
You can select the backend you want to use when calling :func:`init` by passing
|
||||
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
|
||||
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
|
||||
both fail then a singleton group is created.
|
||||
|
||||
.. note::
|
||||
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
|
||||
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
|
||||
After a distributed backend is successfully initialized :func:`init` will
|
||||
return **the same backend** if called without arguments or with backend set to
|
||||
``any``.
|
||||
|
||||
An easy way to pass the host names to MPI is using a host file. A host file
|
||||
looks like the following, where ``host1`` and ``host2`` should be the fully
|
||||
qualified domain names or IPs for these hosts.
|
||||
The following examples aim to clarify the backend initialization logic in MLX:
|
||||
|
||||
.. code::
|
||||
.. code:: python
|
||||
|
||||
host1 slots=1
|
||||
host2 slots=1
|
||||
# Case 1: Initialize MPI regardless of whether it was possible to initialize the ring backend
|
||||
world = mx.distributed.init(backend="mpi")
|
||||
world2 = mx.distributed.init() # subsequent calls return the MPI backend!
|
||||
|
||||
When using MLX, it is very likely that you want to use 1 slot per host, ie one
|
||||
process per host. The hostfile also needs to contain the current
|
||||
host if you want to run on the local host. Passing the host file to
|
||||
``mpirun`` is simply done using the ``--hostfile`` command line argument.
|
||||
# Case 2: Initialize any backend
|
||||
world = mx.distributed.init(backend="any") # equivalent to no arguments
|
||||
world2 = mx.distributed.init() # same as above
|
||||
|
||||
# Case 3: Initialize both backends at the same time
|
||||
world_mpi = mx.distributed.init(backend="mpi")
|
||||
world_ring = mx.distributed.init(backend="ring")
|
||||
world_any = mx.distributed.init() # same as MPI because it was initialized first!
|
||||
|
||||
Training Example
|
||||
----------------
|
||||
@@ -141,12 +152,13 @@ everything else remaining the same.
|
||||
from mlx.utils import tree_map
|
||||
|
||||
def all_reduce_grads(grads):
|
||||
N = mx.distributed.init()
|
||||
N = mx.distributed.init().size()
|
||||
if N == 1:
|
||||
return grads
|
||||
return tree_map(
|
||||
lambda x: mx.distributed.all_sum(x) / N,
|
||||
grads)
|
||||
lambda x: mx.distributed.all_sum(x) / N,
|
||||
grads
|
||||
)
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
@@ -154,13 +166,179 @@ everything else remaining the same.
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
Tuning All Reduce
|
||||
-----------------
|
||||
Utilizing ``nn.average_gradients``
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
We are working on improving the performance of all reduce on MLX but for now
|
||||
the two main things one can do to extract the most out of distributed training with MLX are:
|
||||
Although the code example above works correctly, it performs one communication
|
||||
per gradient. It is significantly more efficient to aggregate several gradients
|
||||
together and perform fewer communication steps.
|
||||
|
||||
1. Perform a few large reductions instead of many small ones to improve
|
||||
bandwidth and latency
|
||||
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
|
||||
connections between each host to improve bandwidth
|
||||
This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
|
||||
almost identical to the example above:
|
||||
|
||||
.. code:: python
|
||||
|
||||
model = ...
|
||||
optimizer = ...
|
||||
dataset = ...
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
grads = mlx.nn.average_gradients(grads) # <---- This line was added
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
for x, y in dataset:
|
||||
loss = step(model, x, y)
|
||||
mx.eval(loss, model.parameters())
|
||||
|
||||
|
||||
Getting Started with MPI
|
||||
------------------------
|
||||
|
||||
MLX already comes with the ability to "talk" to MPI if it is installed on the
|
||||
machine. Launching distributed MLX programs that use MPI can be done with
|
||||
``mpirun`` as expected. However, in the following examples we will be using
|
||||
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
|
||||
absolute paths for the ``mpirun`` executable and the ``libmpi.dylib`` shared
|
||||
library.
|
||||
|
||||
The simplest possible usage is the following which, assuming the minimal
|
||||
example in the beginning of this page, should result in:
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ mlx.launch --backend mpi -n 2 test.py
|
||||
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||
|
||||
The above launches two processes on the same (local) machine and we can see
|
||||
both standard output streams. The processes send the array of 1s to each other
|
||||
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
|
||||
print 4 etc.
|
||||
|
||||
Installing MPI
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
MPI can be installed with Homebrew, using the Anaconda package manager or
|
||||
compiled from source. Most of our testing is done using ``openmpi`` installed
|
||||
with the Anaconda package manager as follows:
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ conda install conda-forge::openmpi
|
||||
|
||||
Installing with Homebrew may require specifying the location of ``libmpi.dylib``
|
||||
so that MLX can find it and load it at runtime. This can simply be achieved by
|
||||
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
|
||||
done automatically by ``mlx.launch``.
|
||||
|
||||
.. code:: shell
|
||||
|
||||
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
|
||||
$ # or simply
|
||||
$ mlx.launch -n 2 test.py
|
||||
|
||||
Setting up Remote Hosts
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
MPI can automatically connect to remote hosts and set up the communication over
|
||||
the network if the remote hosts can be accessed via ssh. A good checklist to
|
||||
debug connectivity issues is the following:
|
||||
|
||||
* ``ssh hostname`` works from all machines to all machines without asking for
|
||||
password or host confirmation
|
||||
* ``mpirun`` is accessible on all machines.
|
||||
* Ensure that the ``hostname`` used by MPI is the one that you have configured
|
||||
in the ``.ssh/config`` files on all machines.
|
||||
|
||||
Tuning MPI All Reduce
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. note::
|
||||
|
||||
For faster all reduce consider using the ring backend either with Thunderbolt
|
||||
connections or over Ethernet.
|
||||
|
||||
Configure MPI to use N tcp connections between each host to improve bandwidth
|
||||
by passing ``--mca btl_tcp_links N``.
|
||||
|
||||
Force MPI to use the most performant network interface by setting ``--mca
|
||||
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
|
||||
to use.
|
||||
|
||||
Getting Started with Ring
|
||||
-------------------------
|
||||
|
||||
The ring backend does not depend on any third party library so it is always
|
||||
available. It uses TCP sockets so the nodes need to be reachable via a network.
|
||||
As the name suggests the nodes are connected in a ring which means that rank 1
|
||||
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
|
||||
and so on and so forth. As a result :func:`send` and :func:`recv` with
|
||||
arbitrary sender and receiver are not supported in the ring backend.
|
||||
|
||||
Defining a Ring
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
The easiest way to define and use a ring is via a JSON hostfile and the
|
||||
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
|
||||
defines a hostname to ssh into to run commands on this node and one or more IPs
|
||||
that this node will listen to for connections.
|
||||
|
||||
For example the hostfile below defines a 4 node ring. ``hostname1`` will be
|
||||
rank 0, ``hostname2`` rank 1 etc.
|
||||
|
||||
.. code:: json
|
||||
|
||||
[
|
||||
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
|
||||
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
|
||||
{"ssh": "hostname3", "ips": ["123.123.123.3"]},
|
||||
{"ssh": "hostname4", "ips": ["123.123.123.4"]}
|
||||
]
|
||||
|
||||
Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
|
||||
node, run the script which will listen for connections in each of the provided
|
||||
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
|
||||
connection from ``123.123.123.4`` and so on and so forth.
|
||||
|
||||
Thunderbolt Ring
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
Although the ring backend can have benefits over MPI even for Ethernet, its
|
||||
main purpose is to use Thunderbolt rings for higher bandwidth communication.
|
||||
Setting up such thunderbolt rings can be done manually, but is a relatively
|
||||
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.
|
||||
|
||||
To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
|
||||
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
|
||||
utility as follows:
|
||||
|
||||
.. code:: shell
|
||||
|
||||
mlx.distributed_config --verbose --hosts host1,host2,host3,host4
|
||||
|
||||
By default the script will attempt to discover the thunderbolt ring and provide
|
||||
you with the commands to configure each node as well as the ``hostfile.json``
|
||||
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
|
||||
then ``--auto-setup`` can be used to configure them automatically.
|
||||
|
||||
To validate your connection without configuring anything
|
||||
``mlx.distributed_config`` can also plot the ring using DOT format.
|
||||
|
||||
.. code:: shell
|
||||
|
||||
mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
|
||||
dot -Tpng ring.dot >ring.png
|
||||
open ring.png
|
||||
|
||||
If you want to go through the process manually, the steps are as follows:
|
||||
|
||||
* Disable the thunderbolt bridge interface
|
||||
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
|
||||
corresponding to that cable in nodes ``i`` and ``i + 1``.
|
||||
* Set up a unique subnetwork connecting the two nodes for the corresponding
|
||||
interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
|
||||
and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
|
||||
``192.168.0.2`` respectively to the two nodes. For more details you can see
|
||||
the commands prepared by the utility script.
|
||||
|
288
docs/src/usage/export.rst
Normal file
288
docs/src/usage/export.rst
Normal file
@@ -0,0 +1,288 @@
|
||||
.. _export_usage:
|
||||
|
||||
Exporting Functions
|
||||
===================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
MLX has an API to export and import functions to and from a file. This lets you
|
||||
run computations written in one MLX front-end (e.g. Python) in another MLX
|
||||
front-end (e.g. C++).
|
||||
|
||||
This guide walks through the basics of the MLX export API with some examples.
|
||||
To see the full list of functions check out the :ref:`API documentation
|
||||
<export>`.
|
||||
|
||||
Basics of Exporting
|
||||
-------------------
|
||||
|
||||
Let's start with a simple example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
mx.export_function("add.mlxfn", fun, x, y)
|
||||
|
||||
To export a function, provide sample input arrays that the function
|
||||
can be called with. The data doesn't matter, but the shapes and types of the
|
||||
arrays do. In the above example we exported ``fun`` with two ``float32``
|
||||
scalar arrays. We can then import the function and run it:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
add_fun = mx.import_function("add.mlxfn")
|
||||
|
||||
out, = add_fun(mx.array(1.0), mx.array(2.0))
|
||||
# Prints: array(3, dtype=float32)
|
||||
print(out)
|
||||
|
||||
out, = add_fun(mx.array(1.0), mx.array(3.0))
|
||||
# Prints: array(4, dtype=float32)
|
||||
print(out)
|
||||
|
||||
# Raises an exception
|
||||
add_fun(mx.array(1), mx.array(3.0))
|
||||
|
||||
# Raises an exception
|
||||
add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
|
||||
|
||||
Notice the third and fourth calls to ``add_fun`` raise exceptions because the
|
||||
shapes and types of the inputs are different than the shapes and types of the
|
||||
example inputs we exported the function with.
|
||||
|
||||
Also notice that even though the original ``fun`` returns a single output
|
||||
array, the imported function always returns a tuple of one or more arrays.
|
||||
|
||||
The inputs to :func:`export_function` and to an imported function can be
|
||||
specified as variable positional arguments or as a tuple of arrays:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
|
||||
# Both arguments to fun are positional
|
||||
mx.export_function("add.mlxfn", fun, x, y)
|
||||
|
||||
# Same as above
|
||||
mx.export_function("add.mlxfn", fun, (x, y))
|
||||
|
||||
imported_fun = mx.import_function("add.mlxfn")
|
||||
|
||||
# Ok
|
||||
out, = imported_fun(x, y)
|
||||
|
||||
# Also ok
|
||||
out, = imported_fun((x, y))
|
||||
|
||||
You can pass example inputs to functions as positional or keyword arguments. If
|
||||
you use keyword arguments to export the function, then you have to use the same
|
||||
keyword arguments when calling the imported function.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
# One argument to fun is positional, the other is a kwarg
|
||||
mx.export_function("add.mlxfn", fun, x, y=y)
|
||||
|
||||
imported_fun = mx.import_function("add.mlxfn")
|
||||
|
||||
# Ok
|
||||
out, = imported_fun(x, y=y)
|
||||
|
||||
# Also ok
|
||||
out, = imported_fun((x,), {"y": y})
|
||||
|
||||
# Raises since the keyword argument is missing
|
||||
out, = imported_fun(x, y)
|
||||
|
||||
# Raises since the keyword argument has the wrong key
|
||||
out, = imported_fun(x, z=y)
|
||||
|
||||
|
||||
Exporting Modules
|
||||
-----------------
|
||||
|
||||
An :obj:`mlx.nn.Module` can be exported with or without the parameters included
|
||||
in the exported function. Here's an example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = nn.Linear(4, 4)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
def call(x):
|
||||
return model(x)
|
||||
|
||||
mx.export_function("model.mlxfn", call, mx.zeros(4))
|
||||
|
||||
In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
|
||||
parameters are also saved to the ``model.mlxfn`` file.
|
||||
|
||||
.. note::
|
||||
|
||||
For enclosed arrays inside an exported function, be extra careful to ensure
|
||||
they are evaluated. The computation graph that gets exported will include
|
||||
the computation that produces enclosed inputs.
|
||||
|
||||
If the above example was missing ``mx.eval(model.parameters())``, the
|
||||
exported function would include the random initialization of the
|
||||
:obj:`mlx.nn.Module` parameters.
|
||||
|
||||
If you only want to export the ``Module.__call__`` function without the
|
||||
parameters, pass them as inputs to the ``call`` wrapper:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = nn.Linear(4, 4)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
def call(x, **params):
|
||||
# Set the model's parameters to the input parameters
|
||||
model.update(tree_unflatten(list(params.items())))
|
||||
return model(x)
|
||||
|
||||
params = dict(tree_flatten(model.parameters()))
|
||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||
|
||||
|
||||
Shapeless Exports
|
||||
-----------------
|
||||
|
||||
Just like :func:`compile`, functions can also be exported for dynamically shaped
|
||||
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
|
||||
to export a function which can be used for inputs with variable shapes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
|
||||
imported_abs = mx.import_function("fun.mlxfn")
|
||||
|
||||
# Ok
|
||||
out, = imported_abs(mx.array(-1.0))
|
||||
|
||||
# Also ok
|
||||
out, = imported_abs(mx.array([-1.0, -2.0]))
|
||||
|
||||
With ``shapeless=False`` (which is the default), the second call to
|
||||
``imported_abs`` would raise an exception with a shape mismatch.
|
||||
|
||||
Shapeless exporting works the same as shapeless compilation and should be
|
||||
used carefully. See the :ref:`documentation on shapeless compilation
|
||||
<shapeless_compile>` for more information.
|
||||
|
||||
Exporting Multiple Traces
|
||||
-------------------------
|
||||
|
||||
In some cases, functions build different computation graphs for different
|
||||
input arguments. A simple way to manage this is to export to a new file with
|
||||
each set of inputs. This is a fine option in many cases. But it can be
|
||||
suboptimal if the exported functions have a large amount of duplicate constant
|
||||
data (for example the parameters of a :obj:`mlx.nn.Module`).
|
||||
|
||||
The export API in MLX lets you export multiple traces of the same function to
|
||||
a single file by creating an exporting context manager with :func:`exporter`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y=None):
|
||||
constant = mx.array(3.0)
|
||||
if y is not None:
|
||||
x += y
|
||||
return x + constant
|
||||
|
||||
with mx.exporter("fun.mlxfn", fun) as exporter:
|
||||
exporter(mx.array(1.0))
|
||||
exporter(mx.array(1.0), y=mx.array(0.0))
|
||||
|
||||
imported_function = mx.import_function("fun.mlxfn")
|
||||
|
||||
# Call the function with y=None
|
||||
out, = imported_function(mx.array(1.0))
|
||||
print(out)
|
||||
|
||||
# Call the function with y specified
|
||||
out, = imported_function(mx.array(1.0), y=mx.array(1.0))
|
||||
print(out)
|
||||
|
||||
In the above example the function's constant data (i.e. ``constant``) is only
|
||||
saved once.
|
||||
|
||||
Transformations with Imported Functions
|
||||
---------------------------------------
|
||||
|
||||
Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
|
||||
on imported functions just like regular Python functions:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x):
|
||||
return mx.sin(x)
|
||||
|
||||
x = mx.array(0.0)
|
||||
mx.export_function("sine.mlxfn", fun, x)
|
||||
|
||||
imported_fun = mx.import_function("sine.mlxfn")
|
||||
|
||||
# Take the derivative of the imported function
|
||||
dfdx = mx.grad(lambda x: imported_fun(x)[0])
|
||||
# Prints: array(1, dtype=float32)
|
||||
print(dfdx(x))
|
||||
|
||||
# Compile the imported function
|
||||
compiled_fun = mx.compile(imported_fun)
|
||||
# Prints: array(0, dtype=float32)
|
||||
print(compiled_fun(x)[0])
|
||||
|
||||
|
||||
Importing Functions in C++
|
||||
--------------------------
|
||||
|
||||
Importing and running functions in C++ is basically the same as importing and
|
||||
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
|
||||
set up a simple C++ project that uses MLX as a library.
|
||||
|
||||
Next, export a simple function from Python:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return mx.exp(x + y)
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
mx.export_function("fun.mlxfn", fun, x, y)
|
||||
|
||||
|
||||
Import and run the function in C++ with only a few lines of code:
|
||||
|
||||
.. code-block:: c++
|
||||
|
||||
auto fun = mx::import_function("fun.mlxfn");
|
||||
|
||||
auto inputs = {mx::array(1.0), mx::array(1.0)};
|
||||
auto outputs = fun(inputs);
|
||||
|
||||
// Prints exp(1 + 1): array(7.38906, dtype=float32)
|
||||
std::cout << outputs[0] << std::endl;
|
||||
|
||||
Imported functions can be transformed in C++ just like in Python. Use
|
||||
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
|
||||
mx::array>`` for keyword arguments when calling imported functions in C++.
|
||||
|
||||
More Examples
|
||||
-------------
|
||||
|
||||
Here are a few more complete examples exporting more complex functions from
|
||||
Python and importing and running them in C++:
|
||||
|
||||
* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_
|
@@ -184,8 +184,8 @@ Let's time these two different versions:
|
||||
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
|
||||
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
|
||||
|
||||
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
|
||||
vectorized version takes only ``0.025`` seconds, more than ten times faster.
|
||||
On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
|
||||
vectorized version takes only ``0.024`` seconds, more than 200 times faster.
|
||||
|
||||
Of course, this operation is quite contrived. A better approach is to simply do
|
||||
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
|
||||
|
105
docs/src/usage/launching_distributed.rst
Normal file
105
docs/src/usage/launching_distributed.rst
Normal file
@@ -0,0 +1,105 @@
|
||||
:orphan:
|
||||
|
||||
.. _usage_launch_distributed:
|
||||
|
||||
Launching Distributed Programs
|
||||
==============================
|
||||
|
||||
.. currentmodule:: mlx.core.distributed
|
||||
|
||||
Installing the MLX python package provides a helper script ``mlx.launch`` that
|
||||
can be used to run python scripts distributed on several nodes. It allows
|
||||
launching using either the MPI backend or the ring backend. See the
|
||||
:doc:`distributed docs <distributed>` for the different backends.
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
The minimal usage example of ``mlx.launch`` is simply
|
||||
|
||||
.. code:: shell
|
||||
|
||||
mlx.launch --hosts ip1,ip2 my_script.py
|
||||
|
||||
or for testing on localhost
|
||||
|
||||
.. code:: shell
|
||||
|
||||
mlx.launch -n 2 my_script.py
|
||||
|
||||
The ``mlx.launch`` command connects to the provided hosts and launches the input
|
||||
script on each host. It monitors each of the launched processes and terminates
|
||||
the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
|
||||
It also takes care of forwarding the output of each remote process to stdout
|
||||
and stderr respectively.
|
||||
|
||||
Providing Hosts
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
Hosts can be provided as command line arguments, like above, but the way that
|
||||
allows you to fully define a list of hosts is via a JSON hostfile. The hostfile has
|
||||
a very simple schema. It is simply a list of objects that define each host via
|
||||
a hostname to ssh to and a list of IPs to utilize for the communication.
|
||||
|
||||
.. code:: json
|
||||
|
||||
[
|
||||
{"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
|
||||
{"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
|
||||
]
|
||||
|
||||
You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
|
||||
with IPs corresponding to the ``en0`` interface.
|
||||
|
||||
Setting up Remote Hosts
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
In order to be able to launch the script on each host we need to be able to
|
||||
connect via ssh. Moreover the input script and python binary need to be on each
|
||||
host and on the same path. A good checklist to debug errors is the following:
|
||||
|
||||
* ``ssh hostname`` works without asking for password or host confirmation
|
||||
* the python binary is available on all hosts at the same path. You can use
|
||||
``mlx.launch --print-python`` to see what that path is.
|
||||
* the script you want to run is available on all hosts at the same path
|
||||
|
||||
.. _mpi_specifics:
|
||||
|
||||
MPI Specifics
|
||||
-------------
|
||||
|
||||
One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
|
||||
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,
|
||||
|
||||
* The IPs in the hostfile are ignored
|
||||
* The ssh connectivity requirement is stronger as every node needs to be able
|
||||
to connect to every other node
|
||||
* ``mpirun`` needs to be available on every node at the same path
|
||||
|
||||
Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance
|
||||
to choose a specific interface for the byte-transfer-layer of MPI we can call
|
||||
``mlx.launch`` as follows:
|
||||
|
||||
.. code:: shell
|
||||
|
||||
mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py
|
||||
|
||||
|
||||
.. _ring_specifics:
|
||||
|
||||
Ring Specifics
|
||||
--------------
|
||||
|
||||
The ring backend, which is also the default backend, can be explicitly selected
|
||||
with the argument ``--backend ring``. The ring backend has some specific
|
||||
requirements and arguments that are different to MPI:
|
||||
|
||||
* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
|
||||
ssh to a hostname that does not correspond to the IP we want to bind to we
|
||||
have to provide a hostfile.
|
||||
* ``--starting-port`` defines the port to bind to on the remote hosts.
|
||||
Specifically rank 0 for the first IP will use this port and each subsequent
|
||||
IP or rank will add 1 to this port.
|
||||
* ``--connections-per-ip`` allows us to increase the number of connections
|
||||
between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
|
||||
``mpirun``.
|
@@ -21,11 +21,13 @@ Let's convert an array to NumPy and back.
|
||||
|
||||
.. note::
|
||||
|
||||
Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
|
||||
``np.array(a.astype(mx.float32))``.
|
||||
Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
|
||||
Since NumPy does not support ``bfloat16`` arrays, you will need to convert
|
||||
to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
|
||||
Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
|
||||
buffer format string does not match the dtype V item size 0.``
|
||||
|
||||
By default, NumPy copies data to a new array. This can be prevented by creating an array view:
|
||||
By default, NumPy copies data to a new array. This can be prevented by creating
|
||||
an array view:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -35,10 +37,16 @@ By default, NumPy copies data to a new array. This can be prevented by creating
|
||||
a_view[0] = 1
|
||||
print(a[0].item()) # 1
|
||||
|
||||
A NumPy array view is a normal NumPy array, except that it does not own its memory.
|
||||
This means writing to the view is reflected in the original array.
|
||||
.. note::
|
||||
|
||||
While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.
|
||||
NumPy arrays with type ``float64`` will by default be converted to MLX arrays
|
||||
with type ``float32``.
|
||||
|
||||
A NumPy array view is a normal NumPy array, except that it does not own its
|
||||
memory. This means writing to the view is reflected in the original array.
|
||||
|
||||
While this is quite powerful to prevent copying arrays, it should be noted that
|
||||
external changes to the memory of arrays cannot be reflected in gradients.
|
||||
|
||||
Let's demonstrate this in an example:
|
||||
|
||||
@@ -56,11 +64,12 @@ Let's demonstrate this in an example:
|
||||
|
||||
|
||||
The function ``f`` indirectly modifies the array ``x`` through a memory view.
|
||||
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
|
||||
representing the gradient of the sum operation alone.
|
||||
The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
|
||||
It's important to note that a similar issue arises during array conversion and copying.
|
||||
For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
|
||||
However, this modification is not reflected in the gradient, as seen in the
|
||||
last line outputting ``1.0``, representing the gradient of the sum operation
|
||||
alone. The squaring of ``x`` occurs externally to MLX, meaning that no
|
||||
gradient is incorporated. It's important to note that a similar issue arises
|
||||
during array conversion and copying. For instance, a function defined as
|
||||
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
|
||||
even though no in-place operations on MLX memory are executed.
|
||||
|
||||
PyTorch
|
||||
@@ -71,7 +80,8 @@ PyTorch
|
||||
PyTorch support for :obj:`memoryview` is experimental and can break for
|
||||
multi-dimensional arrays. Casting to NumPy first is advised for now.
|
||||
|
||||
PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
|
||||
PyTorch supports the buffer protocol, but it requires an explicit
|
||||
:obj:`memoryview`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -82,7 +92,8 @@ PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryvi
|
||||
b = torch.tensor(memoryview(a))
|
||||
c = mx.array(b.numpy())
|
||||
|
||||
Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
|
||||
Conversion from PyTorch tensors back to arrays must be done via intermediate
|
||||
NumPy arrays with ``numpy()``.
|
||||
|
||||
JAX
|
||||
---
|
||||
@@ -100,7 +111,8 @@ JAX fully supports the buffer protocol.
|
||||
TensorFlow
|
||||
----------
|
||||
|
||||
TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
|
||||
TensorFlow supports the buffer protocol, but it requires an explicit
|
||||
:obj:`memoryview`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
22
examples/cmake_project/CMakeLists.txt
Normal file
22
examples/cmake_project/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(example LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
# Comment out the following two commands if only the MLX C++ library is installed and
|
||||
# set(MLX_ROOT "/path/to/mlx") directly if needed.
|
||||
find_package(
|
||||
Python 3.9
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
add_executable(example example.cpp)
|
||||
target_link_libraries(example PRIVATE mlx)
|
26
examples/cmake_project/README.md
Normal file
26
examples/cmake_project/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
## Build and Run
|
||||
|
||||
Install MLX with Python:
|
||||
|
||||
```bash
|
||||
pip install "mlx>=0.22"
|
||||
```
|
||||
|
||||
Build the C++ example:
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
Run the C++ example:
|
||||
|
||||
```
|
||||
./build/example
|
||||
```
|
||||
|
||||
which should output:
|
||||
|
||||
```
|
||||
array([2, 4, 6], dtype=int32)
|
||||
```
|
14
examples/cmake_project/example.cpp
Normal file
14
examples/cmake_project/example.cpp
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
auto x = mx::array({1, 2, 3});
|
||||
auto y = mx::array({1, 2, 3});
|
||||
std::cout << x + y << std::endl;
|
||||
return 0;
|
||||
}
|
@@ -4,19 +4,19 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
if (!distributed::is_available()) {
|
||||
if (!mx::distributed::is_available()) {
|
||||
std::cout << "No communication backend found" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto global_group = distributed::init();
|
||||
auto global_group = mx::distributed::init();
|
||||
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
|
||||
|
||||
array x = ones({10});
|
||||
array out = distributed::all_sum(x, global_group);
|
||||
mx::array x = mx::ones({10});
|
||||
mx::array out = mx::distributed::all_sum(x, global_group);
|
||||
|
||||
std::cout << out << std::endl;
|
||||
}
|
||||
|
@@ -10,7 +10,7 @@
|
||||
/**
|
||||
* An example of linear regression with MLX.
|
||||
*/
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int num_features = 100;
|
||||
@@ -19,35 +19,35 @@ int main() {
|
||||
float learning_rate = 0.01;
|
||||
|
||||
// True parameters
|
||||
auto w_star = random::normal({num_features});
|
||||
auto w_star = mx::random::normal({num_features});
|
||||
|
||||
// The input examples (design matrix)
|
||||
auto X = random::normal({num_examples, num_features});
|
||||
auto X = mx::random::normal({num_examples, num_features});
|
||||
|
||||
// Noisy labels
|
||||
auto eps = 1e-2 * random::normal({num_examples});
|
||||
auto y = matmul(X, w_star) + eps;
|
||||
auto eps = 1e-2 * mx::random::normal({num_examples});
|
||||
auto y = mx::matmul(X, w_star) + eps;
|
||||
|
||||
// Initialize random parameters
|
||||
array w = 1e-2 * random::normal({num_features});
|
||||
mx::array w = 1e-2 * mx::random::normal({num_features});
|
||||
|
||||
auto loss_fn = [&](array w) {
|
||||
auto yhat = matmul(X, w);
|
||||
return (0.5f / num_examples) * sum(square(yhat - y));
|
||||
auto loss_fn = [&](mx::array w) {
|
||||
auto yhat = mx::matmul(X, w);
|
||||
return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
|
||||
};
|
||||
|
||||
auto grad_fn = grad(loss_fn);
|
||||
auto grad_fn = mx::grad(loss_fn);
|
||||
|
||||
auto tic = timer::time();
|
||||
for (int it = 0; it < num_iters; ++it) {
|
||||
auto grad = grad_fn(w);
|
||||
w = w - learning_rate * grad;
|
||||
eval(w);
|
||||
auto grads = grad_fn(w);
|
||||
w = w - learning_rate * grads;
|
||||
mx::eval(w);
|
||||
}
|
||||
auto toc = timer::time();
|
||||
|
||||
auto loss = loss_fn(w);
|
||||
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
|
||||
auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
|
||||
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
|
||||
<< ", Throughput " << throughput << " (it/s)." << std::endl;
|
||||
|
@@ -10,7 +10,7 @@
|
||||
/**
|
||||
* An example of logistic regression with MLX.
|
||||
*/
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int num_features = 100;
|
||||
@@ -19,35 +19,35 @@ int main() {
|
||||
float learning_rate = 0.1;
|
||||
|
||||
// True parameters
|
||||
auto w_star = random::normal({num_features});
|
||||
auto w_star = mx::random::normal({num_features});
|
||||
|
||||
// The input examples
|
||||
auto X = random::normal({num_examples, num_features});
|
||||
auto X = mx::random::normal({num_examples, num_features});
|
||||
|
||||
// Labels
|
||||
auto y = matmul(X, w_star) > 0;
|
||||
auto y = mx::matmul(X, w_star) > 0;
|
||||
|
||||
// Initialize random parameters
|
||||
array w = 1e-2 * random::normal({num_features});
|
||||
mx::array w = 1e-2 * mx::random::normal({num_features});
|
||||
|
||||
auto loss_fn = [&](array w) {
|
||||
auto logits = matmul(X, w);
|
||||
auto loss_fn = [&](mx::array w) {
|
||||
auto logits = mx::matmul(X, w);
|
||||
auto scale = (1.0f / num_examples);
|
||||
return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
|
||||
return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
|
||||
};
|
||||
|
||||
auto grad_fn = grad(loss_fn);
|
||||
auto grad_fn = mx::grad(loss_fn);
|
||||
|
||||
auto tic = timer::time();
|
||||
for (int it = 0; it < num_iters; ++it) {
|
||||
auto grad = grad_fn(w);
|
||||
w = w - learning_rate * grad;
|
||||
eval(w);
|
||||
auto grads = grad_fn(w);
|
||||
w = w - learning_rate * grads;
|
||||
mx::eval(w);
|
||||
}
|
||||
auto toc = timer::time();
|
||||
|
||||
auto loss = loss_fn(w);
|
||||
auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
|
||||
auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
|
||||
auto throughput = num_iters / timer::seconds(toc - tic);
|
||||
std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
|
||||
<< throughput << " (it/s)." << std::endl;
|
||||
|
@@ -5,27 +5,27 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
// To use Metal debugging and profiling:
|
||||
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
|
||||
// 2. Run with MTL_CAPTURE_ENABLED=1.
|
||||
metal::start_capture("mlx_trace.gputrace");
|
||||
mx::metal::start_capture("mlx_trace.gputrace");
|
||||
|
||||
// Start at index two because the default GPU and CPU streams have indices
|
||||
// zero and one, respectively. This naming matches the label assigned to each
|
||||
// stream's command queue.
|
||||
auto s2 = new_stream(Device::gpu);
|
||||
auto s3 = new_stream(Device::gpu);
|
||||
auto s2 = new_stream(mx::Device::gpu);
|
||||
auto s3 = new_stream(mx::Device::gpu);
|
||||
|
||||
auto a = arange(1.f, 10.f, 1.f, float32, s2);
|
||||
auto b = arange(1.f, 10.f, 1.f, float32, s3);
|
||||
auto x = add(a, a, s2);
|
||||
auto y = add(b, b, s3);
|
||||
auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
|
||||
auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
|
||||
auto x = mx::add(a, a, s2);
|
||||
auto y = mx::add(b, b, s3);
|
||||
|
||||
// The multiply will happen on the default stream.
|
||||
std::cout << multiply(x, y) << std::endl;
|
||||
std::cout << mx::multiply(x, y) << std::endl;
|
||||
|
||||
metal::stop_capture();
|
||||
mx::metal::stop_capture();
|
||||
}
|
||||
|
@@ -5,11 +5,11 @@
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
namespace mx = mlx::core;
|
||||
|
||||
void array_basics() {
|
||||
// Make a scalar array:
|
||||
array x(1.0);
|
||||
mx::array x(1.0);
|
||||
|
||||
// Get the value out of it:
|
||||
auto s = x.item<float>();
|
||||
@@ -29,31 +29,31 @@ void array_basics() {
|
||||
|
||||
// The datatype should be float32:
|
||||
auto dtype = x.dtype();
|
||||
assert(dtype == float32);
|
||||
assert(dtype == mx::float32);
|
||||
|
||||
// Specify the dtype when constructing the array:
|
||||
x = array(1, int32);
|
||||
assert(x.dtype() == int32);
|
||||
x = mx::array(1, mx::int32);
|
||||
assert(x.dtype() == mx::int32);
|
||||
x.item<int>(); // OK
|
||||
// x.item<float>(); // Undefined!
|
||||
|
||||
// Make a multidimensional array:
|
||||
x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||
x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
|
||||
// mlx is row-major by default so the first row of this array
|
||||
// is [1.0, 2.0] and the second row is [3.0, 4.0]
|
||||
|
||||
// Make an array of shape {2, 2} filled with ones:
|
||||
auto y = ones({2, 2});
|
||||
auto y = mx::ones({2, 2});
|
||||
|
||||
// Pointwise add x and y:
|
||||
auto z = add(x, y);
|
||||
auto z = mx::add(x, y);
|
||||
|
||||
// Same thing:
|
||||
z = x + y;
|
||||
|
||||
// mlx is lazy by default. At this point `z` only
|
||||
// has a shape and a type but no actual data:
|
||||
assert(z.dtype() == float32);
|
||||
assert(z.dtype() == mx::float32);
|
||||
assert(z.shape(0) == 2);
|
||||
assert(z.shape(1) == 2);
|
||||
|
||||
@@ -63,33 +63,33 @@ void array_basics() {
|
||||
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||
// all of its dependencies are recursively evaluated to produce the result.
|
||||
// Once an array is evaluated, it has data and is detached from its inputs.
|
||||
eval(z);
|
||||
mx::eval(z);
|
||||
|
||||
// Of course the array can still be an input to other operations. You can even
|
||||
// call eval on the array again, this will just be a no-op:
|
||||
eval(z); // no-op
|
||||
// Of course the array can still be an input to other operations. You can
|
||||
// even call eval on the array again, this will just be a no-op:
|
||||
mx::eval(z); // no-op
|
||||
|
||||
// Some functions or methods on arrays implicitly evaluate them. For example
|
||||
// accessing a value in an array or printing the array implicitly evaluates it:
|
||||
z = ones({1});
|
||||
z = mx::ones({1});
|
||||
z.item<float>(); // implicit evaluation
|
||||
|
||||
z = ones({2, 2});
|
||||
z = mx::ones({2, 2});
|
||||
std::cout << z << std::endl; // implicit evaluation
|
||||
}
|
||||
|
||||
void automatic_differentiation() {
|
||||
auto fn = [](array x) { return square(x); };
|
||||
auto fn = [](mx::array x) { return mx::square(x); };
|
||||
|
||||
// Computing the derivative function of a function
|
||||
auto grad_fn = grad(fn);
|
||||
auto grad_fn = mx::grad(fn);
|
||||
// Call grad_fn on the input to get the derivative
|
||||
auto x = array(1.5);
|
||||
auto x = mx::array(1.5);
|
||||
auto dfdx = grad_fn(x);
|
||||
// dfdx is 2 * x
|
||||
|
||||
// Get the second derivative by composing grad with grad
|
||||
auto d2fdx2 = grad(grad(fn))(x);
|
||||
auto d2fdx2 = mx::grad(mx::grad(fn))(x);
|
||||
// d2fdx2 is 2
|
||||
}
|
||||
|
||||
|
22
examples/export/CMakeLists.txt
Normal file
22
examples/export/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
# Standalone build script for the two export examples, eval_mlp and train_mlp.
cmake_minimum_required(VERSION 3.27)

project(import_mlx LANGUAGES CXX)

# Build the examples as C++17 and refuse to fall back to an older standard.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# The Python interpreter is used below to query the installed mlx package.
find_package(Python 3.9 REQUIRED COMPONENTS Interpreter Development.Module)

# Capture whatever `python -m mlx --cmake-dir` prints into MLX_ROOT (presumably
# the directory holding MLX's CMake package files); find_package(MLX) consults
# <PackageName>_ROOT variables when searching.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_VARIABLE MLX_ROOT
  OUTPUT_STRIP_TRAILING_WHITESPACE)
find_package(MLX CONFIG REQUIRED)

# One executable per example source, each linked against the mlx target.
foreach(example IN ITEMS eval_mlp train_mlp)
  add_executable(${example} ${example}.cpp)
  target_link_libraries(${example} PRIVATE mlx)
endforeach()
|
49
examples/export/README.md
Normal file
49
examples/export/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
## Setup
|
||||
|
||||
Install MLX:
|
||||
|
||||
```bash
|
||||
pip install "mlx>=0.22"
|
||||
```
|
||||
|
||||
Build the C++ examples:
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
### Eval MLP
|
||||
|
||||
Run the Python script to export the eval function:
|
||||
|
||||
```bash
|
||||
python eval_mlp.py
|
||||
```
|
||||
|
||||
Then run the C++ program to import and run the function:
|
||||
|
||||
```
|
||||
./build/eval_mlp
|
||||
```
|
||||
|
||||
The Python and C++ programs should output the same result.
|
||||
|
||||
### Train MLP
|
||||
|
||||
Run the Python script to export the model initialization and training
|
||||
functions:
|
||||
|
||||
```bash
|
||||
python train_mlp.py
|
||||
```
|
||||
|
||||
Then run the C++ program to import and run the functions:
|
||||
|
||||
```
|
||||
./build/train_mlp
|
||||
```
|
||||
|
||||
The Python and C++ programs should output the same results.
|
25
examples/export/eval_mlp.cpp
Normal file
25
examples/export/eval_mlp.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <mlx/mlx.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
/**
 * Draw a seeded example batch, import the function stored in
 * "eval_mlp.mlxfn" (relative path) and print the first array it returns.
 */
int main() {
  const int batch_size = 8;
  const int input_dim = 32;

  // Seed the generator before sampling so the input batch is reproducible
  mx::random::seed(42);
  auto x = mx::random::uniform({batch_size, input_dim});

  // Load the exported function
  auto forward = mx::import_function("eval_mlp.mlxfn");

  // Invoke it with the single input; only the first output is printed
  auto outputs = forward({x});
  std::cout << outputs[0] << std::endl;

  return 0;
}
|
52
examples/export/eval_mlp.py
Normal file
52
examples/export/eval_mlp.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.utils
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """A stack of ``nn.Linear`` layers with ReLU after all but the last."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # Widths at each layer boundary: input, `num_layers` hidden, output.
        widths = [input_dim, *([hidden_dim] * num_layers), output_dim]
        stack = []
        # Layers are created front to back, one per consecutive width pair.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
        self.layers = stack

    def __call__(self, x):
        # Every layer but the final one is followed by a ReLU.
        *hidden, head = self.layers
        for layer in hidden:
            x = nn.relu(layer(x))
        return head(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    batch_size, input_dim, output_dim = 8, 32, 10

    # Build the model under a fixed seed so its parameters are reproducible.
    mx.random.seed(0)
    model = MLP(
        num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
    )
    mx.eval(model)

    # `model` is captured from the enclosing scope rather than passed in.
    # Note (from the original author): the model parameters are saved in the
    # exported function.
    def forward(x):
        return model(x)

    # The example input is drawn under its own seed.
    mx.random.seed(42)
    example_x = mx.random.uniform(shape=(batch_size, input_dim))

    mx.export_function("eval_mlp.mlxfn", forward, example_x)

    # Round trip in Python: the imported function must match the original.
    reference = forward(example_x)
    imported_forward = mx.import_function("eval_mlp.mlxfn")
    [result] = imported_forward(example_x)
    assert mx.allclose(reference, result)
    print(result)
|
35
examples/export/train_mlp.cpp
Normal file
35
examples/export/train_mlp.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <mlx/mlx.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
int main() {
|
||||
int batch_size = 8;
|
||||
int input_dim = 32;
|
||||
int output_dim = 10;
|
||||
|
||||
auto state = mx::import_function("init_mlp.mlxfn")({});
|
||||
|
||||
// Make the input
|
||||
mx::random::seed(42);
|
||||
auto example_X = mx::random::normal({batch_size, input_dim});
|
||||
auto example_y = mx::random::randint(0, output_dim, {batch_size});
|
||||
|
||||
// Import the function
|
||||
auto step = mx::import_function("train_mlp.mlxfn");
|
||||
|
||||
// Call the imported function
|
||||
for (int it = 0; it < 100; ++it) {
|
||||
state.insert(state.end(), {example_X, example_y});
|
||||
state = step(state);
|
||||
eval(state);
|
||||
auto loss = state.back();
|
||||
state.pop_back();
|
||||
if (it % 10 == 0) {
|
||||
std::cout << "Loss " << loss.item<float>() << std::endl;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
76
examples/export/train_mlp.py
Normal file
76
examples/export/train_mlp.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import mlx.utils
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Multi-layer perceptron made of ``nn.Linear`` layers.

    ReLU is applied to the output of every layer except the final one.
    """

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # dims[i] -> dims[i + 1] is the shape of the i-th linear layer.
        dims = (input_dim,) + (hidden_dim,) * num_layers + (output_dim,)
        self.layers = [
            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
        ]

    def __call__(self, x):
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # No activation after the final layer.
            if i != last:
                x = nn.relu(x)
        return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
    batch_size = 8
    input_dim = 32
    output_dim = 10

    def init():
        """Create a seeded model and optimizer and flatten their state.

        Returns ``(model, optimizer, tree_structure, state)``, where
        ``tree_structure`` holds the flattened keys and ``state`` the
        matching values, both as tuples.
        """
        # Seed for the parameter initialization
        mx.random.seed(0)
        model = MLP(
            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
        )
        optimizer = optim.SGD(learning_rate=1e-1)
        optimizer.init(model.parameters())
        flat = mlx.utils.tree_flatten([model.parameters(), optimizer.state])
        tree_structure = tuple(key for key, _ in flat)
        state = tuple(value for _, value in flat)
        return model, optimizer, tree_structure, state

    # Export the initialization; the exported function returns only the
    # flattened state (the last element of what init() returns).
    model, optimizer, tree_structure, state = init()
    mx.eval(state)
    mx.export_function("init_mlp.mlxfn", lambda: init()[-1])

    def loss_fn(params, X, y):
        """Mean cross-entropy of the model's outputs on (X, y) under `params`."""
        model.update(params)
        logits = model(X)
        return nn.losses.cross_entropy(logits, y, reduction="mean")

    def step(*inputs):
        """One optimizer update.

        `inputs` is the flattened state followed by X and y; the result is the
        updated flattened state followed by the loss.
        """
        leaves, (X, y) = inputs[:-2], inputs[-2:]
        # Rebuild the parameter / optimizer-state trees from the flat values.
        params, opt_state = mlx.utils.tree_unflatten(
            list(zip(tree_structure, leaves))
        )
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        params = optimizer.apply_gradients(grads, params)
        updated = mlx.utils.tree_flatten([params, optimizer.state])
        new_leaves = tuple(value for _, value in updated)
        return (*new_leaves, loss)

    # Make some random data
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))

    # Export one step of SGD, using the current state and the batch as the
    # example inputs
    mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)

    # Import the exported step and train with it
    imported_step = mx.import_function("train_mlp.mlxfn")

    for it in range(100):
        *state, loss = imported_step(*state, example_X, example_y)
        if it % 10 == 0:
            print(f"Loss {loss.item():.6}")
|
@@ -10,7 +10,6 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
||||
|
||||
# ----------------------------- Dependencies -----------------------------
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
find_package(
|
||||
Python 3.8
|
||||
COMPONENTS Interpreter Development.Module
|
||||
@@ -18,10 +17,15 @@ find_package(
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE NB_DIR)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||
OUTPUT_VARIABLE nanobind_ROOT)
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE MLX_ROOT)
|
||||
find_package(MLX CONFIG REQUIRED)
|
||||
|
||||
# ----------------------------- Extensions -----------------------------
|
||||
|
||||
# Add library
|
||||
|
@@ -1,25 +1,20 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
// Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include "axpby/axpby.h"
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <vecLib/cblas_new.h>
|
||||
#endif
|
||||
|
||||
#ifdef _METAL_
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
namespace my_ext {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation Implementation
|
||||
@@ -32,24 +27,24 @@ namespace mlx::core {
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
mx::array axpby(
|
||||
const mx::array& x, // Input mx::array x
|
||||
const mx::array& y, // Input mx::array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
|
||||
) {
|
||||
// Promote dtypes between x and y as needed
|
||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||
|
||||
// Upcast to float32 for non-floating point inputs x and y
|
||||
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||
auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
|
||||
? promoted_dtype
|
||||
: promote_types(promoted_dtype, float32);
|
||||
: promote_types(promoted_dtype, mx::float32);
|
||||
|
||||
// Cast x and y up to the determined dtype (on the same stream s)
|
||||
auto x_casted = astype(x, out_dtype, s);
|
||||
auto y_casted = astype(y, out_dtype, s);
|
||||
auto x_casted = mx::astype(x, out_dtype, s);
|
||||
auto y_casted = mx::astype(y, out_dtype, s);
|
||||
|
||||
// Broadcast the shapes of x and y (on the same stream s)
|
||||
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
|
||||
@@ -57,12 +52,12 @@ array axpby(
|
||||
|
||||
// Construct the array as the output of the Axpby primitive
|
||||
// with the broadcasted and upcasted arrays as inputs
|
||||
return array(
|
||||
/* const std::vector<int>& shape = */ out_shape,
|
||||
/* Dtype dtype = */ out_dtype,
|
||||
/* std::unique_ptr<Primitive> primitive = */
|
||||
return mx::array(
|
||||
/* const mx::Shape& shape = */ out_shape,
|
||||
/* mx::Dtype dtype = */ out_dtype,
|
||||
/* std::shared_ptr<mx::Primitive> primitive = */
|
||||
std::make_shared<Axpby>(to_stream(s), alpha, beta),
|
||||
/* const std::vector<array>& inputs = */ broadcasted_inputs);
|
||||
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -71,140 +66,69 @@ array axpby(
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
const mx::array& x,
|
||||
const mx::array& y,
|
||||
mx::array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// We only allocate memory when we are ready to fill the output
|
||||
// malloc_or_wait synchronously allocates available memory
|
||||
// There may be a wait executed here if the allocation is requested
|
||||
// under memory-pressured conditions
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
float beta_,
|
||||
mx::Stream stream) {
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
|
||||
// Collect input and output data pointers
|
||||
const T* x_ptr = x.data<T>();
|
||||
const T* y_ptr = y.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
// Get the CPU command encoder and register input and output arrays
|
||||
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(y);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// Cast alpha and beta to the relevant types
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
// Launch the CPU kernel
|
||||
encoder.dispatch([x_ptr = x.data<T>(),
|
||||
y_ptr = y.data<T>(),
|
||||
out_ptr = out.data<T>(),
|
||||
size = out.size(),
|
||||
shape = out.shape(),
|
||||
x_strides = x.strides(),
|
||||
y_strides = y.strides(),
|
||||
alpha_,
|
||||
beta_]() {
|
||||
// Cast alpha and beta to the relevant types
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
||||
// Do the element-wise operation for each output
|
||||
for (size_t out_idx = 0; out_idx < size; out_idx++) {
|
||||
// Map linear indices to offsets in x and y
|
||||
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
|
||||
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
|
||||
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
// We allocate the output to be contiguous and regularly strided
|
||||
// (defaults to row major) and hence it doesn't need additional mapping
|
||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void Axpby::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Check the inputs (registered in the op while constructing the out array)
|
||||
assert(inputs.size() == 2);
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Dispatch to the correct dtype
|
||||
if (out.dtype() == float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == float16) {
|
||||
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
||||
} else if (out.dtype() == complex64) {
|
||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
||||
if (out.dtype() == mx::float32) {
|
||||
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::float16) {
|
||||
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::bfloat16) {
|
||||
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else if (out.dtype() == mx::complex64) {
|
||||
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Axpby is only supported for floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Primitive Accelerate Backend Implementation
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
|
||||
template <typename T>
|
||||
void axpby_impl_accelerate(
|
||||
const array& x,
|
||||
const array& y,
|
||||
array& out,
|
||||
float alpha_,
|
||||
float beta_) {
|
||||
// Accelerate library provides catlas_saxpby which does
|
||||
// Y = (alpha * X) + (beta * Y) in place
|
||||
// To use it, we first copy the data in y over to the output array
|
||||
|
||||
// This specialization requires both x and y be contiguous in the same mode
|
||||
// i.e: corresponding linear indices in both point to corresponding elements
|
||||
// The data in the output array is allocated to match the strides in y
|
||||
// such that x, y, and out are contiguous in the same mode and
|
||||
// no transposition is needed
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// We then copy over the elements using the contiguous vector specialization
|
||||
copy_inplace(y, out, CopyType::Vector);
|
||||
|
||||
// Get x and y pointers for catlas_saxpby
|
||||
const T* x_ptr = x.data<T>();
|
||||
T* y_ptr = out.data<T>();
|
||||
|
||||
T alpha = static_cast<T>(alpha_);
|
||||
T beta = static_cast<T>(beta_);
|
||||
|
||||
// Call the inplace accelerate operator
|
||||
catlas_saxpby(
|
||||
/* N = */ out.size(),
|
||||
/* ALPHA = */ alpha,
|
||||
/* X = */ x_ptr,
|
||||
/* INCX = */ 1,
|
||||
/* BETA = */ beta,
|
||||
/* Y = */ y_ptr,
|
||||
/* INCY = */ 1);
|
||||
}
|
||||
|
||||
/** Evaluate primitive on CPU using accelerate specializations */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Accelerate specialization for contiguous single precision float arrays
|
||||
if (out.dtype() == float32 &&
|
||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to common backend if specializations are not available
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
#else // Accelerate not available
|
||||
|
||||
/** Evaluate primitive on CPU falling back to common backend */
|
||||
void Axpby::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Primitive Metal Backend Implementation
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -213,10 +137,9 @@ void Axpby::eval_cpu(
|
||||
|
||||
/** Evaluate primitive on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) {
|
||||
// Prepare inputs
|
||||
assert(inputs.size() == 2);
|
||||
auto& x = inputs[0];
|
||||
auto& y = inputs[1];
|
||||
auto& out = outputs[0];
|
||||
@@ -225,7 +148,7 @@ void Axpby::eval_gpu(
|
||||
// and each stream carries its device identifiers
|
||||
auto& s = stream();
|
||||
// We get the needed metal device using the stream
|
||||
auto& d = metal::device(s.device);
|
||||
auto& d = mx::metal::device(s.device);
|
||||
|
||||
// Prepare to specialize based on contiguity
|
||||
bool contiguous_kernel =
|
||||
@@ -235,12 +158,12 @@ void Axpby::eval_gpu(
|
||||
// Allocate output memory with strides based on specialization
|
||||
if (contiguous_kernel) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||
mx::allocator::malloc(x.data_size() * out.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
@@ -257,7 +180,7 @@ void Axpby::eval_gpu(
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Kernel parameters are registered with buffer indices corresponding to
|
||||
// those in the kernel declaration at axpby.metal
|
||||
@@ -272,15 +195,15 @@ void Axpby::eval_gpu(
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
// Encode alpha and beta
|
||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
||||
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
||||
compute_encoder.set_bytes(alpha_, 3);
|
||||
compute_encoder.set_bytes(beta_, 4);
|
||||
|
||||
// Encode shape, strides and ndim if needed
|
||||
if (!contiguous_kernel) {
|
||||
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
||||
compute_encoder.set_vector_bytes(x.shape(), 5);
|
||||
compute_encoder.set_vector_bytes(x.strides(), 6);
|
||||
compute_encoder.set_vector_bytes(y.strides(), 7);
|
||||
compute_encoder.set_bytes(ndim, 8);
|
||||
}
|
||||
|
||||
// We launch 1 thread for each input and make sure that the number of
|
||||
@@ -295,15 +218,15 @@ void Axpby::eval_gpu(
|
||||
|
||||
// Launch the grid with the given number of threads divided among
|
||||
// the given threadgroups
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
#else // Metal is not available
|
||||
|
||||
/** Fail evaluation on GPU */
|
||||
void Axpby::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& out) {
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& out) {
|
||||
throw std::runtime_error("Axpby has no GPU implementation.");
|
||||
}
|
||||
|
||||
@@ -314,9 +237,9 @@ void Axpby::eval_gpu(
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
std::vector<array> Axpby::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
std::vector<mx::array> Axpby::jvp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the primitive can be built with ops
|
||||
@@ -328,8 +251,8 @@ std::vector<array> Axpby::jvp(
|
||||
// scaled by beta
|
||||
if (argnums.size() > 1) {
|
||||
auto scale = argnums[0] == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
auto scale_arr = mx::array(scale, tangents[0].dtype());
|
||||
return {mx::multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
@@ -339,24 +262,24 @@ std::vector<array> Axpby::jvp(
|
||||
}
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> Axpby::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
std::vector<mx::array> Axpby::vjp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>&) {
|
||||
const std::vector<mx::array>&) {
|
||||
// Reverse mode diff
|
||||
std::vector<array> vjps;
|
||||
std::vector<mx::array> vjps;
|
||||
for (auto arg : argnums) {
|
||||
auto scale = arg == 0 ? alpha_ : beta_;
|
||||
auto scale_arr = array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
|
||||
auto scale_arr = mx::array(scale, cotangents[0].dtype());
|
||||
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
/** Vectorize primitive along given axis */
|
||||
std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("Axpby has no vmap implementation.");
|
||||
}
|
||||
@@ -367,4 +290,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
|
||||
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace my_ext
|
||||
|
@@ -1,11 +1,13 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
namespace mx = mlx::core;
|
||||
|
||||
namespace my_ext {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Operation
|
||||
@@ -18,22 +20,22 @@ namespace mlx::core {
|
||||
* Follow numpy style broadcasting between x and y
|
||||
* Inputs are upcasted to floats if needed
|
||||
**/
|
||||
array axpby(
|
||||
const array& x, // Input array x
|
||||
const array& y, // Input array y
|
||||
mx::array axpby(
|
||||
const mx::array& x, // Input array x
|
||||
const mx::array& y, // Input array y
|
||||
const float alpha, // Scaling factor for x
|
||||
const float beta, // Scaling factor for y
|
||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
mx::StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||
);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Primitive
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class Axpby : public Primitive {
|
||||
class Axpby : public mx::Primitive {
|
||||
public:
|
||||
explicit Axpby(Stream stream, float alpha, float beta)
|
||||
: Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
explicit Axpby(mx::Stream stream, float alpha, float beta)
|
||||
: mx::Primitive(stream), alpha_(alpha), beta_(beta) {};
|
||||
|
||||
/**
|
||||
* A primitive must know how to evaluate itself on the CPU/GPU
|
||||
@@ -42,23 +44,25 @@ class Axpby : public Primitive {
|
||||
* To avoid unnecessary allocations, the evaluation function
|
||||
* is responsible for allocating space for the array.
|
||||
*/
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_cpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) override;
|
||||
void eval_gpu(
|
||||
const std::vector<mx::array>& inputs,
|
||||
std::vector<mx::array>& outputs) override;
|
||||
|
||||
/** The Jacobian-vector product. */
|
||||
std::vector<array> jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
std::vector<mx::array> jvp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& tangents,
|
||||
const std::vector<int>& argnums) override;
|
||||
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
std::vector<mx::array> vjp(
|
||||
const std::vector<mx::array>& primals,
|
||||
const std::vector<mx::array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
const std::vector<mx::array>& outputs) override;
|
||||
|
||||
/**
|
||||
* The primitive must know how to vectorize itself across
|
||||
@@ -66,8 +70,8 @@ class Axpby : public Primitive {
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
std::pair<std::vector<mx::array>, std::vector<int>> vmap(
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** Print the primitive. */
|
||||
@@ -76,14 +80,11 @@ class Axpby : public Primitive {
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
bool is_equivalent(const mx::Primitive& other) const override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
/** Fall back implementation for evaluation on CPU */
|
||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace my_ext
|
||||
|
@@ -1,8 +1,7 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2025 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T>
|
||||
@@ -13,8 +12,8 @@ template <typename T>
|
||||
constant const float& alpha [[buffer(3)]],
|
||||
constant const float& beta [[buffer(4)]],
|
||||
constant const int* shape [[buffer(5)]],
|
||||
constant const size_t* x_strides [[buffer(6)]],
|
||||
constant const size_t* y_strides [[buffer(7)]],
|
||||
constant const int64_t* x_strides [[buffer(6)]],
|
||||
constant const int64_t* y_strides [[buffer(7)]],
|
||||
constant const int& ndim [[buffer(8)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||
@@ -35,29 +34,14 @@ template <typename T>
|
||||
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
|
||||
}
|
||||
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
|
||||
axpby_general<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
constant const int* shape [[buffer(5)]], \
|
||||
constant const size_t* x_strides [[buffer(6)]], \
|
||||
constant const size_t* y_strides [[buffer(7)]], \
|
||||
constant const int& ndim [[buffer(8)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
|
||||
axpby_contiguous<type>( \
|
||||
device const type* x [[buffer(0)]], \
|
||||
device const type* y [[buffer(1)]], \
|
||||
device type* out [[buffer(2)]], \
|
||||
constant const float& alpha [[buffer(3)]], \
|
||||
constant const float& beta [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
// clang-format off
|
||||
#define instantiate_axpby(type_name, type) \
|
||||
instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \
|
||||
instantiate_kernel( \
|
||||
"axpby_contiguous_" #type_name, axpby_contiguous, type)
|
||||
|
||||
instantiate_axpby(float32, float);
|
||||
instantiate_axpby(float16, half);
|
||||
instantiate_axpby(bfloat16, bfloat16_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
instantiate_axpby(complex64, complex64_t);
|
||||
// clang-format on
|
||||
|
@@ -8,14 +8,12 @@
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
NB_MODULE(_ext, m) {
|
||||
m.doc() = "Sample extension for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
&my_ext::axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
"alpha"_a,
|
||||
|
@@ -1,8 +1,8 @@
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"cmake>=3.24",
|
||||
"cmake>=3.25",
|
||||
"mlx>=0.18.0",
|
||||
"nanobind==2.2.0",
|
||||
"nanobind==2.4.0",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
@@ -1,4 +1,4 @@
|
||||
setuptools>=42
|
||||
cmake>=3.24
|
||||
mlx>=0.18.1
|
||||
cmake>=3.25
|
||||
mlx>=0.21.0
|
||||
nanobind==2.2.0
|
||||
|
15
mlx.pc.in
15
mlx.pc.in
@@ -28,10 +28,19 @@ endif()
|
||||
if (@MLX_BUILD_METAL@)
|
||||
set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
|
||||
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
|
||||
set_and_check(MLX_INCLUDE_DIRS
|
||||
${MLX_INCLUDE_DIRS}
|
||||
set(MLX_INCLUDE_DIRS
|
||||
"${MLX_INCLUDE_DIRS};"
|
||||
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
|
||||
)
|
||||
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
|
||||
set(MLX_INCLUDE_DIRS
|
||||
"${MLX_INCLUDE_DIRS};"
|
||||
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
|
||||
else()
|
||||
set(MLX_INCLUDE_DIRS
|
||||
"${MLX_INCLUDE_DIRS};"
|
||||
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set_target_properties(mlx PROPERTIES
|
||||
@@ -40,4 +49,4 @@ set_target_properties(mlx PROPERTIES
|
||||
)
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
|
||||
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
|
||||
|
@@ -5,6 +5,8 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
@@ -18,24 +20,37 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
# Define MLX_VERSION only in the version.cpp file.
|
||||
add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
|
||||
|
||||
if(MSVC)
|
||||
# Disable some MSVC warnings to speed up compilation.
|
||||
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
# Export symbols by default to behave like macOS/linux.
|
||||
set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
|
||||
if(MLX_BUILD_CPU)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||
elseif(MLX_BUILD_CPU)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||
endif()
|
||||
|
@@ -4,12 +4,11 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::allocator {
|
||||
|
||||
Buffer malloc(size_t size) {
|
||||
auto buffer = allocator().malloc(size, /* allow_swap */ true);
|
||||
auto buffer = allocator().malloc(size);
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
@@ -19,48 +18,7 @@ Buffer malloc(size_t size) {
|
||||
}
|
||||
|
||||
void free(Buffer buffer) {
|
||||
return allocator().free(buffer);
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||
void* ptr = std::malloc(size + sizeof(size_t));
|
||||
if (ptr != nullptr) {
|
||||
*static_cast<size_t*>(ptr) = size;
|
||||
}
|
||||
return Buffer{ptr};
|
||||
}
|
||||
|
||||
void CommonAllocator::free(Buffer buffer) {
|
||||
std::free(buffer.ptr());
|
||||
}
|
||||
|
||||
size_t CommonAllocator::size(Buffer buffer) const {
|
||||
if (buffer.ptr() == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
return *static_cast<size_t*>(buffer.ptr());
|
||||
}
|
||||
|
||||
Buffer malloc_or_wait(size_t size) {
|
||||
auto buffer = allocator().malloc(size);
|
||||
|
||||
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
|
||||
scheduler::wait_for_one();
|
||||
buffer = allocator().malloc(size);
|
||||
}
|
||||
|
||||
// Try swapping if needed
|
||||
if (size && !buffer.ptr()) {
|
||||
buffer = allocator().malloc(size, /* allow_swap = */ true);
|
||||
}
|
||||
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return buffer;
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
|
||||
|
||||
void free(Buffer buffer);
|
||||
|
||||
// Wait for running tasks to finish and free up memory
|
||||
// if allocation fails
|
||||
Buffer malloc_or_wait(size_t size);
|
||||
|
||||
class Allocator {
|
||||
/** Abstract base class for a memory allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||
virtual Buffer malloc(size_t size) = 0;
|
||||
virtual void free(Buffer buffer) = 0;
|
||||
virtual size_t size(Buffer buffer) const = 0;
|
||||
|
||||
@@ -53,16 +49,4 @@ class Allocator {
|
||||
|
||||
Allocator& allocator();
|
||||
|
||||
class CommonAllocator : public Allocator {
|
||||
/** A general CPU allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
|
||||
private:
|
||||
CommonAllocator() = default;
|
||||
friend Allocator& allocator();
|
||||
};
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
137
mlx/array.cpp
137
mlx/array.cpp
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/ops.h"
|
||||
@@ -9,28 +10,14 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
|
||||
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
|
||||
auto cval = static_cast<complex64_t>(val);
|
||||
init(&cval);
|
||||
}
|
||||
|
||||
array::array(
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs)
|
||||
@@ -38,10 +25,21 @@ array::array(
|
||||
std::move(shape),
|
||||
dtype,
|
||||
std::move(primitive),
|
||||
std::move(inputs))) {}
|
||||
std::move(inputs))) {
|
||||
if (has_primitive() && this->primitive().stream().device == Device::gpu) {
|
||||
for (auto& in : this->inputs()) {
|
||||
if (in.dtype() == float64) {
|
||||
throw std::invalid_argument("float64 is not supported on the GPU");
|
||||
}
|
||||
}
|
||||
if (this->dtype() == float64) {
|
||||
throw std::invalid_argument("float64 is not supported on the GPU");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> array::make_arrays(
|
||||
std::vector<std::vector<int>> shapes,
|
||||
std::vector<Shape> shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs) {
|
||||
@@ -58,47 +56,59 @@ std::vector<array> array::make_arrays(
|
||||
return outputs;
|
||||
}
|
||||
|
||||
array array::unsafe_weak_copy(const array& other) {
|
||||
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
|
||||
cpy.set_data(
|
||||
other.buffer(),
|
||||
other.data_size(),
|
||||
other.strides(),
|
||||
other.flags(),
|
||||
[](auto) {});
|
||||
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||
return cpy;
|
||||
}
|
||||
|
||||
array::array(std::initializer_list<float> data)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
float32)) {
|
||||
init(data.begin());
|
||||
}
|
||||
|
||||
array::array(std::initializer_list<int> data, Dtype dtype)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
dtype)) {
|
||||
init(data.begin());
|
||||
}
|
||||
|
||||
/* Build an array from a shared buffer */
|
||||
array::array(
|
||||
allocator::Buffer data,
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter)
|
||||
array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
set_data(data, deleter);
|
||||
}
|
||||
|
||||
void array::detach() {
|
||||
array_desc_->primitive = nullptr;
|
||||
for (auto& s : array_desc_->siblings) {
|
||||
s.array_desc_->primitive = nullptr;
|
||||
}
|
||||
for (auto& s : array_desc_->siblings) {
|
||||
s.array_desc_->inputs.clear();
|
||||
s.array_desc_->siblings.clear();
|
||||
s.array_desc_->position = 0;
|
||||
s.array_desc_->primitive = nullptr;
|
||||
}
|
||||
array_desc_->inputs.clear();
|
||||
array_desc_->siblings.clear();
|
||||
array_desc_->position = 0;
|
||||
array_desc_->primitive = nullptr;
|
||||
}
|
||||
|
||||
bool array::is_available() const {
|
||||
if (status() == Status::available) {
|
||||
return true;
|
||||
} else if (status() == Status::evaluated && event().is_signaled()) {
|
||||
} else if (
|
||||
status() == Status::evaluated &&
|
||||
(!event().valid() || event().is_signaled())) {
|
||||
set_status(Status::available);
|
||||
return true;
|
||||
}
|
||||
@@ -107,7 +117,10 @@ bool array::is_available() const {
|
||||
|
||||
void array::wait() {
|
||||
if (!is_available()) {
|
||||
event().wait();
|
||||
if (event().valid()) {
|
||||
event().wait();
|
||||
detach_event();
|
||||
}
|
||||
set_status(Status::available);
|
||||
}
|
||||
}
|
||||
@@ -122,10 +135,11 @@ void array::eval() {
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing() || retain_graph();
|
||||
return (array_desc_->is_tracer && detail::in_tracing()) ||
|
||||
detail::retain_graph();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
void array::set_data(allocator::Buffer buffer, Deleter d) {
|
||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||
array_desc_->data_ptr = buffer.raw_ptr();
|
||||
array_desc_->data_size = size();
|
||||
@@ -138,9 +152,9 @@ void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
void array::set_data(
|
||||
allocator::Buffer buffer,
|
||||
size_t data_size,
|
||||
std::vector<size_t> strides,
|
||||
Strides strides,
|
||||
Flags flags,
|
||||
deleter_t d) {
|
||||
Deleter d) {
|
||||
array_desc_->data = std::make_shared<Data>(buffer, d);
|
||||
array_desc_->data_ptr = buffer.raw_ptr();
|
||||
array_desc_->data_size = data_size;
|
||||
@@ -150,7 +164,7 @@ void array::set_data(
|
||||
|
||||
void array::copy_shared_buffer(
|
||||
const array& other,
|
||||
const std::vector<size_t>& strides,
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
@@ -167,34 +181,13 @@ void array::copy_shared_buffer(const array& other) {
|
||||
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
void array::move_shared_buffer(
|
||||
array other,
|
||||
const std::vector<size_t>& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset /* = 0 */) {
|
||||
array_desc_->data = std::move(other.array_desc_->data);
|
||||
array_desc_->strides = strides;
|
||||
array_desc_->flags = flags;
|
||||
array_desc_->data_size = data_size;
|
||||
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||
auto data_ptr = other.array_desc_->data_ptr;
|
||||
other.array_desc_->data_ptr = nullptr;
|
||||
array_desc_->data_ptr =
|
||||
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
|
||||
}
|
||||
|
||||
void array::move_shared_buffer(array other) {
|
||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
array::~array() {
|
||||
if (array_desc_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore arrays that might be detached during eval
|
||||
if (status() == array::Status::scheduled) {
|
||||
// Detached/detaching
|
||||
if (array_desc_->primitive == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -214,6 +207,8 @@ array::~array() {
|
||||
if (do_detach) {
|
||||
for (auto& s : siblings()) {
|
||||
for (auto& ss : s.siblings()) {
|
||||
// Set to null here to avoid descending into array destructor
|
||||
// for siblings
|
||||
ss.array_desc_ = nullptr;
|
||||
}
|
||||
s.array_desc_->siblings.clear();
|
||||
@@ -234,13 +229,13 @@ void array::ArrayDesc::init() {
|
||||
}
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
|
||||
array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
|
||||
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
|
||||
init();
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs)
|
||||
@@ -278,7 +273,19 @@ array::ArrayDesc::~ArrayDesc() {
|
||||
}
|
||||
ad.inputs.clear();
|
||||
for (auto& [_, a] : input_map) {
|
||||
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
|
||||
bool is_deletable =
|
||||
(a.array_desc_.use_count() <= a.siblings().size() + 1);
|
||||
// An array with siblings is deletable only if all of its siblings
|
||||
// are deletable
|
||||
for (auto& s : a.siblings()) {
|
||||
if (!is_deletable) {
|
||||
break;
|
||||
}
|
||||
int is_input = (input_map.find(s.id()) != input_map.end());
|
||||
is_deletable &=
|
||||
s.array_desc_.use_count() <= a.siblings().size() + is_input;
|
||||
}
|
||||
if (is_deletable) {
|
||||
for_deletion.push_back(std::move(a.array_desc_));
|
||||
}
|
||||
}
|
||||
@@ -292,6 +299,14 @@ array::ArrayDesc::~ArrayDesc() {
|
||||
auto top = std::move(for_deletion.back());
|
||||
for_deletion.pop_back();
|
||||
append_deletable_inputs(*top);
|
||||
|
||||
// Clear out possible siblings to break circular references
|
||||
for (auto& s : top->siblings) {
|
||||
// Set to null here to avoid descending into top-level
|
||||
// array destructor for siblings
|
||||
s.array_desc_ = nullptr;
|
||||
}
|
||||
top->siblings.clear();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,7 +318,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
}
|
||||
|
||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||
auto start = std::vector<int>(arr.ndim(), 0);
|
||||
auto start = Shape(arr.ndim(), 0);
|
||||
auto end = arr.shape();
|
||||
auto shape = arr.shape();
|
||||
shape.erase(shape.begin());
|
||||
|
102
mlx/array.h
102
mlx/array.h
@@ -15,7 +15,11 @@ namespace mlx::core {
|
||||
|
||||
// Forward declaration
|
||||
class Primitive;
|
||||
using deleter_t = std::function<void(allocator::Buffer)>;
|
||||
|
||||
using Deleter = std::function<void(allocator::Buffer)>;
|
||||
using ShapeElem = int32_t;
|
||||
using Shape = std::vector<ShapeElem>;
|
||||
using Strides = std::vector<int64_t>;
|
||||
|
||||
class array {
|
||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||
@@ -31,33 +35,33 @@ class array {
|
||||
explicit array(const std::complex<float>& val, Dtype dtype = complex64);
|
||||
|
||||
template <typename It>
|
||||
array(
|
||||
explicit array(
|
||||
It data,
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype =
|
||||
TypeToDtype<typename std::iterator_traits<It>::value_type>());
|
||||
|
||||
template <typename T>
|
||||
array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
|
||||
explicit array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Special case so empty lists default to float32. */
|
||||
array(std::initializer_list<float> data);
|
||||
explicit array(std::initializer_list<float> data);
|
||||
|
||||
/* Special case so array({}, type) is an empty array. */
|
||||
array(std::initializer_list<int> data, Dtype dtype);
|
||||
explicit array(std::initializer_list<int> data, Dtype dtype);
|
||||
|
||||
template <typename T>
|
||||
array(
|
||||
explicit array(
|
||||
std::initializer_list<T> data,
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype = TypeToDtype<T>());
|
||||
|
||||
/* Build an array from a buffer */
|
||||
array(
|
||||
explicit array(
|
||||
allocator::Buffer data,
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
deleter_t deleter = allocator::free);
|
||||
Deleter deleter = allocator::free);
|
||||
|
||||
/** Assignment to rvalue does not compile. */
|
||||
array& operator=(const array& other) && = delete;
|
||||
@@ -96,7 +100,7 @@ class array {
|
||||
}
|
||||
|
||||
/** The shape of the array as a vector of integers. */
|
||||
const std::vector<int>& shape() const {
|
||||
const Shape& shape() const {
|
||||
return array_desc_->shape;
|
||||
}
|
||||
|
||||
@@ -105,12 +109,12 @@ class array {
|
||||
*
|
||||
* This function supports negative indexing and provides
|
||||
* bounds checking. */
|
||||
int shape(int dim) const {
|
||||
auto shape(int dim) const {
|
||||
return shape().at(dim < 0 ? dim + ndim() : dim);
|
||||
}
|
||||
|
||||
/** The strides of the array. */
|
||||
const std::vector<size_t>& strides() const {
|
||||
const Strides& strides() const {
|
||||
return array_desc_->strides;
|
||||
}
|
||||
|
||||
@@ -119,7 +123,7 @@ class array {
|
||||
*
|
||||
* This function supports negative indexing and provides
|
||||
* bounds checking. */
|
||||
size_t strides(int dim) const {
|
||||
auto strides(int dim) const {
|
||||
return strides().at(dim < 0 ? dim + ndim() : dim);
|
||||
}
|
||||
|
||||
@@ -184,17 +188,24 @@ class array {
|
||||
*/
|
||||
|
||||
array(
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs);
|
||||
|
||||
static std::vector<array> make_arrays(
|
||||
std::vector<std::vector<int>> shapes,
|
||||
std::vector<Shape> shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
/**
|
||||
* Get a new array that refers to the same data as the input but with a
|
||||
* non-owning pointer to it. Note the array is detached from the graph and has
|
||||
* no inputs, siblings or primitive.
|
||||
*/
|
||||
static array unsafe_weak_copy(const array& other);
|
||||
|
||||
/** A unique identifier for an array. */
|
||||
std::uintptr_t id() const {
|
||||
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
||||
@@ -207,8 +218,8 @@ class array {
|
||||
|
||||
struct Data {
|
||||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||
Deleter d;
|
||||
Data(allocator::Buffer buffer, Deleter d = allocator::free)
|
||||
: buffer(buffer), d(d) {}
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
@@ -328,11 +339,11 @@ class array {
|
||||
return allocator::allocator().size(buffer());
|
||||
}
|
||||
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
std::shared_ptr<Data> data_shared_ptr() const {
|
||||
// Return the shared pointer to the array::Data struct
|
||||
const std::shared_ptr<Data>& data_shared_ptr() const {
|
||||
return array_desc_->data;
|
||||
}
|
||||
|
||||
// Return a raw pointer to the arrays data
|
||||
template <typename T>
|
||||
T* data() {
|
||||
@@ -345,15 +356,10 @@ class array {
|
||||
}
|
||||
|
||||
enum Status {
|
||||
// The ouptut of a computation which has not been scheduled.
|
||||
// The output of a computation which has not been scheduled.
|
||||
// For example, the status of `x` in `auto x = a + b`.
|
||||
unscheduled,
|
||||
|
||||
// The ouptut of a computation which has been scheduled but `eval_*` has
|
||||
// not yet been called on the array's primitive. A possible
|
||||
// status of `x` in `auto x = a + b; eval(x);`
|
||||
scheduled,
|
||||
|
||||
// The array's `eval_*` function has been run, but the computation is not
|
||||
// necessarily complete. The array will have memory allocated and if it is
|
||||
// not a tracer then it will be detached from the graph.
|
||||
@@ -390,6 +396,10 @@ class array {
|
||||
array_desc_->event = std::move(e);
|
||||
}
|
||||
|
||||
void detach_event() const {
|
||||
array_desc_->event = Event{};
|
||||
}
|
||||
|
||||
// Mark the array as a tracer array (true) or not.
|
||||
void set_tracer(bool is_tracer) {
|
||||
array_desc_->is_tracer = is_tracer;
|
||||
@@ -397,33 +407,24 @@ class array {
|
||||
// Check if the array is a tracer array
|
||||
bool is_tracer() const;
|
||||
|
||||
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
|
||||
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
|
||||
|
||||
void set_data(
|
||||
allocator::Buffer buffer,
|
||||
size_t data_size,
|
||||
std::vector<size_t> strides,
|
||||
Strides strides,
|
||||
Flags flags,
|
||||
deleter_t d = allocator::free);
|
||||
Deleter d = allocator::free);
|
||||
|
||||
void copy_shared_buffer(
|
||||
const array& other,
|
||||
const std::vector<size_t>& strides,
|
||||
const Strides& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
||||
void copy_shared_buffer(const array& other);
|
||||
|
||||
void move_shared_buffer(
|
||||
array other,
|
||||
const std::vector<size_t>& strides,
|
||||
Flags flags,
|
||||
size_t data_size,
|
||||
size_t offset = 0);
|
||||
|
||||
void move_shared_buffer(array other);
|
||||
|
||||
void overwrite_descriptor(const array& other) {
|
||||
array_desc_ = other.array_desc_;
|
||||
}
|
||||
@@ -436,8 +437,8 @@ class array {
|
||||
void init(const It src);
|
||||
|
||||
struct ArrayDesc {
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
size_t size;
|
||||
Dtype dtype;
|
||||
std::shared_ptr<Primitive> primitive;
|
||||
@@ -471,10 +472,10 @@ class array {
|
||||
// The arrays position in the output list
|
||||
uint32_t position{0};
|
||||
|
||||
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
|
||||
explicit ArrayDesc(Shape shape, Dtype dtype);
|
||||
|
||||
explicit ArrayDesc(
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array> inputs);
|
||||
@@ -495,14 +496,14 @@ class array {
|
||||
|
||||
template <typename T>
|
||||
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
|
||||
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
|
||||
init(&val);
|
||||
}
|
||||
|
||||
template <typename It>
|
||||
array::array(
|
||||
It data,
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
|
||||
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
init(data);
|
||||
@@ -513,7 +514,7 @@ array::array(
|
||||
std::initializer_list<T> data,
|
||||
Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
dtype)) {
|
||||
init(data.begin());
|
||||
}
|
||||
@@ -521,7 +522,7 @@ array::array(
|
||||
template <typename T>
|
||||
array::array(
|
||||
std::initializer_list<T> data,
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype /* = TypeToDtype<T>() */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||
if (data.size() != size()) {
|
||||
@@ -590,6 +591,9 @@ void array::init(It src) {
|
||||
case float32:
|
||||
std::copy(src, src + size(), data<float>());
|
||||
break;
|
||||
case float64:
|
||||
std::copy(src, src + size(), data<double>());
|
||||
break;
|
||||
case bfloat16:
|
||||
std::copy(src, src + size(), data<bfloat16_t>());
|
||||
break;
|
||||
|
@@ -1,8 +0,0 @@
|
||||
# Source files compiled privately into the mlx target from this directory.
target_sources(
  mlx
  PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
)
|
@@ -1,20 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// CPU entry point for Convolution: currently a plain forward to the generic
// eval() implementation.
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
  // TODO: Add accelerate based optimizations for CPU conv
  eval(inputs, out);
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,253 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/backend/accelerate/utils.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::tuple<bool, size_t, array> check_transpose(const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
}
|
||||
|
||||
inline void matmul_cblas_general(
|
||||
const array& a_pre,
|
||||
const array& b_pre,
|
||||
array& out,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[matmul_cblas] on CPU currently only supports float32");
|
||||
}
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha, // alpha
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
||||
lda,
|
||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
||||
ldb,
|
||||
beta, // beta
|
||||
out.data<float>() + M * N * i,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[matmul_cblas] on CPU currently only supports float32");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
return matmul_cblas_general(a_pre, b_pre, out);
|
||||
}
|
||||
|
||||
inline void matmul_bnns_general(
|
||||
const array& a_pre,
|
||||
const array& b_pre,
|
||||
array& out,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
// TODO: Update to utilize BNNS broadcasting
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
|
||||
|
||||
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
||||
/* float alpha = */ alpha,
|
||||
/* float beta = */ beta,
|
||||
/* bool transA = */ a_transposed,
|
||||
/* bool transB = */ b_transposed,
|
||||
/* bool quadratic = */ false,
|
||||
/* bool a_is_weights = */ false,
|
||||
/* bool b_is_weights = */ false,
|
||||
/* BNNSNDArrayDescriptor iA_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, lda, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
/* BNNSNDArrayDescriptor iB_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, ldb, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
/* BNNSNDArrayDescriptor o_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{N, M, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, N, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
};
|
||||
|
||||
auto bnns_filter =
|
||||
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
BNNSFilterApplyTwoInput(
|
||||
bnns_filter,
|
||||
a.data<uint8_t>() +
|
||||
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
|
||||
b.data<uint8_t>() +
|
||||
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
|
||||
out.data<uint8_t>() + M * N * i * out.itemsize());
|
||||
}
|
||||
|
||||
BNNSFilterDestroy(bnns_filter);
|
||||
}
|
||||
|
||||
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
|
||||
// TODO: Update to utilize BNNS broadcasting
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
return matmul_bnns_general(a_pre, b_pre, out);
|
||||
}
|
||||
|
||||
// Zeroes every tile of an X-by-Y strided matrix whose mask entry is false.
//
// The matrix is split into tile_size x tile_size tiles (edge tiles are
// clipped to the matrix bounds). Tile (i, j) reads its flag from
// mask[i * X_mask_str + j * Y_mask_str]; element (x, y) lives at
// data[x * X_data_str + y * Y_data_str].
template <typename T>
inline void mask_matrix(
    T* data,
    const bool* mask,
    int tile_size,
    const int X,
    const int Y,
    const size_t X_data_str,
    const size_t Y_data_str,
    const size_t X_mask_str,
    const size_t Y_mask_str) {
  // Number of tiles along each axis, rounding up for partial tiles.
  const int tiles_x = (X + tile_size - 1) / tile_size;
  const int tiles_y = (Y + tile_size - 1) / tile_size;

  for (int tx = 0; tx < tiles_x; ++tx) {
    for (int ty = 0; ty < tiles_y; ++ty) {
      // A true flag keeps the tile untouched.
      if (mask[tx * X_mask_str + ty * Y_mask_str]) {
        continue;
      }

      const int x0 = tx * tile_size;
      const int y0 = ty * tile_size;
      T* tile = data + x0 * X_data_str + y0 * Y_data_str;

      // Clip the tile extent at the matrix boundary.
      const int extent_x = std::min(tile_size, X - x0);
      const int extent_y = std::min(tile_size, Y - y0);

      for (int dx = 0; dx < extent_x; ++dx) {
        for (int dy = 0; dy < extent_y; ++dy) {
          tile[dx * X_data_str + dy * Y_data_str] = T(0.);
        }
      }
    }
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() == float32) {
|
||||
return matmul_cblas(inputs[0], inputs[1], out);
|
||||
}
|
||||
return matmul_bnns(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Fill output with C
|
||||
auto& c = inputs[2];
|
||||
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
copy(c, out, ctype);
|
||||
|
||||
if (out.dtype() == float32) {
|
||||
return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
|
||||
}
|
||||
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,601 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
// DEFAULT(Op) defines Op::eval_cpu for a single-output primitive as a direct
// forward to Op::eval.
#define DEFAULT(Op)                                                 \
  void Op::eval_cpu(const std::vector<array>& inputs, array& out) { \
    Op::eval(inputs, out);                                          \
  }

// DEFAULT_MULTI(Op) is the same forward for primitives whose eval takes a
// vector of outputs.
#define DEFAULT_MULTI(Op)                                              \
  void Op::eval_cpu(                                                   \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    Op::eval(inputs, outputs);                                         \
  }
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Use the default implementation for the following primitives
|
||||
DEFAULT(Arange)
|
||||
DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(NumberOfElements)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(Maximum)
|
||||
DEFAULT(Minimum)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Select)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
DEFAULT(Cholesky)
|
||||
DEFAULT_MULTI(Eigh)
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else if (in.dtype() == int32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x + y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x + y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvacosf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvacoshf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvasinf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvasinhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvatanf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
if (a.is_donatable()) {
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b.is_donatable()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
int size = a.data_size();
|
||||
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvatanhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (in.flags().contiguous) {
|
||||
// Use accelerate functions if possible
|
||||
if (in.dtype() == float32 && out.dtype() == uint32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfixu32(
|
||||
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == float32 && out.dtype() == int32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == uint32 && out.dtype() == float32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfltu32(
|
||||
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == int32 && out.dtype() == float32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
|
||||
return;
|
||||
}
|
||||
}
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvcosf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvcoshf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x / y; },
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x / y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpm1f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
assert(in.dtype() == out.dtype());
|
||||
if (in.data_size() == 1 && out.dtype() == float32) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
vvlogf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
break;
|
||||
case Base::two:
|
||||
vvlog2f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
break;
|
||||
case Base::ten:
|
||||
vvlog10f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x * y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
int size = a.size();
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
|
||||
in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
int stride = in.shape(axis_);
|
||||
int count = in.size() / stride;
|
||||
const float* input = in.data<float>();
|
||||
float* output = out.data<float>();
|
||||
float s = 1.0;
|
||||
if (!reverse_) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
|
||||
input += stride;
|
||||
output += stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
input += stride - 1;
|
||||
output += stride - 1;
|
||||
vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvsinf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvsinhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
if (recip_) {
|
||||
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
vvsqrtf(out.data<float>(), in.data<float>(), &size);
|
||||
}
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x - y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
float minus_1 = -1;
|
||||
vDSP_vsmsa(
|
||||
(const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
float val = -(*s);
|
||||
vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x - y; },
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
int val = -(*s);
|
||||
vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
|
||||
},
|
||||
UseDefaultBinaryOp());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvtanf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvtanhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,117 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void _qmm_t_4_64(
|
||||
float* result,
|
||||
const float* x,
|
||||
const uint32_t* w,
|
||||
const float* scales,
|
||||
const float* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int B,
|
||||
bool batched_w) {
|
||||
constexpr int bits = 4;
|
||||
constexpr int group_size = 64;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
int w_els = N * K / pack_factor;
|
||||
int g_els = w_els * pack_factor / group_size;
|
||||
|
||||
for (int i = 0; i < B; i++) {
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const float* scales_local = scales;
|
||||
const float* biases_local = biases;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
const simd_float16* x_local = (simd_float16*)x;
|
||||
simd_float16 sum = 0;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
float scale = *scales_local++;
|
||||
float bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw += 2) {
|
||||
// TODO: vectorize this properly
|
||||
simd_uint16 wi;
|
||||
for (int e = 0; e < 2; e++) {
|
||||
uint32_t wii = *w_local++;
|
||||
for (int p = 0; p < 8; p++) {
|
||||
wi[e * 8 + p] = wii & bitmask;
|
||||
wii >>= bits;
|
||||
}
|
||||
}
|
||||
simd_float16 wf = simd_float(wi);
|
||||
wf *= scale;
|
||||
wf += bias;
|
||||
|
||||
sum += (*x_local) * wf;
|
||||
x_local++;
|
||||
}
|
||||
}
|
||||
|
||||
*result = simd_reduce_add(sum);
|
||||
result++;
|
||||
}
|
||||
|
||||
x += K;
|
||||
}
|
||||
if (batched_w) {
|
||||
w += w_els;
|
||||
scales += g_els;
|
||||
biases += g_els;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
|
||||
auto& x = inputs[0];
|
||||
auto& w = inputs[1];
|
||||
auto& scales = inputs[2];
|
||||
auto& biases = inputs[3];
|
||||
|
||||
bool condition =
|
||||
(transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
|
||||
scales.flags().row_contiguous && biases.flags().row_contiguous &&
|
||||
x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
|
||||
|
||||
if (condition) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
int B = x.size() / K / M;
|
||||
bool batched_w = w.ndim() > 2;
|
||||
_qmm_t_4_64(
|
||||
out.data<float>(),
|
||||
x.data<float>(),
|
||||
w.data<uint32_t>(),
|
||||
scales.data<float>(),
|
||||
biases.data<float>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
B,
|
||||
batched_w);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,139 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Minimum functor with a scalar overload (T) and a vector overload (VT).
template <typename T, typename VT>
struct MinReduction {
  // Scalar minimum; yields lhs when neither operand compares smaller.
  T operator()(const T& lhs, const T& rhs) {
    return (rhs < lhs) ? rhs : lhs;
  }

  // Lane-wise minimum via simd_min.
  VT operator()(VT lhs, VT rhs) {
    return simd_min(lhs, rhs);
  }
};
|
||||
|
||||
// Maximum functor with a scalar overload (T) and a vector overload (VT).
template <typename T, typename VT>
struct MaxReduction {
  // Scalar maximum; yields lhs when neither operand compares larger.
  T operator()(const T& lhs, const T& rhs) {
    return (lhs < rhs) ? rhs : lhs;
  }

  // Lane-wise maximum via simd_max.
  VT operator()(VT lhs, VT rhs) {
    return simd_max(lhs, rhs);
  }
};
|
||||
|
||||
// Addition functor with a scalar overload (T) and a vector overload (VT).
template <typename T, typename VT>
struct SumReduction {
  // Scalar sum.
  T operator()(const T& lhs, const T& rhs) {
    return lhs + rhs;
  }

  // Vector sum using the + operator of VT.
  VT operator()(VT lhs, VT rhs) {
    return lhs + rhs;
  }
};
|
||||
|
||||
// Folds `size` consecutive rows of `stride` elements, read from x, into the
// `stride`-element accumulator using Reduction. Each row is processed N
// elements at a time through VT, then one element at a time for the rest.
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
  void operator()(const T* x, T* accum, int size, size_t stride) {
    Reduction op;

    for (int row = 0; row < size; row++) {
      T* dst = accum; // every row folds into the same accumulator
      size_t remaining = stride;

      // Vector body: N elements per step.
      for (; remaining >= N; remaining -= N, x += N, dst += N) {
        *(VT*)dst = op((*(VT*)x), (*(VT*)dst));
      }

      // Scalar tail.
      for (; remaining > 0; --remaining, ++dst, ++x) {
        *dst = op(*dst, *x);
      }
    }
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (in.dtype() == float32) {
|
||||
if (reduce_type_ == Reduce::Sum) {
|
||||
reduction_op<float, float>(
|
||||
in,
|
||||
out,
|
||||
axes_,
|
||||
0,
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
SumReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float acc;
|
||||
vDSP_sve((const float*)x, 1, &acc, size);
|
||||
(*accum) += acc;
|
||||
},
|
||||
[](auto* accum, auto x) { *accum += x; });
|
||||
return;
|
||||
} else if (reduce_type_ == Reduce::Max) {
|
||||
reduction_op<float, float>(
|
||||
in,
|
||||
out,
|
||||
axes_,
|
||||
-std::numeric_limits<float>::infinity(),
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
MaxReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float max;
|
||||
vDSP_maxv((const float*)x, 1, &max, size);
|
||||
(*accum) = (*accum < max) ? max : *accum;
|
||||
},
|
||||
[](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
|
||||
return;
|
||||
} else if (reduce_type_ == Reduce::Min) {
|
||||
reduction_op<float, float>(
|
||||
in,
|
||||
out,
|
||||
axes_,
|
||||
std::numeric_limits<float>::infinity(),
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
MinReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float min;
|
||||
vDSP_minv((const float*)x, 1, &min, size);
|
||||
(*accum) = (*accum > min) ? min : *accum;
|
||||
},
|
||||
[](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
|
||||
return;
|
||||
}
|
||||
}
|
||||
// TODO: Add integer addition and min/max using the templates above and
|
||||
// simd_int16 and friends.
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,393 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#include <simd/math.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* Compute exp(x) in an optimizer friendly way as follows:
|
||||
*
|
||||
* First change the problem to computing 2**y where y = x / ln(2).
|
||||
*
|
||||
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
|
||||
* `ipart` and y2 is fractional part. For the integer part we perform bit
|
||||
* shifting and for the fractional part we use a polynomial approximation.
|
||||
*
|
||||
* The algorithm and constants of the polynomial taken from
|
||||
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
|
||||
* from Cephes math library.
|
||||
*
|
||||
* Note: The implementation below is a general fast exp. There could be faster
|
||||
* implementations for numbers strictly < 0.
|
||||
*/
|
||||
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
|
||||
auto x = x_init * 1.442695; // multiply with log_2(e)
|
||||
simd_float16 ipart, fpart;
|
||||
simd_int16 epart;
|
||||
x = simd_clamp(x, -80, 80);
|
||||
ipart = simd::floor(x + 0.5);
|
||||
fpart = x - ipart;
|
||||
|
||||
x = 1.535336188319500e-4f;
|
||||
x = x * fpart + 1.339887440266574e-3f;
|
||||
x = x * fpart + 9.618437357674640e-3f;
|
||||
x = x * fpart + 5.550332471162809e-2f;
|
||||
x = x * fpart + 2.402264791363012e-1f;
|
||||
x = x * fpart + 6.931472028550421e-1f;
|
||||
x = x * fpart + 1.000000000000000f;
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
epart = (simd_int(ipart) + 127) << 23;
|
||||
|
||||
// Avoid supressing NaNs
|
||||
simd_int16 eq = (x_init == x_init);
|
||||
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
|
||||
}
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
/**
|
||||
* The ARM neon equivalent of the fast exp above.
|
||||
*/
|
||||
inline float16x8_t neon_fast_exp(float16x8_t x) {
|
||||
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
|
||||
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
|
||||
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
|
||||
|
||||
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
|
||||
float16x8_t fpart = vsubq_f16(x, ipart);
|
||||
|
||||
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
int16x8_t epart = vcvtq_s16_f16(ipart);
|
||||
epart = vaddq_s16(epart, vdupq_n_s16(15));
|
||||
epart = vshlq_n_s16(epart, 10);
|
||||
|
||||
return vmulq_f16(vreinterpretq_f16_s16(epart), x);
|
||||
}
|
||||
|
||||
/**
|
||||
* Implementation of folding maximum for ARM neon. This should possibly be
|
||||
* refactored out of softmax.cpp at some point.
|
||||
*/
|
||||
inline float16_t neon_reduce_max(float16x8_t x) {
|
||||
float16x4_t y;
|
||||
y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
|
||||
y = vpmax_f16(y, y);
|
||||
y = vpmax_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Implementation of folding sum for ARM neon. This should possibly be
|
||||
* refactored out of softmax.cpp at some point.
|
||||
*/
|
||||
inline float16_t neon_reduce_add(float16x8_t x) {
|
||||
float16x4_t y;
|
||||
float16x4_t zero = vdup_n_f16(0);
|
||||
y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
|
||||
y = vpadd_f16(y, zero);
|
||||
y = vpadd_f16(y, zero);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
// Vector operation set used by softmax, implemented with ARM NEON float16
// intrinsics. T is the scalar type and VT the vector type.
template <typename T, typename VT>
struct NeonFp16SimdOps {
  // Broadcast a scalar to every lane.
  VT init(T value) {
    return vdupq_n_f16(value);
  }

  // Load one vector from src.
  VT load(const T* src) {
    return vld1q_f16(src);
  }

  // Store one vector to dst.
  void store(T* dst, VT v) {
    vst1q_f16(dst, v);
  }

  VT max(VT lhs, VT rhs) {
    return vmaxq_f16(lhs, rhs);
  }

  VT exp(VT v) {
    return neon_fast_exp(v);
  }

  VT add(VT lhs, VT rhs) {
    return vaddq_f16(lhs, rhs);
  }

  // Subtract a scalar from every lane.
  VT sub(VT lhs, T rhs) {
    return vsubq_f16(lhs, vdupq_n_f16(rhs));
  }

  VT mul(VT lhs, VT rhs) {
    return vmulq_f16(lhs, rhs);
  }

  // Multiply every lane by a scalar.
  VT mul(VT lhs, T rhs) {
    return vmulq_f16(lhs, vdupq_n_f16(rhs));
  }

  // Horizontal maximum over the lanes.
  T reduce_max(VT v) {
    return neon_reduce_max(v);
  }

  // Horizontal sum over the lanes.
  T reduce_add(VT v) {
    return neon_reduce_add(v);
  }
};
|
||||
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
// Vector operation set used by softmax, implemented with the simd vector
// types and their operators. T is the scalar type and VT the vector type.
template <typename T, typename VT>
struct AccelerateSimdOps {
  // Convert a scalar to VT.
  VT init(T value) {
    return value;
  }

  // Read one VT starting at src.
  VT load(const T* src) {
    return *(VT*)src;
  }

  // Write one VT starting at dst.
  void store(T* dst, VT v) {
    *(VT*)dst = v;
  }

  VT max(VT lhs, VT rhs) {
    return simd_max(lhs, rhs);
  }

  VT exp(VT v) {
    return simd_fast_exp(v);
  }

  VT add(VT lhs, VT rhs) {
    return lhs + rhs;
  }

  // Subtract a scalar from the vector.
  VT sub(VT lhs, T rhs) {
    return lhs - rhs;
  }

  VT mul(VT lhs, VT rhs) {
    return lhs * rhs;
  }

  // Multiply the vector by a scalar.
  VT mul(VT lhs, T rhs) {
    return lhs * rhs;
  }

  // Horizontal maximum over the lanes.
  T reduce_max(VT v) {
    return simd_reduce_max(v);
  }

  // Horizontal sum over the lanes.
  T reduce_add(VT v) {
    return simd_reduce_add(v);
  }
};
|
||||
|
||||
template <typename T, typename AccT, typename VT, typename Ops, int N>
|
||||
void softmax(const array& in, array& out) {
|
||||
Ops ops;
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
int M = in.shape().back();
|
||||
int L = in.data_size() / M;
|
||||
const T* current_in_ptr;
|
||||
T* current_out_ptr;
|
||||
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
VT vals;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
vals = ops.load(current_in_ptr);
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vals[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
}
|
||||
vmaximum = ops.max(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT maximum = ops.reduce_max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
VT vnormalizer = ops.init(0.0);
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
VT vexp;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
vexp = ops.load(current_in_ptr);
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
}
|
||||
vexp = ops.exp(ops.sub(vexp, maximum));
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
ops.store(current_out_ptr, vexp);
|
||||
}
|
||||
vnormalizer = ops.add(vnormalizer, vexp);
|
||||
current_in_ptr += N;
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = ops.reduce_add(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
if (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr = _exp;
|
||||
}
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
current_out_ptr++;
|
||||
}
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Normalize
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
|
||||
} else {
|
||||
VT vexp;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
current_out_ptr[i] = vexp[i];
|
||||
}
|
||||
current_in_ptr += N;
|
||||
}
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(_exp * normalizer);
|
||||
current_in_ptr++;
|
||||
}
|
||||
current_out_ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto check_input = [](array x) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, x_copy, CopyType::General);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
array in = check_input(std::move(inputs[0]));
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case uint16:
|
||||
case uint32:
|
||||
case uint64:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
case int64:
|
||||
throw std::invalid_argument(
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<
|
||||
float,
|
||||
float,
|
||||
simd_float16,
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
break;
|
||||
case float16:
|
||||
if (precise_) {
|
||||
softmax<
|
||||
float16_t,
|
||||
float,
|
||||
simd_float16,
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
} else {
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
softmax<
|
||||
float16_t,
|
||||
float16_t,
|
||||
float16x8_t,
|
||||
NeonFp16SimdOps<float16_t, float16x8_t>,
|
||||
8>(in, out);
|
||||
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
eval(inputs, out); // Redirect to common backend for consistency
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
}
|
||||
break;
|
||||
case bfloat16:
|
||||
eval(inputs, out);
|
||||
break;
|
||||
case complex64:
|
||||
eval(inputs, out);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,28 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include "mlx/dtype.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Map an MLX dtype to the matching BNNSDataType. Unsigned, signed and float
// kinds are formed by OR-ing the BNNS family bit with the element width in
// bits. Throws std::invalid_argument for complex types.
//
// Declared inline because this definition lives in a header; without it,
// including the header from more than one translation unit produces
// duplicate definitions at link time.
inline BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
  uint32_t size_bits = size_of(mlx_dtype) * 8;
  switch (kindof(mlx_dtype)) {
    case Dtype::Kind::b:
      return BNNSDataTypeBoolean;
    case Dtype::Kind::u:
      return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
    case Dtype::Kind::i:
      return BNNSDataType(BNNSDataTypeIntBit | size_bits);
    case Dtype::Kind::f:
      return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
    case Dtype::Kind::V:
      return BNNSDataTypeBFloat16;
    case Dtype::Kind::c:
      throw std::invalid_argument("BNNS does not support complex types");
  }
  // Not reached for the kinds handled above; guards against flowing off the
  // end of a non-void function if the kind holds any other value.
  throw std::invalid_argument("BNNS does not support this type");
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,62 +1,9 @@
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(COMPILER ${CMAKE_C_COMPILER})
|
||||
set(CLANG TRUE)
|
||||
else()
|
||||
set(COMPILER ${CMAKE_CXX_COMPILER})
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND
|
||||
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
|
||||
${PROJECT_SOURCE_DIR} ${CLANG}
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
compiled_preamble.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
|
||||
${PROJECT_SOURCE_DIR}/mlx/types/complex.h
|
||||
ops.h)
|
||||
|
||||
add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
|
||||
|
||||
add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||
|
||||
if(IOS)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
|
||||
endif()
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
|
||||
|
@@ -1,74 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void arange(T start, T next, array& out, size_t size) {
|
||||
auto ptr = out.data<T>();
|
||||
auto step_size = next - start;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
ptr[i] = start;
|
||||
start += step_size;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void arange(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
double start,
|
||||
double step) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
throw std::runtime_error("Bool type unsupported for arange.");
|
||||
break;
|
||||
case uint8:
|
||||
arange<uint8_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case uint16:
|
||||
arange<uint16_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case uint32:
|
||||
arange<uint32_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case uint64:
|
||||
arange<uint64_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case int8:
|
||||
arange<int8_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case int16:
|
||||
arange<int16_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case int32:
|
||||
arange<int32_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case int64:
|
||||
arange<int64_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case float16:
|
||||
arange<float16_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case float32:
|
||||
arange<float>(start, start + step, out, out.size());
|
||||
break;
|
||||
case bfloat16:
|
||||
arange<bfloat16_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
case complex64:
|
||||
arange<complex64_t>(start, start + step, out, out.size());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,112 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename InT, typename OpT>
|
||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
auto axis_size = in.shape()[axis];
|
||||
auto axis_stride = in.strides()[axis];
|
||||
std::vector<size_t> strides = in.strides();
|
||||
std::vector<int> shape = in.shape();
|
||||
strides.erase(strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
for (uint32_t i = 0; i < out.size(); ++i) {
|
||||
auto loc = elem_to_loc(i, shape, strides);
|
||||
auto in_ptr = in.data<InT>() + loc;
|
||||
uint32_t ind_v = 0;
|
||||
InT v = (*in_ptr);
|
||||
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
|
||||
op(j, (*in_ptr), &ind_v, &v);
|
||||
}
|
||||
out.data<uint32_t>()[i] = ind_v;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void arg_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
ArgReduce::ReduceType rtype,
|
||||
int axis) {
|
||||
switch (rtype) {
|
||||
case ArgReduce::ArgMin: {
|
||||
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
|
||||
if (x < (*y)) {
|
||||
(*y) = x;
|
||||
(*ind_y) = ind_x;
|
||||
}
|
||||
};
|
||||
arg_reduce<InT>(in, out, op, axis);
|
||||
break;
|
||||
}
|
||||
case ArgReduce::ArgMax: {
|
||||
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
|
||||
if (x > (*y)) {
|
||||
(*y) = x;
|
||||
(*ind_y) = ind_x;
|
||||
}
|
||||
};
|
||||
arg_reduce<InT>(in, out, op, axis);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint8:
|
||||
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint16:
|
||||
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint32:
|
||||
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint64:
|
||||
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int8:
|
||||
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int16:
|
||||
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int32:
|
||||
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int64:
|
||||
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float16:
|
||||
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float32:
|
||||
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case bfloat16:
|
||||
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case complex64:
|
||||
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,331 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/binary_two.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||
DefaultScalarVector<T, U, Op> opsv(op);
|
||||
DefaultVectorScalar<T, U, Op> opvs(op);
|
||||
DefaultVectorVector<T, U, Op> opvv(op);
|
||||
binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
comparison_op<bool, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
comparison_op<uint8_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
comparison_op<uint16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
comparison_op<uint32_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
comparison_op<uint64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
comparison_op<int8_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
comparison_op<int16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
comparison_op<int32_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
comparison_op<int64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case float16:
|
||||
comparison_op<float16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case float32:
|
||||
comparison_op<float, bool>(a, b, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
comparison_op<bfloat16_t, bool>(a, b, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
comparison_op<complex64_t, bool>(a, b, out, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Add::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Add());
|
||||
}
|
||||
|
||||
void DivMod::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto integral_op = [](auto x, auto y) {
|
||||
return std::make_pair(x / y, x % y);
|
||||
};
|
||||
auto float_op = [](auto x, auto y) {
|
||||
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
|
||||
};
|
||||
switch (outputs[0].dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, outputs, integral_op);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, outputs, float_op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, outputs, float_op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, outputs, float_op);
|
||||
break;
|
||||
case complex64:
|
||||
// Should never get here
|
||||
throw std::runtime_error("[DivMod] Complex type not supported");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Divide());
|
||||
}
|
||||
|
||||
void Remainder::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Remainder());
|
||||
}
|
||||
|
||||
void Equal::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
if (equal_nan_) {
|
||||
comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
|
||||
} else {
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Equal());
|
||||
}
|
||||
}
|
||||
|
||||
// Elementwise "greater than" comparison of the two inputs.
void Greater::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& lhs = inputs[0];
  const auto& rhs = inputs[1];
  comparison_op(lhs, rhs, out, detail::Greater());
}
|
||||
|
||||
// Elementwise "greater than or equal" comparison of the two inputs.
void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& lhs = inputs[0];
  const auto& rhs = inputs[1];
  comparison_op(lhs, rhs, out, detail::GreaterEqual());
}
|
||||
|
||||
// Elementwise "less than" comparison of the two inputs.
void Less::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& lhs = inputs[0];
  const auto& rhs = inputs[1];
  comparison_op(lhs, rhs, out, detail::Less());
}
|
||||
|
||||
// Elementwise "less than or equal" comparison of the two inputs.
void LessEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& lhs = inputs[0];
  const auto& rhs = inputs[1];
  comparison_op(lhs, rhs, out, detail::LessEqual());
}
|
||||
|
||||
void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
std::ostringstream err;
|
||||
err << "[logaddexp] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[logaddexp] Cannot compute logaddexp for arrays with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Maximum::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Maximum());
|
||||
}
|
||||
|
||||
void Minimum::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Minimum());
|
||||
}
|
||||
|
||||
void Multiply::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Multiply());
|
||||
}
|
||||
|
||||
// Evaluates the NotEqual primitive by handing both inputs to comparison_op
// with the detail::NotEqual functor.
void NotEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& lhs = inputs[0];
  const auto& rhs = inputs[1];
  comparison_op(lhs, rhs, out, detail::NotEqual());
}
|
||||
|
||||
void Power::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Power());
|
||||
}
|
||||
|
||||
void Subtract::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Subtract());
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto dispatch_type = [&a, &b, &out](auto op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out, op);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[BitwiseBinary::eval_cpu] Type not supported");
|
||||
break;
|
||||
}
|
||||
};
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
dispatch_type(detail::BitwiseAnd());
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
dispatch_type(detail::BitwiseOr());
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
dispatch_type(detail::BitwiseXor());
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
dispatch_type(detail::LeftShift());
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
dispatch_type(detail::RightShift());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
if (out.dtype() == float32) {
|
||||
binary_op<float>(a, b, out, detail::ArcTan2());
|
||||
} else if (out.dtype() == float16) {
|
||||
binary_op<float16_t>(a, b, out, detail::ArcTan2());
|
||||
} else if (out.dtype() == bfloat16) {
|
||||
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
std::ostringstream err;
|
||||
err << "[arctan2] Does not support " << out.dtype();
|
||||
throw std::invalid_argument(err.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[arctan2] Cannot compute inverse tangent for arrays"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
@@ -9,8 +8,6 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class BinaryOpType {
|
||||
ScalarScalar,
|
||||
ScalarVector,
|
||||
@@ -19,7 +16,7 @@ enum class BinaryOpType {
|
||||
General,
|
||||
};
|
||||
|
||||
BinaryOpType get_binary_op_type(const array& a, const array& b) {
|
||||
inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
|
||||
BinaryOpType bopt;
|
||||
if (a.data_size() == 1 && b.data_size() == 1) {
|
||||
bopt = BinaryOpType::ScalarScalar;
|
||||
@@ -28,8 +25,8 @@ BinaryOpType get_binary_op_type(const array& a, const array& b) {
|
||||
} else if (b.data_size() == 1 && a.flags().contiguous) {
|
||||
bopt = BinaryOpType::VectorScalar;
|
||||
} else if (
|
||||
a.flags().row_contiguous && b.flags().row_contiguous ||
|
||||
a.flags().col_contiguous && b.flags().col_contiguous) {
|
||||
(a.flags().row_contiguous && b.flags().row_contiguous) ||
|
||||
(a.flags().col_contiguous && b.flags().col_contiguous)) {
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
} else {
|
||||
bopt = BinaryOpType::General;
|
||||
@@ -37,29 +34,24 @@ BinaryOpType get_binary_op_type(const array& a, const array& b) {
|
||||
return bopt;
|
||||
}
|
||||
|
||||
void set_binary_op_output_data(
|
||||
inline void set_binary_op_output_data(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
BinaryOpType bopt,
|
||||
bool donate_with_move = false) {
|
||||
BinaryOpType bopt) {
|
||||
bool b_donatable = is_donatable(b, out);
|
||||
bool a_donatable = is_donatable(a, out);
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
if (b_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
out.copy_shared_buffer(b);
|
||||
}
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
||||
allocator::malloc(b.data_size() * out.itemsize()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@@ -67,14 +59,10 @@ void set_binary_op_output_data(
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
if (a_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
out.copy_shared_buffer(a);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -82,20 +70,12 @@ void set_binary_op_output_data(
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
if (a_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b_donatable) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
out.copy_shared_buffer(b);
|
||||
}
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -103,428 +83,15 @@ void set_binary_op_output_data(
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (
|
||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
out.copy_shared_buffer(b);
|
||||
}
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty tag type. The functor-resolving binary_op overload tests its opsv /
// opvs / opvv arguments against this type with std::is_same and substitutes
// the matching Default* functor built from the scalar op.
struct UseDefaultBinaryOp {};
|
||||
|
||||
// Functor that applies a scalar binary op between each element of a
// contiguous run `a[0..size)` and the single value `*b`, writing to `dst`.
template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
  Op op;

  DefaultVectorScalar(Op op_) : op(op_) {}

  // `b` is read once up front; a non-positive `size` writes nothing.
  void operator()(const T* a, const T* b, U* dst, int size) {
    T rhs = *b;
    for (int i = 0; i < size; ++i) {
      dst[i] = op(a[i], rhs);
    }
  }
};
|
||||
|
||||
// Functor that applies a scalar binary op between the single value `*a` and
// each element of a contiguous run `b[0..size)`, writing to `dst`.
template <typename T, typename U, typename Op>
struct DefaultScalarVector {
  Op op;

  DefaultScalarVector(Op op_) : op(op_) {}

  // `a` is read once up front; a non-positive `size` writes nothing.
  void operator()(const T* a, const T* b, U* dst, int size) {
    T lhs = *a;
    for (int i = 0; i < size; ++i) {
      dst[i] = op(lhs, b[i]);
    }
  }
};
|
||||
|
||||
// Functor that applies a scalar binary op pairwise over two contiguous runs
// `a[0..size)` and `b[0..size)`, writing to `dst`.
template <typename T, typename U, typename Op>
struct DefaultVectorVector {
  Op op;

  DefaultVectorVector(Op op_) : op(op_) {}

  // A non-positive `size` writes nothing.
  void operator()(const T* a, const T* b, U* dst, int size) {
    for (int i = 0; i < size; ++i) {
      dst[i] = op(a[i], b[i]);
    }
  }
};
|
||||
|
||||
// Recursively walks D axes of a strided binary operation, starting at `axis`.
//
// At the innermost level (D == 1) each step either calls `op` as a run
// functor `op(a, b, out, stride_out)` when Strided is true, or stores the
// scalar result `*out = op(*a, *b)` otherwise. After every step the three
// pointers advance by their own stride for the current axis.
template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims(
    const T* a,
    const T* b,
    U* out,
    Op op,
    const std::vector<int>& shape,
    const std::vector<size_t>& a_strides,
    const std::vector<size_t>& b_strides,
    const std::vector<size_t>& out_strides,
    int axis) {
  const auto a_step = a_strides[axis];
  const auto b_step = b_strides[axis];
  const auto out_step = out_strides[axis];
  const auto extent = shape[axis];

  for (int idx = 0; idx < extent; ++idx, out += out_step, a += a_step, b += b_step) {
    if constexpr (D > 1) {
      // Descend into the next axis with one fewer dimension left.
      binary_op_dims<T, U, Op, D - 1, Strided>(
          a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
    } else if constexpr (Strided) {
      op(a, b, out, out_step);
    } else {
      *out = op(*a, *b);
    }
  }
}
|
||||
|
||||
template <typename T, typename U, bool Strided, typename Op>
|
||||
void binary_op_dispatch_dims(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
int dim,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& a_strides,
|
||||
const std::vector<size_t>& b_strides,
|
||||
const std::vector<size_t>& out_strides) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
const T* b_ptr = b.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
switch (dim) {
|
||||
case 1:
|
||||
binary_op_dims<T, U, Op, 1, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims<T, U, Op, 2, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 3:
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
|
||||
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
|
||||
size_t stride = out_strides[dim - 4];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 3, Strided>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
out_ptr + elem,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
dim - 3);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
typename Op,
|
||||
typename OpSV,
|
||||
typename OpVS,
|
||||
typename OpVV>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
OpSV opsv,
|
||||
OpVS opvs,
|
||||
OpVV opvv) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
||||
return;
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out.strides()});
|
||||
const auto& a_strides = new_strides[0];
|
||||
const auto& b_strides = new_strides[1];
|
||||
const auto& strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_rc_dim = leftmost_rc_dim(a_strides);
|
||||
auto b_rc_dim = leftmost_rc_dim(b_strides);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == 0; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_s_dim = leftmost_s_dim(a_strides);
|
||||
auto b_s_dim = leftmost_s_dim(b_strides);
|
||||
|
||||
auto ndim = new_shape.size();
|
||||
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||
int dim = ndim;
|
||||
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
dim = d;
|
||||
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorScalar;
|
||||
dim = d;
|
||||
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::ScalarVector;
|
||||
dim = d;
|
||||
}
|
||||
|
||||
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||
// contiguous methods above. Except for the case that the flags do not
|
||||
// correspond to the underlying contiguity.
|
||||
if (dim == 0 || strides[dim - 1] < 16) {
|
||||
bopt = BinaryOpType::General;
|
||||
dim = ndim;
|
||||
}
|
||||
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U, true>(
|
||||
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U, false>(
|
||||
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
OpSV opsv,
|
||||
OpVS opvs,
|
||||
OpVV opvv) {
|
||||
// TODO: The following mess of constexpr evaluations can probably be achieved
|
||||
// with template specializations and overloading. Would it be simpler?
|
||||
|
||||
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opsv and opvs were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
opvv);
|
||||
}
|
||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
// opsv and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
DefaultScalarVector<T, T, Op>(op),
|
||||
opvs,
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opsv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
|
||||
}
|
||||
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||
// opvs and opvv were UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
op,
|
||||
opsv,
|
||||
DefaultVectorScalar<T, T, Op>(op),
|
||||
DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// opvs was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
|
||||
}
|
||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
||||
value) {
|
||||
// opvv was UseDefaultBinaryOp
|
||||
binary_op<T, T>(
|
||||
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
|
||||
} else {
|
||||
// All ops provided
|
||||
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
DefaultScalarVector<T, T, Op> opsv(op);
|
||||
DefaultVectorScalar<T, T, Op> opvs(op);
|
||||
DefaultVectorVector<T, T, Op> opvv(op);
|
||||
binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
|
||||
}
|
||||
|
||||
template <typename... Ops>
|
||||
void binary(const array& a, const array& b, array& out, Ops... ops) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out, ops...);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out, ops...);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out, ops...);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out, ops...);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out, ops...);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out, ops...);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out, ops...);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out, ops...);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out, ops...);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out, ops...);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out, ops...);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out, ops...);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, out, ops...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
24
mlx/backend/common/broadcasting.cpp
Normal file
24
mlx/backend/common/broadcasting.cpp
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
11
mlx/backend/common/broadcasting.h
Normal file
11
mlx/backend/common/broadcasting.h
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Sets up `out` as a broadcast of `in` that shares `in`'s buffer: broadcast
// axes get a zero stride. If `out` has no elements its data is set to null.
void broadcast(const array& in, array& out);
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,74 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||
// Lapack uses the column-major convention. We take advantage of the fact that
|
||||
// the matrix should be symmetric:
|
||||
// (A)ᵀ = A
|
||||
// and that a column-major lower triangular matrix is a row-major upper
|
||||
// triangular matrix, so uplo is the opposite of what we would expect from
|
||||
// upper
|
||||
|
||||
char uplo = (upper) ? 'L' : 'U';
|
||||
|
||||
// The decomposition is computed in place, so just copy the input to the
|
||||
// output.
|
||||
copy(
|
||||
a,
|
||||
factor,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
|
||||
float* matrix = factor.data<float>();
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
// Compute Cholesky factorization.
|
||||
int info;
|
||||
MLX_LAPACK_FUNC(spotrf)
|
||||
(
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
|
||||
// TODO: We do nothing when the matrix is not positive semi-definite
|
||||
// because throwing an error would result in a crash. If we figure out how
|
||||
// to catch errors from the implementation we should throw.
|
||||
if (info < 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[cholesky] Cholesky decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Zero out the upper/lower triangle while advancing the pointer to the
|
||||
// next matrix at the same time.
|
||||
for (int row = 0; row < N; row++) {
|
||||
if (upper) {
|
||||
std::fill(matrix, matrix + row, 0);
|
||||
} else {
|
||||
std::fill(matrix + row + 1, matrix + N, 0);
|
||||
}
|
||||
matrix += N;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Cholesky::eval] only supports float32.");
|
||||
}
|
||||
cholesky_impl(inputs[0], output, upper_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -43,22 +44,11 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
std::vector<size_t> strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
broadcast(inputs[0], out);
|
||||
}
|
||||
|
||||
// Evaluates the BroadcastAxes primitive by delegating to the shared
// broadcast() helper with the first input.
void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
  const auto& in = inputs[0];
  broadcast(in, out);
}
|
||||
|
||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -85,9 +75,19 @@ void Depends::eval(
|
||||
}
|
||||
}
|
||||
|
||||
void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
auto strides = in.strides();
|
||||
for (auto ax : axes_) {
|
||||
strides.insert(strides.begin() + ax, 1);
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
double numel = 1;
|
||||
for (auto ax : axes_) {
|
||||
@@ -135,15 +135,16 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
case bfloat16:
|
||||
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
|
||||
break;
|
||||
case float64:
|
||||
*out.data<double>() = static_cast<double>(numel);
|
||||
break;
|
||||
case complex64:
|
||||
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out) {
|
||||
// Special case for empty arrays or row contiguous arrays
|
||||
if (in.size() == 0 || in.flags().row_contiguous) {
|
||||
return {false, out.strides()};
|
||||
@@ -151,8 +152,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
|
||||
// Special case for scalars
|
||||
if (in.ndim() == 0) {
|
||||
std::vector<size_t> out_strides(out.ndim(), 0);
|
||||
return {false, out_strides};
|
||||
return {false, Strides(out.ndim(), 0)};
|
||||
}
|
||||
|
||||
// Firstly let's collapse all the contiguous dimensions of the input
|
||||
@@ -160,7 +160,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
|
||||
// If shapes fit exactly in the contiguous dims then no copy is necessary so
|
||||
// let's check.
|
||||
std::vector<size_t> out_strides;
|
||||
Strides out_strides;
|
||||
bool copy_necessary = false;
|
||||
int j = 0;
|
||||
for (int i = 0; i < out.ndim(); i++) {
|
||||
@@ -181,9 +181,9 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
|
||||
return {copy_necessary, out_strides};
|
||||
}
|
||||
|
||||
void Reshape::shared_buffer_reshape(
|
||||
void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
const Strides& out_strides,
|
||||
array& out) {
|
||||
auto flags = in.flags();
|
||||
if (flags.row_contiguous) {
|
||||
@@ -249,16 +249,18 @@ void Split::eval(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
void Squeeze::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
Strides strides;
|
||||
for (int i = 0, j = 0; i < in.ndim(); ++i) {
|
||||
if (j < axes_.size() && i == axes_[j]) {
|
||||
j++;
|
||||
} else {
|
||||
strides.push_back(in.strides(i));
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(data_offset, inp_strides);
|
||||
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
|
||||
}
|
||||
|
||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
@@ -268,7 +270,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
std::vector<size_t> out_strides(out.ndim());
|
||||
Strides out_strides(out.ndim());
|
||||
auto& in = inputs[0];
|
||||
for (int ax = 0; ax < axes_.size(); ++ax) {
|
||||
out_strides[ax] = in.strides()[axes_[ax]];
|
||||
@@ -285,8 +287,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||
// true, they stay true)
|
||||
auto flags = in.flags();
|
||||
if (flags.contiguous && in.data_size() == in.size()) {
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
int64_t f_stride = 1;
|
||||
int64_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
|
||||
|
@@ -130,7 +130,7 @@ std::string build_lib_name(
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& shape) {
|
||||
const Shape& shape) {
|
||||
bool contiguous = true;
|
||||
bool all_contig = true;
|
||||
bool all_row_contig = true;
|
||||
@@ -161,11 +161,10 @@ void compiled_allocate_outputs(
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous,
|
||||
bool move_buffers /* = false */) {
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
std::vector<size_t> strides;
|
||||
Strides strides;
|
||||
size_t data_size;
|
||||
array::Flags flags;
|
||||
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
|
||||
@@ -178,11 +177,7 @@ void compiled_allocate_outputs(
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o++].move_shared_buffer(in);
|
||||
} else {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||
@@ -193,7 +188,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
@@ -210,18 +205,13 @@ void compiled_allocate_outputs(
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o].move_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
} else {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
}
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -11,9 +11,7 @@
|
||||
namespace mlx::core {
|
||||
|
||||
inline bool is_static_cast(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
|
||||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
|
||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
@@ -56,7 +54,7 @@ inline bool is_scalar(const array& x) {
|
||||
// Check if we can use a contiguous operation given inputs and the output shape
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& shape);
|
||||
const Shape& shape);
|
||||
|
||||
// Allocate space for the outputs possibly with input donation
|
||||
void compiled_allocate_outputs(
|
||||
@@ -64,7 +62,6 @@ void compiled_allocate_outputs(
|
||||
std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous,
|
||||
bool move_buffers = false);
|
||||
bool contiguous);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -23,18 +22,25 @@ enum class CopyType {
|
||||
GeneralGeneral
|
||||
};
|
||||
|
||||
void copy(const array& src, array& dst, CopyType ctype);
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype);
|
||||
|
||||
template <typename stride_t>
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
// have the same size, then the input buffer can hold the output.
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(in);
|
||||
return true;
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,196 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#define DEFAULT(primitive) \
|
||||
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
|
||||
primitive::eval(inputs, out); \
|
||||
}
|
||||
|
||||
#define DEFAULT_MULTI(primitive) \
|
||||
void primitive::eval_cpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
primitive::eval(inputs, outputs); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
DEFAULT(Abs)
|
||||
DEFAULT(Add)
|
||||
DEFAULT(Arange)
|
||||
DEFAULT(ArcCos)
|
||||
DEFAULT(ArcCosh)
|
||||
DEFAULT(ArcSin)
|
||||
DEFAULT(ArcSinh)
|
||||
DEFAULT(ArcTan)
|
||||
DEFAULT(ArcTan2)
|
||||
DEFAULT(ArcTanh)
|
||||
DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(NumberOfElements)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(Exp)
|
||||
DEFAULT(Expm1)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Full)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(Log)
|
||||
DEFAULT(Log1p)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(Maximum)
|
||||
DEFAULT(Minimum)
|
||||
DEFAULT(Multiply)
|
||||
DEFAULT(Negative)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT(Power)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(QuantizedMatmul)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scan)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Select)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
DEFAULT(Sin)
|
||||
DEFAULT(Sinh)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT(Softmax)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Square)
|
||||
DEFAULT(Sqrt)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Subtract)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Tan)
|
||||
DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
DEFAULT(Cholesky)
|
||||
DEFAULT_MULTI(Eigh)
|
||||
|
||||
namespace {
|
||||
|
||||
inline void matmul_common_general(
|
||||
const array& a_pre,
|
||||
const array& b_pre,
|
||||
array& out,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
auto check_transpose = [](const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha, // alpha
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
||||
lda,
|
||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
||||
ldb,
|
||||
beta, // beta
|
||||
out.data<float>() + M * N * i,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[Matmul::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
return matmul_common_general(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[AddMM::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
|
||||
// Fill output with C
|
||||
auto& c = inputs[2];
|
||||
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
copy(c, out, ctype);
|
||||
|
||||
return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,117 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void ssyevd(
|
||||
char jobz,
|
||||
char uplo,
|
||||
float* a,
|
||||
int N,
|
||||
float* w,
|
||||
float* work,
|
||||
int lwork,
|
||||
int* iwork,
|
||||
int liwork) {
|
||||
int info;
|
||||
MLX_LAPACK_FUNC(ssyevd)
|
||||
(
|
||||
/* jobz = */ &jobz,
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ a,
|
||||
/* lda = */ &N,
|
||||
/* w = */ w,
|
||||
/* work = */ work,
|
||||
/* lwork = */ &lwork,
|
||||
/* iwork = */ iwork,
|
||||
/* liwork = */ &liwork,
|
||||
/* info = */ &info);
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
const auto& a = inputs[0];
|
||||
auto& values = outputs[0];
|
||||
|
||||
auto vectors = compute_eigenvectors_
|
||||
? outputs[1]
|
||||
: array(a.shape(), a.dtype(), nullptr, {});
|
||||
|
||||
values.set_data(allocator::malloc_or_wait(values.nbytes()));
|
||||
|
||||
copy(
|
||||
a,
|
||||
vectors,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
if (compute_eigenvectors_) {
|
||||
// Set the strides and flags so the eigenvectors
|
||||
// are in the columns of the output
|
||||
auto flags = vectors.flags();
|
||||
auto strides = vectors.strides();
|
||||
auto ndim = a.ndim();
|
||||
std::swap(strides[ndim - 1], strides[ndim - 2]);
|
||||
|
||||
if (a.size() > 1) {
|
||||
flags.row_contiguous = false;
|
||||
if (ndim > 2) {
|
||||
flags.col_contiguous = false;
|
||||
} else {
|
||||
flags.col_contiguous = true;
|
||||
}
|
||||
}
|
||||
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
|
||||
}
|
||||
|
||||
auto vec_ptr = vectors.data<float>();
|
||||
auto eig_ptr = values.data<float>();
|
||||
|
||||
char jobz = compute_eigenvectors_ ? 'V' : 'N';
|
||||
auto N = a.shape(-1);
|
||||
|
||||
// Work query
|
||||
int lwork;
|
||||
int liwork;
|
||||
{
|
||||
float work;
|
||||
int iwork;
|
||||
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
|
||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
||||
for (size_t i = 0; i < a.size() / (N * N); ++i) {
|
||||
ssyevd(
|
||||
jobz,
|
||||
uplo_[0],
|
||||
vec_ptr,
|
||||
N,
|
||||
eig_ptr,
|
||||
static_cast<float*>(work_buf.buffer.raw_ptr()),
|
||||
lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
liwork);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,40 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cmath>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
/* Approximation to the inverse error function.
|
||||
* Based on code from:
|
||||
* https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
|
||||
*/
|
||||
// Single-precision approximation of the inverse error function.
//
// Evaluates one of two polynomials in t = log(1 - a*a) with Horner's scheme
// (one fused multiply-add per coefficient) and scales the result by `a`.
float erfinv(float a) {
  // Coefficients for the tail region, |t| > 6.125 (maximum ulp error = 2.35793).
  static constexpr float kTail[] = {
      3.03697567e-10f, // 0x1.4deb44p-32
      2.93243101e-8f, // 0x1.f7c9aep-26
      1.22150334e-6f, // 0x1.47e512p-20
      2.84108955e-5f, // 0x1.dca7dep-16
      3.93552968e-4f, // 0x1.9cab92p-12
      3.02698812e-3f, // 0x1.8cc0dep-9
      4.83185798e-3f, // 0x1.3ca920p-8
      -2.64646143e-1f, // -0x1.0eff66p-2
      8.40016484e-1f, // 0x1.ae16a4p-1
  };
  // Coefficients for the central region (maximum ulp error = 2.35002).
  static constexpr float kCentral[] = {
      5.43877832e-9f, // 0x1.75c000p-28
      1.43285448e-7f, // 0x1.33b402p-23
      1.22774793e-6f, // 0x1.499232p-20
      1.12963626e-7f, // 0x1.e52cd2p-24
      -5.61530760e-5f, // -0x1.d70bd0p-15
      -1.47697632e-4f, // -0x1.35be90p-13
      2.31468678e-3f, // 0x1.2f6400p-9
      1.15392581e-2f, // 0x1.7a1e50p-7
      -2.32015476e-1f, // -0x1.db2aeep-3
      8.86226892e-1f, // 0x1.c5bf88p-1
  };

  // t = log(1 - a^2); the fma keeps 1 - a*a to a single rounding.
  float t = std::fma(a, 0.0f - a, 1.0f);
  t = std::log(t);

  const bool in_tail = std::abs(t) > 6.125f;
  const float* coeffs = in_tail ? kTail : kCentral;
  const int count = in_tail ? 9 : 10;

  float p = coeffs[0];
  for (int k = 1; k < count; ++k) {
    p = std::fma(p, t, coeffs[k]);
  }
  return a * p;
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,87 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlx/3rdparty/pocketfft.h"
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void FFT::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
std::vector<std::ptrdiff_t> strides_in(
|
||||
in.strides().begin(), in.strides().end());
|
||||
for (auto& s : strides_in) {
|
||||
s *= in.itemsize();
|
||||
}
|
||||
std::vector<std::ptrdiff_t> strides_out(
|
||||
out.strides().begin(), out.strides().end());
|
||||
for (auto& s : strides_out) {
|
||||
s *= out.itemsize();
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
std::vector<size_t> shape;
|
||||
if (out.dtype() == float32) {
|
||||
shape.insert(shape.end(), out.shape().begin(), out.shape().end());
|
||||
} else {
|
||||
shape.insert(shape.end(), in.shape().begin(), in.shape().end());
|
||||
}
|
||||
|
||||
float scale = 1.0f;
|
||||
if (inverse_) {
|
||||
size_t nelem = std::accumulate(
|
||||
axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) {
|
||||
return x * shape[y];
|
||||
});
|
||||
scale /= nelem;
|
||||
}
|
||||
if (in.dtype() == complex64 && out.dtype() == complex64) {
|
||||
auto in_ptr =
|
||||
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
|
||||
auto out_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
|
||||
pocketfft::c2c(
|
||||
shape,
|
||||
strides_in,
|
||||
strides_out,
|
||||
axes_,
|
||||
!inverse_,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale);
|
||||
} else if (in.dtype() == float32 && out.dtype() == complex64) {
|
||||
auto in_ptr = in.data<float>();
|
||||
auto out_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
|
||||
pocketfft::r2c(
|
||||
shape,
|
||||
strides_in,
|
||||
strides_out,
|
||||
axes_,
|
||||
!inverse_,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale);
|
||||
} else if (in.dtype() == complex64 && out.dtype() == float32) {
|
||||
auto in_ptr =
|
||||
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
|
||||
auto out_ptr = out.data<float>();
|
||||
pocketfft::c2r(
|
||||
shape,
|
||||
strides_in,
|
||||
strides_out,
|
||||
axes_,
|
||||
!inverse_,
|
||||
in_ptr,
|
||||
out_ptr,
|
||||
scale);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[FFT] Received unexpected input and output type combination.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
if (n > (1 << 26)) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where k <= 26");
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,394 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Map a possibly-negative index onto [0, size): negative values wrap around
// by adding `size`, everything else is returned unchanged.
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
  if (idx < 0) {
    return idx + size;
  }
  return idx;
}

// bool indices can never be negative; pass them through.
template <>
inline size_t offset_neg_idx(bool idx, size_t /* size */) {
  return idx;
}

// uint32_t indices can never be negative; pass them through.
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t /* size */) {
  return idx;
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
void gather(
|
||||
const array& src,
|
||||
const std::vector<array>& inds,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& slice_sizes) {
|
||||
// If the array is row contiguous then we can do a contiguous copy given
|
||||
// two conditions on the slice size:
|
||||
// - Any number of leading ones in the slice sizes are allowed
|
||||
// - All other slice sizes match the corresponding dimension except the
|
||||
// first non-singleton slice size
|
||||
// If the array is col contiguous then the reverse is the case:
|
||||
// - Any number of trailing ones in the slice sizes are allowed
|
||||
// - All other slice sizes match the corresponding dimension except the
|
||||
// first non-singleton slice size from the end
|
||||
|
||||
bool can_copy = false;
|
||||
if (src.flags().row_contiguous) {
|
||||
can_copy = true;
|
||||
|
||||
// Ignore leading 1s
|
||||
int i = 0;
|
||||
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
|
||||
;
|
||||
|
||||
// Check the remaining
|
||||
i++;
|
||||
for (; i < src.ndim() && can_copy; ++i) {
|
||||
can_copy = (src.shape(i) == slice_sizes[i]);
|
||||
}
|
||||
} else if (src.flags().col_contiguous) {
|
||||
can_copy = true;
|
||||
|
||||
// Ignore trailing 1s
|
||||
int i = slice_sizes.size() - 1;
|
||||
for (; i >= 0 && slice_sizes[i] == 1; --i)
|
||||
;
|
||||
|
||||
// Skip the next slice size and check the remaining
|
||||
i--;
|
||||
for (; i >= 0 && can_copy; --i) {
|
||||
can_copy = (src.shape(i) == slice_sizes[i]);
|
||||
}
|
||||
}
|
||||
size_t slice_size = 1;
|
||||
for (auto s : slice_sizes) {
|
||||
slice_size *= s;
|
||||
}
|
||||
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
|
||||
const T* src_ptr = src.data<T>();
|
||||
T* dst_ptr = out.data<T>();
|
||||
size_t out_idx = 0;
|
||||
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> src_it;
|
||||
if (!can_copy && src.ndim() > 0) {
|
||||
src_it = std::move(
|
||||
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
|
||||
}
|
||||
for (int idx = 0; idx < ind_size; idx++) {
|
||||
size_t src_idx = 0;
|
||||
for (int ii = 0; ii < inds.size(); ++ii) {
|
||||
auto ax = axes[ii];
|
||||
auto idx_loc = its[ii].loc;
|
||||
its[ii].step();
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
|
||||
src_idx += (idx_val * src.strides()[ax]);
|
||||
}
|
||||
|
||||
if (slice_size == 1) {
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx];
|
||||
} else if (can_copy) {
|
||||
std::copy(
|
||||
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
|
||||
out_idx += slice_size;
|
||||
} else {
|
||||
for (int jj = 0; jj < slice_size; jj++) {
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
|
||||
src_it.step();
|
||||
}
|
||||
src_it.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IdxT>
|
||||
void dispatch_gather(
|
||||
const array& src,
|
||||
const std::vector<array>& inds,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& size) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
gather<bool, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint8:
|
||||
gather<uint8_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint16:
|
||||
gather<uint16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint32:
|
||||
gather<uint32_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case uint64:
|
||||
gather<uint64_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int8:
|
||||
gather<int8_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int16:
|
||||
gather<int16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int32:
|
||||
gather<int32_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case int64:
|
||||
gather<int64_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case float16:
|
||||
gather<float16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case float32:
|
||||
gather<float, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case bfloat16:
|
||||
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
case complex64:
|
||||
gather<complex64_t, IdxT>(src, inds, out, axes, size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Gather::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> inds(inputs.begin() + 1, inputs.end());
|
||||
|
||||
if (inds.empty()) {
|
||||
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case bool_:
|
||||
dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
|
||||
break;
|
||||
case float16:
|
||||
case float32:
|
||||
case bfloat16:
|
||||
case complex64:
|
||||
throw std::runtime_error(
|
||||
"[Gather::eval] Cannot gather with floating point indices.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT, typename IdxT, typename OpT>
|
||||
void scatter(
|
||||
const array& updates,
|
||||
array& out,
|
||||
const std::vector<array>& inds,
|
||||
const std::vector<int>& axes,
|
||||
const OpT& op) {
|
||||
int nind = inds.size();
|
||||
auto inds_ndim = updates.ndim() - out.ndim();
|
||||
size_t n_updates = nind ? inds[0].size() : 1;
|
||||
|
||||
std::vector<int> update_shape(
|
||||
updates.shape().begin() + inds_ndim, updates.shape().end());
|
||||
size_t update_size = 1;
|
||||
for (auto us : update_shape) {
|
||||
update_size *= us;
|
||||
}
|
||||
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> update_it(updates);
|
||||
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
|
||||
|
||||
for (int i = 0; i < n_updates; ++i) {
|
||||
size_t out_offset = 0;
|
||||
for (int j = 0; j < nind; ++j) {
|
||||
auto ax = axes[j];
|
||||
auto idx_loc = its[j].loc;
|
||||
its[j].step();
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
|
||||
out_offset += (idx_val * out.strides()[ax]);
|
||||
}
|
||||
update_it.seek(i * update_size);
|
||||
for (int j = 0; j < update_size; ++j) {
|
||||
op(updates.data<InT>()[update_it.loc],
|
||||
out.data<InT>() + out_offset + out_it.loc);
|
||||
update_it.step();
|
||||
out_it.step();
|
||||
}
|
||||
out_it.reset();
|
||||
update_it.reset();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT, typename IdxT>
|
||||
void dispatch_scatter_inds(
|
||||
array& out,
|
||||
const std::vector<array>& indices,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType rtype) {
|
||||
switch (rtype) {
|
||||
case Scatter::None:
|
||||
scatter<InT, IdxT>(
|
||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
|
||||
break;
|
||||
case Scatter::Sum:
|
||||
scatter<InT, IdxT>(
|
||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
|
||||
break;
|
||||
case Scatter::Prod:
|
||||
scatter<InT, IdxT>(
|
||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
|
||||
break;
|
||||
case Scatter::Max:
|
||||
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
});
|
||||
break;
|
||||
case Scatter::Min:
|
||||
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void dispatch_scatter(
|
||||
array& out,
|
||||
const std::vector<array>& inds,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType rtype) {
|
||||
if (inds.empty()) {
|
||||
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case float16:
|
||||
case float32:
|
||||
case bfloat16:
|
||||
case complex64:
|
||||
throw std::runtime_error(
|
||||
"[Scatter::eval_cpu] Cannot scatter with floating point indices.");
|
||||
}
|
||||
}
|
||||
|
||||
void Scatter::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() >= 2);
|
||||
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
|
||||
auto& updates = inputs.back();
|
||||
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
copy(src, out, CopyType::General);
|
||||
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case float16:
|
||||
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case float32:
|
||||
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case bfloat16:
|
||||
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case complex64:
|
||||
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,120 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
// Call the LAPACK strtri routine on an N x N float matrix in place and return
// its `info` status code to the caller. N is also passed as the leading
// dimension.
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
  int status;
  MLX_LAPACK_FUNC(strtri)
  (
      /* uplo = */ &uplo,
      /* diag = */ &diag,
      /* N = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* info = */ &status);
  return status;
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Invert the i-th N x N float matrix stored in `inv`, in place, using an LU
// factorization (sgetrf) followed by sgetri. Throws std::runtime_error when
// any LAPACK call reports a non-zero status.
void general_inv(array& inv, int N, int i) {
  int info;
  auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};

  // Step 1: LU factorization.
  sgetrf_(
      /* m = */ &N,
      /* n = */ &N,
      /* a = */ inv.data<float>() + N * N * i,
      /* lda = */ &N,
      /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
      /* info = */ &info);
  if (info != 0) {
    std::stringstream ss;
    ss << "inverse_impl: LU factorization failed with error code " << info;
    throw std::runtime_error(ss.str());
  }

  // Step 2: workspace query (lwork = -1 makes sgetri report the size it
  // wants in workspace_size).
  static const int lwork_query = -1;
  float workspace_size = 0;
  sgetri_(
      /* m = */ &N,
      /* a = */ nullptr,
      /* lda = */ &N,
      /* ipiv = */ nullptr,
      /* work = */ &workspace_size,
      /* lwork = */ &lwork_query,
      /* info = */ &info);
  if (info != 0) {
    std::stringstream ss;
    ss << "inverse_impl: LU workspace calculation failed with error code "
       << info;
    throw std::runtime_error(ss.str());
  }

  const int lwork = workspace_size;
  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};

  // Step 3: compute the inverse from the LU factors.
  sgetri_(
      /* m = */ &N,
      /* a = */ inv.data<float>() + N * N * i,
      /* lda = */ &N,
      /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
      /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
      /* lwork = */ &lwork,
      /* info = */ &info);
  if (info != 0) {
    std::stringstream ss;
    ss << "inverse_impl: inversion failed with error code " << info;
    throw std::runtime_error(ss.str());
  }
}
|
||||
|
||||
// Inverts, in place, the i-th N x N float32 triangular matrix held in `inv`
// via LAPACK's strtri.
//
// inv:   output buffer already populated with the input matrices; matrix `i`
//        is taken to start at element offset N * N * i.
// N:     matrix order.
// i:     index of the matrix within the batch.
// upper: whether the caller's matrix is upper triangular.
//
// Throws std::runtime_error when strtri reports a non-zero `info`.
void tri_inv(array& inv, int N, int i, bool upper) {
  // Not a typo: LAPACK reads the buffer column-major, i.e. as the transpose
  // of the caller's matrix, so an upper-triangular input looks
  // lower-triangular to strtri (and vice versa).
  const char uplo = upper ? 'L' : 'U';
  // 'N' = non-unit diagonal: the stored diagonal entries are used as-is.
  const char diag = 'N';

  // Compute the element offset in size_t: evaluating N * N * i in int
  // overflows for large batches of large matrices.
  const size_t offset = static_cast<size_t>(N) * N * i;

  const int info = strtri_wrapper(uplo, diag, inv.data<float>() + offset, N);
  if (info != 0) {
    std::stringstream ss;
    ss << "inverse_impl: triangular inversion failed with error code " << info;
    throw std::runtime_error(ss.str());
  }
}
|
||||
|
||||
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
||||
// Lapack uses the column-major convention. We take advantage of the following
|
||||
// identity to avoid transposing (see
|
||||
// https://math.stackexchange.com/a/340234):
|
||||
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
|
||||
|
||||
// The inverse is computed in place, so just copy the input to the output.
|
||||
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
if (tri) {
|
||||
tri_inv(inv, N, i, upper);
|
||||
} else {
|
||||
general_inv(inv, N, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Inverse::eval(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Inverse::eval] only supports float32.");
|
||||
}
|
||||
inverse_impl(inputs[0], output, tri_, upper_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,24 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
||||
#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
|
||||
|
||||
// This is to work around a change in the function signatures of lapack >= 3.9.1
|
||||
// where functions taking char* also include a strlen argument, see a similar
|
||||
// change in OpenCV:
|
||||
// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
|
||||
// Maps a routine name to its LAPACK_<name> form, e.g. strtri -> LAPACK_strtri
// (presumably a macro supplied by <lapack.h> — confirm against the header).
#define MLX_LAPACK_FUNC(f) LAPACK_##f
|
||||
|
||||
#else
|
||||
|
||||
// Fallback when neither LAPACK_GLOBAL nor LAPACK_NAME is defined: append a
// trailing underscore to the routine name, e.g. strtri -> strtri_.
#define MLX_LAPACK_FUNC(f) f##_
|
||||
|
||||
#endif
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user