Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: interrupt_...937ce79660 (537 commits)
Commits in this range: 937ce79660 (newest) down to 3779150750 (oldest). The full 537-entry commit table is omitted here; only the SHA column survived in this capture, the Author, Date, and message columns are empty.
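To see the commit messages and authors that the table above no longer shows, the same range can be listed from a local clone. This is a minimal sketch, assuming the clone has both the `interrupt_` branch and commit `937ce79660` fetched; the exact count may differ slightly depending on how the compare range is interpreted.

```bash
# Sketch: list the compare range locally (assumes both refs are present in the clone).
git log --oneline interrupt_..937ce79660      # one line per commit, newest first
git rev-list --count interrupt_..937ce79660   # should be close to the 537 commits shown above
```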
CircleCI configuration (deleted, 418 lines)

version: 2.1

orbs:
  apple: ml-explore/pr-approval@0.1.0

parameters:
  nightly_build:
    type: boolean
    default: false
  weekly_build:
    type: boolean
    default: false
  test_release:
    type: boolean
    default: false
  linux_release:
    type: boolean
    default: false

jobs:
  build_documentation:
    parameters:
      upload-docs:
        type: boolean
        default: false
    macos:
      xcode: "15.2.0"
    resource_class: macos.m1.medium.gen1
    steps:
      - checkout
      - run:
          name: Install
          command: |
            brew install python@3.9
            brew install doxygen
            python3.9 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install -r docs/requirements.txt
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
      - when:
          condition:
            not: << parameters.upload-docs >>
          steps:
            - run:
                name: Build documentation
                command: |
                  source env/bin/activate
                  cd docs && doxygen && make html O=-W
      - when:
          condition: << parameters.upload-docs >>
          steps:
            - add_ssh_keys:
                fingerprints:
                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
            - run:
                name: Upload documentation
                command: |
                  source env/bin/activate
                  git config user.email "mlx@group.apple.com"
                  git config user.name "CircleCI Docs"
                  git checkout gh-pages
                  git rebase main
                  cd docs
                  git rm -rf build/html
                  doxygen && make html O=-W
                  git add -f build/html
                  git commit -m "rebase"
                  git push -f origin gh-pages

  linux_build_and_test:
    docker:
      - image: cimg/python:3.9
    steps:
      - checkout
      - run:
          name: Run style checks
          command: |
            pip install pre-commit
            pre-commit run --all
            if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
      - run:
          name: Install dependencies
          command: |
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install numpy
            sudo apt-get update
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
      - run:
          name: Install Python package
          command: |
            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF
              CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
              python3 setup.py build_ext --inplace
            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF \
              CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
              python3 setup.py develop
      - run:
          name: Generate package stubs
          command: |
            echo "stubs"
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            python3 -m unittest discover python/tests -v
      - run:
          name: Build CPP only
          command: |
            mkdir -p build && cd build
            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
            make -j `nproc`
      - run:
          name: Run CPP tests
          command: ./build/tests/tests

  mac_build_and_test:
    parameters:
      xcode_version:
        type: string
        default: "15.2.0"
    macos:
      xcode: << parameters.xcode_version >>
    resource_class: macos.m1.medium.gen1
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            brew install python@3.9
            brew install openmpi
            python3.9 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install numpy
            pip install torch
            pip install tensorflow
            pip install unittest-xml-reporting
      - run:
          name: Install Python package
          command: |
            source env/bin/activate
            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
              CMAKE_ARGS="CMAKE_COMPILE_WARNING_AS_ERROR=ON" \
              pip install -e . -v
      - run:
          name: Generate package stubs
          command: |
            source env/bin/activate
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            source env/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
      - run:
          name: Build example extension
          command: |
            source env/bin/activate
            cd examples/extensions
            pip install -r requirements.txt
            python setup.py build_ext -j8
      - store_test_results:
          path: test-results
      - run:
          name: Build CPP only
          command: |
            source env/bin/activate
            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
      - run:
          name: Run CPP tests
          command: |
            DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
      - run:
          name: Build small binary
          command: |
            source env/bin/activate
            cd build/
            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
              -DBUILD_SHARED_LIBS=ON \
              -DMLX_BUILD_CPU=OFF \
              -DMLX_BUILD_SAFETENSORS=OFF \
              -DMLX_BUILD_GGUF=OFF \
              -DMLX_METAL_JIT=ON
            make -j `sysctl -n hw.ncpu`
      - run:
          name: Run Python tests with JIT
          command: |
            source env/bin/activate
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
              CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
              pip install -e . -v
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
              METAL_DEBUG_ERROR_MODE=0 \
              python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

  build_release:
    parameters:
      python_version:
        type: string
        default: "3.9"
      xcode_version:
        type: string
        default: "15.2.0"
      build_env:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    resource_class: macos.m1.medium.gen1
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            brew install python@<< parameters.python_version >>
            brew install openmpi
            python<< parameters.python_version >> -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install twine
            pip install build
      - run:
          name: Install Python package
          command: |
            source env/bin/activate
            DEV_RELEASE=1 \
              CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
              pip install . -v
      - run:
          name: Generate package stubs
          command: |
            source env/bin/activate
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Build Python package
          command: |
            source env/bin/activate
            << parameters.build_env >> \
              CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
              python -m build -w
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
                  source env/bin/activate
                  twine upload dist/*
      - store_artifacts:
          path: dist/

  build_linux_release:
    parameters:
      python_version:
        type: string
        default: "3.9"
      extra_env:
        type: string
        default: "DEV_RELEASE=1"
    docker:
      - image: ubuntu:20.04
    steps:
      - checkout
      - run:
          name: Build wheel
          command: |
            PYTHON=python<< parameters.python_version >>
            apt-get update
            apt-get upgrade -y
            DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
            apt-get install -y apt-utils
            apt-get install -y software-properties-common
            add-apt-repository -y ppa:deadsnakes/ppa
            apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
            apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            apt-get install -y build-essential git
            $PYTHON -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
            << parameters.extra_env >> \
              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
              pip install . -v
            pip install typing_extensions
            python setup.py generate_stubs
            << parameters.extra_env >> \
              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
              python -m build --wheel
            auditwheel show dist/*
            auditwheel repair dist/* --plat manylinux_2_31_x86_64
      - run:
          name: Upload package
          command: |
            source env/bin/activate
            twine upload wheelhouse/*
      - store_artifacts:
          path: wheelhouse/

workflows:
  build_and_test:
    when:
      and:
        - matches:
            pattern: "^(?!pull/)[-\\w]+$"
            value: << pipeline.git.branch >>
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - mac_build_and_test:
          matrix:
            parameters:
              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
      - linux_build_and_test
      - build_documentation

  build_pypi_release:
    when:
      and:
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              xcode_version: ["15.0.0", "15.2.0"]
              build_env: ["PYPI_RELEASE=1"]
      - build_documentation:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          upload-docs: true

  prb:
    when:
      matches:
        pattern: "^pull/\\d+(/head)?$"
        value: << pipeline.git.branch >>
    jobs:
      - hold:
          type: approval
      - apple/authenticate:
          context: pr-approval
      - mac_build_and_test:
          requires: [ hold ]
          matrix:
            parameters:
              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
      - linux_build_and_test:
          requires: [ hold ]

  nightly_build:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.nightly_build >>
    jobs:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              xcode_version: ["15.0.0", "15.2.0"]

  weekly_build:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.weekly_build >>
    jobs:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
              build_env: ["DEV_RELEASE=1"]

  linux_test_release:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.linux_release >>
    jobs:
      - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              extra_env: ["PYPI_RELEASE=1"]
.github/actions/build-cuda-release/action.yml (new file, 24 lines)

name: 'Build CUDA wheel'
description: 'Build CUDA wheel'

inputs:
  arch:
    description: 'Platform architecture tag'
    required: true
    type: choice
    options:
      - x86_64
      - aarch64

runs:
  using: "composite"
  steps:
    - name: Build package
      shell: bash
      env:
        CMAKE_ARGS: -DMLX_BUILD_CUDA=ON
      run: |
        pip install auditwheel build patchelf setuptools
        python setup.py clean --all
        MLX_BUILD_STAGE=2 python -m build -w
        bash python/scripts/repair_cuda.sh ${{ inputs.arch }}
.github/actions/build-docs/action.yml (new file, 38 lines)

name: 'Build Documentation'
description: 'Build documentation'

runs:
  using: "composite"
  steps:
    - name: Setup machine
      uses: ./.github/actions/setup-linux

    - name: Install dependencies
      shell: bash
      run: |
        sudo apt-get install -y doxygen
        source .venv/bin/activate
        pip install -r docs/requirements.txt
        pip install . -v

    - name: Build documentation
      shell: bash
      run: |
        source .venv/bin/activate
        cd docs
        doxygen
        make html O=-W

    - name: Create artifact tar
      shell: bash
      run: tar -cf artifact.tar -C docs --dereference build/html index.html

    # Do it manually because upload-pages-artifact requires gtar
    - name: Upload artifact
      id: upload-artifact
      uses: actions/upload-artifact@v5
      with:
        name: github-pages
        path: artifact.tar
        retention-days: 1
        if-no-files-found: error
.github/actions/build-linux-release/action.yml (new file, 40 lines)

name: 'Build Linux wheel'
description: 'Build Linux wheel'

inputs:
  build-backend:
    description: 'Build the backend mlx-cpu package'
    type: boolean
    required: false
    default: false
  arch:
    description: 'Platform architecture tag'
    required: true
    type: choice
    options:
      - x86_64
      - aarch64

runs:
  using: "composite"
  steps:
    - name: Generate package stubs
      shell: bash
      run: |
        pip install -e ".[dev]" -v
        pip install typing_extensions
        python setup.py generate_stubs
    - name: Build Python package
      shell: bash
      run: |
        pip install auditwheel patchelf build
        python setup.py clean --all
        MLX_BUILD_STAGE=1 python -m build -w
        bash python/scripts/repair_linux.sh ${{ inputs.arch }}
    - name: Build backend package
      if: ${{ inputs.build-backend }}
      shell: bash
      run: |
        python setup.py clean --all
        MLX_BUILD_STAGE=2 python -m build -w
        auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}
.github/actions/build-linux/action.yml (new file, 41 lines)

name: 'Build and Test on Linux'

inputs:
  toolkit:
    description: 'The toolkit to build with'
    required: false
    default: 'cpu'

runs:
  using: "composite"
  steps:
    - name: Install Python package
      id: python_build
      shell: sh
      env:
        DEBUG: 1
        CMAKE_ARGS: >-
          -DCMAKE_COMPILE_WARNING_AS_ERROR=ON
          -DMLX_BUILD_CUDA=${{ startsWith(inputs.toolkit, 'cuda') && 'ON' || 'OFF' }}
      run: |
        if ${{ startsWith(inputs.toolkit, 'cuda') && runner.arch == 'arm64' }} ; then
          # There is no GPU in arm64 runner, use a common arch.
          CMAKE_ARGS="$CMAKE_ARGS -DMLX_CUDA_ARCHITECTURES=90a"
          # Can not build tests when the built executables can not run.
          CMAKE_ARGS="$CMAKE_ARGS -DMLX_BUILD_TESTS=OFF"
        fi
        pip install --no-build-isolation -e ".[dev]" -v
        # Pass the CMAKE_ARGS to following steps.
        echo CMAKE_ARGS="$CMAKE_ARGS" >> $GITHUB_OUTPUT

    - name: Generate package stubs
      shell: sh
      run: |
        pip install typing_extensions
        python setup.py generate_stubs

    - name: Build CPP only
      shell: bash
      run: |
        cmake . -B build -DCMAKE_BUILD_TYPE=Debug ${{ steps.python_build.outputs.CMAKE_ARGS }}
        cmake --build build -j $(nproc)
.github/actions/build-macos-release/action.yml (new file, 34 lines)

name: 'Build macOS release'
description: 'Build MLX releases macOS'

inputs:
  macos-target:
    description: 'macOS build target'
    required: false
    default: '15.0'
  build-backend:
    description: 'Build the backend mlx-metal package'
    type: boolean
    required: false
    default: false

runs:
  using: "composite"
  steps:
    - name: Build Python package
      shell: bash -l {0}
      env:
        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
      run: |
        pip install build
        python setup.py clean --all
        MLX_BUILD_STAGE=1 python -m build -w

    - name: Build backend package
      if: ${{ inputs.build-backend }}
      shell: bash -l {0}
      env:
        MACOSX_DEPLOYMENT_TARGET: ${{ inputs.macos-target }}
      run: |
        python setup.py clean --all
        MLX_BUILD_STAGE=2 python -m build -w
.github/actions/build-macos/action.yml (new file, 88 lines)

name: 'Build and Test on macOS'
description: 'Build and test MLX on macOS'

runs:
  using: "composite"
  steps:
    - name: Install dependencies
      env:
        DEBUG: 1
        CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
      shell: bash -l {0}
      run: |
        pip install --upgrade pip
        pip install cmake setuptools nanobind==2.4.0
        pip install -e . -v

    - name: Generate package stubs
      shell: bash -l {0}
      run: |
        pip install typing_extensions
        python setup.py generate_stubs

    - name: Install tests dependencies
      shell: bash -l {0}
      run: |
        pip install numpy torch tensorflow unittest-xml-reporting

    - name: Run Python tests
      shell: bash -l {0}
      env:
        LOW_MEMORY: 1
      run: |
        DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
        DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
        mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
        if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi

    - name: Build example extension
      shell: bash -l {0}
      run: |
        cd examples/extensions
        pip install -r requirements.txt
        python setup.py build_ext --inplace
        python test.py

    - name: Build CPP only
      shell: bash -l {0}
      run: |
        mkdir -p build
        cd build
        cmake ..
        make -j $(sysctl -n hw.ncpu)

    - name: Run CPP tests
      shell: bash -l {0}
      env:
        DEVICE: gpu
        METAL_DEVICE_WRAPPER_TYPE: 1
        METAL_DEBUG_ERROR_MODE: 0
      run: ./build/tests/tests

    - name: Build small binary with JIT
      shell: bash -l {0}
      run: |
        mkdir -p build
        cd build
        cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
          -DBUILD_SHARED_LIBS=ON \
          -DMLX_BUILD_CPU=OFF \
          -DMLX_BUILD_SAFETENSORS=OFF \
          -DMLX_BUILD_GGUF=OFF \
          -DMLX_METAL_JIT=ON
        make -j $(sysctl -n hw.ncpu)

    - name: Run Python tests with JIT
      shell: bash -l {0}
      env:
        LOW_MEMORY: 1
        DEVICE: gpu
        METAL_DEVICE_WRAPPER_TYPE: 1
        METAL_DEBUG_ERROR_MODE: 0
      run: |
        CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
          pip install -e . -v
        python -m xmlrunner discover \
          -v python/tests \
          -o test-results/gpu_jit
.github/actions/setup-linux/action.yml (new file, 87 lines)

name: 'Setup Linux Environment'
description: 'Install dependencies for Linux builds'

inputs:
  toolkit:
    description: 'Which toolkit to install'
    required: false
    default: 'cpu'
  python-version:
    description: 'Version of python to set up'
    required: false
    default: '3.10'

runs:
  using: "composite"
  steps:
    - name: Use ccache
      if: ${{ runner.arch == 'x86_64' }}
      uses: hendrikmuhs/ccache-action@v1.2
      with:
        key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
        max-size: 1GB

    - name: Install common dependencies
      shell: bash
      run: |
        sudo apt-get update
        sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip

    - uses: actions/setup-python@v6
      with:
        python-version: ${{ inputs.python-version }}

    - name: Setup Python venv
      shell: bash
      run: |
        python -m venv .venv
        source .venv/bin/activate
        pip install setuptools cmake nanobind==2.4.0
        echo PATH=$PATH >> $GITHUB_ENV
        # Make cmake search .venv for nanobind
        echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV

    - name: Install MPI
      shell: bash
      run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev

    - name: Install CUDA toolkit
      if: ${{ startsWith(inputs.toolkit, 'cuda') }}
      shell: bash
      env:
        # Note: the CI machine does not meet CUDA 13's driver requirement.
        # Compatibility matrix:
        # https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
        PACKAGES: |
          {
            "cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
            "cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
            "cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
          }
      run: |
        # The CUDA binaries are hosted in the "sbsa" repo, the "arm64" repo is
        # Jetson specific. SBSA means Arm Server Base System Architecture.
        ARCH=${{ runner.arch == 'arm64' && 'sbsa' || 'x86_64' }}
        wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
        sudo dpkg -i cuda-keyring_1.1-1_all.deb
        sudo apt-get update
        sudo apt-get install -y \
          libnccl2 libnccl-dev \
          ${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
        echo "/usr/local/${{ inputs.toolkit }}/bin" >> $GITHUB_PATH

    - name: CUDA packages and driver report
      if: ${{ startsWith(inputs.toolkit, 'cuda') }}
      shell: bash
      run: |
        sudo apt-get install -y ubuntu-drivers-common dkms
        echo "NVIDIA Driver Packages Available:"
        sudo ubuntu-drivers list --gpgpu
        echo "NVIDIA Driver Version:"
        cat /proc/driver/nvidia/version || echo "nvidia driver not found"
        echo "Installed NVIDIA and CUDA packages:"
        dpkg -l | egrep "cuda|nvidia" -i
        echo "DKMS Status:"
        dkms status || echo "dkms not found"
        echo "NVIDIA-SMI Status:"
        nvidia-smi || echo "nvidia-smi not found"
.github/actions/setup-macos/action.yml (new file, 24 lines)

name: 'Setup macOS Environment'
description: 'Install dependencies for macOS builds'

inputs:
  python-version:
    description: 'Python version to use'
    required: false
    default: '3.10'

runs:
  using: "composite"
  steps:
    - name: Install Homebrew packages
      shell: sh
      run: /opt/homebrew/bin/brew install openmpi

    - name: Verify MetalToolchain installed
      shell: bash
      run: xcodebuild -showComponent MetalToolchain

    - uses: conda-incubator/setup-miniconda@v3
      with:
        miniconda-version: "latest"
        python-version: ${{ inputs.python-version }}
.github/actions/test-linux/action.yml (new file, 69 lines)

name: 'Run Linux tests'

inputs:
  has-gpu:
    description: 'Run GPU tests'
    required: false
    default: false

runs:
  using: "composite"
  steps:
    - name: Run MPI tests
      shell: bash
      run: |
        echo "::group::MPI tests"
        mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
        echo "::endgroup::"

    - name: Run distributed tests
      if: ${{ inputs.has-gpu == 'false' }}
      shell: bash
      run: |
        echo "::group::Distributed tests"
        mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
        if grep -Fq '[WARN]' stderr.log ; then
          grep -F '[WARN]' stderr.log
          echo "Distributed ring test failed";
          exit 1;
        fi
        echo "::endgroup::"

    - name: Run Python tests - CPU
      if: ${{ inputs.has-gpu == 'false' }}
      shell: bash
      env:
        DEVICE: cpu
      run: |
        echo "::group::Python tests - CPU"
        python -m unittest discover python/tests -v
        echo "::endgroup::"

    - name: Run Python tests - GPU
      if: ${{ inputs.has-gpu == 'true' }}
      shell: bash
      env:
        DEVICE: gpu
      run: |
        echo "::group::Python tests - GPU"
        python -m tests discover python/tests -v
        echo "::endgroup::"

    - name: Run CPP tests - CPU
      shell: bash
      env:
        DEVICE: cpu
      run: |
        echo "::group::CPP tests - CPU"
        ./build/tests/tests
        echo "::endgroup::"

    - name: Run CPP tests - GPU
      if: ${{ inputs.has-gpu == 'true' }}
      shell: bash
      env:
        DEVICE: gpu
      run: |
        echo "::group::CPP tests - GPU"
        ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
        echo "::endgroup::"
.github/dependabot.yml (new file, 6 lines)

version: 2
updates:
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "weekly"
.github/scripts/setup+build-cpp-linux-fedora-container.sh (new executable file, 27 lines)

#!/bin/bash
set -ex

# [Setup] Install dependencies inside the container.
dnf update -y
dnf install -y \
  blas-devel \
  lapack-devel \
  openblas-devel \
  make \
  cmake \
  clang \
  git
dnf clean all

# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
export DEBUG=1
export CMAKE_C_COMPILER=/usr/bin/clang
export CMAKE_CXX_COMPILER=/usr/bin/clang++

mkdir -p build
pushd build
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
make -j $(nproc)
./tests/tests
popd
.github/workflows/build_and_test.yml (new file, 108 lines)

name: Build and Test

on:
  pull_request:
  push:
    branches:
      - main
      # For testing CI without starting a pull request:
      - test/*

permissions:
  contents: read

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

jobs:
  check_lint:
    name: Check Lint
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v6
      - uses: pre-commit/action@v3.0.1

  linux_build_and_test:
    name: Linux (cpu, ${{ matrix.arch }})
    needs: check_lint
    strategy:
      fail-fast: false
      matrix:
        arch: ['x86_64', 'aarch64']
    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
      - uses: ./.github/actions/build-linux
      - uses: ./.github/actions/test-linux

  cuda_build_and_test:
    name: Linux (${{ matrix.toolkit }}, ${{ matrix.arch }})
    if: github.repository == 'ml-explore/mlx'
    needs: check_lint
    strategy:
      fail-fast: false
      matrix:
        arch: ['x86_64', 'aarch64']
        toolkit: ['cuda-12.6', 'cuda-12.9']
    runs-on: ${{ matrix.arch == 'x86_64' && 'gpu-t4-4-core' || 'ubuntu-22.04-arm' }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          toolkit: ${{ matrix.toolkit }}
      - uses: ./.github/actions/build-linux
        with:
          toolkit: ${{ matrix.toolkit }}
      - uses: ./.github/actions/test-linux
        if: matrix.arch == 'x86_64'
        with:
          has-gpu: true

  mac_build_and_test:
    name: macOS (${{ matrix.macos-target }})
    if: github.repository == 'ml-explore/mlx'
    strategy:
      matrix:
        macos-target: ["14.0", "15.0"]
    runs-on: [self-hosted, macos]
    env:
      MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
    needs: check_lint
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
      - uses: ./.github/actions/build-macos

  build_documentation:
    name: Build Documentation
    if: github.repository == 'ml-explore/mlx'
    runs-on: ubuntu-22.04
    needs: check_lint
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/build-docs

  linux_fedora_build_cpp:
    name: Linux Fedora (${{ matrix.arch }})
    needs: check_lint
    strategy:
      fail-fast: false
      matrix:
        include:
          - host: ubuntu-22.04
            arch: x86_64
          - host: ubuntu-22.04-arm
            arch: aarch64

    runs-on: ${{ matrix.host }}
    container:
      image: fedora:42
    steps:
      - name: Checkout code
        uses: actions/checkout@v6

      - name: CPP Build Test - No Release
        run: |
          bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
.github/workflows/documentation.yml (new file, 28 lines)

name: Documentation

on:
  workflow_dispatch:

permissions:
  contents: read

jobs:
  build:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/build-docs

  deploy:
    needs: build
    permissions:
      pages: write
      id-token: write
    runs-on: ubuntu-latest
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v4
.github/workflows/nightly.yml (new file, 96 lines)

name: Nightly Build

on:
  schedule:
    - cron: 33 6 * * 1-5
  workflow_dispatch:

permissions:
  contents: read

jobs:
  build_linux_release:
    strategy:
      fail-fast: false
      matrix:
        python_version: ["3.10", "3.14"]
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
      - uses: ./.github/actions/build-linux-release
        with:
          build-backend: ${{ matrix.python-version == '3.10' }}
          arch: "x86_64"
      - name: Upload mlx artifacts
        uses: actions/upload-artifact@v5
        with:
          name: linux-wheels-${{ matrix.python_version }}
          path: wheelhouse/mlx-*.whl
          retention-days: 7
      - name: Upload mlx-cpu artifacts
        if: matrix.python_version == '3.10'
        uses: actions/upload-artifact@v5
        with:
          name: mlx-cpu
          path: wheelhouse/mlx_cpu-*.whl
          retention-days: 7

  build_linux_with_tests:
    strategy:
      fail-fast: false
      matrix:
        python_version: ["3.11", "3.12", "3.13", "3.14"]
        runner:
          - ubuntu-22.04
          - ubuntu-22.04-arm
    runs-on: ${{ matrix.runner }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          python-version: ${{ matrix.python_version }}
      - uses: ./.github/actions/build-linux
      - uses: ./.github/actions/test-linux

  build_mac_release:
    if: github.repository == 'ml-explore/mlx'
    strategy:
      matrix:
        python-version: ["3.10", "3.13"]
    runs-on: [self-hosted, macos]
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
        with:
          python-version: ${{ matrix.python-version }}
      - uses: ./.github/actions/build-macos
      - name: Build macOS 15 package
        uses: ./.github/actions/build-macos-release
        with:
          macos-target: 15.0
          build-backend: ${{ matrix.python-version == '3.10' }}
      - name: Build macOS 14 package
        uses: ./.github/actions/build-macos-release
        with:
          macos-target: 14.0
          build-backend: ${{ matrix.python-version == '3.10' }}

  build_cuda_release:
    if: github.repository == 'ml-explore/mlx'
    runs-on: ubuntu-22-large
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          toolkit: 'cuda-12.9'
      - name: Build Python package
        uses: ./.github/actions/build-cuda-release
        with:
          toolkit: 'cuda-12.9'
      - name: Upload artifacts
        uses: actions/upload-artifact@v5
        with:
          name: mlx-cuda
          path: wheelhouse/mlx_cuda-*.whl
          retention-days: 7
.github/workflows/pull_request.yml (deleted, 20 lines)

on:
  pull_request:
    branches:
      - main

jobs:
  check_lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          python-version: 3.8
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pre-commit black isort clang-format
      - name: Run lint
        run: |
          pre-commit run --all-files
244
.github/workflows/release.yml
vendored
Normal file
244
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
name: PyPI Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v*'
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
dev_release:
|
||||||
|
description: "Do a dev release or regular release"
|
||||||
|
required: true
|
||||||
|
default: "false"
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
setup:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Set publishing variables
|
||||||
|
run: echo "Publishing setup complete"
|
||||||
|
|
||||||
|
build_documentation:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- uses: ./.github/actions/build-docs
|
||||||
|
|
||||||
|
deploy_documentation:
|
||||||
|
needs: build_documentation
|
||||||
|
permissions:
|
||||||
|
pages: write
|
||||||
|
id-token: write
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
environment:
|
||||||
|
name: github-pages
|
||||||
|
url: ${{ steps.deployment.outputs.page_url }}
|
||||||
|
steps:
|
||||||
|
- name: Deploy to GitHub Pages
|
||||||
|
id: deployment
|
||||||
|
uses: actions/deploy-pages@v4
|
||||||
|
|
||||||
|
build_linux_release:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||||
|
arch: ['x86_64', 'aarch64']
|
||||||
|
    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
    env:
      PYPI_RELEASE: 1
      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          python-version: ${{ matrix.python_version }}
      - uses: ./.github/actions/build-linux-release
        with:
          build-backend: ${{ matrix.python_version == '3.10' }}
          arch: ${{ matrix.arch }}
      - name: Upload MLX artifacts
        uses: actions/upload-artifact@v5
        with:
          overwrite: true
          name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
          path: wheelhouse/mlx-*.whl
      - name: Upload CPU artifacts
        if: matrix.python_version == '3.10'
        uses: actions/upload-artifact@v5
        with:
          overwrite: true
          name: mlx-cpu-${{ matrix.arch }}
          path: wheelhouse/mlx_cpu-*.whl

  build_mac_release:
    if: github.repository == 'ml-explore/mlx'
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
    runs-on: [self-hosted, macos]
    env:
      PYPI_RELEASE: 1
      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}

    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-macos
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        shell: bash -l {0}
        run: |
          pip install --upgrade pip
          pip install cmake setuptools nanobind==2.4.0
          pip install -e . -v
      - name: Generate package stubs
        shell: bash -l {0}
        run: |
          pip install typing_extensions
          python setup.py generate_stubs
      - name: Build macOS 14 package
        uses: ./.github/actions/build-macos-release
        with:
          macos-target: 14.0
          build-backend: ${{ matrix.python-version == '3.10' }}
      - name: Build macOS 15 package
        uses: ./.github/actions/build-macos-release
        with:
          macos-target: 15.0
          build-backend: ${{ matrix.python-version == '3.10' }}
      - name: Upload MLX artifacts
        uses: actions/upload-artifact@v5
        with:
          overwrite: true
          name: mac-wheels-${{ matrix.python-version }}
          path: dist/mlx-*.whl
      - name: Upload Metal artifacts
        if: matrix.python-version == '3.10'
        uses: actions/upload-artifact@v5
        with:
          overwrite: true
          name: mlx-metal
          path: dist/mlx_metal-*.whl

  build_cuda_release:
    if: github.repository == 'ml-explore/mlx'
    strategy:
      matrix:
        arch: ['x86_64', 'aarch64']
        toolkit: ['cuda-12.9', 'cuda-13.0']
    runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22-large' || 'ubuntu-22-large-arm' }}
    env:
      PYPI_RELEASE: 1
      DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
    steps:
      - uses: actions/checkout@v6
      - uses: ./.github/actions/setup-linux
        with:
          toolkit: ${{ matrix.toolkit }}
      - name: Build Python package
        uses: ./.github/actions/build-cuda-release
        with:
          arch: ${{ matrix.arch }}
      - name: Upload artifacts
        uses: actions/upload-artifact@v5
        with:
          overwrite: true
          name: mlx-cuda
          path: wheelhouse/mlx_cuda-*.whl

  pypi-publish:
    name: Upload release to PyPI
    runs-on: ubuntu-latest
    needs: [setup, build_linux_release, build_mac_release]
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx
    steps:
      - uses: actions/download-artifact@v6
        with:
          pattern: linux-wheels-*
          merge-multiple: true
          path: dist
      - uses: actions/download-artifact@v6
        with:
          pattern: mac-wheels-*
          merge-multiple: true
          path: dist
      - name: Display structure of downloaded files
        run: ls -R dist
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://upload.pypi.org/legacy/

  pypi-publish-cuda:
    name: Upload CUDA release to PyPI
    runs-on: ubuntu-latest
    needs: [setup, build_cuda_release]
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx-cuda
    steps:
      - uses: actions/download-artifact@v6
        with:
          name: mlx-cuda
          path: dist
      - name: Display structure of downloaded files
        run: ls -R dist
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://upload.pypi.org/legacy/

  pypi-publish-cpu:
    name: Upload CPU release to PyPI
    runs-on: ubuntu-latest
    needs: [setup, build_linux_release]
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx-cpu
    steps:
      - uses: actions/download-artifact@v6
        with:
          pattern: mlx-cpu-*
          merge-multiple: true
          path: dist
      - name: Display structure of downloaded files
        run: ls -R dist
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://upload.pypi.org/legacy/

  pypi-publish-metal:
    name: Upload Metal release to PyPI
    runs-on: ubuntu-latest
    needs: [setup, build_mac_release]
    permissions:
      id-token: write
    environment:
      name: pypi
      url: https://pypi.org/p/mlx-metal
    steps:
      - uses: actions/download-artifact@v6
        with:
          name: mlx-metal
          path: dist
      - name: Display structure of downloaded files
        run: ls -R dist
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://upload.pypi.org/legacy/
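The build matrices export `PYPI_RELEASE` and `DEV_RELEASE`, and the publish jobs only pick up the resulting wheels. The snippet below is a hypothetical sketch of how such flags are commonly folded into a wheel version string; the actual logic lives in the repository's `setup.py`, which is not part of this compare.

```python
import os

def release_version(base):
    # Hypothetical helper, not MLX's actual setup.py logic: shows how
    # PYPI_RELEASE / DEV_RELEASE style flags are often mapped to a version.
    if os.environ.get("DEV_RELEASE") == "1":
        build = os.environ.get("GITHUB_RUN_NUMBER", "0")
        return f"{base}.dev{build}"   # e.g. "0.30.0.dev123"
    if os.environ.get("PYPI_RELEASE") == "1":
        return base                   # plain release, e.g. "0.30.0"
    return f"{base}+local"            # local, unpublished build

print(release_version("0.30.0"))
```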
1 .gitignore (vendored)
@@ -36,6 +36,7 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+uv.lock

 # vim
 *.swp
.pre-commit-config.yaml
@@ -1,4 +1,10 @@
 repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v6.0.0
+    hooks:
+      - id: check-yaml
+      # - id: end-of-file-fixer
+      # - id: trailing-whitespace
   - repo: https://github.com/pre-commit/mirrors-clang-format
     rev: v19.1.7
     hooks:
ACKNOWLEDGMENTS.md
@@ -19,11 +19,17 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 </a>

+# Organizations
+
+MLX has received contributions from the following companies:
+- NVIDIA Corporation & Affiliates
+
 # Third-Party Software

 MLX leverages several third-party software, listed here together with
100 CMakeLists.txt
@@ -26,6 +26,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(CMAKE_INSTALL_MESSAGE NEVER)
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

 # ----------------------------- Configuration -----------------------------
 option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -34,13 +35,16 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
 option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
 option(MLX_BUILD_METAL "Build metal backend" ON)
 option(MLX_BUILD_CPU "Build cpu backend" ON)
+option(MLX_BUILD_CUDA "Build cuda backend" OFF)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
 option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
+option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
+option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)

 # --------------------- Processor tests -------------------------
 message(
@@ -63,10 +67,18 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
       message(WARNING "Building for x86_64 arch is not officially supported.")
     endif()
   endif()

 else()
   set(MLX_BUILD_METAL OFF)
-  message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
+endif()
+
+if(MLX_USE_CCACHE)
+  find_program(CCACHE_PROGRAM ccache)
+  if(CCACHE_PROGRAM)
+    message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
+    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+  endif()
 endif()

 # ----------------------------- Lib -----------------------------
@@ -77,18 +89,26 @@ cmake_policy(SET CMP0135 NEW)

 add_library(mlx)

-if(MLX_BUILD_METAL)
-  set(METAL_LIB "-framework Metal")
-  set(FOUNDATION_LIB "-framework Foundation")
-  set(QUARTZ_LIB "-framework QuartzCore")
+# Supress warnings: note: parameter passing for argument of type
+# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
+# 10.1
+target_compile_options(mlx PRIVATE -Wno-psabi)
+
+if(MLX_BUILD_CUDA)
+  enable_language(CUDA)
 endif()

-if(MLX_BUILD_METAL AND NOT METAL_LIB)
-  message(STATUS "Metal not found. Unable to build GPU")
-  set(MLX_BUILD_METAL OFF)
-  set(MLX_METAL_DEBUG OFF)
-elseif(MLX_BUILD_METAL)
-  message(STATUS "Building METAL sources")
+if(MLX_BUILD_METAL)
+  find_library(METAL_LIB Metal)
+  find_library(FOUNDATION_LIB Foundation)
+  find_library(QUARTZ_LIB QuartzCore)
+  if(METAL_LIB)
+    message(STATUS "Metal found ${METAL_LIB}")
+  else()
+    message(
+      FATAL_ERROR
+        "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
+  endif()

   if(MLX_METAL_DEBUG)
     add_compile_definitions(MLX_METAL_DEBUG)
@@ -97,7 +117,8 @@ elseif(MLX_BUILD_METAL)
   # Throw an error if xcrun not found
   execute_process(
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
+    OUTPUT_VARIABLE MACOS_SDK_VERSION
+    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)

   if(${MACOS_SDK_VERSION} LESS 14.0)
     message(
@@ -107,9 +128,12 @@ elseif(MLX_BUILD_METAL)
   message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")

   set(METAL_CPP_URL
-      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
+      https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)

   if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
+    if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
+      message(FATAL_ERROR "MLX requires macOS >= 14.0")
+    endif()
     set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
   endif()
   execute_process(
@@ -118,7 +142,6 @@ elseif(MLX_BUILD_METAL)
       "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
     OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
   FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
-
   FetchContent_MakeAvailable(metal_cpp)
   target_include_directories(
     mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
@@ -126,6 +149,12 @@ elseif(MLX_BUILD_METAL)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
 endif()

+if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+  # With newer clang/gcc versions following libs are implicitly linked, but when
+  # building on old distributions they need to be explicitly listed.
+  target_link_libraries(mlx PRIVATE dl pthread)
+endif()
+
 if(WIN32)
   if(MSVC)
     # GGUF does not build with MSVC.
@@ -153,7 +182,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
   else()
-    message(STATUS "Accelerate or arm neon not found, using default backend.")
+    message(STATUS "Accelerate not found, using default backend.")
     set(MLX_BUILD_ACCELERATE OFF)
   endif()

@@ -212,24 +241,6 @@ else()
   set(MLX_BUILD_ACCELERATE OFF)
 endif()

-find_package(MPI)
-if(MPI_FOUND)
-  execute_process(
-    COMMAND zsh "-c" "mpirun --version"
-    OUTPUT_VARIABLE MPI_VERSION
-    ERROR_QUIET)
-  if(${MPI_VERSION} MATCHES ".*Open MPI.*")
-    target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
-  elseif(MPI_VERSION STREQUAL "")
-    set(MPI_FOUND FALSE)
-    message(
-      WARNING "MPI found but mpirun is not available. Building without MPI.")
-  else()
-    set(MPI_FOUND FALSE)
-    message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
-  endif()
-endif()
-
 message(STATUS "Downloading json")
 FetchContent_Declare(
   json
@@ -244,18 +255,25 @@ target_include_directories(
   mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
              $<INSTALL_INTERFACE:include>)

-FetchContent_Declare(
-  fmt
-  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
-  GIT_TAG 10.2.1
-  EXCLUDE_FROM_ALL)
-FetchContent_MakeAvailable(fmt)
+# Do not add mlx_EXPORTS define for shared library.
+set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
+
+if(USE_SYSTEM_FMT)
+  find_package(fmt REQUIRED)
+else()
+  FetchContent_Declare(
+    fmt
+    GIT_REPOSITORY https://github.com/fmtlib/fmt.git
+    GIT_TAG 10.2.1
+    EXCLUDE_FROM_ALL)
+  FetchContent_MakeAvailable(fmt)
+endif()
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)

 if(MLX_BUILD_PYTHON_BINDINGS)
   message(STATUS "Building Python bindings.")
   find_package(
-    Python 3.8
+    Python 3.10
     COMPONENTS Interpreter Development.Module
     REQUIRED)
   execute_process(
CONTRIBUTING.md
@@ -5,26 +5,26 @@ possible.

 ## Pull Requests

 1. Fork and submit pull requests to the repo.
 2. If you've added code that should be tested, add tests.
 3. If a change is likely to impact efficiency, run some of the benchmarks before
    and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
 4. If you've changed APIs, update the documentation.
 5. Every PR should have passing tests and at least one review.
 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
    This should install hooks for running `black` and `clang-format` to ensure
    consistent style for C++ and python code.

 You can also run the formatters manually as follows:

-```
+```shell
 clang-format -i file.cpp
 ```

-```
+```shell
 black file.py
 ```

 or run `pre-commit run --all-files` to check all files in the repo.

 ## Issues
MANIFEST.in
@@ -1,4 +1,6 @@
 include CMakeLists.txt
+include mlx.pc.in
 recursive-include mlx/ *
+include cmake/*
 include python/src/*
 include python/mlx/py.typed # support type hinting as in PEP-561
57 README.md
@@ -2,7 +2,7 @@

 [**Quickstart**](#quickstart) | [**Installation**](#installation) |
 [**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
 [**Examples**](#examples)

 [](https://circleci.com/gh/ml-explore/mlx)

@@ -11,37 +11,37 @@ brought to you by Apple machine learning research.

 Some key features of MLX include:

 - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
   also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
   [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
   the Python API. MLX has higher-level packages like `mlx.nn` and
   `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
   more complex models.

 - **Composable function transformations**: MLX supports composable function
   transformations for automatic differentiation, automatic vectorization,
   and computation graph optimization.

 - **Lazy computation**: Computations in MLX are lazy. Arrays are only
   materialized when needed.

 - **Dynamic graph construction**: Computation graphs in MLX are constructed
   dynamically. Changing the shapes of function arguments does not trigger
   slow compilations, and debugging is simple and intuitive.

 - **Multi-device**: Operations can run on any of the supported devices
   (currently the CPU and the GPU).

 - **Unified memory**: A notable difference from MLX and other frameworks
   is the *unified memory model*. Arrays in MLX live in shared memory.
   Operations on MLX arrays can be performed on any of the supported
   device types without transferring data.

 MLX is designed by machine learning researchers for machine learning
 researchers. The framework is intended to be user-friendly, but still efficient
 to train and deploy models. The design of the framework itself is also
 conceptually simple. We intend to make it easy for researchers to extend and
 improve MLX with the goal of quickly exploring new ideas.

 The design of MLX is inspired by frameworks like
 [NumPy](https://numpy.org/doc/stable/index.html),
@@ -68,25 +68,30 @@ in the documentation.

 ## Installation

-MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
-
-**With `pip`**:
-
-```
-pip install mlx
-```
-
-**With `conda`**:
-
-```
-conda install -c conda-forge mlx
-```
+MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
+macOS, run:
+
+```bash
+pip install mlx
+```
+
+To install the CUDA backend on Linux, run:
+
+```bash
+pip install mlx[cuda]
+```
+
+To install a CPU-only Linux package, run:
+
+```bash
+pip install mlx[cpu]
+```
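After installing any of the wheels above, a quick sanity check (a minimal sketch; the device reported depends on which backend was installed) looks like this:

```python
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = (a * a).sum()            # lazy; nothing is computed yet
mx.eval(b)                   # force evaluation
print(b, mx.default_device())
```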

 Checkout the
 [documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
 for more information on building the C++ and Python APIs from source.

 ## Contributing

 Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
 on contributing to MLX. See the
@@ -105,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
 MLX useful in your research and wish to cite it, please use the following
 BibTex entry:

-```
+```text
 @software{mlx2023,
   author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
   title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
@@ -1,5 +1,6 @@
 // Copyright © 2023 Apple Inc.

+#include <cstring>
 #include <iostream>
 #include <sstream>
@@ -74,7 +75,7 @@ void time_irregular_binary_ops_3D() {

 void time_irregular_binary_ops_4D() {
   auto device = mx::default_device();
-  std::vector<int> shape = {8, 8, 512, 512};
+  mx::Shape shape = {8, 8, 512, 512};
   auto a = mx::random::uniform(shape);
   auto b = mx::random::uniform(shape);

@@ -114,7 +115,7 @@ void time_irregular_binary_ops_4D() {

 void time_irregular_reshape() {
   auto device = mx::default_device();
-  std::vector<int> shape;
+  mx::Shape shape;
   auto reshape_fn = [&shape, device](const mx::array& a) {
     return mx::reshape(a, shape, device);
   };
@@ -169,7 +170,7 @@ void time_irregular_astype_1D() {
 void time_irregular_astype_2D() {
   auto device = mx::default_device();
   int size = 2048;
-  std::vector<int> shape = {size, size};
+  mx::Shape shape = {size, size};

   auto a = mx::random::uniform(shape);
   TIMEM("2D regular", mx::astype, a, mx::int32, device);
@@ -192,6 +192,22 @@ void time_reductions() {

   auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
   TIME(argmin_along_1);

+  auto indices = mx::array({1});
+  auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
+  std::vector<int> axes{0};
+  auto b = scatter(a, {indices}, updates, axes);
+  mx::eval(b);
+
+  auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
+  TIME(max_along_0);
+  auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
+  TIME(max_along_1);
+
+  auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
+  TIME(min_along_0);
+  auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
+  TIME(min_along_1);
 }

 void time_gather_scatter() {
@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
     t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)

     c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
-    c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
-        np.float32
-    )
+    c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)

     atol = 1e-5 if np_dtype == np.float32 else 1e-4

@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run gemm benchmarks")

-    dtypes = ("float32", "float16")
+    dtypes = ("float32", "float16", "complex64")
     transposes = ("nn", "nt", "tn")
     shapes = (
         (16, 234, 768, 3072),
@@ -187,7 +185,7 @@ if __name__ == "__main__":
         diff = gflops_mx / gflops_pt - 1.0

         print(
-            f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
+            f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
         )
         if gflops_pt >= 2.0 * gflops_mx:
             print("ATTENTION ^^^^^^^")
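The `gflops_pt` / `gflops_mx` columns come from `get_gflop_count`, whose body is outside this hunk; for a batched GEMM the conventional count is 2·B·M·N·K floating-point operations. For the first shape in the list:

```python
B, M, N, K = 16, 234, 768, 3072
flops = 2 * B * M * N * K     # one multiply and one add per inner-product term
print(flops / 1e9)            # ~17.67 GFLOP per call
```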
@@ -1,6 +1,5 @@
 # Copyright © 2023 Apple Inc.

-import argparse
 import os
 import subprocess
 import time
@@ -196,7 +195,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):


 for transpose in (False, True):
-    for dtype in ("float32", "float16"):
+    for dtype in ("float32", "float16", "complex64"):
         fig, axs = plt.subplots(
             len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
         )
@@ -215,7 +214,7 @@ for transpose in (False, True):
         fig.suptitle(f"{device_name}: {dtype} {op_name}")
         fig.savefig(
             os.path.join(
-                results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
+                results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
             )
         )
         plt.close(fig)
@@ -5,6 +5,7 @@ import os
 import time

 import torch
+import torch.cuda
 import torch.mps


@@ -44,8 +45,10 @@ def bench(f, *args):


 def sync_if_needed(x):
-    if x.device != torch.device("cpu"):
+    if x.device == torch.device("mps"):
         torch.mps.synchronize()
+    elif x.device == torch.device("cuda"):
+        torch.cuda.synchronize()


 @torch.no_grad()
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
     sync_if_needed(x)


+@torch.no_grad()
+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def softmax(axis, x):
     ys = []
@@ -340,7 +351,11 @@ if __name__ == "__main__":
         args.axis.pop(0)

     torch.set_num_threads(1)
-    device = "cpu" if args.cpu else "mps"
+    device = "mps"
+    if torch.cuda.is_available():
+        device = "cuda"
+    if args.cpu:
+        device = "cpu"

     types = args.dtype
     if not types:
@@ -460,5 +475,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))

+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
107 benchmarks/python/conv_unaligned_bench.py (new file)
@@ -0,0 +1,107 @@
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    dtype = "float32"
    shapes = (
        (4, 32, 32, 21, 3, 3, 128),
        (4, 32, 32, 21, 3, 3, 37),
        (4, 32, 32, 370, 3, 3, 370),
        (4, 32, 32, 370, 7, 7, 128),
        (2, 320, 640, 21, 7, 7, 21),
    )
    for N, H, W, C, kh, kw, O in shapes:
        time_mlx, time_torch = bench_shape(
            N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
        )
        diff = time_torch / time_mlx - 1.0

        print(
            f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
        )
        if time_mlx >= 2.0 * time_torch:
            print("ATTENTION ^^^^^^^")
74 benchmarks/python/gather_mm_bench.py (new file)
@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_mm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = x @ w1.T
        x = x @ w2.T
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_mm()
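The `gather_sort` / `scatter_unsort` pair relies on the argsort of an argsort being the inverse permutation, which is what lets the benchmark restore the original token order after the grouped matmuls. A standalone sketch of that identity:

```python
import mlx.core as mx

idx = mx.array([2, 0, 1, 0])
order = mx.argsort(idx)         # permutation that sorts idx
inv_order = mx.argsort(order)   # inverse of that permutation
x = mx.array([10, 20, 30, 40])
print(mx.all(x[order][inv_order] == x))  # True
```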
84 benchmarks/python/gather_qmm_bench.py (new file)
@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate(
            [
                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
                for i, j in enumerate(idx.tolist())
            ],
            axis=0,
        )
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_qmm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = mx.quantized_matmul(x, *w1, transpose=True)
        x = mx.quantized_matmul(x, *w2, transpose=True)
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_qmm()
@@ -1,5 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

+from functools import partial
+
 import mlx.core as mx
 import mlx.nn as nn
 from time_utils import time_fn
@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
     return y


-def time_layer_norm():
+def time_layer_norm(N, dt):
+    L = 1024
     f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
     f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0, 1, 2))
     g2 = mx.grad(f2, argnums=(0, 1, 2))

-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    w = mx.random.uniform(shape=(N,)).astype(dt)
+    b = mx.random.uniform(shape=(N,)).astype(dt)
+    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
     mx.eval(x, w, b, y)

-    def layer_norm_loop(g, x, w, b):
+    def layer_norm_loop(f, x, w, b):
+        for _ in range(32):
+            x = f(x, w, b)
+        return x
+
+    time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
+    time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
+
+    def layer_norm_grad_loop(g, x, w, b):
         gx, gw, gb = x, w, b
         for _ in range(32):
             gx, gw, gb = g(gx, gw, gb, y)
         return gx, gw, gb

-    time_fn(layer_norm_loop, g1, x, w, b)
-    time_fn(layer_norm_loop, g2, x, w, b)
-    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
-    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
+    time_fn(layer_norm_grad_loop, g1, x, w, b)
+    time_fn(layer_norm_grad_loop, g2, x, w, b)
+    time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
+    time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)

     f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
     f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
     g1 = mx.grad(f1, argnums=(0,))
     g2 = mx.grad(f2, argnums=(0,))

-    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
-    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
-    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
+    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
+    w = mx.random.uniform(shape=(N,)).astype(dt)
+    b = mx.random.uniform(shape=(N,)).astype(dt)
+    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
     mx.eval(x, w, b, y)

-    def layer_norm_loop(g, x):
+    def layer_norm_grad_x_loop(g, x):
         gx = x
         for _ in range(32):
             gx = g(gx, y)
         return gx

-    time_fn(layer_norm_loop, g1, x)
-    time_fn(layer_norm_loop, g2, x)
-    time_fn(layer_norm_loop, mx.compile(g1), x)
-    time_fn(layer_norm_loop, mx.compile(g2), x)
+    time_fn(layer_norm_grad_x_loop, g1, x)
+    time_fn(layer_norm_grad_x_loop, g2, x)
+    time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
+    time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)


 if __name__ == "__main__":
-    time_layer_norm()
+    for dt in [mx.float32, mx.float16, mx.bfloat16]:
+        for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
+            print(dt, n)
+            time_layer_norm(n, dt)
212 benchmarks/python/masked_scatter.py (new file)
@@ -0,0 +1,212 @@
import math
import os
import subprocess
import time
from copy import copy
from functools import partial

import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from matplotlib.ticker import FuncFormatter

RESULTS_DIR = "./results"


if not os.path.isdir(RESULTS_DIR):
    os.mkdir(RESULTS_DIR)

DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")

TORCH_DEVICE = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)


N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20

VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
D_TYPES = ("float32", "float16")


def _power_of_two_formatter(value, _position):
    if value <= 0:
        return ""
    exponent = int(round(math.log2(value)))
    if abs(value - (1 << exponent)) / value > 1e-6:
        return f"{value:g}"
    return f"$2^{{{exponent}}}$"


def torch_sync():
    if TORCH_DEVICE.type == "cuda":
        torch.cuda.synchronize()
    elif TORCH_DEVICE.type == "mps":
        torch.mps.synchronize()


def masked_scatter_mlx(self_arr, mask_arr, src_arr):
    outs = []
    for _ in range(N_ITER_FUNC):
        out = copy(self_arr)
        out[mask_arr] = src_arr
        outs.append(out)
    mx.eval(outs)
    return outs


@torch.no_grad()
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
    outs = []
    for _ in range(N_ITER_FUNC):
        out = self_tensor.clone()
        out.masked_scatter_(mask_tensor, src_tensor)
        outs.append(out)
    torch_sync()
    return outs


def measure(fn):
    for _ in range(N_WARMUP):
        fn()
    start = time.perf_counter_ns()
    for _ in range(N_ITER_BENCH):
        fn()
    end = time.perf_counter_ns()
    return (end - start) * 1e-9


def bytes_touched(length, true_count, item_size):
    mask_bytes = length
    self_bytes = length * item_size * 2  # read + write
    src_bytes = true_count * item_size
    return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH


def build_case(length, density, np_dtype, torch_dtype):
    true_count = max(1, int(round(length * density)))

    rng = np.random.default_rng()
    self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
    mask_np = np.zeros(length, dtype=bool)
    mask_np[:true_count] = True
    rng.shuffle(mask_np)
    src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)

    self_mlx = mx.array(self_np)
    mask_mlx = mx.array(mask_np)
    src_mlx = mx.array(src_np)

    self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
    mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
    src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)

    # Correctness check once per configuration
    mx_out = mx.array(self_np)
    mx_out[mask_mlx] = src_mlx
    mx.eval(mx_out)
    torch_out = self_torch.clone()
    torch_out.masked_scatter_(mask_torch, src_torch)

    atol = 5e-3 if np_dtype == np.float16 else 1e-5
    if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
        raise AssertionError("masked_scatter results diverged between MLX and Torch")

    return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)


def bench_case(length, density, dtype):
    np_dtype = getattr(np, dtype)
    torch_dtype = getattr(torch, dtype)
    (
        self_mlx,
        mask_mlx,
        src_mlx,
        self_torch,
        mask_torch,
        src_torch,
        true_count,
    ) = build_case(length, density, np_dtype, torch_dtype)

    time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
    time_torch = measure(
        partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
    )

    total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
    bytes_per_gb = float(1024**3)
    mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
    torch_gbps = (total_bytes / bytes_per_gb) / time_torch

    return time_mlx, time_torch, mlx_gbps, torch_gbps


def plot_density(ax_perf, ax_speedup, density, dtype):
    mlx_gbps = []
    torch_gbps = []
    mlx_times = []
    torch_times = []

    for length in VECTOR_LENGTHS:
        t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
        mlx_gbps.append(gbps_mlx)
        torch_gbps.append(gbps_torch)
        mlx_times.append(t_mlx)
        torch_times.append(t_torch)

    ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
    ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
    ax_perf.set_xscale("log", base=2)
    ax_perf.set_xticks(VECTOR_LENGTHS)
    formatter = FuncFormatter(_power_of_two_formatter)
    ax_perf.xaxis.set_major_formatter(formatter)
    ax_perf.set_title(f"density={density:.2f}")
    ax_perf.set_ylabel("GB/s")
    ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
    ax_perf.legend()

    speedup = np.array(torch_times) / np.array(mlx_times)
    ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
    ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
    ax_speedup.set_xscale("log", base=2)
    ax_speedup.set_xticks(VECTOR_LENGTHS)
    ax_speedup.xaxis.set_major_formatter(formatter)
    ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
    ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)


def main():
    for dtype in D_TYPES:
        fig, axs = plt.subplots(
            len(MASK_DENSITIES),
            2,
            figsize=(10, 12),
            layout="constrained",
            sharex=True,
        )

        for i, density in enumerate(MASK_DENSITIES):
            plot_density(axs[i][0], axs[i][1], density, dtype)
            axs[i][0].set_xlabel("vector length")
            axs[i][1].set_xlabel("vector length")

        fig.suptitle(
            f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
        )
        output_path = os.path.join(
            RESULTS_DIR,
            f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
        )
        fig.savefig(output_path)
        plt.close(fig)


if __name__ == "__main__":
    main()
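The GB/s figures divide the bytes estimated by `bytes_touched` by the measured wall time. As a worked example for one configuration (float32, length 4096, density 0.25, so 1024 selected elements):

```python
length, true_count, item_size = 4096, 1024, 4
per_call = length + length * item_size * 2 + true_count * item_size
# mask read (4096 B) + destination read/write (32768 B) + source read (4096 B)
total = per_call * 20 * 50   # N_ITER_FUNC * N_ITER_BENCH iterations per measurement
print(per_call, total)       # 40960 40960000
```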
@@ -28,11 +28,34 @@ def bench(f, *args):
     return (e - s) * 1e-9


-def mlx_sdpa_fused_inner(q, k, v, scale):
-    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
+def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
+    np_dtype = getattr(np, dtype)
+
+    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
+    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
+
+    scale = 1.0 / math.sqrt(D)
+
+    q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
+    k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
+    v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
+
+    q_mx = mx.array(q_np)
+    k_mx = mx.array(k_np)
+    v_mx = mx.array(v_np)
+
+    if mask is not None:
+        if mask == "additive":
+            mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
+            mask = mx.array(mask_np)
+        elif mask == "bool":
+            mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
+            mask = mx.array(mask_np)
+
+    return q_mx, k_mx, v_mx, scale, mask


-def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
+def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
     q_dtype = q.dtype
     q = q * mx.array(scale, q_dtype)
     n_q_heads = q.shape[-3]
@@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):

     B = q.shape[0]
     L = q.shape[2]
+    kL = k.shape[2]

     if n_repeats > 1:
         q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
@@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
         v = mx.expand_dims(v, 2)

     scores = q @ mx.swapaxes(k, -1, -2)
-    if f32softmax:
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
-    else:
-        scores = mx.softmax(scores, axis=-1)
+    if mask is not None:
+
+        if mask == "causal":
+            q_offset = max(0, kL - L)
+            q_indices = mx.arange(q_offset, q_offset + L)
+            k_indices = mx.arange(kL)
+            mask = q_indices[:, None] >= k_indices[None]
+
+        if n_repeats > 1 and mask.ndim >= 3:
+            if mask.shape[-3] == 1:
+                mask = mx.expand_dims(mask, -3)
+            else:
+                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
+
+        if mask.dtype == mx.bool_:
+            scores = mx.where(mask, scores, -np.float32(np.inf))
+        else:
+            scores += mask
+
+    scores = mx.softmax(scores, axis=-1, precise=True)

     out = scores @ v
     if n_repeats > 1:
@@ -60,74 +101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
     return out


-def mlx_spda_unfused(q, k, v, scale, transpose):
-    q_out = q
+def mlx_fused_attn(q, k, v, scale, mask):
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+
+
+def do_attention(f, q, k, v, scale, mask=None, transpose=False):
     if transpose:
-        k = mx.transpose(k, (0, 2, 1, 3))
-        v = mx.transpose(v, (0, 2, 1, 3))
+        q_t = mx.transpose(q, (0, 2, 1, 3))
+        k_t = mx.transpose(k, (0, 2, 1, 3))
+        v_t = mx.transpose(v, (0, 2, 1, 3))
+        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
+        return mx.transpose(o_t, (0, 2, 1, 3))
+    else:
+        return f(q, k, v, scale=scale, mask=mask)
+
+
+def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
+    q_out = q

     for i in range(N_iter_func):
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-        q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
+        q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)

     mx.eval(q_out)
     return q_out


-def mlx_spda_fused(q, k, v, scale, transpose):
-    q_out = q
-    if transpose:
-        k = mx.transpose(k, (0, 2, 1, 3))
-        v = mx.transpose(v, (0, 2, 1, 3))
-
-    for i in range(N_iter_func):
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-        q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-
-    mx.eval(q_out)
-    return q_out
-
-
-def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
-    shape_q = (
-        (B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
-    )
-    shape_kv = (
-        (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
-    )
-
-    q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
-    k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
-    v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
-
-    scale = math.sqrt(1.0 / head_dim)
-
-    q_mx = mx.array(q_np)
-    k_mx = mx.array(k_np)
-    v_mx = mx.array(v_np)
-
-    time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
-    time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
-
-    if transpose:
-        q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
-        k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
-        v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
-
-    o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
-    o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
-
-    atol = 1e-5 if np_dtype == np.float32 else 1e-4
-
-    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
+def bench_shape(
+    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
+):
+    q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
+        B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
+    )
+
+    time_mlx_unfused = bench(
+        do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
+    )
+    time_mlx_fused = bench(
+        do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
+    )
+
+    o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
+    o_mlx_unfused = do_attention(
+        mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
+    )
+
+    atol = 1e-5 if dtype == "float32" else 2e-4
+
+    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
         print(
-            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
+            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
         )

     return time_mlx_fused, time_mlx_unfused
@@ -151,39 +173,51 @@ if __name__ == "__main__":
         ( 1, 128, 128, 64, 32, 32),
         ( 1, 256, 256, 64, 32, 32),
         ( 1, 512, 512, 64, 32, 32),
-        ( 1, 1024, 1024, 64, 32, 32),
-        ( 1, 2048, 2048, 64, 32, 32),
-        ( 1, 4096, 4096, 64, 32, 32),
+        ( 1, 1024, 1024, 64, 32, 8),
+        ( 1, 2048, 2048, 64, 32, 8),
+        ( 1, 4096, 4096, 64, 32, 8),
     )

     shapes_80 = (
         # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
-        ( 1, 1024, 1024, 80, 32, 32),
-        ( 1, 2048, 2048, 80, 32, 32),
-        ( 1, 4096, 4096, 80, 32, 32),
+        ( 1, 1024, 1024, 80, 32, 8),
+        ( 1, 2048, 2048, 80, 32, 8),
+        ( 1, 4096, 4096, 80, 32, 8),
     )

     shapes_128 = (
         # ( B, qsl, ksl, head_dim, n_qh, n_kvh)
-        ( 1, 1024, 1024, 128, 32, 32),
-        ( 1, 2048, 2048, 128, 32, 32),
-        ( 1, 4096, 4096, 128, 32, 32),
+        ( 1, 1024, 1024, 128, 32, 8),
+        ( 1, 2048, 2048, 128, 32, 8),
+        ( 1, 4096, 4096, 128, 32, 8),
     )
     # fmt: on

     shapes = shapes_64 + shapes_80 + shapes_128

-    print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
+    masks = [None, "bool", "causal"]
+
+    print(
+        " B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
+    )

     for dtype in dtypes:
         for transpose in transposes:
             for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
-                np_dtype = getattr(np, dtype)
-                time_mlx_fused, time_mlx_unfused = bench_shape(
-                    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
-                )
-                diff = time_mlx_unfused / time_mlx_fused - 1.0
-                t_str = 1 if transpose else 0
-                print(
-                    f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
-                )
+                for mask_in in masks:
+                    time_mlx_fused, time_mlx_unfused = bench_shape(
+                        B,
+                        qsl,
+                        ksl,
+                        head_dim,
+                        n_q_heads,
+                        n_kv_heads,
+                        dtype,
+                        transpose,
+                        mask_in,
+                    )
+                    diff = time_mlx_unfused / time_mlx_fused - 1.0
+                    t_str = 1 if transpose else 0
+                    print(
+                        f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
                    )
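The updated benchmark times `mx.fast.scaled_dot_product_attention` with boolean and causal masks and with fewer KV heads than query heads. As a rough, self-contained sketch of the call being timed (the shapes and mask below are illustrative, not the benchmark's own configuration):

    import math
    import mlx.core as mx

    B, n_q_heads, n_kv_heads, L, D = 1, 32, 8, 1024, 64
    q = mx.random.normal((B, n_q_heads, L, D))
    k = mx.random.normal((B, n_kv_heads, L, D))
    v = mx.random.normal((B, n_kv_heads, L, D))

    # Boolean mask: True entries are kept, False entries are masked out.
    mask = mx.random.uniform(shape=(B, n_q_heads, L, L)) < 0.5

    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=1.0 / math.sqrt(D), mask=mask
    )
    mx.eval(out)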
@@ -51,6 +51,20 @@ def time_maximum():
     time_fn(mx.maximum, a, b)


+def time_max():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    a[1, 1] = mx.nan
+    mx.eval(a)
+    time_fn(mx.max, a, 0)
+
+
+def time_min():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    a[1, 1] = mx.nan
+    mx.eval(a)
+    time_fn(mx.min, a, 0)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
@@ -108,6 +122,8 @@ if __name__ == "__main__":

     time_add()
     time_matmul()
+    time_min()
+    time_max()
     time_maximum()
     time_exp()
     time_negative()
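The new `time_min` and `time_max` cases plant a NaN in the input before timing, so the reductions are measured on the NaN-propagating path. A minimal standalone sketch of what they exercise (the benchmark's `time_fn` harness is left out):

    import mlx.core as mx

    a = mx.random.uniform(shape=(32, 1024, 1024))
    a[1, 1] = mx.nan  # mirror the benchmark: inject a NaN into the input
    mx.eval(a)

    # Reduce over the first axis, as in time_max()/time_min() above.
    print(mx.max(a, 0).shape)  # (1024, 1024)
    print(mx.min(a, 0).shape)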
cmake/FindNCCL.cmake (new file, 54 lines)
@@ -0,0 +1,54 @@
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
# directories.

set(NCCL_ROOT_DIR
    $ENV{NCCL_ROOT_DIR}
    CACHE PATH "Folder contains NVIDIA NCCL")

find_path(
  NCCL_INCLUDE_DIRS
  NAMES nccl.h
  HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
        ${CUDA_TOOLKIT_ROOT_DIR}/include)

if($ENV{USE_STATIC_NCCL})
  message(
    STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
  set(NCCL_LIBNAME "libnccl_static.a")
else()
  set(NCCL_LIBNAME "nccl")
endif()

find_library(
  NCCL_LIBRARIES
  NAMES ${NCCL_LIBNAME}
  HINTS ${NCCL_LIB_DIR}
        ${NCCL_ROOT_DIR}
        ${NCCL_ROOT_DIR}/lib
        ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
        ${NCCL_ROOT_DIR}/lib64
        ${CUDA_TOOLKIT_ROOT_DIR}/lib
        ${CUDA_TOOLKIT_ROOT_DIR}/lib64)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
                                  NCCL_LIBRARIES)

if(NCCL_FOUND)
  set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
  message(
    STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
  file(
    STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
    REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
    LIMIT_COUNT 1)
  if(NCCL_MAJOR_VERSION_DEFINED)
    string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
                         NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
    message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
  endif()
  message(
    STATUS
      "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()
cmake/Findnvpl.cmake (new file, 3 lines)
@@ -0,0 +1,3 @@
# This file does nothing but to suppress the cmake warning: "By not providing
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.
@@ -11,13 +11,14 @@ include(CMakeParseArguments)
 # Args: TARGET: Custom target to be added for the metal library TITLE: Name of
 # the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
 # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
-# files (like headers)
+# files (like headers) DEBUG: Boolean, if true, enables debug compile options
+# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
 #
 # clang format on

 macro(mlx_build_metallib)
   # Parse args
-  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
+  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
   set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
   cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

@@ -26,6 +27,10 @@ macro(mlx_build_metallib)

   # Collect compile options
   set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
+  if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
+    set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
+                               -frecord-sources)
+  endif()

   # Prepare metallib build command
   add_custom_command(
@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
 CREATE_SUBDIRS = NO
 FULL_PATH_NAMES = YES
 RECURSIVE = YES
-GENERATE_HTML = YES
+GENERATE_HTML = NO
 GENERATE_LATEX = NO
 GENERATE_XML = YES
 XML_PROGRAMLISTING = YES
|||||||
@@ -1,4 +1,5 @@
|
|||||||
sphinx
|
sphinx
|
||||||
breathe
|
breathe
|
||||||
sphinx-book-theme
|
sphinx-book-theme
|
||||||
|
sphinx-copybutton
|
||||||
mlx
|
mlx
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import mlx.core as mx
 # -- Project information -----------------------------------------------------

 project = "MLX"
-copyright = "2023, MLX Contributors"
+copyright = "2023, Apple"
 author = "MLX Contributors"
 version = ".".join(mx.__version__.split(".")[:3])
 release = version
@@ -18,6 +18,7 @@ release = version
 # -- General configuration ---------------------------------------------------

 extensions = [
+    "sphinx_copybutton",
     "sphinx.ext.autodoc",
     "sphinx.ext.autosummary",
     "sphinx.ext.intersphinx",
@@ -8,23 +8,26 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
 Simple Example
 --------------

+.. currentmodule:: mlx.core
+
 Let's write a custom kernel that computes ``exp`` elementwise:

 .. code-block:: python

-   def exp_elementwise(a: mx.array):
-       source = """
-           uint elem = thread_position_in_grid.x;
-           T tmp = inp[elem];
-           out[elem] = metal::exp(tmp);
-       """
-
-       kernel = mx.fast.metal_kernel(
-           name="myexp",
-           input_names=["inp"],
-           output_names=["out"],
-           source=source,
-       )
+   source = """
+       uint elem = thread_position_in_grid.x;
+       T tmp = inp[elem];
+       out[elem] = metal::exp(tmp);
+   """
+
+   kernel = mx.fast.metal_kernel(
+       name="myexp",
+       input_names=["inp"],
+       output_names=["out"],
+       source=source,
+   )
+
+   def exp_elementwise(a: mx.array):
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
   b = exp_elementwise(a)
   assert mx.allclose(b, mx.exp(a))

+Every time you make a kernel, a new Metal library is created and possibly
+JIT compiled. To reduce the overhead from that, build the kernel once with
+:func:`fast.metal_kernel` and then use it many times.
+
 .. note::
-   We are only required to pass the body of the Metal kernel in ``source``.
+   Only pass the body of the Metal kernel in ``source``. The function
+   signature is generated automatically.

 The full function signature will be generated using:
@@ -78,44 +86,52 @@ Putting this all together, the generated function signature for ``myexp`` is as follows:

   template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

-Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
-This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
-For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
+Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
+<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
+function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
+``threadgroup`` size threadgroups. For optimal performance, each thread group
+dimension should be less than or equal to the corresponding grid dimension.

-Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
+Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
+generated code for debugging purposes.

 Using Shape/Strides
 -------------------

-``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
-This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
-Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
-when indexing.
+:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
+is ``True`` by default. This will copy the array inputs if needed
+before the kernel is launched to ensure that the memory layout is row
+contiguous. Generally this makes writing the kernel easier, since we don't
+have to worry about gaps or the ordering of the dims when indexing.

-If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
-input array ``a`` if any are present in ``source``.
-We can then use MLX's built in indexing utils to fetch the right elements for each thread.
+If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
+``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
+present in ``source``. We can then use MLX's built in indexing utils to fetch
+the right elements for each thread.

-Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
+Let's convert ``myexp`` above to support arbitrarily strided arrays without
+relying on a copy from ``ensure_row_contiguous``:

 .. code-block:: python

+   source = """
+       uint elem = thread_position_in_grid.x;
+       // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
+       uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
+       T tmp = inp[loc];
+       // Output arrays are always row contiguous
+       out[elem] = metal::exp(tmp);
+   """
+
+   kernel = mx.fast.metal_kernel(
+       name="myexp_strided",
+       input_names=["inp"],
+       output_names=["out"],
+       source=source,
+       ensure_row_contiguous=False,
+   )
+
    def exp_elementwise(a: mx.array):
-       source = """
-           uint elem = thread_position_in_grid.x;
-           // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
-           uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
-           T tmp = inp[loc];
-           // Output arrays are always row contiguous
-           out[elem] = metal::exp(tmp);
-       """
-
-       kernel = mx.fast.metal_kernel(
-           name="myexp_strided",
-           input_names=["inp"],
-           output_names=["out"],
-           source=source
-       )
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
@@ -123,7 +139,6 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
-           ensure_row_contiguous=False,
        )
        return outputs[0]
@@ -142,137 +157,139 @@ We'll start with the following MLX implementation using standard ops:

.. code-block:: python

   def grid_sample_ref(x, grid):
       N, H_in, W_in, _ = x.shape
       ix = ((grid[..., 0] + 1) * W_in - 1) / 2
       iy = ((grid[..., 1] + 1) * H_in - 1) / 2

       ix_nw = mx.floor(ix).astype(mx.int32)
       iy_nw = mx.floor(iy).astype(mx.int32)

       ix_ne = ix_nw + 1
       iy_ne = iy_nw

       ix_sw = ix_nw
       iy_sw = iy_nw + 1

       ix_se = ix_nw + 1
       iy_se = iy_nw + 1

       nw = (ix_se - ix) * (iy_se - iy)
       ne = (ix - ix_sw) * (iy_sw - iy)
       sw = (ix_ne - ix) * (iy - iy_ne)
       se = (ix - ix_nw) * (iy - iy_nw)

       I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
       I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
       I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
       I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

       mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
       mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
       mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
       mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

       I_nw *= mask_nw[..., None]
       I_ne *= mask_ne[..., None]
       I_sw *= mask_sw[..., None]
       I_se *= mask_se[..., None]

       output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

       return output

Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:

.. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       int H = x_shape[1];
       int W = x_shape[2];
       int C = x_shape[3];
       int gH = grid_shape[1];
       int gW = grid_shape[2];

       int w_stride = C;
       int h_stride = W * w_stride;
       int b_stride = H * h_stride;

       uint grid_idx = elem / C * 2;
       float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
       float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

       int ix_nw = floor(ix);
       int iy_nw = floor(iy);

       int ix_ne = ix_nw + 1;
       int iy_ne = iy_nw;

       int ix_sw = ix_nw;
       int iy_sw = iy_nw + 1;

       int ix_se = ix_nw + 1;
       int iy_se = iy_nw + 1;

       T nw = (ix_se - ix) * (iy_se - iy);
       T ne = (ix - ix_sw) * (iy_sw - iy);
       T sw = (ix_ne - ix) * (iy - iy_ne);
       T se = (ix - ix_nw) * (iy - iy_nw);

       int batch_idx = elem / C / gH / gW * b_stride;
       int channel_idx = elem % C;
       int base_idx = batch_idx + channel_idx;

       T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
       T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
       T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
       T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

       I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
       I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
       I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
       I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

       out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
   """

   kernel = mx.fast.metal_kernel(
       name="grid_sample",
       input_names=["x", "grid"],
       output_names=["out"],
       source=source,
   )

   @mx.custom_function
   def grid_sample(x, grid):

       assert x.ndim == 4, "`x` must be 4D."
       assert grid.ndim == 4, "`grid` must be 4D."

       B, _, _, C = x.shape
       _, gN, gM, D = grid.shape
       out_shape = (B, gN, gM, C)

       assert D == 2, "Last dim of `grid` must be size 2."

       outputs = kernel(
           inputs=[x, grid],
           template=[("T", x.dtype)],
           output_shapes=[out_shape],
           output_dtypes=[x.dtype],
           grid=(np.prod(out_shape), 1, 1),
           threadgroup=(256, 1, 1),
       )
       return outputs[0]

For a reasonably sized input such as:

.. code-block:: python

   x.shape = (8, 1024, 1024, 64)
   grid.shape = (8, 256, 256, 2)

On an M1 Max, we see a big performance improvement:
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
 Grid Sample VJP
 ---------------

-Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
-its custom vjp transform so MLX can differentiate it.
+Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
+define its custom vjp transform so MLX can differentiate it.

 The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
-requires a few extra ``mx.fast.metal_kernel`` features:
+requires a few extra :func:`fast.metal_kernel` features:

 * ``init_value=0``
   Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@@ -299,128 +316,129 @@ We can then implement the backwards pass as follows:

.. code-block:: python

   source = """
       uint elem = thread_position_in_grid.x;
       int H = x_shape[1];
       int W = x_shape[2];
       int C = x_shape[3];
       // Pad C to the nearest larger simdgroup size multiple
       int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

       int gH = grid_shape[1];
       int gW = grid_shape[2];

       int w_stride = C;
       int h_stride = W * w_stride;
       int b_stride = H * h_stride;

       uint grid_idx = elem / C_padded * 2;
       float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
       float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

       int ix_nw = floor(ix);
       int iy_nw = floor(iy);

       int ix_ne = ix_nw + 1;
       int iy_ne = iy_nw;

       int ix_sw = ix_nw;
       int iy_sw = iy_nw + 1;

       int ix_se = ix_nw + 1;
       int iy_se = iy_nw + 1;

       T nw = (ix_se - ix) * (iy_se - iy);
       T ne = (ix - ix_sw) * (iy_sw - iy);
       T sw = (ix_ne - ix) * (iy - iy_ne);
       T se = (ix - ix_nw) * (iy - iy_nw);

       int batch_idx = elem / C_padded / gH / gW * b_stride;
       int channel_idx = elem % C_padded;
       int base_idx = batch_idx + channel_idx;

       T gix = T(0);
       T giy = T(0);
       if (channel_idx < C) {
           int cot_index = elem / C_padded * C + channel_idx;
           T cot = cotangent[cot_index];
           if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
               int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

               T I_nw = x[offset];
               gix -= I_nw * (iy_se - iy) * cot;
               giy -= I_nw * (ix_se - ix) * cot;
           }
           if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
               int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

               T I_ne = x[offset];
               gix += I_ne * (iy_sw - iy) * cot;
               giy -= I_ne * (ix - ix_sw) * cot;
           }
           if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
               int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

               T I_sw = x[offset];
               gix -= I_sw * (iy - iy_ne) * cot;
               giy += I_sw * (ix_ne - ix) * cot;
           }
           if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
               int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
               atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

               T I_se = x[offset];
               gix += I_se * (iy - iy_nw) * cot;
               giy += I_se * (ix - ix_nw) * cot;
           }
       }

       T gix_mult = W / 2;
       T giy_mult = H / 2;

       // Reduce across each simdgroup first.
       // This is much faster than relying purely on atomics.
       gix = simd_sum(gix);
       giy = simd_sum(giy);

       if (thread_index_in_simdgroup == 0) {
           atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
           atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
       }
   """

   kernel = mx.fast.metal_kernel(
       name="grid_sample_grad",
       input_names=["x", "grid", "cotangent"],
       output_names=["x_grad", "grid_grad"],
       source=source,
       atomic_outputs=True,
   )

   @grid_sample.vjp
   def grid_sample_vjp(primals, cotangent, _):
       x, grid = primals
       B, _, _, C = x.shape
       _, gN, gM, D = grid.shape

       assert D == 2, "Last dim of `grid` must be size 2."

       # pad the output channels to simd group size
       # so that our `simd_sum`s don't overlap.
       simdgroup_size = 32
       C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
       grid_size = B * gN * gM * C_padded
       outputs = kernel(
           inputs=[x, grid, cotangent],
           template=[("T", x.dtype)],
           output_shapes=[x.shape, grid.shape],
           output_dtypes=[x.dtype, x.dtype],
           grid=(grid_size, 1, 1),
           threadgroup=(256, 1, 1),
           init_value=0,
       )
       return outputs[0], outputs[1]

There's an even larger speed up for the vjp:
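As a hedged usage sketch (not part of the docs diff above): once ``grid_sample`` and its VJP are defined as in the example, the custom backward pass is picked up by the usual transforms such as ``mx.grad``. The shapes below are made up for illustration.

    import mlx.core as mx

    # Assumes `grid_sample` from the example above is defined in scope.
    x = mx.random.normal((2, 16, 16, 4))
    grid = mx.random.uniform(low=-1, high=1, shape=(2, 8, 8, 2))

    def loss(x, grid):
        return grid_sample(x, grid).sum()

    # Gradients with respect to both inputs flow through the custom VJP kernel.
    dx, dgrid = mx.grad(loss, argnums=(0, 1))(x, grid)
    mx.eval(dx, dgrid)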
@@ -93,9 +93,9 @@ Primitives
 ^^^^^^^^^^^

 A :class:`Primitive` is part of the computation graph of an :class:`array`. It
-defines how to create outputs arrays given a input arrays. Further, a
+defines how to create output arrays given input arrays. Further, a
 :class:`Primitive` has methods to run on the CPU or GPU and for function
-transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
+transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
 more concrete:

 .. code-block:: C++
@@ -128,7 +128,7 @@ more concrete:
    /** The vector-Jacobian product. */
    std::vector<array> vjp(
        const std::vector<array>& primals,
-       const array& cotan,
+       const std::vector<array>& cotangents,
        const std::vector<int>& argnums,
        const std::vector<array>& outputs) override;

@@ -138,13 +138,13 @@ more concrete:
     * representing the vectorized computation and the axis which
     * corresponds to the output vectorized dimension.
     */
-   virtual std::pair<std::vector<array>, std::vector<int>> vmap(
+   std::pair<std::vector<array>, std::vector<int>> vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) override;

-   /** Print the primitive. */
-   void print(std::ostream& os) override {
-     os << "Axpby";
+   /** The name of primitive. */
+   const char* name() const override {
+     return "Axpby";
    }

    /** Equivalence check **/
@@ -247,9 +247,7 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
      float alpha_,
      float beta_,
      mx::Stream stream) {
-   // Allocate the output with `malloc_or_wait` which synchronously allocates
-   // memory, potentially waiting if the system is under memory pressure
-   out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
+   out.set_data(mx::allocator::malloc(out.nbytes()));

    // Get the CPU command encoder and register input and output arrays
    auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -393,17 +391,17 @@ below.
    auto& d = metal::device(s.device);

    // Allocate output memory
-   out.set_data(allocator::malloc_or_wait(out.nbytes()));
+   out.set_data(allocator::malloc(out.nbytes()));

    // Resolve name of kernel
-   std::ostringstream kname;
-   kname << "axpby_" << "general_" << type_to_name(out);
+   std::string kname;
+   kname = "axpby_general_" + type_to_name(out);

-   // Make sure the metal library is available
-   d.register_library("mlx_ext");
+   // Load the metal library
+   auto lib = d.get_library("mlx_ext", current_binary_dir());

    // Make a kernel from this metal library
-   auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+   auto kernel = d.get_kernel(kname, lib);

    // Prepare to encode kernel
    auto& compute_encoder = d.get_command_encoder(s.index);
@@ -471,7 +469,7 @@ one we just defined:
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   // Forward mode diff that pushes along the tangents
-  // The jvp transform on the primitive can built with ops
+  // The jvp transform on the primitive can be built with ops
   // that are scheduled on the same stream as the primitive

   // If argnums = {0}, we only push along x in which case the
@@ -483,7 +481,7 @@ one we just defined:
     auto scale_arr = array(scale, tangents[0].dtype());
     return {multiply(scale_arr, tangents[0], stream())};
   }
-  // If, argnums = {0, 1}, we take contributions from both
+  // If argnums = {0, 1}, we take contributions from both
   // which gives us jvp = tangent_x * alpha + tangent_y * beta
   else {
     return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -737,7 +735,7 @@ Let's look at a simple script and its results:

   print(f"c shape: {c.shape}")
   print(f"c dtype: {c.dtype}")
-  print(f"c correct: {mx.all(c == 6.0).item()}")
+  print(f"c is correct: {mx.all(c == 6.0).item()}")

 Output:

@@ -745,7 +743,7 @@ Output:

   c shape: [3, 4]
   c dtype: float32
-  c correctness: True
+  c is correct: True

 Results
 ^^^^^^^
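A hedged usage sketch of the extension this document builds. The package name ``mlx_sample_extensions`` and the positional ``alpha``/``beta`` arguments follow the sample script referenced in the diff above; treat them as assumptions and adjust to your own extension.

    import mlx.core as mx
    from mlx_sample_extensions import axpby  # assumes the sample extension is built and installed

    x = mx.ones((3, 4))
    y = mx.ones((3, 4))
    c = axpby(x, y, 4.0, 2.0)  # alpha * x + beta * y

    print(f"c shape: {c.shape}")
    print(f"c dtype: {c.dtype}")
    print(f"c is correct: {mx.all(c == 6.0).item()}")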
@@ -70,6 +70,8 @@ are the CPU and GPU.
    python/fft
    python/linalg
    python/metal
+   python/cuda
+   python/memory_management
    python/nn
    python/optimizers
    python/distributed
|||||||
@@ -13,22 +13,51 @@ silicon computer is
|
|||||||
|
|
||||||
pip install mlx
|
pip install mlx
|
||||||
|
|
||||||
To install from PyPI you must meet the following requirements:
|
To install from PyPI your system must meet the following requirements:
|
||||||
|
|
||||||
- Using an M series chip (Apple silicon)
|
- Using an M series chip (Apple silicon)
|
||||||
- Using a native Python >= 3.9
|
- Using a native Python >= 3.10
|
||||||
- macOS >= 13.5
|
- macOS >= 14.0
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
MLX is only available on devices running macOS >= 13.5
|
MLX is only available on devices running macOS >= 14.0 and higher.
|
||||||
It is highly recommended to use macOS 14 (Sonoma)
|
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
MLX is also available on conda-forge. To install MLX with conda do:
|
MLX has a CUDA backend which you can install with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
conda install conda-forge::mlx
|
pip install mlx[cuda12]
|
||||||
|
|
||||||
|
|
||||||
|
To install the CUDA package from PyPi your system must meet the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Nvidia architecture >= SM 7.5
|
||||||
|
- Nvidia driver >= 550.54.14
|
||||||
|
- CUDA toolkit >= 12.0
|
||||||
|
- Linux distribution with glibc >= 2.35
|
||||||
|
- Python >= 3.10
|
||||||
|
|
||||||
|
For CUDA 13 use ``pip install mlx[cuda13]``. The CUDA 13 package requires
|
||||||
|
an Nvidia driver >= 580 or an appropriate CUDA compatibility package.
|
||||||
|
|
||||||
|
CPU-only (Linux)
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
For a CPU-only version of MLX that runs on Linux use:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install mlx[cpu]
|
||||||
|
|
||||||
|
To install the CPU-only package from PyPi your system must meet the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Linux distribution with glibc >= 2.35
|
||||||
|
- Python >= 3.10
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
@@ -65,6 +94,8 @@ Build Requirements
|
|||||||
Python API
|
Python API
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
.. _python install:
|
||||||
|
|
||||||
To build and install the MLX python library from source, first, clone MLX from
|
To build and install the MLX python library from source, first, clone MLX from
|
||||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
@@ -76,20 +107,20 @@ Then simply build and install MLX using pip:
|
|||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
|
pip install .
|
||||||
|
|
||||||
For developing, install the package with development dependencies, and use an
|
For developing, install the package with development dependencies, and use an
|
||||||
editable install:
|
editable install:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
Once the development dependencies are installed, you can build faster with:
|
Once the development dependencies are installed, you can build faster with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
|
python setup.py build_ext --inplace
|
||||||
|
|
||||||
Run the tests with:
|
Run the tests with:
|
||||||
|
|
||||||
@@ -107,6 +138,8 @@ IDE:
|
|||||||
C++ API
|
C++ API
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
|
.. _cpp install:
|
||||||
|
|
||||||
Currently, MLX must be built and installed from source.
|
Currently, MLX must be built and installed from source.
|
||||||
|
|
||||||
Similarly to the python library, to build and install the MLX C++ library start
|
Similarly to the python library, to build and install the MLX C++ library start
|
||||||
@@ -185,6 +218,7 @@ should point to the path to the built metal library.
|
|||||||
|
|
||||||
xcrun -sdk macosx --show-sdk-version
|
xcrun -sdk macosx --show-sdk-version
|
||||||
|
|
||||||
|
|
||||||
Binary Size Minimization
|
Binary Size Minimization
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -213,6 +247,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
|||||||
application. Once a kernel is compiled, it will be cached by the system. The
|
application. Once a kernel is compiled, it will be cached by the system. The
|
||||||
Metal kernel cache persists across reboots.
|
Metal kernel cache persists across reboots.
|
||||||
|
|
||||||
|
Linux
|
||||||
|
^^^^^
|
||||||
|
|
||||||
|
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||||
|
For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
apt-get update -y
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
From here follow the instructions to install either the :ref:`Python <python
|
||||||
|
install>` or :ref:`C++ <cpp install>` APIs.
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||||
|
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
apt-get update -y
|
||||||
|
apt-get -y install cuda-toolkit-12-9
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
|
||||||
|
|
||||||
|
|
||||||
|
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||||
|
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||||
|
|
||||||
|
To build the C++ package run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
mkdir -p build && cd build
|
||||||
|
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|||||||
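A quick, hedged way to confirm that an install from PyPI or from source is importable and sees a device:

    import mlx.core as mx

    print(mx.__version__)       # installed MLX version
    print(mx.default_device())  # typically Device(gpu, 0) on Apple silicon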
@@ -19,6 +19,8 @@ Array
    array.ndim
    array.shape
    array.size
+   array.real
+   array.imag
    array.abs
    array.all
    array.any
@@ -38,6 +40,7 @@ Array
    array.log10
    array.log1p
    array.log2
+   array.logcumsumexp
    array.logsumexp
    array.max
    array.mean
9
docs/src/python/cuda.rst
Normal file
9
docs/src/python/cuda.rst
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
CUDA
|
||||||
|
=====
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.cuda
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
is_available
|
||||||
@@ -13,3 +13,4 @@ Fast
|
|||||||
rope
|
rope
|
||||||
scaled_dot_product_attention
|
scaled_dot_product_attention
|
||||||
metal_kernel
|
metal_kernel
|
||||||
|
cuda_kernel
|
||||||
|
|||||||
@@ -20,3 +20,5 @@ FFT
|
|||||||
irfft2
|
irfft2
|
||||||
rfftn
|
rfftn
|
||||||
irfftn
|
irfftn
|
||||||
|
fftshift
|
||||||
|
ifftshift
|
||||||
|
|||||||
@@ -16,9 +16,12 @@ Linear Algebra
|
|||||||
cross
|
cross
|
||||||
qr
|
qr
|
||||||
svd
|
svd
|
||||||
|
eigvals
|
||||||
|
eig
|
||||||
eigvalsh
|
eigvalsh
|
||||||
eigh
|
eigh
|
||||||
lu
|
lu
|
||||||
lu_factor
|
lu_factor
|
||||||
|
pinv
|
||||||
solve
|
solve
|
||||||
solve_triangular
|
solve_triangular
|
||||||
|
|||||||
16
docs/src/python/memory_management.rst
Normal file
16
docs/src/python/memory_management.rst
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
Memory Management
|
||||||
|
=================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
get_active_memory
|
||||||
|
get_peak_memory
|
||||||
|
reset_peak_memory
|
||||||
|
get_cache_memory
|
||||||
|
set_memory_limit
|
||||||
|
set_cache_limit
|
||||||
|
set_wired_limit
|
||||||
|
clear_cache
|
||||||
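A short sketch of the memory-management helpers listed in the new page above, which now live on ``mlx.core`` rather than ``mlx.core.metal``. The byte counts printed are illustrative and will vary by machine.

.. code-block:: python

  # Sketch of the mlx.core memory-management API listed above.
  import mlx.core as mx

  x = mx.random.normal((4096, 4096))
  mx.eval(x)

  print("active bytes:", mx.get_active_memory())
  print("peak bytes:  ", mx.get_peak_memory())
  print("cached bytes:", mx.get_cache_memory())

  mx.set_cache_limit(512 * 1024 * 1024)  # cap the buffer cache at 512 MB
  mx.clear_cache()                       # drop cached buffers immediately
  mx.reset_peak_memory()                 # restart peak tracking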
@@ -8,13 +8,5 @@ Metal

 is_available
 device_info
-get_active_memory
-get_peak_memory
-reset_peak_memory
-get_cache_memory
-set_memory_limit
-set_cache_limit
-set_wired_limit
-clear_cache
 start_capture
 stop_capture

@@ -27,6 +27,7 @@ simple functions.
 mish
 prelu
 relu
+relu2
 relu6
 selu
 sigmoid

@@ -50,6 +50,7 @@ Layers
 QuantizedLinear
 RMSNorm
 ReLU
+ReLU2
 ReLU6
 RNN
 RoPE

@@ -36,10 +36,12 @@ Operations
 bitwise_or
 bitwise_xor
 block_masked_mm
+broadcast_arrays
 broadcast_to
 ceil
 clip
 concatenate
+contiguous
 conj
 conjugate
 convolve

@@ -101,6 +103,7 @@ Operations
 log10
 log1p
 logaddexp
+logcumsumexp
 logical_not
 logical_and
 logical_or

@@ -109,6 +112,7 @@ Operations
 max
 maximum
 mean
+median
 meshgrid
 min
 minimum

@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
 optimizer.update(model, grads)

 # Save the state
-state = tree_flatten(optimizer.state)
-mx.save_safetensors("optimizer.safetensors", dict(state))
+state = tree_flatten(optimizer.state, destination={})
+mx.save_safetensors("optimizer.safetensors", state)

 # Later on, for example when loading from a checkpoint,
 # recreate the optimizer and load the state
 optimizer = optim.Adam(learning_rate=1e-2)

-state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
+state = tree_unflatten(mx.load("optimizer.safetensors"))
 optimizer.state = state

 Note, not every optimizer configuration parameter is saved in the state. For
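A runnable sketch of the save/load round trip using the updated ``tree_flatten`` API shown above. The tiny ``Linear`` model here is only an illustration; any MLX optimizer state can be handled the same way.

.. code-block:: python

  # Hedged sketch of saving and restoring optimizer state with the new API.
  import mlx.core as mx
  import mlx.nn as nn
  import mlx.optimizers as optim
  from mlx.utils import tree_flatten, tree_unflatten

  model = nn.Linear(4, 4)
  optimizer = optim.Adam(learning_rate=1e-2)
  optimizer.init(model.trainable_parameters())

  # Save: flatten directly into a dict of arrays and write to safetensors.
  state = tree_flatten(optimizer.state, destination={})
  mx.save_safetensors("optimizer.safetensors", state)

  # Load: recreate the optimizer and restore its state.
  optimizer = optim.Adam(learning_rate=1e-2)
  optimizer.state = tree_unflatten(mx.load("optimizer.safetensors"))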
@@ -18,3 +18,5 @@ Common Optimizers
 AdamW
 Adamax
 Lion
+MultiOptimizer
+Muon

@@ -9,6 +9,7 @@ Transforms
 :toctree: _autosummary

 eval
+async_eval
 compile
 custom_function
 disable_compile

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
 .. code-block:: python

   x = mx.random.uniform(shape=(32, 1000, 4096))
-  timeit(nn.gelu, x)
-  timeit(mx.compile(nn.gelu), x)
+  timeit(gelu, x)
+  timeit(mx.compile(gelu), x)

 On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
 five times faster.
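For completeness, a sketch of the kind of timing harness the benchmark above assumes. The ``timeit`` helper here is hypothetical (it is not part of MLX); it simply averages a number of evaluated calls.

.. code-block:: python

  # Illustrative timing harness; the `timeit` helper is an assumption.
  import time
  import mlx.core as mx
  import mlx.nn as nn

  def timeit(fn, x, iters=100):
      mx.eval(fn(x))  # warm up (and compile, if applicable)
      tic = time.perf_counter()
      for _ in range(iters):
          mx.eval(fn(x))
      toc = time.perf_counter()
      print(f"{1e3 * (toc - tic) / iters:.3f} ms per call")

  x = mx.random.uniform(shape=(32, 1000, 4096))
  timeit(nn.gelu, x)
  timeit(mx.compile(nn.gelu), x)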
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
 def fun(x, y):
     z = x + y
     state.append(z)
-    return mx.exp(z), state
+    return mx.exp(z)

 fun(mx.array(1.0), mx.array(2.0))
 # Prints [array(3, dtype=float32)]

@@ -7,12 +7,13 @@ Distributed Communication

 MLX supports distributed communication operations that allow the computational cost
 of training or inference to be shared across many physical machines. At the
-moment we support two different communication backends:
+moment we support three different communication backends:

 * `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
   full-featured and mature distributed communications library
-* A **ring** backend of our own that uses native TCP sockets and should be
-  faster for thunderbolt connections.
+* A **ring** backend of our own that uses native TCP sockets. It should be
+  faster for thunderbolt connections, but it also works over Ethernet.
+* `nccl <https://developer.nvidia.com/nccl>`_, for use in CUDA environments.

 The list of all currently supported operations and their documentation can be
 seen in the :ref:`API docs<distributed>`.

@@ -84,9 +85,8 @@ Selecting Backend
 ^^^^^^^^^^^^^^^^^

 You can select the backend you want to use when calling :func:`init` by passing
-one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
-initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
-both fail then a singleton group is created.
+one of ``{'any', 'ring', 'mpi', 'nccl'}``. When passing ``any``, MLX will try all
+available backends. If they all fail then a singleton group is created.

 .. note::
    After a distributed backend is successfully initialized :func:`init` will
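A minimal sketch of selecting a backend explicitly and running one collective, assuming the script is launched with ``mlx.launch`` (or ``mpirun`` for the MPI backend). The exact keyword name for the backend argument follows the documentation above.

.. code-block:: python

  # Hedged sketch: pick a backend and do a single all_sum across ranks.
  import mlx.core as mx

  world = mx.distributed.init(backend="ring")  # or "mpi", "nccl", "any"
  x = mx.ones((4,)) * world.rank()
  total = mx.distributed.all_sum(x)
  mx.eval(total)
  print(world.rank(), total)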
@@ -184,7 +184,7 @@ almost identical to the example above:

 def step(model, x, y):
     loss, grads = loss_grad_fn(model, x, y)
-    grads = mlx.nn.average_gradients(grads) # <---- This line was added
+    grads = mx.nn.average_gradients(grads) # <---- This line was added
     optimizer.update(model, grads)
     return loss

@@ -220,7 +220,7 @@ print 4 etc.
 Installing MPI
 ^^^^^^^^^^^^^^

-MPI can be installed with Homebrew, using the Anaconda package manager or
+MPI can be installed with Homebrew, pip, using the Anaconda package manager, or
 compiled from source. Most of our testing is done using ``openmpi`` installed
 with the Anaconda package manager as follows:

@@ -228,14 +228,16 @@ with the Anaconda package manager as follows:

   $ conda install conda-forge::openmpi

-Installing with Homebrew may require specifying the location of ``libmpi.dyld``
+Installing with Homebrew or pip requires specifying the location of ``libmpi.dyld``
 so that MLX can find it and load it at runtime. This can simply be achieved by
 passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
-done automatically by ``mlx.launch``.
+done automatically by ``mlx.launch``. Some environments use a non-standard
+library filename that can be specified using the ``MPI_LIBNAME`` environment
+variable. This is automatically taken care of by ``mlx.launch`` as well.

 .. code:: shell

-  $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
+  $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ -x MPI_LIBNAME=libmpi.40.dylib python test.py
   $ # or simply
   $ mlx.launch -n 2 test.py
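A minimal ``test.py`` of the sort launched above, as a quick sanity check that the distributed group comes up; this is an illustration rather than part of the changeset.

.. code-block:: python

  # Minimal distributed hello-world for `mlx.launch -n 2 test.py`.
  import mlx.core as mx

  group = mx.distributed.init()
  print(f"hello from rank {group.rank()} of {group.size()}")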
@@ -7,17 +7,17 @@ Exporting Functions

 MLX has an API to export and import functions to and from a file. This lets you
 run computations written in one MLX front-end (e.g. Python) in another MLX
 front-end (e.g. C++).

 This guide walks through the basics of the MLX export API with some examples.
 To see the full list of functions check-out the :ref:`API documentation
 <export>`.

 Basics of Exporting
 -------------------

 Let's start with a simple example:

 .. code-block:: python

   def fun(x, y):

@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:

   x = mx.array(1.0)
   y = mx.array(1.0)

   # Both arguments to fun are positional
   mx.export_function("add.mlxfn", fun, x, y)

@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.
 For enclosed arrays inside an exported function, be extra careful to ensure
 they are evaluated. The computation graph that gets exported will include
 the computation that produces enclosed inputs.

 If the above example was missing ``mx.eval(model.parameters()``, the
 exported function would include the random initialization of the
 :obj:`mlx.nn.Module` parameters.

@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:
       # Set the model's parameters to the input parameters
       model.update(tree_unflatten(list(params.items())))
       return model(x)

-  params = dict(tree_flatten(model.parameters()))
+  params = tree_flatten(model.parameters(), destination={})
   mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

@@ -164,13 +164,13 @@ to export a function which can be used for inputs with variable shapes:

 .. code-block:: python

-  mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
+  mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
   imported_abs = mx.import_function("fun.mlxfn")

   # Ok
-  out, = imported_abs(mx.array(-1.0))
+  out, = imported_abs(mx.array([-1.0]))

   # Also ok
   out, = imported_abs(mx.array([-1.0, -2.0]))

 With ``shapeless=False`` (which is the default), the second call to
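To make the shapeless behavior concrete, a small sketch is included below. It assumes only the ``export_function``/``import_function`` API shown above; note that with ``shapeless=True`` the traced graph is reused for new input sizes, so any shape-dependent Python control flow is frozen at export time.

.. code-block:: python

  # Hedged sketch of exporting a shapeless function and reusing it on new sizes.
  import mlx.core as mx

  def fun(x):
      return mx.abs(x) + 1.0

  mx.export_function("abs1.mlxfn", fun, mx.array([0.0]), shapeless=True)
  imported = mx.import_function("abs1.mlxfn")

  (out,) = imported(mx.array([-1.0, -2.0, -3.0]))  # ok: same rank, new length
  print(out)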
@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
   def fun(x, y=None):
       constant = mx.array(3.0)
       if y is not None:
           x += y
       return x + constant

   with mx.exporter("fun.mlxfn", fun) as exporter:

@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:
   print(out)

 In the above example the function constant data, (i.e. ``constant``), is only
 saved once.

 Transformations with Imported Functions
 ---------------------------------------

@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:
   # Prints: array(1, dtype=float32)
   print(dfdx(x))

   # Compile the imported function
   mx.compile(imported_fun)
   # Prints: array(0, dtype=float32)
   print(compiled_fun(x)[0])

@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:
   // Prints: array(2, dtype=float32)
   std::cout << outputs[0] << std::endl;

 Imported functions can be transformed in C++ just like in Python. Use
 ``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
 mx::array>`` for keyword arguments when calling imported functions in C++.
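A self-contained sketch of transforming an imported function, following the pattern described above. It only assumes that imported functions return a tuple of outputs, so the lambdas unwrap the first output before applying ``grad`` or ``compile``.

.. code-block:: python

  # Hedged sketch: grad and compile applied to an imported function.
  import mlx.core as mx

  def fun(x):
      return mx.sin(x)

  mx.export_function("sin.mlxfn", fun, mx.array(0.0))
  imported_fun = mx.import_function("sin.mlxfn")

  dfdx = mx.grad(lambda x: imported_fun(x)[0])
  print(dfdx(mx.array(0.0)))          # cos(0) == 1

  compiled_fun = mx.compile(lambda x: imported_fun(x)[0])
  print(compiled_fun(mx.array(0.0)))  # sin(0) == 0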
@@ -70,7 +70,8 @@ Differences from NumPy

 * Indexing does not perform bounds checking. Indexing out of bounds is
   undefined behavior.
-* Boolean mask based indexing is not yet supported.
+* Boolean mask based indexing is supported for assignment only (see
+  :ref:`boolean-mask-assignment`).

 The reason for the lack of bounds checking is that exceptions cannot propagate
 from the GPU. Performing bounds checking for array indices before launching the

@@ -107,6 +108,28 @@ same array:
   >>> a
   array([1, 2, 0], dtype=int32)

+Note that unlike NumPy, slicing an array creates a copy, not a view. So
+mutating it does not mutate the original array:
+
+.. code-block:: shell
+
+  >>> a = mx.array([1, 2, 3])
+  >>> b = a[:]
+  >>> b[2] = 0
+  >>> b
+  array([1, 2, 0], dtype=int32)
+  >>> a
+  array([1, 2, 3], dtype=int32)
+
+Also unlike NumPy, updates to the same location are nondeterministic:
+
+.. code-block:: shell
+
+  >>> a = mx.array([1, 2, 3])
+  >>> a[[0, 0]] = mx.array([4, 5])
+
+The first element of ``a`` could be ``4`` or ``5``.
+
 Transformations of functions which use in-place updates are allowed and work as
 expected. For example:
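For reference, a small sketch of the in-place-update-plus-gradient behavior referenced in the next hunk; the concrete indices and values are illustrative only.

.. code-block:: python

  # Hedged sketch: gradients through an in-place index update.
  import mlx.core as mx

  def fun(x):
      idx = mx.array([0, 2])
      x[idx] = mx.array([7.0, 7.0])
      return x.sum()

  dfdx = mx.grad(fun)(mx.ones((4,)))
  print(dfdx)  # zeros at the updated indices, ones elsewhere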
@@ -121,3 +144,51 @@ expected. For example:

 In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
 and ones elsewhere.
+
+.. _boolean-mask-assignment:
+
+Boolean Mask Assignment
+-----------------------
+
+MLX supports boolean indices using NumPy syntax. A mask must already be
+a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
+Other index types are routed through the standard scatter code.
+
+.. code-block:: shell
+
+  >>> a = mx.array([1.0, 2.0, 3.0])
+  >>> mask = mx.array([True, False, True])
+  >>> updates = mx.array([5.0, 6.0])
+  >>> a[mask] = updates
+  >>> a
+  array([5.0, 2.0, 6.0], dtype=float32)
+
+Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
+assignments, ``updates`` must provide at least as many elements as there are
+``True`` entries in ``mask``.
+
+.. code-block:: shell
+
+  >>> a = mx.zeros((2, 3))
+  >>> mask = mx.array([[True, False, True],
+                       [False, False, True]])
+  >>> a[mask] = 1.0
+  >>> a
+  array([[1.0, 0.0, 1.0],
+         [0.0, 0.0, 1.0]], dtype=float32)
+
+Boolean masks follow NumPy semantics:
+
+- The mask shape must match the shape of the axes it indexes exactly. The only
+  exception is a scalar boolean mask, which broadcasts to the full array.
+- Any axes not covered by the mask are taken in full.
+
+.. code-block:: shell
+
+  >>> a = mx.arange(1000).reshape(10, 10, 10)
+  >>> a[mx.random.normal((10, 10)) > 0.0] = 0  # valid: mask covers axes 0 and 1
+
+The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
+selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
+Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
+axes and therefore raise errors.
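As a quick aside, the sketch below shows a mask assignment next to a functional ``mx.where`` formulation, which also covers the read case that plain boolean indexing does not; the values are illustrative.

.. code-block:: python

  # Hedged sketch: boolean-mask assignment vs. an equivalent mx.where.
  import mlx.core as mx

  a = mx.array([1.0, 2.0, 3.0, 4.0])
  mask = a > 2.0

  b = a[:]        # slicing copies in MLX (see above)
  b[mask] = 0.0   # scalar broadcast to every True entry

  c = mx.where(mask, 0.0, a)  # same result, no mutation
  print(b, c)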
@@ -1,5 +1,6 @@
 // Copyright © 2023-2025 Apple Inc.

+#include <dlfcn.h>
 #include <iostream>
 #include <sstream>

@@ -16,6 +17,19 @@

 namespace my_ext {

+// A helper function to find the location of the current binary on disk.
+// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
+std::string current_binary_dir() {
+  static std::string binary_dir = []() {
+    Dl_info info;
+    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+      throw std::runtime_error("Unable to get current binary dir.");
+    }
+    return std::filesystem::path(info.dli_fname).parent_path().string();
+  }();
+  return binary_dir;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Operation Implementation
 ///////////////////////////////////////////////////////////////////////////////

@@ -72,9 +86,7 @@ void axpby_impl(
     float alpha_,
     float beta_,
     mx::Stream stream) {
-  // Allocate the output with `malloc_or_wait` which synchronously allocates
-  // memory, potentially waiting if the system is under memory pressure
-  out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(mx::allocator::malloc(out.nbytes()));

   // Get the CPU command encoder and register input and output arrays
   auto& encoder = mx::cpu::get_command_encoder(stream);

@@ -160,25 +172,24 @@ void Axpby::eval_gpu(
   // Allocate output memory with strides based on specialization
   if (contiguous_kernel) {
     out.set_data(
-        mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
+        mx::allocator::malloc(x.data_size() * out.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
   } else {
-    out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(mx::allocator::malloc(out.nbytes()));
   }

   // Resolve name of kernel (corresponds to axpby.metal)
-  std::ostringstream kname;
-  kname << "axpby_";
-  kname << (contiguous_kernel ? "contiguous_" : "general_");
-  kname << type_to_name(out);
+  std::string kname = "axpby_";
+  kname += (contiguous_kernel ? "contiguous_" : "general_");
+  kname += type_to_name(out);

-  // Make sure the metal library is available
-  d.register_library("mlx_ext");
+  // Load the metal library
+  auto lib = d.get_library("mlx_ext", current_binary_dir());

   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+  auto kernel = d.get_kernel(kname, lib);

   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
       const std::vector<mx::array>& inputs,
       const std::vector<int>& axes) override;

-  /** Print the primitive. */
-  void print(std::ostream& os) override {
-    os << "Axpby";
+  /** The name of primitive. */
+  const char* name() const override {
+    return "Axpby";
   }

   /** Equivalence check **/

@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.25
 mlx>=0.21.0
-nanobind==2.2.0
+nanobind==2.4.0
@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby

 a = mx.ones((3, 4))
 b = mx.ones((3, 4))
-c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)

-print(f"c shape: {c.shape}")
-print(f"c dtype: {c.dtype}")
-print(f"c correct: {mx.all(c == 6.0).item()}")
+print(f"c shape: {c_cpu.shape}")
+print(f"c dtype: {c_cpu.dtype}")
+print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
+print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")

@@ -1,10 +1,10 @@
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+         ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp

@@ -20,7 +20,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)

 # Define MLX_VERSION only in the version.cpp file.
-add_library(mlx_version STATIC ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
+add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
 target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
@@ -48,5 +48,19 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 if(MLX_BUILD_METAL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
 else()
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
+endif()
+
+if(MLX_BUILD_CUDA)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
+else()
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
+endif()
+
+if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
+else()
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
 endif()
@@ -1,66 +0,0 @@
-// Copyright © 2023 Apple Inc.
-
-#include <cstdlib>
-#include <sstream>
-
-#include "mlx/allocator.h"
-#include "mlx/scheduler.h"
-
-namespace mlx::core::allocator {
-
-Buffer malloc(size_t size) {
-  auto buffer = allocator().malloc(size, /* allow_swap */ true);
-  if (size && !buffer.ptr()) {
-    std::ostringstream msg;
-    msg << "[malloc] Unable to allocate " << size << " bytes.";
-    throw std::runtime_error(msg.str());
-  }
-  return buffer;
-}
-
-void free(Buffer buffer) {
-  allocator().free(buffer);
-}
-
-Buffer CommonAllocator::malloc(size_t size, bool) {
-  void* ptr = std::malloc(size + sizeof(size_t));
-  if (ptr != nullptr) {
-    *static_cast<size_t*>(ptr) = size;
-  }
-  return Buffer{ptr};
-}
-
-void CommonAllocator::free(Buffer buffer) {
-  std::free(buffer.ptr());
-}
-
-size_t CommonAllocator::size(Buffer buffer) const {
-  if (buffer.ptr() == nullptr) {
-    return 0;
-  }
-  return *static_cast<size_t*>(buffer.ptr());
-}
-
-Buffer malloc_or_wait(size_t size) {
-  auto buffer = allocator().malloc(size);
-
-  while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
-    scheduler::wait_for_one();
-    buffer = allocator().malloc(size);
-  }
-
-  // Try swapping if needed
-  if (size && !buffer.ptr()) {
-    buffer = allocator().malloc(size, /* allow_swap = */ true);
-  }
-
-  if (size && !buffer.ptr()) {
-    std::ostringstream msg;
-    msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
-    throw std::runtime_error(msg.str());
-  }
-
-  return buffer;
-}
-
-} // namespace mlx::core::allocator
@@ -14,7 +14,7 @@ class Buffer {
   void* ptr_;

  public:
-  Buffer(void* ptr) : ptr_(ptr) {};
+  explicit Buffer(void* ptr) : ptr_(ptr) {};

   // Get the raw data pointer from the buffer
   void* raw_ptr();

@@ -28,20 +28,16 @@ class Buffer {
   };
 };

-Buffer malloc(size_t size);
-
-void free(Buffer buffer);
-
-// Wait for running tasks to finish and free up memory
-// if allocation fails
-Buffer malloc_or_wait(size_t size);
-
 class Allocator {
   /** Abstract base class for a memory allocator. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
+  virtual Buffer malloc(size_t size) = 0;
   virtual void free(Buffer buffer) = 0;
   virtual size_t size(Buffer buffer) const = 0;
+  virtual Buffer make_buffer(void* ptr, size_t size) {
+    return Buffer{nullptr};
+  };
+  virtual void release(Buffer buffer) {}

   Allocator() = default;
   Allocator(const Allocator& other) = delete;

@@ -53,16 +49,25 @@ class Allocator {

 Allocator& allocator();

-class CommonAllocator : public Allocator {
-  /** A general CPU allocator. */
- public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
-  virtual void free(Buffer buffer) override;
-  virtual size_t size(Buffer buffer) const override;
-
- private:
-  CommonAllocator() = default;
-  friend Allocator& allocator();
-};
+inline Buffer malloc(size_t size) {
+  return allocator().malloc(size);
+}
+
+inline void free(Buffer buffer) {
+  allocator().free(buffer);
+}
+
+// Make a Buffer from a raw pointer of the given size without a copy. If a
+// no-copy conversion is not possible then the returned buffer.ptr() will be
+// nullptr. Any buffer created with this function must be released with
+// release(buffer)
+inline Buffer make_buffer(void* ptr, size_t size) {
+  return allocator().make_buffer(ptr, size);
+};
+
+// Release a buffer from the allocator made with make_buffer
+inline void release(Buffer buffer) {
+  allocator().release(buffer);
+}

 } // namespace mlx::core::allocator
@@ -64,7 +64,7 @@ array array::unsafe_weak_copy(const array& other) {
       other.strides(),
       other.flags(),
       [](auto) {});
-  cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
+  cpy.array_desc_->offset = other.array_desc_->offset;
   return cpy;
 }

@@ -82,6 +82,28 @@ array::array(std::initializer_list<int> data, Dtype dtype)
   init(data.begin());
 }

+array::array(
+    void* data,
+    Shape shape,
+    Dtype dtype,
+    const std::function<void(void*)>& deleter)
+    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
+  auto buffer = allocator::make_buffer(data, nbytes());
+  if (buffer.ptr() == nullptr) {
+    set_data(allocator::malloc(nbytes()));
+    auto ptr = static_cast<char*>(data);
+    std::copy(ptr, ptr + nbytes(), this->data<char>());
+    deleter(data);
+  } else {
+    auto wrapped_deleter = [deleter](allocator::Buffer buffer) {
+      auto ptr = buffer.ptr();
+      allocator::release(buffer);
+      return deleter(ptr);
+    };
+    set_data(buffer, std::move(wrapped_deleter));
+  }
+}
+
 /* Build an array from a shared buffer */
 array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {

@@ -141,7 +163,7 @@ bool array::is_tracer() const {

 void array::set_data(allocator::Buffer buffer, Deleter d) {
   array_desc_->data = std::make_shared<Data>(buffer, d);
-  array_desc_->data_ptr = buffer.raw_ptr();
+  array_desc_->offset = 0;
   array_desc_->data_size = size();
   array_desc_->flags.contiguous = true;
   array_desc_->flags.row_contiguous = true;

@@ -156,7 +178,7 @@ void array::set_data(
     Flags flags,
     Deleter d) {
   array_desc_->data = std::make_shared<Data>(buffer, d);
-  array_desc_->data_ptr = buffer.raw_ptr();
+  array_desc_->offset = 0;
   array_desc_->data_size = data_size;
   array_desc_->strides = std::move(strides);
   array_desc_->flags = flags;

@@ -167,14 +189,13 @@ void array::copy_shared_buffer(
     const Strides& strides,
     Flags flags,
     size_t data_size,
-    size_t offset /* = 0 */) {
+    int64_t offset /* = 0 */) {
   array_desc_->data = other.array_desc_->data;
   array_desc_->strides = strides;
   array_desc_->flags = flags;
   array_desc_->data_size = data_size;
-  auto char_offset = sizeof(char) * itemsize() * offset;
-  array_desc_->data_ptr = static_cast<void*>(
-      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
+  array_desc_->offset =
+      sizeof(char) * itemsize() * offset + other.array_desc_->offset;
 }

 void array::copy_shared_buffer(const array& other) {

@@ -241,8 +262,8 @@ array::ArrayDesc::ArrayDesc(
     std::vector<array> inputs)
     : shape(std::move(shape)),
       dtype(dtype),
-      status(Status::unscheduled),
       primitive(std::move(primitive)),
+      status(Status::unscheduled),
       inputs(std::move(inputs)) {
   init();
 }
52  mlx/array.h

@@ -10,6 +10,7 @@
 #include "mlx/allocator.h"
 #include "mlx/dtype.h"
 #include "mlx/event.h"
+#include "mlx/small_vector.h"

 namespace mlx::core {

@@ -18,8 +19,8 @@ class Primitive;

 using Deleter = std::function<void(allocator::Buffer)>;
 using ShapeElem = int32_t;
-using Shape = std::vector<ShapeElem>;
-using Strides = std::vector<int64_t>;
+using Shape = SmallVector<ShapeElem>;
+using Strides = SmallVector<int64_t>;

 class array {
   /* An array is really a node in a graph. It contains a shared ArrayDesc

@@ -56,6 +57,16 @@ class array {
       Shape shape,
       Dtype dtype = TypeToDtype<T>());

+  /* Build an array from a raw pointer. The constructor will attempt to use the
+   * input data without a copy. The deleter will be called when the array no
+   * longer needs the underlying memory - after the array is destroyed in the
+   * no-copy case and after the copy otherwise. */
+  explicit array(
+      void* data,
+      Shape shape,
+      Dtype dtype,
+      const std::function<void(void*)>& deleter);
+
   /* Build an array from a buffer */
   explicit array(
       allocator::Buffer data,

@@ -224,6 +235,10 @@ class array {
     // Not copyable
     Data(const Data& d) = delete;
     Data& operator=(const Data& d) = delete;
+    Data(Data&& o) : buffer(o.buffer), d(o.d) {
+      o.buffer = allocator::Buffer(nullptr);
+      o.d = [](allocator::Buffer) {};
+    }
     ~Data() {
       d(buffer);
     }

@@ -289,6 +304,11 @@ class array {
     return array_desc_->siblings;
   }

+  /** The array's position in the sibling list. */
+  int sibling_position() const {
+    return array_desc_->position;
+  }
+
   void set_siblings(std::vector<array> siblings, uint16_t position) {
     array_desc_->siblings = std::move(siblings);
     array_desc_->position = position;

@@ -339,24 +359,32 @@ class array {
     return allocator::allocator().size(buffer());
   }

-  // Return a copy of the shared pointer
-  // to the array::Data struct
-  std::shared_ptr<Data> data_shared_ptr() const {
+  // Return the shared pointer to the array::Data struct
+  const std::shared_ptr<Data>& data_shared_ptr() const {
     return array_desc_->data;
   }
-  // Return a raw pointer to the arrays data
+
+  // Return a raw pointer to the arrays data. This function may do a copy if
+  // the underlying buffer is not accessible on the CPU. When accessing the
+  // data for GPU kernels, be sure to use the correct method / function for the
+  // given backend to access the GPU pointer.
   template <typename T>
   T* data() {
-    return static_cast<T*>(array_desc_->data_ptr);
+    return reinterpret_cast<T*>(
+        (static_cast<char*>(buffer().raw_ptr()) + array_desc_->offset));
   }

   template <typename T>
   const T* data() const {
-    return static_cast<T*>(array_desc_->data_ptr);
+    return const_cast<array&>(*this).data<T>();
+  }
+
+  int64_t offset() const {
+    return array_desc_->offset;
   }

   enum Status {
-    // The ouptut of a computation which has not been scheduled.
+    // The output of a computation which has not been scheduled.
     // For example, the status of `x` in `auto x = a + b`.
     unscheduled,

@@ -421,7 +449,7 @@ class array {
       const Strides& strides,
       Flags flags,
       size_t data_size,
-      size_t offset = 0);
+      int64_t offset = 0);

   void copy_shared_buffer(const array& other);

@@ -456,8 +484,8 @@ class array {
   // can share the underlying data buffer.
   std::shared_ptr<Data> data;

-  // Properly offset data pointer
-  void* data_ptr{nullptr};
+  // Offset from beginning of data pointer
+  int64_t offset{0};

   // The size in elements of the data buffer the array accesses
   size_t data_size;
@@ -1,6 +1,7 @@
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp

@@ -38,20 +38,20 @@ inline void set_binary_op_output_data(
     const array& a,
     const array& b,
     array& out,
-    BinaryOpType bopt) {
+    BinaryOpType bopt,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
   bool b_donatable = is_donatable(b, out);
   bool a_donatable = is_donatable(a, out);
   switch (bopt) {
     case BinaryOpType::ScalarScalar:
-      out.set_data(
-          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
+      out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags());
       break;
     case BinaryOpType::ScalarVector:
       if (b_donatable) {
         out.copy_shared_buffer(b);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(b.data_size() * out.itemsize()),
+            mallocfn(b.data_size() * out.itemsize()),
             b.data_size(),
             b.strides(),
             b.flags());

@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
         out.copy_shared_buffer(a);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+            mallocfn(a.data_size() * out.itemsize()),
             a.data_size(),
             a.strides(),
             a.flags());

@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
         out.copy_shared_buffer(b);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+            mallocfn(a.data_size() * out.itemsize()),
             a.data_size(),
             a.strides(),
             a.flags());

@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
           b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
         out.copy_shared_buffer(b);
       } else {
-        out.set_data(allocator::malloc_or_wait(out.nbytes()));
+        out.set_data(mallocfn(out.nbytes()));
       }
       break;
   }
 }
24  mlx/backend/common/broadcasting.cpp  Normal file
@@ -0,0 +1,24 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+void broadcast(const array& in, array& out) {
+  if (out.size() == 0) {
+    out.set_data(allocator::malloc(0));
+    return;
+  }
+  Strides strides(out.ndim(), 0);
+  int diff = out.ndim() - in.ndim();
+  for (int i = in.ndim() - 1; i >= 0; --i) {
+    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
+  }
+  auto flags = in.flags();
+  if (out.size() > in.size()) {
+    flags.row_contiguous = flags.col_contiguous = false;
+  }
+  out.copy_shared_buffer(in, strides, flags, in.data_size());
+}
+
+} // namespace mlx::core

11  mlx/backend/common/broadcasting.h  Normal file
@@ -0,0 +1,11 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+void broadcast(const array& in, array& out);
+
+} // namespace mlx::core
157  mlx/backend/common/buffer_cache.h  Normal file
@@ -0,0 +1,157 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cassert>
+#include <functional>
+#include <map>
+
+namespace mlx::core {
+
+template <typename T>
+class BufferCache {
+ public:
+  BufferCache(
+      size_t page_size,
+      std::function<size_t(T*)> get_size,
+      std::function<void(T*)> free)
+      : page_size_(page_size),
+        get_size_(std::move(get_size)),
+        free_(std::move(free)) {}
+
+  ~BufferCache() {
+    clear();
+  }
+
+  BufferCache(const BufferCache&) = delete;
+  BufferCache& operator=(const BufferCache&) = delete;
+
+  T* reuse_from_cache(size_t size) {
+    // Find the closest buffer in pool.
+    auto it = buffer_pool_.lower_bound(size);
+    if (it == buffer_pool_.end() ||
+        it->first >= std::min(2 * size, size + 2 * page_size_)) {
+      return nullptr;
+    }
+
+    // Collect from the cache.
+    T* buf = it->second->buf;
+    pool_size_ -= it->first;
+
+    // Remove from record.
+    remove_from_list(it->second);
+    buffer_pool_.erase(it);
+    return buf;
+  }
+
+  void recycle_to_cache(T* buf) {
+    assert(buf);
+    // Add to cache.
+    BufferHolder* bh = new BufferHolder(buf);
+    add_at_head(bh);
+    size_t size = get_size_(buf);
+    pool_size_ += size;
+    buffer_pool_.emplace(size, bh);
+  }
+
+  int release_cached_buffers(size_t min_bytes_to_free) {
+    if (min_bytes_to_free >= 0.9 * pool_size_) {
+      return clear();
+    } else {
+      int n_release = 0;
+      size_t total_bytes_freed = 0;
+
+      while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
+        // Release buffer.
+        size_t size = get_size_(tail_->buf);
+        total_bytes_freed += size;
+        free_(tail_->buf);
+        n_release++;
+
+        // Remove from record.
+        auto its = buffer_pool_.equal_range(size);
+        auto it = std::find_if(its.first, its.second, [this](const auto& el) {
+          return el.second == tail_;
+        });
+        assert(it != buffer_pool_.end());
+        buffer_pool_.erase(it);
+        remove_from_list(tail_);
+      }
+
+      pool_size_ -= total_bytes_freed;
+      return n_release;
+    }
+  }
+
+  int clear() {
+    int n_release = 0;
+    for (auto& [size, holder] : buffer_pool_) {
+      free_(holder->buf);
+      n_release++;
+      delete holder;
+    }
+    buffer_pool_.clear();
+    pool_size_ = 0;
+    head_ = nullptr;
+    tail_ = nullptr;
+    return n_release;
+  }
+
+  size_t cache_size() const {
+    return pool_size_;
+  }
+
+  size_t page_size() const {
+    return page_size_;
+  }
+
+ private:
+  struct BufferHolder {
+   public:
+    explicit BufferHolder(T* buf_) : buf(buf_) {}
+
+    BufferHolder* prev{nullptr};
+    BufferHolder* next{nullptr};
+    T* buf;
+  };
+
+  void add_at_head(BufferHolder* to_add) {
+    if (!head_) {
+      head_ = to_add;
+      tail_ = to_add;
+    } else {
+      head_->prev = to_add;
+      to_add->next = head_;
+      head_ = to_add;
+    }
+  }
+
+  void remove_from_list(BufferHolder* to_remove) {
+    if (to_remove->prev && to_remove->next) { // if middle
+      to_remove->prev->next = to_remove->next;
+      to_remove->next->prev = to_remove->prev;
+    } else if (to_remove->prev && to_remove == tail_) { // if tail
+      tail_ = to_remove->prev;
+      tail_->next = nullptr;
+    } else if (to_remove == head_ && to_remove->next) { // if head
+      head_ = to_remove->next;
+      head_->prev = nullptr;
+    } else if (to_remove == head_ && to_remove == tail_) { // if only element
+      head_ = nullptr;
+      tail_ = nullptr;
+    }
+
+    delete to_remove;
+  }
+
+  std::multimap<size_t, BufferHolder*> buffer_pool_;
+  BufferHolder* head_{nullptr};
+  BufferHolder* tail_{nullptr};
+  size_t pool_size_{0};
+
+  const size_t page_size_;
+  std::function<size_t(T*)> get_size_;
+  std::function<void(T*)> free_;
+};
+
+} // namespace mlx::core
@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.
 #include <cassert>

+#include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"

@@ -42,23 +43,6 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
 }

-void broadcast(const array& in, array& out) {
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-  Strides strides(out.ndim(), 0);
-  int diff = out.ndim() - in.ndim();
-  for (int i = in.ndim() - 1; i >= 0; --i) {
-    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
-  }
-  auto flags = in.flags();
-  if (out.size() > in.size()) {
-    flags.row_contiguous = flags.col_contiguous = false;
-  }
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
-}
-
 void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   broadcast(inputs[0], out);
 }

@@ -103,7 +87,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {

 void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));

   double numel = 1;
   for (auto ax : axes_) {
@@ -1,8 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.

 #include "mlx/backend/common/compiled.h"
-#include "mlx/graph_utils.h"
-#include "mlx/primitives.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/utils.h"

 namespace mlx::core {
@@ -15,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) {
       return print_float_constant<float16_t>(os, x);
     case bfloat16:
       return print_float_constant<bfloat16_t>(os, x);
+    case float64:
+      return print_float_constant<double>(os, x);
     case complex64:
       return print_complex_constant<complex64_t>(os, x);
     case int8:
@@ -51,6 +52,8 @@ std::string get_type_string(Dtype d) {
       return "float16_t";
     case bfloat16:
       return "bfloat16_t";
+    case float64:
+      return "double";
     case complex64:
       return "complex64_t";
     case bool_:
@@ -79,55 +82,6 @@ std::string get_type_string(Dtype d) {
   }
 }

-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids) {
-  NodeNamer namer;
-  std::ostringstream os;
-  std::ostringstream constant_hasher;
-
-  // Fill the input names. This is not really necessary, I just like having A,
-  // B, C, ... as the inputs.
-  for (auto& x : inputs) {
-    namer.get_name(x);
-  }
-
-  // The primitives describing the tape. For unary and binary primitives this
-  // must be enough to describe the full computation.
-  for (auto& a : tape) {
-    // name and type of output
-    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
-    // computation performed
-    a.primitive().print(os);
-    // name of inputs to the function
-    for (auto& inp : a.inputs()) {
-      os << namer.get_name(inp);
-    }
-  }
-  os << "_";
-
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      os << "C";
-      print_constant(constant_hasher, x);
-    } else {
-      os << (is_scalar(x) ? "S" : "V");
-    }
-  }
-  os << "_";
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      continue;
-    }
-    os << kindof(x.dtype()) << x.itemsize();
-  }
-  os << "_" << std::hash<std::string>{}(constant_hasher.str());
-
-  return os.str();
-}
-
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
     const Shape& shape) {
@@ -159,9 +113,10 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
-    bool contiguous) {
+    const std::function<bool(size_t)>& is_constant,
+    bool contiguous,
+    const std::function<allocator::Buffer(size_t)>&
+        mallocfn /* = allocator::malloc */) {
   if (contiguous) {
     int o = 0;
     Strides strides;
@@ -175,8 +130,7 @@ void compiled_allocate_outputs(
       // - Donatable
       // - Not a constant
       if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-          in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          in.is_donatable() && is_constant(i)) {
         outputs[o++].copy_shared_buffer(in);
       }
       // Get representative input flags to properly set non-donated outputs
@@ -188,7 +142,7 @@ void compiled_allocate_outputs(
     }
     for (; o < outputs.size(); ++o) {
       outputs[o].set_data(
-          allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
+          mallocfn(data_size * outputs[o].itemsize()),
           data_size,
           strides,
           flags);
@@ -204,16 +158,86 @@ void compiled_allocate_outputs(
       // - Not a constant
       if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
           in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          is_constant(i)) {
         outputs[o].copy_shared_buffer(
             in, outputs[o].strides(), in.flags(), in.data_size());
         o++;
       }
     }
     for (; o < outputs.size(); ++o) {
-      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
+      outputs[o].set_data(mallocfn(outputs[o].nbytes()));
     }
   }
 }

+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant) {
+  const Shape& shape = out.shape();
+  bool contiguous = compiled_check_contiguity(inputs, shape);
+  if (contiguous) {
+    return {true, shape, {}};
+  }
+
+  std::vector<Strides> strides_vec{out.strides()};
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    // Skip constants.
+    if (is_constant(i)) {
+      continue;
+    }
+
+    // Skip scalar inputs.
+    const auto& x = inputs[i];
+    if (is_scalar(x)) {
+      continue;
+    }
+
+    // Broadcast the inputs to the output shape.
+    Strides xstrides;
+    size_t j = 0;
+    for (; j < shape.size() - x.ndim(); ++j) {
+      if (shape[j] == 1) {
+        xstrides.push_back(out.strides()[j]);
+      } else {
+        xstrides.push_back(0);
+      }
+    }
+    for (size_t i = 0; i < x.ndim(); ++i, ++j) {
+      if (x.shape(i) == 1) {
+        if (shape[j] == 1) {
+          xstrides.push_back(out.strides()[j]);
+        } else {
+          xstrides.push_back(0);
+        }
+      } else {
+        xstrides.push_back(x.strides()[i]);
+      }
+    }
+    strides_vec.push_back(std::move(xstrides));
+  }
+
+  auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
+  return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
+}
+
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    bool contiguous) {
+  if (contiguous) {
+    size_t max_size = 0;
+    for (const auto& in : inputs) {
+      max_size = std::max(max_size, in.data_size());
+    }
+    return max_size > UINT32_MAX;
+  } else {
+    size_t max_size = 0;
+    for (const auto& o : outputs) {
+      max_size = std::max(max_size, o.size());
+    }
+    return max_size > UINT32_MAX;
+  }
+}
+
 } // namespace mlx::core
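The hunk above swaps the parallel `inputs_` vector plus `std::unordered_set<uintptr_t>` of constant ids for a single `std::function<bool(size_t)>` predicate over input indices. A minimal standalone sketch of that refactor pattern follows; the `Input` struct and donation rule here are simplified stand-ins rather than mlx's real types, and the predicate contract ("donatable and not a constant") is one plausible reading, not a statement of the library's exact semantics.

```cpp
// Sketch: replacing an id-set lookup with an index predicate.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <unordered_set>
#include <vector>

struct Input {
  uintptr_t id;
  bool donatable;
};

// Old style: caller threads the inputs and a set of constant ids everywhere.
bool can_donate_old(
    const std::vector<Input>& inputs,
    const std::unordered_set<uintptr_t>& constant_ids,
    size_t i) {
  return inputs[i].donatable &&
      constant_ids.find(inputs[i].id) == constant_ids.end();
}

// New style: caller passes one predicate answering "is input i a constant?".
bool can_donate_new(
    const std::vector<Input>& inputs,
    const std::function<bool(size_t)>& is_constant,
    size_t i) {
  return inputs[i].donatable && !is_constant(i);
}

int main() {
  std::vector<Input> inputs{{0x1, true}, {0x2, true}, {0x3, false}};
  std::unordered_set<uintptr_t> constant_ids{0x2};
  auto is_constant = [&](size_t i) {
    return constant_ids.count(inputs[i].id) > 0;
  };
  for (size_t i = 0; i < inputs.size(); ++i) {
    std::cout << i << ": " << can_donate_old(inputs, constant_ids, i) << " "
              << can_donate_new(inputs, is_constant, i) << "\n";
  }
}
```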
@@ -1,9 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.
 #pragma once

+#include <functional>
 #include <iomanip>
-#include <sstream>
-#include <unordered_set>

 #include "mlx/array.h"
 #include "mlx/primitives.h"
@@ -14,19 +13,17 @@ inline bool is_static_cast(const Primitive& p) {
   return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
 }

-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids);
-
 std::string get_type_string(Dtype d);

 template <typename T>
 void print_float_constant(std::ostream& os, const array& x) {
   auto old_precision = os.precision();
-  os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
-     << x.item<T>() << std::setprecision(old_precision);
+  if constexpr (std::is_same_v<T, double>) {
+    os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
+  } else {
+    os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
+  }
+  os << x.item<T>() << std::setprecision(old_precision);
 }

 template <typename T>
@@ -60,8 +57,21 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
+    bool contiguous,
+    const std::function<allocator::Buffer(size_t)>& mallocfn =
+        allocator::malloc);
+
+// Collapse contiguous dims ignoring scalars and constants.
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant);
+
+// Return whether the kernel should use large index.
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
     bool contiguous);

 } // namespace mlx::core
@@ -2,7 +2,7 @@

 #pragma once

-#include "mlx/array.h"
+#include "mlx/backend/common/utils.h"

 namespace mlx::core {

@@ -22,23 +22,27 @@ enum class CopyType {
   GeneralGeneral
 };

-inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
+inline bool set_copy_output_data(
+    const array& in,
+    array& out,
+    CopyType ctype,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
   if (ctype == CopyType::Vector) {
     // If the input is donateable, we are doing a vector copy and the types
     // have the same size, then the input buffer can hold the output.
-    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+    if (is_donatable(in, out)) {
       out.copy_shared_buffer(in);
       return true;
     } else {
       out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          mallocfn(in.data_size() * out.itemsize()),
           in.data_size(),
           in.strides(),
           in.flags());
       return false;
     }
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(mallocfn(out.nbytes()));
     return false;
   }
 }
@@ -99,7 +99,11 @@ inline std::pair<int, int> decompose_hadamard(int n) {
           "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
     }
   }
+  if (n > (1 << 26)) {
+    throw std::invalid_argument(
+        "[hadamard] Only supports n = m*2^k where k <= 26");
+  }
   return {n, m};
 }

 } // namespace mlx::core
@@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
 namespace mlx::core {

 void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto read_task = [out_ptr = out.data<char>(),
                     size = out.size(),
                     itemsize = out.itemsize(),
mlx/backend/common/matmul.h (new file, 67 lines)
@@ -0,0 +1,67 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/utils.h"
+
+#include <sstream>
+
+namespace mlx::core {
+
+inline std::tuple<Shape, Strides, Strides> collapse_batches(
+    const array& a,
+    const array& b) {
+  if (a.ndim() == 2) {
+    return {Shape{1}, Strides{0}, Strides{0}};
+  }
+
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
+
+  auto [batch_shape, batch_strides] =
+      collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
+
+  auto a_batch_strides = batch_strides[0];
+  auto b_batch_strides = batch_strides[1];
+
+  if (batch_shape.empty()) {
+    batch_shape.push_back(1);
+    a_batch_strides.push_back(0);
+    b_batch_strides.push_back(0);
+  }
+
+  return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
+}
+
+inline std::tuple<Shape, Strides, Strides, Strides>
+collapse_batches(const array& a, const array& b, const array& c) {
+  if (a.ndim() == 2) {
+    return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
+  }
+
+  Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
+  Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
+  Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
+  Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
+
+  auto [batch_shape, batch_strides] = collapse_contiguous_dims(
+      A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
+
+  auto A_batch_stride = batch_strides[0];
+  auto B_batch_stride = batch_strides[1];
+  auto C_batch_stride = batch_strides[2];
+
+  if (batch_shape.empty()) {
+    batch_shape.push_back(1);
+    A_batch_stride.push_back(0);
+    B_batch_stride.push_back(0);
+    C_batch_stride.push_back(0);
+  }
+
+  return std::make_tuple(
+      batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
+}
+
+} // namespace mlx::core
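The new `collapse_batches` helpers drop the trailing two (matrix) dimensions and then merge whatever leading batch dimensions are jointly contiguous across the operands. Below is a self-contained sketch of that merging step under a simplified rule (merge dim i into i+1 when `stride[i] == stride[i+1] * shape[i+1]` for every operand); it stands in for mlx's `collapse_contiguous_dims` rather than reproducing it.

```cpp
// Sketch: merge adjacent batch dims that are contiguous for every operand.
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int>;
using Strides = std::vector<int64_t>;

void collapse(Shape& shape, std::vector<Strides>& strides) {
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    bool mergeable = true;
    for (const auto& s : strides) {
      mergeable = mergeable && (s[i] == s[i + 1] * shape[i + 1]);
    }
    if (mergeable) {
      shape[i + 1] *= shape[i];
      shape.erase(shape.begin() + i);
      for (auto& s : strides) {
        s.erase(s.begin() + i);
      }
    }
  }
}

int main() {
  // Batch dims (4, 3) of two row-contiguous (4, 3, 8, 8) operands.
  Shape batch_shape{4, 3};
  std::vector<Strides> batch_strides{{192, 64}, {192, 64}};
  collapse(batch_shape, batch_strides);
  std::cout << batch_shape[0] << "\n"; // 12: one flat batch loop suffices
}
```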
@@ -5,11 +5,9 @@
 namespace mlx::core {

 std::pair<Shape, Strides> shapes_without_reduction_axes(
-    const array& x,
+    Shape shape,
+    Strides strides,
     const std::vector<int>& axes) {
-  auto shape = x.shape();
-  auto strides = x.strides();
-
   for (int i = axes.size() - 1; i >= 0; i--) {
     int a = axes[i];
     shape.erase(shape.begin() + a);
@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
   return std::make_pair(shape, strides);
 }

+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    const array& x,
+    const std::vector<int>& axes) {
+  auto shape = x.shape();
+  auto strides = x.strides();
+  return shapes_without_reduction_axes(
+      std::move(shape), std::move(strides), axes);
+}
+
 ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
   // The data is all there and we are reducing over everything
   if (x.size() == x.data_size() && axes.size() == x.ndim() &&
@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
 std::pair<Shape, Strides> shapes_without_reduction_axes(
     const array& x,
     const std::vector<int>& axes);
+std::pair<Shape, Strides> shapes_without_reduction_axes(
+    Shape shape,
+    Strides strides,
+    const std::vector<int>& axes);

 } // namespace mlx::core
@@ -14,17 +14,13 @@ std::tuple<int64_t, Strides> prepare_slice(
     data_offset += start_indices[i] * in.strides()[i];
     inp_strides[i] = in.strides()[i] * strides[i];
   }
-  // Normalize the offset
-  if (data_offset < 0) {
-    data_offset += in.data_size();
-  }
   return std::make_tuple(data_offset, inp_strides);
 }

 void shared_buffer_slice(
     const array& in,
     const Strides& out_strides,
-    size_t data_offset,
+    int64_t data_offset,
     size_t data_size,
     array& out) {
   // Compute row/col contiguity
@@ -45,23 +41,30 @@ void slice(
     const Shape& start_indices,
     const Shape& strides) {
   if (out.size() == 0) {
-    out.set_data(nullptr);
+    out.set_data(allocator::malloc(0));
     return;
   }

   // Calculate out strides, initial offset
   auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
-  int64_t data_end = 1;
-  for (int i = 0; i < start_indices.size(); ++i) {
-    if (in.shape()[i] > 1) {
-      auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1;
-      data_end += end_idx * in.strides()[i];
+  // Get the location of the end based on the inp strides and out.shape()
+  int64_t low_idx = 0;
+  int64_t high_idx = 0;
+  for (int i = 0; i < inp_strides.size(); ++i) {
+    auto delta = inp_strides[i] * (out.shape()[i] - 1);
+    if (inp_strides[i] > 0) {
+      high_idx += delta;
+    } else {
+      low_idx += delta;
     }
   }
-  if (data_end < 0) {
-    data_end += in.data_size();
+  int64_t data_size = (high_idx - low_idx) + 1;
+  if (data_size < 0) {
+    std::ostringstream msg;
+    msg << "[slice] Computed invalid data size: " << data_size << ".";
+    throw std::runtime_error(msg.str());
   }
-  size_t data_size = (data_end - data_offset);
   shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
 }

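The rewritten slice sizing accumulates positive-stride and negative-stride contributions separately, so a view whose strides step backward still gets a valid extent. A small standalone sketch of that computation, using plain vectors in place of mlx arrays:

```cpp
// Sketch: extent of a strided view from per-dimension deltas. Positive-stride
// deltas raise the highest reachable offset, negative-stride deltas lower the
// lowest one; the span between them plus one is the data size.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t view_data_size(
    const std::vector<int64_t>& strides,
    const std::vector<int>& out_shape) {
  int64_t low_idx = 0;
  int64_t high_idx = 0;
  for (size_t i = 0; i < strides.size(); ++i) {
    int64_t delta = strides[i] * (out_shape[i] - 1);
    if (strides[i] > 0) {
      high_idx += delta;
    } else {
      low_idx += delta;
    }
  }
  return (high_idx - low_idx) + 1;
}

int main() {
  // A (3, 4) view with strides (8, -1): rows step forward, columns step
  // backward, yet the view spans 2*8 + 3*1 + 1 = 20 elements.
  std::cout << view_data_size({8, -1}, {3, 4}) << "\n"; // 20
}
```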
@@ -11,6 +11,8 @@ namespace mlx::core {
 enum class TernaryOpType {
   ScalarScalarScalar,
   VectorVectorVector,
+  VectorVectorScalar,
+  VectorScalarVector,
   General,
 };

@@ -25,6 +27,14 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
       (a.flags().col_contiguous && b.flags().col_contiguous &&
        c.flags().col_contiguous)) {
     topt = TernaryOpType::VectorVectorVector;
+  } else if (
+      b.data_size() == 1 && a.flags().row_contiguous &&
+      c.flags().row_contiguous) {
+    topt = TernaryOpType::VectorScalarVector;
+  } else if (
+      c.data_size() == 1 && a.flags().row_contiguous &&
+      b.flags().row_contiguous) {
+    topt = TernaryOpType::VectorVectorScalar;
   } else {
     topt = TernaryOpType::General;
   }
@@ -36,7 +46,8 @@ inline void set_ternary_op_output_data(
     const array& b,
     const array& c,
     array& out,
-    TernaryOpType topt) {
+    TernaryOpType topt,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
   auto maybe_donate = [&out](const array& x) {
     if (is_donatable(x, out)) {
       out.copy_shared_buffer(x);
@@ -47,24 +58,25 @@ inline void set_ternary_op_output_data(

   switch (topt) {
     case TernaryOpType::ScalarScalarScalar:
-      out.set_data(
-          allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
+      out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
       break;
     case TernaryOpType::VectorVectorVector:
       if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
         out.set_data(
-            allocator::malloc_or_wait(out.itemsize() * b.data_size()),
+            mallocfn(out.itemsize() * b.data_size()),
             b.data_size(),
             b.strides(),
             b.flags());
       }
       break;
+    case TernaryOpType::VectorVectorScalar:
+    case TernaryOpType::VectorScalarVector:
     case TernaryOpType::General:
       // Try to donate an input which is row_contiguous
       if (!((a.flags().row_contiguous && maybe_donate(a)) ||
             (b.flags().row_contiguous && maybe_donate(b)) ||
             (c.flags().row_contiguous && maybe_donate(c)))) {
-        out.set_data(allocator::malloc_or_wait(out.nbytes()));
+        out.set_data(mallocfn(out.nbytes()));
       }
       break;
   }
mlx/backend/common/unary.h (new file, 29 lines)
@@ -0,0 +1,29 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/allocator.h"
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+inline void set_unary_output_data(
+    const array& in,
+    array& out,
+    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
+  if (in.flags().contiguous) {
+    if (is_donatable(in, out)) {
+      out.copy_shared_buffer(in);
+    } else {
+      out.set_data(
+          mallocfn(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
+  } else {
+    out.set_data(mallocfn(out.nbytes()));
+  }
+}
+
+} // namespace mlx::core
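unary.h, copy.h, and ternary.h all gain the same hook in this diff: an allocation callback that defaults to the regular allocator, so a caller can route output buffers through a different allocation path without adding overloads. A hedged standalone sketch of that defaulted `std::function` parameter pattern; the `Buffer`, `default_alloc`, and pool-style lambda below are illustrative stand-ins, not mlx's allocator API.

```cpp
// Sketch: inject the allocation routine via a defaulted std::function parameter.
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iostream>

struct Buffer {
  void* ptr;
};

Buffer default_alloc(size_t n) {
  return Buffer{std::malloc(n)};
}

// Callers that say nothing get default_alloc; others pass their own routine.
Buffer make_output(
    size_t nbytes,
    std::function<Buffer(size_t)> mallocfn = default_alloc) {
  return mallocfn(nbytes);
}

int main() {
  Buffer a = make_output(64); // default allocation path
  Buffer b = make_output(64, [](size_t n) {
    std::cout << "custom alloc of " << n << " bytes\n"; // e.g. a pooled path
    return Buffer{std::malloc(n)};
  });
  std::free(a.ptr);
  std::free(b.ptr);
}
```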
@@ -1,9 +1,22 @@
 // Copyright © 2023-2024 Apple Inc.

+#include <dlfcn.h>
+
 #include "mlx/backend/common/utils.h"

 namespace mlx::core {

+std::filesystem::path current_binary_dir() {
+  static std::filesystem::path binary_dir = []() {
+    Dl_info info;
+    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+      throw std::runtime_error("Unable to get current binary dir.");
+    }
+    return std::filesystem::path(info.dli_fname).parent_path();
+  }();
+  return binary_dir;
+}
+
 std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
     const Shape& shape,
     const std::vector<Strides>& strides,
@@ -101,4 +114,118 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
   return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
 }

+Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
+  int pows[3] = {0, 0, 0};
+  int sum = 0;
+  while (true) {
+    int presum = sum;
+    // Check all the pows
+    if (dim0 >= (1 << (pows[0] + 1))) {
+      pows[0]++;
+      sum++;
+    }
+    if (sum == 10) {
+      break;
+    }
+    if (dim1 >= (1 << (pows[1] + 1))) {
+      pows[1]++;
+      sum++;
+    }
+    if (sum == 10) {
+      break;
+    }
+    if (dim2 >= (1 << (pows[2] + 1))) {
+      pows[2]++;
+      sum++;
+    }
+    if (sum == presum || sum == pow2) {
+      break;
+    }
+  }
+  return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
+}
+
+Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
+  // Dims with strides of 0 are ignored as they
+  // correspond to broadcasted dimensions
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  if (grid_y > grid_x) {
+    std::swap(grid_x, grid_y);
+  }
+  return std::make_tuple(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
+Dims get_2d_grid_dims_common(
+    const Shape& shape,
+    const Strides& strides,
+    size_t divisor) {
+  // Compute the 2d grid dimensions such that the total size of the grid is
+  // divided by divisor.
+  size_t grid_x = 1;
+  size_t grid_y = 1;
+  for (int i = 0; i < shape.size(); ++i) {
+    if (strides[i] == 0) {
+      continue;
+    }
+
+    // No need to add this shape we can just remove it from the divisor.
+    if (divisor % shape[i] == 0) {
+      divisor /= shape[i];
+      continue;
+    }
+
+    if (grid_x * shape[i] < UINT32_MAX) {
+      grid_x *= shape[i];
+    } else {
+      grid_y *= shape[i];
+    }
+
+    if (divisor > 1) {
+      if (grid_x % divisor == 0) {
+        grid_x /= divisor;
+        divisor = 1;
+      } else if (grid_y % divisor == 0) {
+        grid_y /= divisor;
+        divisor = 1;
+      }
+    }
+  }
+  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
+    throw std::runtime_error("Unable to safely factor shape.");
+  }
+  if (grid_y > grid_x) {
+    std::swap(grid_x, grid_y);
+  }
+  if (divisor > 1) {
+    grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
+  }
+  return std::make_tuple(
+      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+}
+
+std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
+  auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
+  auto gx = (dim0 + bx - 1) / bx;
+  auto gy = (dim1 + by - 1) / by;
+  auto gz = (dim2 + bz - 1) / bz;
+
+  return std::make_pair(
+      std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
+}
+
 } // namespace mlx::core
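The new `current_binary_dir()` resolves the on-disk location of the loaded library by asking the dynamic linker which module contains one of its own symbols. A standalone sketch of the same `dladdr` lookup (POSIX; link with `-ldl` on Linux if needed); the `module_dir` name is just for the example.

```cpp
// Sketch: locate the directory containing the running binary or shared
// library via dladdr, mirroring the lookup added in the hunk above.
#include <dlfcn.h>

#include <filesystem>
#include <iostream>
#include <stdexcept>

std::filesystem::path module_dir() {
  Dl_info info;
  // Any address inside this module works as the probe symbol.
  if (!dladdr(reinterpret_cast<void*>(&module_dir), &info)) {
    throw std::runtime_error("dladdr failed");
  }
  return std::filesystem::path(info.dli_fname).parent_path();
}

int main() {
  std::cout << module_dir() << "\n";
}
```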
@@ -2,12 +2,17 @@

 #pragma once

+#include <filesystem>
+#include <tuple>
 #include <vector>

 #include "mlx/array.h"

 namespace mlx::core {

+// Return the directory that contains current shared library.
+std::filesystem::path current_binary_dir();
+
 inline int64_t
 elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
   int64_t loc = 0;
@@ -70,6 +75,31 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
     const array& a,
     int64_t size_cap = std::numeric_limits<int32_t>::max());

+// Compute the thread block dimensions which fit the given
+// input dimensions.
+// - The thread block dimensions will be powers of two
+// - The thread block size will be less than 2^pow2
+using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
+Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
+
+// Computes a 2D grid where each element is < UINT_MAX
+// Assumes:
+// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
+// - shape and strides correspond to a contiguous (no holes) but
+//   possibly broadcasted array
+Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
+
+// Same as above but we do an implicit division with divisor.
+// Basically, equivalent to factorizing
+// Prod(s \forall s in shape if strides[s] > 0) / divisor.
+Dims get_2d_grid_dims_common(
+    const Shape& shape,
+    const Strides& strides,
+    size_t divisor);
+
+// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
+std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
+
 struct ContiguousIterator {
   inline void step() {
     int dims = shape_.size();
@@ -165,4 +195,11 @@ void shared_buffer_reshape(
     const array& in,
     const Strides& out_strides,
     array& out);

+template <typename T>
+inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
+  vec.erase(std::next(vec.begin(), index));
+  return vec;
+}
+
 } // namespace mlx::core
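`remove_index` gives call sites such as `arg_reduce` (later in this diff) a one-liner for dropping the reduced axis from a shape or strides container. A tiny sketch of the same helper written over `std::vector`, since `SmallVector` is internal to mlx:

```cpp
// Sketch of remove_index: return a copy of the container with the element at
// `index` erased, keeping call sites single-expression.
#include <cstddef>
#include <iostream>
#include <iterator>
#include <vector>

template <typename T>
std::vector<T> remove_index(std::vector<T> vec, size_t index) {
  vec.erase(std::next(vec.begin(), index));
  return vec;
}

int main() {
  std::vector<int> shape{2, 3, 4};
  auto reduced = remove_index(shape, 1); // drop axis 1
  for (int d : reduced) {
    std::cout << d << " "; // 2 4
  }
  std::cout << "\n";
}
```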
@@ -40,11 +40,13 @@ add_dependencies(mlx cpu_compiled_preamble)

 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@@ -58,6 +60,7 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
@@ -73,8 +76,8 @@ target_sources(
 if(MLX_BUILD_ACCELERATE)
   target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
 else()
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
-                             ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
+  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
+                             ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
 endif()

 if(IOS)
@@ -14,10 +14,8 @@ template <typename InT, typename OpT>
 void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
   auto axis_size = in.shape()[axis];
   auto axis_stride = in.strides()[axis];
-  Strides strides = in.strides();
-  Shape shape = in.shape();
-  strides.erase(strides.begin() + axis);
-  shape.erase(shape.begin() + axis);
+  Strides strides = remove_index(in.strides(), axis);
+  Shape shape = remove_index(in.shape(), axis);
   auto in_ptr = in.data<InT>();
   auto out_ptr = out.data<uint32_t>();

@@ -68,7 +66,7 @@ void arg_reduce_dispatch(
 void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
   encoder.set_output_array(out);
mlx/backend/cpu/available.cpp (new file, 11 lines)
@@ -0,0 +1,11 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cpu/available.h"
+
+namespace mlx::core::cpu {
+
+bool is_available() {
+  return true;
+}
+
+} // namespace mlx::core::cpu
mlx/backend/cpu/available.h (new file, 9 lines)
@@ -0,0 +1,9 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+namespace mlx::core::cpu {
+
+bool is_available();
+
+} // namespace mlx::core::cpu
@@ -14,230 +14,11 @@

 namespace mlx::core {

-namespace {
-
-template <typename Op>
-void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
-  auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out, bopt);
-
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-  encoder.dispatch([a = array::unsafe_weak_copy(a),
-                    b = array::unsafe_weak_copy(b),
-                    out = array::unsafe_weak_copy(out),
-                    bopt]() mutable {
-    switch (out.dtype()) {
-      case bool_:
-        binary_op<bool, Op>(a, b, out, bopt);
-        break;
-      case uint8:
-        binary_op<uint8_t, Op>(a, b, out, bopt);
-        break;
-      case uint16:
-        binary_op<uint16_t, Op>(a, b, out, bopt);
-        break;
-      case uint32:
-        binary_op<uint32_t, Op>(a, b, out, bopt);
-        break;
-      case uint64:
-        binary_op<uint64_t, Op>(a, b, out, bopt);
-        break;
-      case int8:
-        binary_op<int8_t, Op>(a, b, out, bopt);
-        break;
-      case int16:
-        binary_op<int16_t, Op>(a, b, out, bopt);
-        break;
-      case int32:
-        binary_op<int32_t, Op>(a, b, out, bopt);
-        break;
-      case int64:
-        binary_op<int64_t, Op>(a, b, out, bopt);
-        break;
-      case float16:
-        binary_op<float16_t, Op>(a, b, out, bopt);
-        break;
-      case float32:
-        binary_op<float, Op>(a, b, out, bopt);
-        break;
-      case float64:
-        binary_op<double, Op>(a, b, out, bopt);
-        break;
-      case bfloat16:
-        binary_op<bfloat16_t, Op>(a, b, out, bopt);
-        break;
-      case complex64:
-        binary_op<complex64_t, Op>(a, b, out, bopt);
-        break;
-    }
-  });
-}
-
-template <typename Op>
-void comparison_op(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op,
-    Stream stream) {
-  auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out, bopt);
-
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-  encoder.dispatch([a = array::unsafe_weak_copy(a),
-                    b = array::unsafe_weak_copy(b),
-                    out = array::unsafe_weak_copy(out),
-                    bopt]() mutable {
-    switch (a.dtype()) {
-      case bool_:
-        binary_op<bool, bool, Op>(a, b, out, bopt);
-        break;
-      case uint8:
-        binary_op<uint8_t, bool, Op>(a, b, out, bopt);
-        break;
-      case uint16:
-        binary_op<uint16_t, bool, Op>(a, b, out, bopt);
-        break;
-      case uint32:
-        binary_op<uint32_t, bool, Op>(a, b, out, bopt);
-        break;
-      case uint64:
-        binary_op<uint64_t, bool, Op>(a, b, out, bopt);
-        break;
-      case int8:
-        binary_op<int8_t, bool, Op>(a, b, out, bopt);
-        break;
-      case int16:
-        binary_op<int16_t, bool, Op>(a, b, out, bopt);
-        break;
-      case int32:
-        binary_op<int32_t, bool, Op>(a, b, out, bopt);
-        break;
-      case int64:
-        binary_op<int64_t, bool, Op>(a, b, out, bopt);
-        break;
-      case float16:
-        binary_op<float16_t, bool, Op>(a, b, out, bopt);
-        break;
-      case float32:
-        binary_op<float, bool, Op>(a, b, out, bopt);
-        break;
-      case float64:
-        binary_op<double, bool, Op>(a, b, out, bopt);
-        break;
-      case bfloat16:
-        binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
-        break;
-      case complex64:
-        binary_op<complex64_t, bool, Op>(a, b, out, bopt);
-        break;
-    }
-  });
-}
-
-template <typename Op>
-void binary_float(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op,
-    Stream stream) {
-  auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out, bopt);
-
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-  encoder.dispatch([a = array::unsafe_weak_copy(a),
-                    b = array::unsafe_weak_copy(b),
-                    out = array::unsafe_weak_copy(out),
-                    bopt]() mutable {
-    switch (out.dtype()) {
-      case float16:
-        binary_op<float16_t, Op>(a, b, out, bopt);
-        break;
-      case float32:
-        binary_op<float, Op>(a, b, out, bopt);
-        break;
-      case float64:
-        binary_op<double, Op>(a, b, out, bopt);
-        break;
-      case bfloat16:
-        binary_op<bfloat16_t, Op>(a, b, out, bopt);
-        break;
-      default:
-        throw std::runtime_error(
-            "[binary_float] Only supports non-complex floating point types.");
-    }
-  });
-}
-
-template <typename Op>
-void binary_int(
-    const array& a,
-    const array& b,
-    array& out,
-    Op op,
-    Stream stream) {
-  auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out, bopt);
-
-  auto& encoder = cpu::get_command_encoder(stream);
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-  encoder.dispatch([a = array::unsafe_weak_copy(a),
-                    b = array::unsafe_weak_copy(b),
-                    out = array::unsafe_weak_copy(out),
-                    bopt]() mutable {
-    switch (out.dtype()) {
-      case bool_:
-        binary_op<bool, Op>(a, b, out, bopt);
-      case uint8:
-        binary_op<uint8_t, Op>(a, b, out, bopt);
-        break;
-      case uint16:
-        binary_op<uint16_t, Op>(a, b, out, bopt);
-        break;
-      case uint32:
-        binary_op<uint32_t, Op>(a, b, out, bopt);
-        break;
-      case uint64:
-        binary_op<uint64_t, Op>(a, b, out, bopt);
-        break;
-      case int8:
-        binary_op<int8_t, Op>(a, b, out, bopt);
-        break;
-      case int16:
-        binary_op<int16_t, Op>(a, b, out, bopt);
-        break;
-      case int32:
-        binary_op<int32_t, Op>(a, b, out, bopt);
-        break;
-      case int64:
-        binary_op<int64_t, Op>(a, b, out, bopt);
-        break;
-      default:
-        throw std::runtime_error("[binary_int] Type not supported");
-        break;
-    }
-  });
-}
-
-} // namespace
-
 void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Add(), stream());
+  binary_op_cpu(a, b, out, detail::Add(), stream());
 }

 void DivMod::eval_cpu(
@@ -321,14 +102,14 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Divide(), stream());
+  binary_op_cpu(a, b, out, detail::Divide(), stream());
 }

 void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Remainder(), stream());
+  binary_op_cpu(a, b, out, detail::Remainder(), stream());
 }

 void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -369,89 +150,90 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
       }
     });
   } else {
-    comparison_op(a, b, out, detail::Equal(), stream());
+    comparison_op_cpu(a, b, out, detail::Equal(), stream());
   }
 }

 void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
+  comparison_op_cpu(inputs[0], inputs[1], out, detail::Greater(), stream());
 }

 void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
+  comparison_op_cpu(
+      inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
 }

 void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
+  comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
 }

 void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
+  comparison_op_cpu(inputs[0], inputs[1], out, detail::LessEqual(), stream());
 }

 void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary_float(a, b, out, detail::LogAddExp(), stream());
+  binary_float_op_cpu(a, b, out, detail::LogAddExp(), stream());
 }

 void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2); // LogicalAnd requires two input arrays
   auto& in1 = inputs[0];
   auto& in2 = inputs[1];
-  binary(in1, in2, out, detail::LogicalAnd(), stream());
+  binary_op_cpu(in1, in2, out, detail::LogicalAnd(), stream());
 }

 void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2); // LogicalOr requires two input arrays
   auto& in1 = inputs[0];
   auto& in2 = inputs[1];
-  binary(in1, in2, out, detail::LogicalOr(), stream());
+  binary_op_cpu(in1, in2, out, detail::LogicalOr(), stream());
 }

 void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Maximum(), stream());
+  binary_op_cpu(a, b, out, detail::Maximum(), stream());
 }

 void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Minimum(), stream());
+  binary_op_cpu(a, b, out, detail::Minimum(), stream());
 }

 void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Multiply(), stream());
+  binary_op_cpu(a, b, out, detail::Multiply(), stream());
 }

 void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
-  comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
+  comparison_op_cpu(inputs[0], inputs[1], out, detail::NotEqual(), stream());
 }

 void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Power(), stream());
+  binary_op_cpu(a, b, out, detail::Power(), stream());
 }

 void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  binary(a, b, out, detail::Subtract(), stream());
+  binary_op_cpu(a, b, out, detail::Subtract(), stream());
 }

 void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -460,19 +242,19 @@ void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& b = inputs[1];
   switch (op_) {
     case BitwiseBinary::And:
-      binary_int(a, b, out, detail::BitwiseAnd(), stream());
+      binary_int_op_cpu(a, b, out, detail::BitwiseAnd(), stream());
       break;
     case BitwiseBinary::Or:
-      binary_int(a, b, out, detail::BitwiseOr(), stream());
+      binary_int_op_cpu(a, b, out, detail::BitwiseOr(), stream());
       break;
     case BitwiseBinary::Xor:
-      binary_int(a, b, out, detail::BitwiseXor(), stream());
+      binary_int_op_cpu(a, b, out, detail::BitwiseXor(), stream());
       break;
     case BitwiseBinary::LeftShift:
-      binary_int(a, b, out, detail::LeftShift(), stream());
+      binary_int_op_cpu(a, b, out, detail::LeftShift(), stream());
       break;
    case BitwiseBinary::RightShift:
-      binary_int(a, b, out, detail::RightShift(), stream());
+      binary_int_op_cpu(a, b, out, detail::RightShift(), stream());
       break;
   }
 }
@@ -481,7 +263,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   const auto& a = inputs[0];
   const auto& b = inputs[1];
-  binary_float(a, b, out, detail::ArcTan2(), stream());
+  binary_float_op_cpu(a, b, out, detail::ArcTan2(), stream());
 }

 } // namespace mlx::core
@@ -7,6 +7,7 @@
 #include "mlx/backend/common/binary.h"
 #include "mlx/backend/common/utils.h"
 
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 
 namespace mlx::core {
@@ -290,4 +291,227 @@ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
   binary_op<T, T, Op>(a, b, out, bopt);
 }
 
+template <typename Op>
+void binary_op_cpu(
+    const array& a,
+    const array& b,
+    array& out,
+    Op op,
+    Stream stream) {
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    out = array::unsafe_weak_copy(out),
+                    bopt]() mutable {
+    switch (out.dtype()) {
+      case bool_:
+        binary_op<bool, Op>(a, b, out, bopt);
+        break;
+      case uint8:
+        binary_op<uint8_t, Op>(a, b, out, bopt);
+        break;
+      case uint16:
+        binary_op<uint16_t, Op>(a, b, out, bopt);
+        break;
+      case uint32:
+        binary_op<uint32_t, Op>(a, b, out, bopt);
+        break;
+      case uint64:
+        binary_op<uint64_t, Op>(a, b, out, bopt);
+        break;
+      case int8:
+        binary_op<int8_t, Op>(a, b, out, bopt);
+        break;
+      case int16:
+        binary_op<int16_t, Op>(a, b, out, bopt);
+        break;
+      case int32:
+        binary_op<int32_t, Op>(a, b, out, bopt);
+        break;
+      case int64:
+        binary_op<int64_t, Op>(a, b, out, bopt);
+        break;
+      case float16:
+        binary_op<float16_t, Op>(a, b, out, bopt);
+        break;
+      case float32:
+        binary_op<float, Op>(a, b, out, bopt);
+        break;
+      case float64:
+        binary_op<double, Op>(a, b, out, bopt);
+        break;
+      case bfloat16:
+        binary_op<bfloat16_t, Op>(a, b, out, bopt);
+        break;
+      case complex64:
+        binary_op<complex64_t, Op>(a, b, out, bopt);
+        break;
+    }
+  });
+}
+
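Note how binary_op_cpu does the dtype dispatch inside a closure handed to the CPU command encoder's dispatch(), with the arrays captured as unsafe weak copies (per the capture list above); the typed loop therefore runs when the encoder executes the task, not inside eval_cpu itself. A toy stand-in for that deferred-execution idea, using only the standard library (this is not MLX's encoder, just an illustration of why the work is packaged into a lambda):

#include <functional>
#include <iostream>
#include <queue>
#include <vector>

struct ToyEncoder {
  std::queue<std::function<void()>> tasks;
  void dispatch(std::function<void()> f) {
    tasks.push(std::move(f)); // record the work, do not run it yet
  }
  void run_all() {
    while (!tasks.empty()) {
      tasks.front()();
      tasks.pop();
    }
  }
};

int main() {
  ToyEncoder enc;
  std::vector<float> a{1, 2, 3}, b{4, 5, 6}, out(3);
  // Capture what the task needs; here by reference, since main outlives the queue.
  enc.dispatch([&a, &b, &out]() {
    for (size_t i = 0; i < out.size(); ++i) {
      out[i] = a[i] + b[i];
    }
  });
  enc.run_all();
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n"; // 5 7 9
}
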
+template <typename Op>
+void comparison_op_cpu(
+    const array& a,
+    const array& b,
+    array& out,
+    Op op,
+    Stream stream) {
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    out = array::unsafe_weak_copy(out),
+                    bopt]() mutable {
+    switch (a.dtype()) {
+      case bool_:
+        binary_op<bool, bool, Op>(a, b, out, bopt);
+        break;
+      case uint8:
+        binary_op<uint8_t, bool, Op>(a, b, out, bopt);
+        break;
+      case uint16:
+        binary_op<uint16_t, bool, Op>(a, b, out, bopt);
+        break;
+      case uint32:
+        binary_op<uint32_t, bool, Op>(a, b, out, bopt);
+        break;
+      case uint64:
+        binary_op<uint64_t, bool, Op>(a, b, out, bopt);
+        break;
+      case int8:
+        binary_op<int8_t, bool, Op>(a, b, out, bopt);
+        break;
+      case int16:
+        binary_op<int16_t, bool, Op>(a, b, out, bopt);
+        break;
+      case int32:
+        binary_op<int32_t, bool, Op>(a, b, out, bopt);
+        break;
+      case int64:
+        binary_op<int64_t, bool, Op>(a, b, out, bopt);
+        break;
+      case float16:
+        binary_op<float16_t, bool, Op>(a, b, out, bopt);
+        break;
+      case float32:
+        binary_op<float, bool, Op>(a, b, out, bopt);
+        break;
+      case float64:
+        binary_op<double, bool, Op>(a, b, out, bopt);
+        break;
+      case bfloat16:
+        binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
+        break;
+      case complex64:
+        binary_op<complex64_t, bool, Op>(a, b, out, bopt);
+        break;
+    }
+  });
+}
+
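comparison_op_cpu differs from binary_op_cpu only in that it switches on the input dtype and instantiates binary_op<T, bool, Op>, since comparisons always produce a bool output. For illustration, a comparison primitive would route through it like this (Less is not part of the hunks shown; the call shape mirrors ArcTan2 above):

void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  // out has dtype bool_, so the switch above keys off inputs[0].dtype().
  comparison_op_cpu(inputs[0], inputs[1], out, detail::Less(), stream());
}
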
+template <typename Op>
+void binary_float_op_cpu(
+    const array& a,
+    const array& b,
+    array& out,
+    Op op,
+    Stream stream) {
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    out = array::unsafe_weak_copy(out),
+                    bopt]() mutable {
+    switch (out.dtype()) {
+      case float16:
+        binary_op<float16_t, Op>(a, b, out, bopt);
+        break;
+      case float32:
+        binary_op<float, Op>(a, b, out, bopt);
+        break;
+      case float64:
+        binary_op<double, Op>(a, b, out, bopt);
+        break;
+      case bfloat16:
+        binary_op<bfloat16_t, Op>(a, b, out, bopt);
+        break;
+      case complex64:
+        binary_op<complex64_t, Op>(a, b, out, bopt);
+        break;
+      default:
+        throw std::runtime_error(
+            "[binary_float] Only supports floating point types.");
+    }
+  });
+}
+
+template <typename Op>
+void binary_int_op_cpu(
+    const array& a,
+    const array& b,
+    array& out,
+    Op op,
+    Stream stream) {
+  auto bopt = get_binary_op_type(a, b);
+  set_binary_op_output_data(a, b, out, bopt);
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  encoder.dispatch([a = array::unsafe_weak_copy(a),
+                    b = array::unsafe_weak_copy(b),
+                    out = array::unsafe_weak_copy(out),
+                    bopt]() mutable {
+    switch (out.dtype()) {
+      case bool_:
+        binary_op<bool, Op>(a, b, out, bopt);
+        break;
+      case uint8:
+        binary_op<uint8_t, Op>(a, b, out, bopt);
+        break;
+      case uint16:
+        binary_op<uint16_t, Op>(a, b, out, bopt);
+        break;
+      case uint32:
+        binary_op<uint32_t, Op>(a, b, out, bopt);
+        break;
+      case uint64:
+        binary_op<uint64_t, Op>(a, b, out, bopt);
+        break;
+      case int8:
+        binary_op<int8_t, Op>(a, b, out, bopt);
+        break;
+      case int16:
+        binary_op<int16_t, Op>(a, b, out, bopt);
+        break;
+      case int32:
+        binary_op<int32_t, Op>(a, b, out, bopt);
+        break;
+      case int64:
+        binary_op<int64_t, Op>(a, b, out, bopt);
+        break;
+      default:
+        throw std::runtime_error("[binary_int] Type not supported");
+        break;
+    }
+  });
+}
+
 } // namespace mlx::core

@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
 
   // The decomposition is computed in place, so just copy the input to the
   // output.
-  copy(
+  copy_cpu(
       a,
       factor,
       a.flags().row_contiguous ? CopyType::Vector : CopyType::General,

@@ -15,6 +15,7 @@
 #include "mlx/backend/cpu/jit_compiler.h"
 #include "mlx/device.h"
 #include "mlx/graph_utils.h"
+#include "mlx/version.h"
 
 namespace mlx::core {
 
@@ -40,7 +41,10 @@ struct CompilerCache {
   std::shared_mutex mtx;
 };
 
-static CompilerCache cache{};
+static CompilerCache& cache() {
+  static CompilerCache cache_;
+  return cache_;
+};
 
 // GPU compile is always available if the GPU is available and since we are in
 // this file CPU compile is also available.
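Turning the file-scope static CompilerCache cache{}; into a function that returns a function-local static is the construct-on-first-use idiom: the cache is created the first time compile() asks for it, and its initialization is guaranteed thread safe since C++11, which avoids static-initialization-order surprises across translation units. A self-contained sketch of the idiom, with illustrative names unrelated to MLX:

#include <iostream>
#include <map>
#include <string>

struct Registry {
  std::map<std::string, int> entries;
};

// Built lazily on first use; later calls return the same instance.
Registry& registry() {
  static Registry r;
  return r;
}

int main() {
  registry().entries["answer"] = 42;
  std::cout << registry().entries["answer"] << "\n"; // prints 42
}
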
@@ -56,14 +60,16 @@ void* compile(
     const std::string& kernel_name,
     const std::function<std::string(void)>& source_builder) {
   {
-    std::shared_lock lock(cache.mtx);
-    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
+    std::shared_lock lock(cache().mtx);
+    if (auto it = cache().kernels.find(kernel_name);
+        it != cache().kernels.end()) {
       return it->second;
     }
   }
 
-  std::unique_lock lock(cache.mtx);
-  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
+  std::unique_lock lock(cache().mtx);
+  if (auto it = cache().kernels.find(kernel_name);
+      it != cache().kernels.end()) {
     return it->second;
   }
   std::string source_code = source_builder();
@@ -89,7 +95,11 @@ void* compile(
     kernel_file_name = kernel_name;
   }
 
-  auto output_dir = std::filesystem::temp_directory_path();
+  auto output_dir =
+      std::filesystem::temp_directory_path() / "mlx" / version() / "cpu";
+  if (!std::filesystem::exists(output_dir)) {
+    std::filesystem::create_directories(output_dir);
+  }
 
   std::string shared_lib_name = "lib" + kernel_file_name + ".so";
   auto shared_lib_path = (output_dir / shared_lib_name).string();
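The lookup above is a read-mostly, double-checked pattern: readers share a std::shared_lock for the common cache-hit case, and a writer takes a std::unique_lock and re-checks the map before compiling, since another thread may have inserted the kernel between releasing the shared lock and acquiring the exclusive one. A compact standalone sketch of that flow (names and the int payload are illustrative):

#include <map>
#include <mutex>
#include <shared_mutex>
#include <string>

std::shared_mutex mtx;
std::map<std::string, int> kernels;

int get_or_build(const std::string& name) {
  {
    // Fast path: many readers may hold the shared lock at once.
    std::shared_lock lock(mtx);
    if (auto it = kernels.find(name); it != kernels.end()) {
      return it->second;
    }
  }
  // Slow path: exclusive lock, then re-check before the expensive build.
  std::unique_lock lock(mtx);
  if (auto it = kernels.find(name); it != kernels.end()) {
    return it->second;
  }
  int built = static_cast<int>(name.size()); // stand-in for real compilation
  kernels.emplace(name, built);
  return built;
}
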
@@ -120,10 +130,10 @@ void* compile(
   }
 
   // load library
-  cache.libs.emplace_back(shared_lib_path);
+  cache().libs.emplace_back(shared_lib_path);
 
   // Load function
-  void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
+  void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
   if (!fun) {
     std::ostringstream msg;
     msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -131,7 +141,7 @@ void* compile(
         << dlerror();
     throw std::runtime_error(msg.str());
   }
-  cache.kernels.insert({kernel_name, fun});
+  cache().kernels.insert({kernel_name, fun});
   return fun;
 }
 
@@ -141,18 +151,9 @@ inline void build_kernel(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
     const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous,
     int ndim) {
-  // All outputs should have the exact same shape and will be row contiguous
-  auto output_shape = outputs[0].shape();
-  auto output_strides = outputs[0].strides();
-
-  // Constants are scalars that are captured by value and cannot change
-  auto is_constant = [&constant_ids](const array& x) {
-    return constant_ids.find(x.id()) != constant_ids.end();
-  };
-
   NodeNamer namer;
 
 #ifdef _MSC_VER
@@ -161,25 +162,28 @@ inline void build_kernel(
 #endif
 
   // Start the kernel
-  os << "void " << kernel_name << "(void** args) {" << std::endl;
+  os << "void " << kernel_name
+     << "(int* shape, int64_t** strides, void** args) {" << std::endl;
 
   // Add the input arguments
   int cnt = 0;
-  for (auto& x : inputs) {
-    auto& xname = namer.get_name(x);
+  int strides_index = 1;
+  for (size_t i = 0; i < inputs.size(); ++i) {
 
     // Skip constants from the input list
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       continue;
     }
 
+    const auto& x = inputs[i];
+    auto& xname = namer.get_name(x);
+
     auto tstr = get_type_string(x.dtype());
     os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
        << "];" << std::endl;
     // Scalars and contiguous need no strides
     if (!is_scalar(x) && !contiguous) {
-      os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
-         << "];" << std::endl;
+      os << " const int64_t* " << xname << "_strides = strides["
+         << strides_index++ << "];" << std::endl;
     }
   }
 
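For a strided (non-contiguous) kernel, the os << statements above now emit a prologue roughly like the following hand-written illustration (identifiers are invented; the real names come from NodeNamer, and strides_index starts at 1, apparently leaving strides[0] for the output's strides). Per-input strides arrive through the strides array rather than being appended to args as before:

// Sketch of generated source for two non-scalar float inputs and one output.
void example_kernel_strided_2(int* shape, int64_t** strides, void** args) {
  float* in_a = (float*)args[0];
  const int64_t* in_a_strides = strides[1];
  float* in_b = (float*)args[1];
  const int64_t* in_b_strides = strides[2];
  float* out_0 = (float*)args[2];
  // ... nested loops over shape[0..ndim), advancing in_a/in_b by their
  // strides in each dimension and writing out_0 contiguously.
}
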
@@ -189,10 +193,8 @@ inline void build_kernel(
     os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
        << "*)args[" << cnt++ << "];" << std::endl;
   }
-  // Add output strides and shape to extract the indices.
-  if (!contiguous) {
-    os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
-  } else {
+  // Add output size
+  if (contiguous) {
     os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
   }
 
@@ -206,10 +208,11 @@ inline void build_kernel(
   }
 
   // Read the inputs in tmps
-  for (auto& x : inputs) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    const auto& x = inputs[i];
     auto& xname = namer.get_name(x);
 
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
       print_constant(os, x);
       os << ";" << std::endl;
@@ -233,7 +236,7 @@ inline void build_kernel(
       os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
          << namer.get_name(x.inputs()[0]) << ");" << std::endl;
     } else {
-      x.primitive().print(os);
+      os << x.primitive().name();
       os << "()(";
       for (int i = 0; i < x.inputs().size() - 1; i++) {
         os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
@@ -259,8 +262,9 @@ inline void build_kernel(
   } else {
     for (int d = ndim - 1; d >= 0; --d) {
       // Update pointers
-      for (auto& x : inputs) {
-        if (is_constant(x) || is_scalar(x)) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        if (is_constant(i) || is_scalar(x)) {
           continue;
         }
         auto& xname = namer.get_name(x);
@@ -282,65 +286,33 @@ inline void build_kernel(
 void Compiled::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (kernel_lib_.empty()) {
-    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
-  }
-
-  // Figure out which kernel we are using
-  auto& shape = outputs[0].shape();
-  auto contiguous = compiled_check_contiguity(inputs, shape);
   auto& encoder = cpu::get_command_encoder(stream());
 
-  // Handle all broadcasting and collect function input arguments
+  // Collapse contiguous dims to route to a faster kernel if possible. Also
+  // handle all broadcasting.
+  auto [contiguous, shape, strides] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
+
+  // Collect function input arguments.
   std::vector<void*> args;
-  std::vector<std::vector<size_t>> strides;
-  for (int i = 0; i < inputs.size(); i++) {
-    // Skip constants.
-    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    if (is_constant_(i)) {
       continue;
     }
-    auto& x = inputs[i];
+    const auto& x = inputs[i];
     encoder.set_input_array(x);
     args.push_back((void*)x.data<void>());
-
-    if (contiguous || is_scalar(x)) {
-      continue;
-    }
-
-    // Broadcast the input to the output shape.
-    std::vector<size_t> xstrides;
-    int j = 0;
-    for (; j < shape.size() - x.ndim(); j++) {
-      if (shape[j] == 1) {
-        xstrides.push_back(outputs[0].strides()[j]);
-      } else {
-        xstrides.push_back(0);
-      }
-    }
-    for (int i = 0; i < x.ndim(); i++, j++) {
-      if (x.shape(i) == 1) {
-        if (shape[j] == 1) {
-          xstrides.push_back(outputs[0].strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      } else {
-        xstrides.push_back(x.strides()[i]);
-      }
-    }
-    strides.push_back(std::move(xstrides));
-    args.push_back(strides.back().data());
   }
 
   // Get the kernel name from the lib
   int ndim = shape.size();
   auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
   if (!contiguous) {
-    kernel_name += std::to_string(shape.size());
+    kernel_name += std::to_string(ndim);
   }
 
   // Get the function
-  auto fn_ptr = compile(kernel_name, [&]() {
+  auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
     std::ostringstream kernel;
     kernel << get_kernel_preamble() << std::endl;
     kernel << "extern \"C\" {" << std::endl;
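The block removed above was doing manual broadcasting: each non-contiguous input was padded to the output's rank, and any dimension where the input's extent is 1 (or missing) got stride 0 so the single element is reused along that axis; compiled_collapse_contiguous_dims now produces those strides and also merges contiguous dimensions. A standalone sketch of the core stride rule (simplified: it ignores the degenerate size-1 output dimensions the old code also special-cased):

#include <cstdint>
#include <iostream>
#include <vector>

// Broadcast an input's strides to an output shape: missing or size-1 input
// dimensions get stride 0, so the same element is reused along that axis.
std::vector<int64_t> broadcast_strides(
    const std::vector<int>& in_shape,
    const std::vector<int64_t>& in_strides,
    const std::vector<int>& out_shape) {
  std::vector<int64_t> out(out_shape.size(), 0);
  size_t offset = out_shape.size() - in_shape.size();
  for (size_t i = 0; i < in_shape.size(); ++i) {
    out[offset + i] = (in_shape[i] == 1) ? 0 : in_strides[i];
  }
  return out;
}

int main() {
  // A (3, 1) row-major input broadcast against a (2, 3, 4) output.
  auto s = broadcast_strides({3, 1}, {1, 1}, {2, 3, 4});
  for (auto v : s) {
    std::cout << v << " "; // prints: 0 1 0
  }
  std::cout << "\n";
}
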
@@ -350,7 +322,7 @@ void Compiled::eval_cpu(
       inputs_,
       outputs_,
       tape_,
-      constant_ids_,
+      is_constant_,
       contiguous,
       ndim);
   // Close extern "C"
@@ -358,26 +330,26 @@ void Compiled::eval_cpu(
     return kernel.str();
   });
 
-  compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous);
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
 
   for (auto& x : outputs) {
     args.push_back(x.data<void>());
     encoder.set_output_array(x);
   }
-  Shape out_shape;
-  if (!contiguous) {
-    out_shape = outputs[0].shape();
-    args.push_back((void*)out_shape.data());
-  } else {
+  if (contiguous) {
     args.push_back((void*)outputs[0].data_size());
   }
-  auto fun = (void (*)(void**))fn_ptr;
-  encoder.dispatch(
-      [fun,
-       args = std::move(args),
-       strides = std::move(strides),
-       out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
+  auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
+  encoder.dispatch([fun,
+                    args = std::move(args),
+                    strides = std::move(strides),
+                    shape = std::move(shape)]() mutable {
+    SmallVector<int64_t*> strides_ptrs;
+    for (auto& s : strides) {
+      strides_ptrs.push_back(s.data());
+    }
+    fun(shape.data(), strides_ptrs.data(), args.data());
+  });
 }
 
 } // namespace mlx::core
File diff suppressed because it is too large. Some files were not shown because too many files have changed in this diff.