mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
1219 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54f1cc6e3e | ||
|
|
b3825ac149 | ||
|
|
7f4b7e553c | ||
|
|
ad16f41a7f | ||
|
|
f46877bc08 | ||
|
|
6f35017d1b | ||
|
|
b167f0df1c | ||
|
|
a9f0d6b160 | ||
|
|
940f4c7818 | ||
|
|
35f81728f1 | ||
|
|
4442ed86c1 | ||
|
|
698559c231 | ||
|
|
ecc4879b07 | ||
|
|
32b18d8b66 | ||
|
|
472c43a0c8 | ||
|
|
b7214ff01e | ||
|
|
76414c8971 | ||
|
|
49e4566df3 | ||
|
|
aad49f932f | ||
|
|
86765cce34 | ||
|
|
1bedcbd556 | ||
|
|
9ac7dbe877 | ||
|
|
1bf605d56d | ||
|
|
3c622ddd1d | ||
|
|
27ff069175 | ||
|
|
3b2ffcefc3 | ||
|
|
b65f882df3 | ||
|
|
b704e9e77a | ||
|
|
66519fb348 | ||
|
|
8973550ff3 | ||
|
|
3f866be665 | ||
|
|
23f81ed1c1 | ||
|
|
3fe2250c00 | ||
|
|
047114b988 | ||
|
|
9320eb89a8 | ||
|
|
75819d70ea | ||
|
|
60d80a3728 | ||
|
|
eba6a9d163 | ||
|
|
be9e2aebd6 | ||
|
|
df58b4133a | ||
|
|
27778156dc | ||
|
|
761f901a41 | ||
|
|
6ece97f69b | ||
|
|
d3bc6a9bff | ||
|
|
26ceb507eb | ||
|
|
910b3e3299 | ||
|
|
50fa315d18 | ||
|
|
1ff2b713b6 | ||
|
|
50514a6146 | ||
|
|
93d76b0f30 | ||
|
|
78678de0cd | ||
|
|
ed9c6b1117 | ||
|
|
39b04ce638 | ||
|
|
d9e6349657 | ||
|
|
b901a9f311 | ||
|
|
68c5fa1c95 | ||
|
|
793a31eeb6 | ||
|
|
74c1ed25bb | ||
|
|
ec72b44417 | ||
|
|
460691a0e8 | ||
|
|
969924cc69 | ||
|
|
d1e06117e8 | ||
|
|
539d8322d1 | ||
|
|
c4767d110f | ||
|
|
895217f25b | ||
|
|
0cfeeb60ca | ||
|
|
8f8af61a37 | ||
|
|
233384161e | ||
|
|
5bcf3a6794 | ||
|
|
7707196297 | ||
|
|
7e3471c987 | ||
|
|
9f0ba3ddf1 | ||
|
|
4bce5f9b2d | ||
|
|
e9eab527eb | ||
|
|
36ca62dba8 | ||
|
|
9cbb1b0148 | ||
|
|
9bfc476d72 | ||
|
|
25e2356316 | ||
|
|
226a1d24e0 | ||
|
|
630350ad3e | ||
|
|
380aeb58ae | ||
|
|
f37389d100 | ||
|
|
e89e8b4272 | ||
|
|
85a8824a8c | ||
|
|
f5d4397e5c | ||
|
|
343e33b6d5 | ||
|
|
0073096dd1 | ||
|
|
e3d004fed9 | ||
|
|
a393435d28 | ||
|
|
a7a94b29d7 | ||
|
|
22a5da76c8 | ||
|
|
287c63a093 | ||
|
|
1c9ae1eaa1 | ||
|
|
c2c3e0b0a2 | ||
|
|
b0cc71ae71 | ||
|
|
e88f2d4a8e | ||
|
|
9cee557423 | ||
|
|
bbf1423953 | ||
|
|
eb24267b56 | ||
|
|
dc371ae7a5 | ||
|
|
e76a8dd5c5 | ||
|
|
b466dea982 | ||
|
|
7a6adda1e6 | ||
|
|
1a9f820af6 | ||
|
|
d4f4ff3c5e | ||
|
|
7c7e48dbd1 | ||
|
|
fbbf3b9b3e | ||
|
|
bf01ad9367 | ||
|
|
ae438d05fa | ||
|
|
711a645807 | ||
|
|
aa9d44b3d4 | ||
|
|
ec2ab42888 | ||
|
|
787c0d90cd | ||
|
|
e8b604a6a3 | ||
|
|
50cc09887f | ||
|
|
3f730e77aa | ||
|
|
caecbe876a | ||
|
|
8afb6d62f2 | ||
|
|
6ccfa603cd | ||
|
|
36cad99a11 | ||
|
|
ee18e1cbf0 | ||
|
|
af120c2bc0 | ||
|
|
6a3acf2301 | ||
|
|
d6977f2a57 | ||
|
|
db5443e831 | ||
|
|
52b8384d10 | ||
|
|
44cc5da4bc | ||
|
|
dde3682b69 | ||
|
|
17310d91a6 | ||
|
|
b194d65a6a | ||
|
|
a44b27f5f8 | ||
|
|
e5a33f2223 | ||
|
|
c1e3340b23 | ||
|
|
8f163a367d | ||
|
|
89a3df9014 | ||
|
|
c5d2937aa5 | ||
|
|
b61a65e313 | ||
|
|
04cbb4191c | ||
|
|
c5460762e7 | ||
|
|
8ce49cd39e | ||
|
|
9c68b50853 | ||
|
|
111f1e71af | ||
|
|
827003d568 | ||
|
|
d363a76aa4 | ||
|
|
70560b6bd5 | ||
|
|
7ef8a6f2d5 | ||
|
|
31c6f6e33f | ||
|
|
584d48458e | ||
|
|
5cf984ca87 | ||
|
|
a9bac3d9e5 | ||
|
|
5458d43247 | ||
|
|
a4dba65220 | ||
|
|
3dcb286baf | ||
|
|
4822c3dbe9 | ||
|
|
2ca75bb529 | ||
|
|
db14e29a0b | ||
|
|
d2f540f4e0 | ||
|
|
333ffea273 | ||
|
|
f55b6f1f2f | ||
|
|
30561229c7 | ||
|
|
068a4612e9 | ||
|
|
5722c147de | ||
|
|
f6819a1f26 | ||
|
|
f93f87c802 | ||
|
|
9392fc3f88 | ||
|
|
e843c4d8d5 | ||
|
|
0c5fc63a36 | ||
|
|
e397177f6e | ||
|
|
f4c8888cbe | ||
|
|
25c1e03205 | ||
|
|
512281781c | ||
|
|
ac85ddfdb7 | ||
|
|
65d0d40232 | ||
|
|
cea9369610 | ||
|
|
e7c6e1db82 | ||
|
|
c5fcd5b61b | ||
|
|
1df9887998 | ||
|
|
73f22d6226 | ||
|
|
c422050ca7 | ||
|
|
1ba18ff7d9 | ||
|
|
37b440faa8 | ||
|
|
888b13ed63 | ||
|
|
4abb218d21 | ||
|
|
6441c21a94 | ||
|
|
dfb5022eab | ||
|
|
ac207ce7aa | ||
|
|
fce53b61d6 | ||
|
|
8ae4a76308 | ||
|
|
7fde1b6a1e | ||
|
|
aa7b47481a | ||
|
|
56be773610 | ||
|
|
a9bdd67baa | ||
|
|
f2adb5638d | ||
|
|
728d4db582 | ||
|
|
db5c7efcf6 | ||
|
|
7bb96e4249 | ||
|
|
fa89f0b150 | ||
|
|
ca973d1e83 | ||
|
|
828c5f1137 | ||
|
|
7d86a5c108 | ||
|
|
0b807893a7 | ||
|
|
6ad0889c8a | ||
|
|
737dd6d1ac | ||
|
|
aaf78f4c6b | ||
|
|
8831064493 | ||
|
|
be9bc96da4 | ||
|
|
86258f292f | ||
|
|
b26d88591c | ||
|
|
86c6a15571 | ||
|
|
8b25ce62d5 | ||
|
|
da5912e4f2 | ||
|
|
daafee676f | ||
|
|
d32519c8ee | ||
|
|
b405591249 | ||
|
|
3bf81ed1bd | ||
|
|
2204182bba | ||
|
|
3628e5d497 | ||
|
|
a0ae49d397 | ||
|
|
254476718b | ||
|
|
3adba92ebe | ||
|
|
ef631d63af | ||
|
|
970dbe8e25 | ||
|
|
641be9463b | ||
|
|
ab0e608862 | ||
|
|
1588659062 | ||
|
|
b9e88fb976 | ||
|
|
4ad53414dd | ||
|
|
d1165b215e | ||
|
|
dcb8319f3d | ||
|
|
5597fa089c | ||
|
|
9acec364c2 | ||
|
|
7d9d6ef456 | ||
|
|
6f5874a2f2 | ||
|
|
70dc336785 | ||
|
|
4e504039f5 | ||
|
|
d1f4d291e8 | ||
|
|
e1840853ce | ||
|
|
0f5ce173da | ||
|
|
588854195f | ||
|
|
28d068bce6 | ||
|
|
d107d8d495 | ||
|
|
1e496ddb82 | ||
|
|
74eccbf3fa | ||
|
|
08638223ca | ||
|
|
56cc858af9 | ||
|
|
f55c4ed1d6 | ||
|
|
93d70419e7 | ||
|
|
63f663d9c6 | ||
|
|
84b4d96efa | ||
|
|
aec67f2fa6 | ||
|
|
deee214a95 | ||
|
|
45adec102c | ||
|
|
31fc530c76 | ||
|
|
fbb3f65a1a | ||
|
|
6b1b8ea91b | ||
|
|
b2273733ea | ||
|
|
f409b229a4 | ||
|
|
30571e2326 | ||
|
|
d7734edd9f | ||
|
|
2ba69bc8fa | ||
|
|
cb349a291c | ||
|
|
f0a0b077a0 | ||
|
|
49114f28ab | ||
|
|
e7d2ebadd2 | ||
|
|
e569803d7c | ||
|
|
d34f887abc | ||
|
|
5201df5030 | ||
|
|
2d3c26c565 | ||
|
|
6325f60d52 | ||
|
|
42cc9cfbc7 | ||
|
|
8347575ba1 | ||
|
|
b6eec20260 | ||
|
|
0eb035b4b1 | ||
|
|
afb9817599 | ||
|
|
8fb3e7a26c | ||
|
|
8c7bc30ce4 | ||
|
|
85873cb162 | ||
|
|
e14ee12491 | ||
|
|
8b9a3f3cea | ||
|
|
fb4e8b896b | ||
|
|
2ca533b279 | ||
|
|
4a9b29a875 | ||
|
|
a4fcc893cd | ||
|
|
9d10239af7 | ||
|
|
19facd4b20 | ||
|
|
f5299f72cd | ||
|
|
0e0d9ac522 | ||
|
|
8917022deb | ||
|
|
ec0d5db67b | ||
|
|
e76e9b87f0 | ||
|
|
cfb6a244ea | ||
|
|
58f3860306 | ||
|
|
dd4f53db63 | ||
|
|
3d5e17e507 | ||
|
|
33bf1a244b | ||
|
|
772f471ff2 | ||
|
|
2c11d10f8d | ||
|
|
656ed7f780 | ||
|
|
81bb9a2a9e | ||
|
|
5adf185f86 | ||
|
|
c9a9180584 | ||
|
|
76831ed83d | ||
|
|
b3d7b85376 | ||
|
|
cad5c0241c | ||
|
|
b8022c578a | ||
|
|
bc53f8293f | ||
|
|
c552ff2451 | ||
|
|
4fda5fbdf9 | ||
|
|
580776559b | ||
|
|
a14aaa7c9d | ||
|
|
a6d780154f | ||
|
|
6871e2eeb7 | ||
|
|
8402a2acf4 | ||
|
|
fddb6933e1 | ||
|
|
c8b4787e4e | ||
|
|
2188199ff8 | ||
|
|
aa07429bad | ||
|
|
918761a25a | ||
|
|
a4fc671d3e | ||
|
|
f5f65ef48c | ||
|
|
c2dd81a8aa | ||
|
|
d7e680ffe4 | ||
|
|
c371baf53a | ||
|
|
ccf78f566c | ||
|
|
c9fa68664a | ||
|
|
c35f4d089a | ||
|
|
8590c0941e | ||
|
|
095163b8d1 | ||
|
|
99c33d011d | ||
|
|
62fecf3e13 | ||
|
|
7c4eb5d03e | ||
|
|
bae9a6b404 | ||
|
|
004c1d8ef2 | ||
|
|
7ebb2e0193 | ||
|
|
9ce77798b1 | ||
|
|
f8bad60609 | ||
|
|
5866b3857b | ||
|
|
1ca616844b | ||
|
|
2e8cf0b450 | ||
|
|
24f89173d1 | ||
|
|
c6a20b427a | ||
|
|
a5ac9244c4 | ||
|
|
c763fe1be0 | ||
|
|
52dc8c8cd5 | ||
|
|
aede70e81d | ||
|
|
85a8beb5e4 | ||
|
|
0bb89e9e5f | ||
|
|
5685ceb3c7 | ||
|
|
0408ba0a76 | ||
|
|
cbad6c3093 | ||
|
|
1b021f6984 | ||
|
|
95b7551d65 | ||
|
|
db5a7c6192 | ||
|
|
6ef2f67e7f | ||
|
|
f76ee1ffd2 | ||
|
|
54a71f270a | ||
|
|
55b4062dd8 | ||
|
|
79071bfba4 | ||
|
|
7774b87cbd | ||
|
|
35c87741cf | ||
|
|
4cbe605214 | ||
|
|
ab8883dd55 | ||
|
|
eebe73001a | ||
|
|
0359bf02c9 | ||
|
|
237f9e58a8 | ||
|
|
8576e6fe36 | ||
|
|
0654543dcc | ||
|
|
48ef3e74e2 | ||
|
|
7d4b378952 | ||
|
|
7ff5c41e06 | ||
|
|
602f43e3d1 | ||
|
|
a2cadb8218 | ||
|
|
c1eb9d05d9 | ||
|
|
cf6c939e86 | ||
|
|
130df35e1b | ||
|
|
0751263dec | ||
|
|
eca2f3eb97 | ||
|
|
3aa9cf3f9e | ||
|
|
8f3d208dce | ||
|
|
caaa3f1f8c | ||
|
|
659a51919f | ||
|
|
6661387066 | ||
|
|
a7fae8a176 | ||
|
|
0cae0bdac8 | ||
|
|
5a1a5d5ed1 | ||
|
|
1683975acf | ||
|
|
af705590ac | ||
|
|
825124af8f | ||
|
|
9c5e7da507 | ||
|
|
481349495b | ||
|
|
9daa6b003f | ||
|
|
a3a632d567 | ||
|
|
e496c5a4b4 | ||
|
|
ea890d8710 | ||
|
|
aa5d84f102 | ||
|
|
f1606486d2 | ||
|
|
87720a8908 | ||
|
|
bb6565ef14 | ||
|
|
7bb063bcb3 | ||
|
|
b36dd472bb | ||
|
|
167b759a38 | ||
|
|
99b9868859 | ||
|
|
6b2d5448f2 | ||
|
|
eaf709b83e | ||
|
|
f0e70afff0 | ||
|
|
86984cad68 | ||
|
|
fbc89e3ced | ||
|
|
38c1e720c2 | ||
|
|
600e87e03c | ||
|
|
3836445241 | ||
|
|
1d2c9d6a07 | ||
|
|
e8ac6bd2f5 | ||
|
|
fdadc4f22c | ||
|
|
79b527f45f | ||
|
|
dc4eada7f0 | ||
|
|
70ebc3b598 | ||
|
|
b13f2aed16 | ||
|
|
5f04c0f818 | ||
|
|
55935ccae7 | ||
|
|
b529515eb1 | ||
|
|
3cde719eb7 | ||
|
|
5de6d94a90 | ||
|
|
99eefd2ec0 | ||
|
|
e9e268336b | ||
|
|
7275ac7523 | ||
|
|
c4189a38e4 | ||
|
|
68d1b3256b | ||
|
|
9c6953bda7 | ||
|
|
ef7ece9851 | ||
|
|
ddaa4b7dcb | ||
|
|
dfae2c6989 | ||
|
|
515f104926 | ||
|
|
9ecefd56db | ||
|
|
e5d35aa187 | ||
|
|
00794c42bc | ||
|
|
08a1bf3f10 | ||
|
|
60c4154346 | ||
|
|
f2c85308c1 | ||
|
|
1a28b69ee2 | ||
|
|
ba09f01ce8 | ||
|
|
6cf48872b7 | ||
|
|
7b3b8fa000 | ||
|
|
ec5e2aae61 | ||
|
|
86389bf970 | ||
|
|
3290bfa690 | ||
|
|
8777fd104f | ||
|
|
c41f7565ed | ||
|
|
9ba81e3da4 | ||
|
|
c23888acd7 | ||
|
|
f98ce25ab9 | ||
|
|
de5f38fd48 | ||
|
|
ec2854b13a | ||
|
|
90823d2938 | ||
|
|
5f5770e3a2 | ||
|
|
28f39e9038 | ||
|
|
b2d2b37888 | ||
|
|
fe597e141c | ||
|
|
72ca1539e0 | ||
|
|
13b26775f1 | ||
|
|
05d7118561 | ||
|
|
98b901ad66 | ||
|
|
5580b47291 | ||
|
|
bc62932984 | ||
|
|
a6b5d6e759 | ||
|
|
a8931306e1 | ||
|
|
fecdb8717e | ||
|
|
916fd273ea | ||
|
|
0da8506552 | ||
|
|
eda7a7b43e | ||
|
|
022eabb734 | ||
|
|
aba899cef8 | ||
|
|
6a40e1c176 | ||
|
|
9307b2ab8b | ||
|
|
522d8d3917 | ||
|
|
a84cc0123f | ||
|
|
f018e248cd | ||
|
|
cfd7237a80 | ||
|
|
4eef8102c9 | ||
|
|
69e4dd506b | ||
|
|
25814a9458 | ||
|
|
2a980a76ce | ||
|
|
d343782c8b | ||
|
|
4e1994e9d7 | ||
|
|
65a38c452b | ||
|
|
7b7e2352cd | ||
|
|
1177d28395 | ||
|
|
005e7efa64 | ||
|
|
b42d13ec84 | ||
|
|
9adcd1a650 | ||
|
|
3c164fca8c | ||
|
|
95e335db7b | ||
|
|
f90206ad74 | ||
|
|
3779150750 | ||
|
|
0a9777aa5c | ||
|
|
45ad06aac8 | ||
|
|
c6ea2ba329 | ||
|
|
2770a10240 | ||
|
|
d2a94f9e6a | ||
|
|
32da94507a | ||
|
|
736a340478 | ||
|
|
117e1355a2 | ||
|
|
3c3e558c60 | ||
|
|
cffceda6ee | ||
|
|
048805ad2c | ||
|
|
d14c9fe7ea | ||
|
|
5db90ce822 | ||
|
|
d699cc1330 | ||
|
|
c4230747a1 | ||
|
|
5245f12a46 | ||
|
|
a198b2787e | ||
|
|
04edad8c59 | ||
|
|
392b3060b0 | ||
|
|
85b34d59bc | ||
|
|
f599c11bc8 | ||
|
|
0792ff02ff | ||
|
|
fd0d63ba5b | ||
|
|
3835a428c5 | ||
|
|
9680f72cca | ||
|
|
a0737273d3 | ||
|
|
e613d0eaf0 | ||
|
|
6bcd6bcf70 | ||
|
|
ba12e4999a | ||
|
|
4e7cd31d12 | ||
|
|
5e6c130d93 | ||
|
|
5d68082881 | ||
|
|
607181644f | ||
|
|
89d327075f | ||
|
|
6bf00ef631 | ||
|
|
7d042f17fe | ||
|
|
28b8079e30 | ||
|
|
7face5d9fd | ||
|
|
a44dc4bdb0 | ||
|
|
2d0f384b6f | ||
|
|
8ff84b5c43 | ||
|
|
10b271d963 | ||
|
|
0ebc8a3d25 | ||
|
|
bbda0fdbdb | ||
|
|
c86422bdd4 | ||
|
|
c707b2b0a6 | ||
|
|
78ba24c37d | ||
|
|
1a2cb72030 | ||
|
|
344a29506e | ||
|
|
71de73a668 | ||
|
|
4c1dfa58b7 | ||
|
|
5274c3c43f | ||
|
|
1762793989 | ||
|
|
6cec78d8f2 | ||
|
|
2dc307f2e6 | ||
|
|
7aea5b1895 | ||
|
|
9733e16496 | ||
|
|
7f2d1024f3 | ||
|
|
428f589364 | ||
|
|
5cd97f7ffe | ||
|
|
e425dc00c0 | ||
|
|
d274ae77f2 | ||
|
|
55c5ac7820 | ||
|
|
0145911bea | ||
|
|
0a5215693e | ||
|
|
2a45056ba8 | ||
|
|
142b77751d | ||
|
|
a5ededf1c3 | ||
|
|
7df3f792a2 | ||
|
|
9eb7d7362f | ||
|
|
1c0c118f7c | ||
|
|
1a1b2108ec | ||
|
|
b6c6552d20 | ||
|
|
83a0340fa7 | ||
|
|
a62fc1b39f | ||
|
|
af1b725fda | ||
|
|
9174606d4c | ||
|
|
ca305afdbe | ||
|
|
fe5987b81d | ||
|
|
a229c8cef0 | ||
|
|
f6c0499b8d | ||
|
|
1156c84e86 | ||
|
|
ec7c7def40 | ||
|
|
2d8e667400 | ||
|
|
80c863b972 | ||
|
|
f5cc1eea72 | ||
|
|
b7c9f1d38f | ||
|
|
c6fc07f1f4 | ||
|
|
ded914f442 | ||
|
|
4758c8baa1 | ||
|
|
7064fed1b1 | ||
|
|
1017ac4a9e | ||
|
|
ccb61d7aae | ||
|
|
2235dee906 | ||
|
|
28091aa1ff | ||
|
|
121d9a0702 | ||
|
|
0cea88bcc5 | ||
|
|
72146fc4cd | ||
|
|
e6a7ab9675 | ||
|
|
1f4c127fb9 | ||
|
|
90532b1f37 | ||
|
|
a8666a757a | ||
|
|
a4667da1eb | ||
|
|
0c259961ac | ||
|
|
f288db8d34 | ||
|
|
33421c1dd3 | ||
|
|
5cc5201914 | ||
|
|
252e423e81 | ||
|
|
a4a2764a52 | ||
|
|
ab8e832c18 | ||
|
|
1ce0c0fcb0 | ||
|
|
657f466402 | ||
|
|
c7b0300af5 | ||
|
|
da8c885784 | ||
|
|
1ccaf80575 | ||
|
|
ec36bfa317 | ||
|
|
b8f76f717a | ||
|
|
d1766f2c70 | ||
|
|
516ded618b | ||
|
|
c9c81d0584 | ||
|
|
545f84d905 | ||
|
|
d5ec172c95 | ||
|
|
25b3a3e541 | ||
|
|
058d6ce683 | ||
|
|
eab93985b8 | ||
|
|
b51d70a83c | ||
|
|
259025100e | ||
|
|
c9d30aa6ac | ||
|
|
8544b42007 | ||
|
|
6fa0501387 | ||
|
|
ae69cb15e9 | ||
|
|
a64a8dfe45 | ||
|
|
491fa95b1f | ||
|
|
92ec632ad5 | ||
|
|
8ecdfb718b | ||
|
|
4ba0c24a8f | ||
|
|
935c8c4bb1 | ||
|
|
88f993da38 | ||
|
|
ebfe64b92d | ||
|
|
0308e9af71 | ||
|
|
c3628eea49 | ||
|
|
e03f0372b1 | ||
|
|
f17536af9c | ||
|
|
ed4ec81bca | ||
|
|
7480059306 | ||
|
|
8bae22b0fa | ||
|
|
49c34c4161 | ||
|
|
5548fcc96d | ||
|
|
070bd433ab | ||
|
|
c8fb54951a | ||
|
|
f110357aaa | ||
|
|
a6b426422e | ||
|
|
d03c01dfbc | ||
|
|
a82996e9fb | ||
|
|
af5a614aad | ||
|
|
f9640e049d | ||
|
|
4768c61b57 | ||
|
|
dfccd17ab9 | ||
|
|
635117c5d4 | ||
|
|
50f3535693 | ||
|
|
9111999af3 | ||
|
|
6bd28d246e | ||
|
|
4d595a2a39 | ||
|
|
3a21f61772 | ||
|
|
4e1e9520e1 | ||
|
|
0bf19037ca | ||
|
|
f3dfa36a3a | ||
|
|
4f9b60dd53 | ||
|
|
f76a49e555 | ||
|
|
310ad8d9db | ||
|
|
56db268f47 | ||
|
|
92ab6bdeb8 | ||
|
|
0070e360a1 | ||
|
|
9df8fed046 | ||
|
|
a59fae040f | ||
|
|
29a620cab2 | ||
|
|
87d7a2520e | ||
|
|
40c62c1321 | ||
|
|
35b412c099 | ||
|
|
d0f471cff7 | ||
|
|
6f316b8bf5 | ||
|
|
7c10c93a1f | ||
|
|
d92ea094f1 | ||
|
|
6ae5423b4a | ||
|
|
9635cffdc8 | ||
|
|
96986fb362 | ||
|
|
3ceb341a75 | ||
|
|
50fa705125 | ||
|
|
69a2991614 | ||
|
|
fd3377dd1f | ||
|
|
d0b6cb0425 | ||
|
|
95c4a2e3af | ||
|
|
bc2a29f033 | ||
|
|
3bb5b4a302 | ||
|
|
fc88fd9097 | ||
|
|
c5b0928c1f | ||
|
|
e047fd977d | ||
|
|
9d40e521d7 | ||
|
|
1445dcaa60 | ||
|
|
e4eeb4e910 | ||
|
|
aa86876813 | ||
|
|
974bb54ab2 | ||
|
|
9bc2183a31 | ||
|
|
d4b222b6d3 | ||
|
|
af2af818a6 | ||
|
|
698e63a608 | ||
|
|
211411faf2 | ||
|
|
bb303c45a5 | ||
|
|
6f7986d592 | ||
|
|
7cbb4aef17 | ||
|
|
02bec0bb6d | ||
|
|
c79f6a4a8c | ||
|
|
0c5eea226b | ||
|
|
dcca0d7477 | ||
|
|
0d5e7716ad | ||
|
|
d8c824c594 | ||
|
|
cb431dfc9f | ||
|
|
61d787726a | ||
|
|
5e89aace9b | ||
|
|
2af7e8a9a6 | ||
|
|
2419edd5b2 | ||
|
|
bf481e8e5d | ||
|
|
9d7fa6b8e6 | ||
|
|
073076ac7d | ||
|
|
9bd03dd9b4 | ||
|
|
6931f84412 | ||
|
|
16ec0556a0 | ||
|
|
610af352d4 | ||
|
|
b35f1e3c9c | ||
|
|
dfa0b9aab4 | ||
|
|
a4c47b0276 | ||
|
|
111fefd5e9 | ||
|
|
c1fe1ef081 | ||
|
|
8c34c9dac4 | ||
|
|
91c0277356 | ||
|
|
9f0d5c12fc | ||
|
|
59247c2b62 | ||
|
|
9a3842a2d9 | ||
|
|
726dbd9267 | ||
|
|
54f05e7195 | ||
|
|
26be608470 | ||
|
|
248431eb3c | ||
|
|
76f275b4df | ||
|
|
f1951d6cce | ||
|
|
62f297b51d | ||
|
|
09bc32f62f | ||
|
|
46d8b16ab4 | ||
|
|
42533931fa | ||
|
|
9bd3a7102f | ||
|
|
9e516b71ea | ||
|
|
eac961ddb1 | ||
|
|
57c6aa7188 | ||
|
|
cde5b4ad80 | ||
|
|
4f72c66911 | ||
|
|
960e3f0f05 | ||
|
|
884af42da2 | ||
|
|
048fabdabd | ||
|
|
917252a5a1 | ||
|
|
1a992e31e8 | ||
|
|
d2ff04a4f2 | ||
|
|
015c247393 | ||
|
|
d3cd26820e | ||
|
|
91f6c499d7 | ||
|
|
35e9c87ab9 | ||
|
|
8e88e30d95 | ||
|
|
0eb56d5be0 | ||
|
|
f70764a162 | ||
|
|
dad1b00b13 | ||
|
|
430ffef58a | ||
|
|
3d17077187 | ||
|
|
c9b41d460f | ||
|
|
32972a5924 | ||
|
|
f6afb9c09b | ||
|
|
3ddc07e936 | ||
|
|
c26208f67d | ||
|
|
d15fa13daf | ||
|
|
58a855682c | ||
|
|
92d7cb71f8 | ||
|
|
50d8bed468 | ||
|
|
9dd72cd421 | ||
|
|
343aa46b78 | ||
|
|
b8ab89b413 | ||
|
|
f9f8c167d4 | ||
|
|
3f86399922 | ||
|
|
2b8ace6a03 | ||
|
|
0ab8e099e8 | ||
|
|
020f048cd0 | ||
|
|
881615b072 | ||
|
|
0eef4febfd | ||
|
|
b54a70ec2d | ||
|
|
bf6ec92216 | ||
|
|
c21331d47f | ||
|
|
e1c9600da3 | ||
|
|
1fa0d20a30 | ||
|
|
3274c6a087 | ||
|
|
9b12093739 | ||
|
|
f374b6ca4d | ||
|
|
0070e1db40 | ||
|
|
95d04805b3 | ||
|
|
e4534dac17 | ||
|
|
fef3c4ec1d | ||
|
|
1bdc038bf9 | ||
|
|
5523d9c426 | ||
|
|
d878015228 | ||
|
|
5900e3249f | ||
|
|
bacced53d3 | ||
|
|
4a64d4bff1 | ||
|
|
b1e2b53c2d | ||
|
|
11354d5bff | ||
|
|
718aea3f1d | ||
|
|
5b6f38df2b | ||
|
|
0b4a58699e | ||
|
|
4f9f9ebb6f | ||
|
|
afc9c0ec1b | ||
|
|
195b429d99 | ||
|
|
2b878e9dd7 | ||
|
|
67b6bf530d | ||
|
|
6af5ca35b2 | ||
|
|
4f46e9c997 | ||
|
|
c6739ba7f3 | ||
|
|
914409fef9 | ||
|
|
8d68a3e805 | ||
|
|
6bbcc453ef | ||
|
|
d5ed4d7a71 | ||
|
|
669c27140d | ||
|
|
adcc88e208 | ||
|
|
d6492b0163 | ||
|
|
b3f52c9fbe | ||
|
|
bd8396fad8 | ||
|
|
d0c58841d1 | ||
|
|
881f09b2e2 | ||
|
|
8b30acd7eb | ||
|
|
02efb310ca | ||
|
|
e7e59c6f05 | ||
|
|
3ae6aabe9f | ||
|
|
dc627dcb5e | ||
|
|
efeb9c0f02 | ||
|
|
ba3e913c7a | ||
|
|
7cca1727af | ||
|
|
11371fe251 | ||
|
|
41c603d48a | ||
|
|
969337345f | ||
|
|
9592766939 | ||
|
|
58dca7d846 | ||
|
|
0d302cd25b | ||
|
|
da691257ec | ||
|
|
1600092e92 | ||
|
|
dba2bd1105 | ||
|
|
28be4de7c2 | ||
|
|
a6c3b38fba | ||
|
|
fcb65a3897 | ||
|
|
4e22a1dffe | ||
|
|
291cf40aca | ||
|
|
bd47e1f066 | ||
|
|
e6b223df5f | ||
|
|
e64349bbdd | ||
|
|
cdb59faea6 | ||
|
|
1d94ac3f90 | ||
|
|
5f7d19d1f5 | ||
|
|
2fdf9eb535 | ||
|
|
860d3a50d7 | ||
|
|
d1183821a7 | ||
|
|
8081df79be | ||
|
|
64bec4fad7 | ||
|
|
b96e105244 | ||
|
|
3b4d5484c7 | ||
|
|
684e11c664 | ||
|
|
b57a52813b | ||
|
|
da8deb2b62 | ||
|
|
98b6ce3460 | ||
|
|
f9e00efe31 | ||
|
|
0fd2a1f4b0 | ||
|
|
df3233454d | ||
|
|
82db84b899 | ||
|
|
8ae751d3da | ||
|
|
d40e76809f | ||
|
|
bb1b76d9dc | ||
|
|
9d26441224 | ||
|
|
f12f24a77c | ||
|
|
ae5b5cabfd | ||
|
|
d0630ffe8c | ||
|
|
99bb7d3a58 | ||
|
|
63ae767232 | ||
|
|
eaaea02010 | ||
|
|
a098bc92e0 | ||
|
|
1086dc4db0 | ||
|
|
19fb69e2ed | ||
|
|
9231617eb3 | ||
|
|
32668a7317 | ||
|
|
780c197f95 | ||
|
|
eb8819e91e | ||
|
|
30bbea2f08 | ||
|
|
635ccd9e25 | ||
|
|
8c9f0278b9 | ||
|
|
58d0e199e1 | ||
|
|
10b5835501 | ||
|
|
6c8dd307eb | ||
|
|
43ffdab172 | ||
|
|
40b6d67333 | ||
|
|
c52d1600f0 | ||
|
|
aa1d6cadad | ||
|
|
6e06e3a904 | ||
|
|
8cfb9fc0b8 | ||
|
|
7b456fd2c0 | ||
|
|
e9e53856d2 | ||
|
|
5029894662 | ||
|
|
baf9fa5f42 | ||
|
|
7f914365fd | ||
|
|
ebd7135b50 | ||
|
|
50eff6a10a | ||
|
|
c34a5ae7f7 | ||
|
|
e2aa6ec8ae | ||
|
|
6768c6a54a | ||
|
|
6307d166eb | ||
|
|
1fba87b0df | ||
|
|
df124e018a | ||
|
|
2f83d6e4b7 | ||
|
|
987785d8d7 | ||
|
|
8c01a7893b | ||
|
|
218047c75a | ||
|
|
d0da74209b | ||
|
|
5c1fa64fb0 | ||
|
|
a3c287354f | ||
|
|
03cf033f82 | ||
|
|
bdb36c9a63 | ||
|
|
20bb301195 | ||
|
|
d6383a1c6a | ||
|
|
b05bcfd27f | ||
|
|
2615660e62 | ||
|
|
5b0af4cdb1 | ||
|
|
8c2e15e6c8 | ||
|
|
56c8a33439 | ||
|
|
4eef1e8a3e | ||
|
|
95d11bda06 | ||
|
|
af9079cc1f | ||
|
|
2d6cd47713 | ||
|
|
fe3167d7ea | ||
|
|
31e134be35 | ||
|
|
e84ba8056d | ||
|
|
f20e97b092 | ||
|
|
934683088e | ||
|
|
de2b9e7d0a | ||
|
|
dd7d8e5e29 | ||
|
|
df964132fb | ||
|
|
709ccc6800 | ||
|
|
cf236fc390 | ||
|
|
27d70c7d9d | ||
|
|
0e585b4409 | ||
|
|
0163a8e57a | ||
|
|
578842954c | ||
|
|
496315fe1d | ||
|
|
0fe6895893 | ||
|
|
0b7d71fd2f | ||
|
|
83b11bc58d | ||
|
|
375a8bbdcc | ||
|
|
ea9090bbc4 | ||
|
|
81def6ac76 | ||
|
|
3de8ce3f3c | ||
|
|
4d485fca24 | ||
|
|
1865299a30 | ||
|
|
3576b547c5 | ||
|
|
079882495d | ||
|
|
ab977109db | ||
|
|
fd1c08137b | ||
|
|
76b6cece46 | ||
|
|
9f0df51f8d | ||
|
|
e7a2a3dcd1 | ||
|
|
a87ef5bfc1 | ||
|
|
9f9cb7a2ef | ||
|
|
7e26fd8032 | ||
|
|
eab2685c67 | ||
|
|
50dfb664db | ||
|
|
0189ab6ab6 | ||
|
|
9401507336 | ||
|
|
eb8321d863 | ||
|
|
79ef49b2c2 | ||
|
|
e110ca11e2 | ||
|
|
226748b3e7 | ||
|
|
d568c7ee36 | ||
|
|
e6fecbb3e1 | ||
|
|
da83f899bb | ||
|
|
7e5674d8be | ||
|
|
0a558577bf | ||
|
|
fb71a82ada | ||
|
|
23406c9e9e | ||
|
|
b3ec792380 | ||
|
|
6a9b584f3d | ||
|
|
81dd33af66 | ||
|
|
8b76571896 | ||
|
|
e78a6518fa | ||
|
|
1873ffda01 | ||
|
|
c417e42116 | ||
|
|
358e1fd6ab | ||
|
|
631dfbe673 | ||
|
|
56a4eaed72 | ||
|
|
bf925d9dc7 | ||
|
|
1a7ed5dcb6 | ||
|
|
5be5daa6ef | ||
|
|
60cb11764e | ||
|
|
cbd5445ea7 | ||
|
|
2c7e9b5158 | ||
|
|
2263e4b279 | ||
|
|
863039da4c | ||
|
|
7178ac0111 | ||
|
|
e7f9710499 | ||
|
|
ff4223904d | ||
|
|
a9f80d60f6 | ||
|
|
2e158cf6d0 | ||
|
|
8bd6bfa4b5 | ||
|
|
8b1906abd0 | ||
|
|
06375e6605 | ||
|
|
b21242faf1 | ||
|
|
cc05a281c4 | ||
|
|
fe96ceee66 | ||
|
|
9814a2ae12 | ||
|
|
6992498e7a | ||
|
|
21623156a3 | ||
|
|
79c859e2e0 | ||
|
|
b00ac960b4 | ||
|
|
02a9fc7bfa | ||
|
|
f390957685 | ||
|
|
17f57df797 | ||
|
|
7f7b9662ea | ||
|
|
19bef39f5c | ||
|
|
a30e7ed2da | ||
|
|
8db7161c94 | ||
|
|
09f1777896 | ||
|
|
490c0c4fdc | ||
|
|
c4a471c99d | ||
|
|
86f495985b | ||
|
|
67d1894759 | ||
|
|
5bfe89bdb1 | ||
|
|
82463e9938 | ||
|
|
771575d27b | ||
|
|
20a01bbd9f | ||
|
|
ec8578d41a | ||
|
|
d0dbfe0b97 | ||
|
|
3d405fb3b1 | ||
|
|
b0012cdd0f | ||
|
|
84d61d27aa | ||
|
|
ed83908931 | ||
|
|
ef5f7d1aea | ||
|
|
090ff659dc | ||
|
|
85c8a91a27 | ||
|
|
581b699ac9 | ||
|
|
8a0677d56d | ||
|
|
b18468bf81 | ||
|
|
107ba2891a | ||
|
|
cd9e184529 | ||
|
|
2e7c02d5cd | ||
|
|
ae18326533 | ||
|
|
91eba8e485 | ||
|
|
d07e295c62 | ||
|
|
dce4bd74a4 | ||
|
|
ffff671273 | ||
|
|
12d4507ee3 | ||
|
|
8580d997ff | ||
|
|
061cf9a4ce | ||
|
|
99abb9eff4 | ||
|
|
fffe072028 | ||
|
|
a1a31eed27 | ||
|
|
ae812350f9 | ||
|
|
b63ef10a7f | ||
|
|
42afe27e12 | ||
|
|
76e63212ff | ||
|
|
aac2f9fb61 | ||
|
|
bddf23f175 | ||
|
|
039da779d1 | ||
|
|
d88d2124b5 | ||
|
|
e142aaf8a1 | ||
|
|
0caf35f4b8 | ||
|
|
3fc993f82d | ||
|
|
741eb28443 | ||
|
|
1a87dc5ea8 | ||
|
|
2427fa171e | ||
|
|
639e06e1f3 | ||
|
|
02fedbf1da | ||
|
|
110d9b149d | ||
|
|
9cbff5ec1d | ||
|
|
433c0206b0 | ||
|
|
8915901966 | ||
|
|
f48bc496c7 | ||
|
|
913b19329c | ||
|
|
d8cb3128f6 | ||
|
|
5f9ba3019f | ||
|
|
46caf0bef0 | ||
|
|
45f636e759 | ||
|
|
a7b404ff53 | ||
|
|
c4fd0e5ede | ||
|
|
bab5386306 | ||
|
|
aca7584635 | ||
|
|
d611251502 | ||
|
|
f30b659291 | ||
|
|
90dfa43ff1 | ||
|
|
dc175f08d3 | ||
|
|
29221fa238 | ||
|
|
a789685c63 | ||
|
|
240d10699c | ||
|
|
925014b661 | ||
|
|
5611e1a95e | ||
|
|
570f2bf29e | ||
|
|
9948eddf11 | ||
|
|
a3ee03da01 | ||
|
|
28fcd2b519 | ||
|
|
8e686764ac | ||
|
|
479051ce1c | ||
|
|
bfb5bad4f0 | ||
|
|
1e16331d9c | ||
|
|
be98f4ab6b | ||
|
|
6ee1112f30 | ||
|
|
8e5a5a1ccd | ||
|
|
fcda3a0e66 | ||
|
|
9663c22fe9 | ||
|
|
f0ae00da12 | ||
|
|
44390bd3d0 | ||
|
|
2225374060 | ||
|
|
105d236889 | ||
|
|
53e6a9367c | ||
|
|
f5a1582fe8 | ||
|
|
a54f06b16f | ||
|
|
4650d94d98 | ||
|
|
a5681ebc52 | ||
|
|
e849b3424a | ||
|
|
b219d12a6b | ||
|
|
cec8661113 | ||
|
|
73a8c090e0 | ||
|
|
db6796ac61 | ||
|
|
9a8ee00246 | ||
|
|
d39ed54f8e | ||
|
|
16546c70d8 | ||
|
|
eaba55c9bf | ||
|
|
19ec023256 | ||
|
|
63ab0ab580 | ||
|
|
8dfc376c00 | ||
|
|
1efee9db09 | ||
|
|
43abc402d8 | ||
|
|
3f8b1668c4 | ||
|
|
76c919b4ec | ||
|
|
29d0c10ee5 | ||
|
|
5ad133f8bb | ||
|
|
d0c544a868 | ||
|
|
ffb19df3c0 | ||
|
|
8b7532b9ab | ||
|
|
366478c560 | ||
|
|
8e5600022a | ||
|
|
0e95b64942 | ||
|
|
0ae22b915b | ||
|
|
7c441600fe | ||
|
|
a4d290adb9 | ||
|
|
28301807c2 | ||
|
|
74ed0974b3 | ||
|
|
ec8a4864fa | ||
|
|
b7588fd5d7 | ||
|
|
f512b905c7 | ||
|
|
afd5274049 | ||
|
|
1074674e32 | ||
|
|
7762e07fde | ||
|
|
cbefd9129e | ||
|
|
e39bebe13e | ||
|
|
14b4e51a7c | ||
|
|
cbcf44a4ca | ||
|
|
859ae15a54 | ||
|
|
0787724c44 | ||
|
|
7b463ffb07 | ||
|
|
6686e61ca4 | ||
|
|
c096a77b9b | ||
|
|
5121f028d9 | ||
|
|
6a665ea6ed | ||
|
|
bc06cb9ff6 | ||
|
|
8e281c76c3 | ||
|
|
d5964a2710 | ||
|
|
cf3eb87e52 | ||
|
|
ab3a466711 | ||
|
|
4494970f47 | ||
|
|
776c3d226d | ||
|
|
f5f18b704f | ||
|
|
420ff2f331 | ||
|
|
56ba3ec40e | ||
|
|
de3d2467a3 | ||
|
|
fe1dabf272 | ||
|
|
08226ab491 | ||
|
|
3b661b7394 | ||
|
|
e6418781ab | ||
|
|
ac02cf33bd | ||
|
|
22364c40b7 | ||
|
|
d729a1991b | ||
|
|
126c9869c8 | ||
|
|
ad4a45e615 | ||
|
|
04fc896016 | ||
|
|
884b4ed43b | ||
|
|
972d9a3aea | ||
|
|
7dcdd88e27 | ||
|
|
8120a3b65c | ||
|
|
5798256fcf | ||
|
|
d0fda82595 | ||
|
|
f883fcede0 | ||
|
|
e1bdf6a8d9 | ||
|
|
1a4f4c5ea6 | ||
|
|
0925af43b0 | ||
|
|
dc937b8ed3 | ||
|
|
c3965fc5ee | ||
|
|
bf7cd29970 | ||
|
|
a000d2288c | ||
|
|
165abf0e4c | ||
|
|
818cda16bc | ||
|
|
85143fecdd | ||
|
|
35431a4ac8 | ||
|
|
ccf1645995 | ||
|
|
1a48713d32 | ||
|
|
1eb04aa23f | ||
|
|
0c65517e91 | ||
|
|
2fdc2462c3 | ||
|
|
be6e9d6a9f | ||
|
|
e54cbb7ba6 | ||
|
|
40c108766b | ||
|
|
4cc70290f7 | ||
|
|
74caa68d02 | ||
|
|
3756381358 | ||
|
|
d12573daa6 | ||
|
|
0dbc4c7547 | ||
|
|
06072601ce | ||
|
|
11d2c8f7a1 | ||
|
|
7f3f8d8f8d | ||
|
|
b96be943dc | ||
|
|
b670485185 | ||
|
|
b57bd0488d |
@@ -1,209 +0,0 @@
|
|||||||
version: 2.1
|
|
||||||
|
|
||||||
parameters:
|
|
||||||
nightly_build:
|
|
||||||
type: boolean
|
|
||||||
default: false
|
|
||||||
weekly_build:
|
|
||||||
type: boolean
|
|
||||||
default: false
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
linux_build_and_test:
|
|
||||||
docker:
|
|
||||||
- image: cimg/python:3.9
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- checkout
|
|
||||||
- run:
|
|
||||||
name: Run style checks
|
|
||||||
command: |
|
|
||||||
pip install pre-commit
|
|
||||||
pre-commit run --all
|
|
||||||
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
|
|
||||||
- run:
|
|
||||||
name: Install dependencies
|
|
||||||
command: |
|
|
||||||
pip install --upgrade cmake
|
|
||||||
pip install --upgrade pybind11[global]
|
|
||||||
pip install pybind11-stubgen
|
|
||||||
pip install numpy
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
|
||||||
- run:
|
|
||||||
name: Install Python package
|
|
||||||
command: |
|
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
|
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
|
|
||||||
- run:
|
|
||||||
name: Generate package stubs
|
|
||||||
command: |
|
|
||||||
python3 setup.py generate_stubs
|
|
||||||
- run:
|
|
||||||
name: Run Python tests
|
|
||||||
command: |
|
|
||||||
python3 -m unittest discover python/tests -v
|
|
||||||
# TODO: Reenable when extension api becomes stable
|
|
||||||
# - run:
|
|
||||||
# name: Build example extension
|
|
||||||
# command: |
|
|
||||||
# cd examples/extensions && python3 -m pip install .
|
|
||||||
- run:
|
|
||||||
name: Build CPP only
|
|
||||||
command: |
|
|
||||||
mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
|
|
||||||
- run:
|
|
||||||
name: Run CPP tests
|
|
||||||
command: ./build/tests/tests
|
|
||||||
|
|
||||||
mac_build_and_test:
|
|
||||||
macos:
|
|
||||||
xcode: "15.2.0"
|
|
||||||
resource_class: macos.m1.large.gen1
|
|
||||||
steps:
|
|
||||||
- checkout
|
|
||||||
- run:
|
|
||||||
name: Install dependencies
|
|
||||||
command: |
|
|
||||||
brew install python@3.9
|
|
||||||
python3.9 -m venv env
|
|
||||||
source env/bin/activate
|
|
||||||
pip install --upgrade pip
|
|
||||||
pip install --upgrade cmake
|
|
||||||
pip install --upgrade pybind11[global]
|
|
||||||
pip install pybind11-stubgen
|
|
||||||
pip install numpy
|
|
||||||
pip install torch
|
|
||||||
pip install tensorflow
|
|
||||||
pip install unittest-xml-reporting
|
|
||||||
- run:
|
|
||||||
name: Install Python package
|
|
||||||
command: |
|
|
||||||
source env/bin/activate
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
|
|
||||||
- run:
|
|
||||||
name: Generate package stubs
|
|
||||||
command: |
|
|
||||||
source env/bin/activate
|
|
||||||
python setup.py generate_stubs
|
|
||||||
- run:
|
|
||||||
name: Run Python tests
|
|
||||||
command: |
|
|
||||||
source env/bin/activate
|
|
||||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
|
||||||
# TODO: Reenable when Circle CI can run gpu jobs
|
|
||||||
# DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
|
|
||||||
# TODO: Reenable when extension api becomes stable
|
|
||||||
# - run:
|
|
||||||
# name: Build example extension
|
|
||||||
# command: |
|
|
||||||
# cd examples/extensions && python3.11 -m pip install .
|
|
||||||
- store_test_results:
|
|
||||||
path: test-results
|
|
||||||
- run:
|
|
||||||
name: Build CPP only
|
|
||||||
command: |
|
|
||||||
source env/bin/activate
|
|
||||||
mkdir -p build && cd build && cmake .. && make -j
|
|
||||||
- run:
|
|
||||||
name: Run CPP tests
|
|
||||||
#command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
|
|
||||||
command: DEVICE=cpu ./build/tests/tests
|
|
||||||
|
|
||||||
build_release:
|
|
||||||
parameters:
|
|
||||||
python_version:
|
|
||||||
type: string
|
|
||||||
default: "3.9"
|
|
||||||
xcode_version:
|
|
||||||
type: string
|
|
||||||
default: "15.2.0"
|
|
||||||
build_env:
|
|
||||||
type: string
|
|
||||||
default: ""
|
|
||||||
macos:
|
|
||||||
xcode: << parameters.xcode_version >>
|
|
||||||
resource_class: macos.m1.large.gen1
|
|
||||||
steps:
|
|
||||||
- checkout
|
|
||||||
- run:
|
|
||||||
name: Install dependencies
|
|
||||||
command: |
|
|
||||||
brew install python@<< parameters.python_version >>
|
|
||||||
python<< parameters.python_version >> -m venv env
|
|
||||||
source env/bin/activate
|
|
||||||
pip install --upgrade pip
|
|
||||||
pip install --upgrade cmake
|
|
||||||
pip install --upgrade pybind11[global]
|
|
||||||
pip install --upgrade setuptools
|
|
||||||
pip install pybind11-stubgen
|
|
||||||
pip install numpy
|
|
||||||
pip install twine
|
|
||||||
pip install build
|
|
||||||
- run:
|
|
||||||
name: Install Python package
|
|
||||||
command: |
|
|
||||||
source env/bin/activate
|
|
||||||
DEV_RELEASE=1 \
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
|
||||||
pip install . -v
|
|
||||||
- run:
|
|
||||||
name: Generate package stubs
|
|
||||||
command: |
|
|
||||||
source env/bin/activate
|
|
||||||
python setup.py generate_stubs
|
|
||||||
- run:
|
|
||||||
name: Build Python package
|
|
||||||
command: |
|
|
||||||
source env/bin/activate
|
|
||||||
<< parameters.build_env >> \
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
|
||||||
python -m build -w
|
|
||||||
- when:
|
|
||||||
condition: << parameters.build_env >>
|
|
||||||
steps:
|
|
||||||
- run:
|
|
||||||
name: Upload package
|
|
||||||
command: |
|
|
||||||
source env/bin/activate
|
|
||||||
twine upload dist/*
|
|
||||||
- store_artifacts:
|
|
||||||
path: dist/
|
|
||||||
|
|
||||||
workflows:
|
|
||||||
build_and_test:
|
|
||||||
when:
|
|
||||||
and:
|
|
||||||
- not: << pipeline.parameters.nightly_build >>
|
|
||||||
- not: << pipeline.parameters.weekly_build >>
|
|
||||||
jobs:
|
|
||||||
- mac_build_and_test
|
|
||||||
- linux_build_and_test
|
|
||||||
- build_release:
|
|
||||||
filters:
|
|
||||||
tags:
|
|
||||||
only: /^v.*/
|
|
||||||
branches:
|
|
||||||
ignore: /.*/
|
|
||||||
matrix:
|
|
||||||
parameters:
|
|
||||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
|
||||||
xcode_version: ["14.3.1", "15.2.0"]
|
|
||||||
build_env: ["PYPI_RELEASE=1"]
|
|
||||||
nightly_build:
|
|
||||||
when: << pipeline.parameters.nightly_build >>
|
|
||||||
jobs:
|
|
||||||
- build_release:
|
|
||||||
matrix:
|
|
||||||
parameters:
|
|
||||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
|
||||||
xcode_version: ["14.3.1", "15.2.0"]
|
|
||||||
weekly_build:
|
|
||||||
when: << pipeline.parameters.weekly_build >>
|
|
||||||
jobs:
|
|
||||||
- build_release:
|
|
||||||
matrix:
|
|
||||||
parameters:
|
|
||||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
|
||||||
xcode_version: ["14.3.1", "15.2.0"]
|
|
||||||
build_env: ["DEV_RELEASE=1"]
|
|
||||||
20
.github/actions/build-cuda-release/action.yml
vendored
Normal file
20
.github/actions/build-cuda-release/action.yml
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
name: 'Build CUDA wheel'
|
||||||
|
description: 'Build CUDA wheel'
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
toolkit:
|
||||||
|
description: 'The CUDA toolkit'
|
||||||
|
required: true
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Build package
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
|
||||||
|
run: |
|
||||||
|
pip install auditwheel build patchelf setuptools
|
||||||
|
python setup.py clean --all
|
||||||
|
MLX_BUILD_STAGE=2 python -m build -w
|
||||||
|
bash python/scripts/repair_cuda.sh
|
||||||
26
.github/actions/build-cuda/action.yml
vendored
Normal file
26
.github/actions/build-cuda/action.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
name: 'Build and Test with CUDA'
|
||||||
|
description: 'Build and test MLX with CUDA'
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
toolkit:
|
||||||
|
description: 'The CUDA toolkit'
|
||||||
|
required: true
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Install Python package
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
DEBUG: 1
|
||||||
|
CMAKE_ARGS: -DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc
|
||||||
|
run: pip install --no-build-isolation -e ".[dev]" -v
|
||||||
|
|
||||||
|
- name: Build CPP only
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cmake . -B build \
|
||||||
|
-DMLX_BUILD_CUDA=ON \
|
||||||
|
-DCMAKE_CUDA_COMPILER=/usr/local/${{ inputs.toolkit }}/bin/nvcc \
|
||||||
|
-DCMAKE_BUILD_TYPE=DEBUG
|
||||||
|
cmake --build build -j $(nproc)
|
||||||
38
.github/actions/build-docs/action.yml
vendored
Normal file
38
.github/actions/build-docs/action.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
name: 'Build Documentation'
|
||||||
|
description: 'Build documentation'
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Setup machine
|
||||||
|
uses: ./.github/actions/setup-linux
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
sudo apt-get install -y doxygen
|
||||||
|
source .venv/bin/activate
|
||||||
|
pip install -r docs/requirements.txt
|
||||||
|
pip install . -v
|
||||||
|
|
||||||
|
- name: Build documentation
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate
|
||||||
|
cd docs
|
||||||
|
doxygen
|
||||||
|
make html O=-W
|
||||||
|
|
||||||
|
- name: Create artifact tar
|
||||||
|
shell: bash
|
||||||
|
run: tar -cf artifact.tar -C docs --dereference build/html index.html
|
||||||
|
|
||||||
|
# Do it manually because upload-pages-artifact requires gtar
|
||||||
|
- name: Upload artifact
|
||||||
|
id: upload-artifact
|
||||||
|
uses: actions/upload-artifact@v5
|
||||||
|
with:
|
||||||
|
name: github-pages
|
||||||
|
path: artifact.tar
|
||||||
|
retention-days: 1
|
||||||
|
if-no-files-found: error
|
||||||
40
.github/actions/build-linux-release/action.yml
vendored
Normal file
40
.github/actions/build-linux-release/action.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
name: 'Build Linux wheel'
|
||||||
|
description: 'Build Linux wheel'
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
build-backend:
|
||||||
|
description: 'Build the backend mlx-cpu package'
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
arch:
|
||||||
|
description: 'Platform architecture tag'
|
||||||
|
required: true
|
||||||
|
type: choice
|
||||||
|
options:
|
||||||
|
- x86_64
|
||||||
|
- aarch64
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Generate package stubs
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
pip install -e ".[dev]" -v
|
||||||
|
pip install typing_extensions
|
||||||
|
python setup.py generate_stubs
|
||||||
|
- name: Build Python package
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
pip install auditwheel patchelf build
|
||||||
|
python setup.py clean --all
|
||||||
|
MLX_BUILD_STAGE=1 python -m build -w
|
||||||
|
bash python/scripts/repair_linux.sh ${{ inputs.arch }}
|
||||||
|
- name: Build backend package
|
||||||
|
if: ${{ inputs.build-backend }}
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python setup.py clean --all
|
||||||
|
MLX_BUILD_STAGE=2 python -m build -w
|
||||||
|
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_${{ inputs.arch }}
|
||||||
25
.github/actions/build-linux/action.yml
vendored
Normal file
25
.github/actions/build-linux/action.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
name: 'Build and Test on Linux'
|
||||||
|
description: 'Build and test MLX on Linux'
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Install Python package
|
||||||
|
shell: sh
|
||||||
|
env:
|
||||||
|
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
||||||
|
DEBUG: 1
|
||||||
|
run: pip install --no-build-isolation -e ".[dev]" -v
|
||||||
|
|
||||||
|
- name: Generate package stubs
|
||||||
|
shell: sh
|
||||||
|
run: |
|
||||||
|
pip install typing_extensions
|
||||||
|
python setup.py generate_stubs
|
||||||
|
|
||||||
|
- name: Build CPP only
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
mkdir -p build && cd build
|
||||||
|
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||||
|
make -j $(nproc)
|
||||||
30
.github/actions/build-macos-release/action.yml
vendored
Normal file
30
.github/actions/build-macos-release/action.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
name: 'Build macOS release'
|
||||||
|
description: 'Build MLX releases macOS'
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
macos-target:
|
||||||
|
description: 'macOS build target'
|
||||||
|
required: false
|
||||||
|
default: '15.0'
|
||||||
|
build-backend:
|
||||||
|
description: 'Build the backend mlx-metal package'
|
||||||
|
type: boolean
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Build Python package
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
pip install build
|
||||||
|
python setup.py clean --all
|
||||||
|
MLX_BUILD_STAGE=1 python -m build -w
|
||||||
|
|
||||||
|
- name: Build backend package
|
||||||
|
if: ${{ inputs.build-backend }}
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
python setup.py clean --all
|
||||||
|
MLX_BUILD_STAGE=2 python -m build -w
|
||||||
88
.github/actions/build-macos/action.yml
vendored
Normal file
88
.github/actions/build-macos/action.yml
vendored
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
name: 'Build and Test on macOS'
|
||||||
|
description: 'Build and test MLX on macOS'
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Install dependencies
|
||||||
|
env:
|
||||||
|
DEBUG: 1
|
||||||
|
CMAKE_ARGS: "-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install cmake setuptools nanobind==2.4.0
|
||||||
|
pip install -e . -v
|
||||||
|
|
||||||
|
- name: Generate package stubs
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
pip install typing_extensions
|
||||||
|
python setup.py generate_stubs
|
||||||
|
|
||||||
|
- name: Install tests dependencies
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
pip install numpy torch tensorflow unittest-xml-reporting
|
||||||
|
|
||||||
|
- name: Run Python tests
|
||||||
|
shell: bash -l {0}
|
||||||
|
env:
|
||||||
|
LOW_MEMORY: 1
|
||||||
|
run: |
|
||||||
|
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||||
|
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||||
|
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||||
|
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||||
|
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||||
|
|
||||||
|
- name: Build example extension
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
cd examples/extensions
|
||||||
|
pip install -r requirements.txt
|
||||||
|
python setup.py build_ext --inplace
|
||||||
|
python test.py
|
||||||
|
|
||||||
|
- name: Build CPP only
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
mkdir -p build
|
||||||
|
cd build
|
||||||
|
cmake ..
|
||||||
|
make -j $(sysctl -n hw.ncpu)
|
||||||
|
|
||||||
|
- name: Run CPP tests
|
||||||
|
shell: bash -l {0}
|
||||||
|
env:
|
||||||
|
DEVICE: gpu
|
||||||
|
METAL_DEVICE_WRAPPER_TYPE: 1
|
||||||
|
METAL_DEBUG_ERROR_MODE: 0
|
||||||
|
run: ./build/tests/tests
|
||||||
|
|
||||||
|
- name: Build small binary with JIT
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
mkdir -p build
|
||||||
|
cd build
|
||||||
|
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||||
|
-DBUILD_SHARED_LIBS=ON \
|
||||||
|
-DMLX_BUILD_CPU=OFF \
|
||||||
|
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||||
|
-DMLX_BUILD_GGUF=OFF \
|
||||||
|
-DMLX_METAL_JIT=ON
|
||||||
|
make -j $(sysctl -n hw.ncpu)
|
||||||
|
|
||||||
|
- name: Run Python tests with JIT
|
||||||
|
shell: bash -l {0}
|
||||||
|
env:
|
||||||
|
LOW_MEMORY: 1
|
||||||
|
DEVICE: gpu
|
||||||
|
METAL_DEVICE_WRAPPER_TYPE: 1
|
||||||
|
METAL_DEBUG_ERROR_MODE: 0
|
||||||
|
run: |
|
||||||
|
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||||
|
pip install -e . -v
|
||||||
|
python -m xmlrunner discover \
|
||||||
|
-v python/tests \
|
||||||
|
-o test-results/gpu_jit
|
||||||
85
.github/actions/setup-linux/action.yml
vendored
Normal file
85
.github/actions/setup-linux/action.yml
vendored
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
name: 'Setup Linux Environment'
|
||||||
|
description: 'Install dependencies for Linux builds'
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
toolkit:
|
||||||
|
description: 'Which toolkit to install'
|
||||||
|
required: false
|
||||||
|
default: 'cpu'
|
||||||
|
python-version:
|
||||||
|
description: 'Version of python to set up'
|
||||||
|
required: false
|
||||||
|
default: '3.10'
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Use ccache
|
||||||
|
uses: hendrikmuhs/ccache-action@v1.2
|
||||||
|
with:
|
||||||
|
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
|
||||||
|
max-size: 1GB
|
||||||
|
|
||||||
|
- name: Install common dependencies
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v6
|
||||||
|
with:
|
||||||
|
python-version: ${{ inputs.python-version }}
|
||||||
|
|
||||||
|
- name: Setup Python venv
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
pip install setuptools cmake nanobind==2.4.0
|
||||||
|
echo PATH=$PATH >> $GITHUB_ENV
|
||||||
|
# Make cmake search .venv for nanobind
|
||||||
|
echo PYTHONPATH=`python -c 'import sys; print(sys.path[-1])'` >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Install MPI
|
||||||
|
shell: bash
|
||||||
|
run: sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev
|
||||||
|
|
||||||
|
- name: Install CUDA toolkit
|
||||||
|
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
# Note: the CI machine does not meet CUDA 13's driver requirement.
|
||||||
|
# Compatibility matrix:
|
||||||
|
# https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html
|
||||||
|
# The `nvcc` is installed into `/usr/local/cuda-VERSION/bin/nvcc` - but
|
||||||
|
# it's *not* on the default toolkit path.
|
||||||
|
PACKAGES: |
|
||||||
|
{
|
||||||
|
"cuda-12.6": "libcudnn9-dev-cuda-12 cuda-toolkit-12-6",
|
||||||
|
"cuda-12.9": "libcudnn9-dev-cuda-12 cuda-toolkit-12-9",
|
||||||
|
"cuda-13.0": "libcudnn9-dev-cuda-13 cuda-toolkit-13-0"
|
||||||
|
}
|
||||||
|
run: |
|
||||||
|
export ARCH=${{ runner.arch == 'arm64' && 'arm64' || 'x86_64' }}
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$ARCH/cuda-keyring_1.1-1_all.deb
|
||||||
|
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y \
|
||||||
|
libnccl2 libnccl-dev \
|
||||||
|
${{ fromJson(env.PACKAGES)[inputs.toolkit] }}
|
||||||
|
|
||||||
|
- name: CUDA packages and driver report
|
||||||
|
if: ${{ startsWith(inputs.toolkit, 'cuda') }}
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
sudo apt-get install -y ubuntu-drivers-common dkms
|
||||||
|
echo "NVIDIA Driver Packages Available:"
|
||||||
|
sudo ubuntu-drivers list --gpgpu
|
||||||
|
echo "NVIDIA Driver Version:"
|
||||||
|
cat /proc/driver/nvidia/version || echo "nvidia driver not found"
|
||||||
|
echo "Installed NVIDIA and CUDA packages:"
|
||||||
|
dpkg -l | egrep "cuda|nvidia" -i
|
||||||
|
echo "DKMS Status:"
|
||||||
|
dkms status || echo "dkms not found"
|
||||||
|
echo "NVIDIA-SMI Status:"
|
||||||
|
nvidia-smi || echo "nvidia-smi not found"
|
||||||
24
.github/actions/setup-macos/action.yml
vendored
Normal file
24
.github/actions/setup-macos/action.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
name: 'Setup macOS Environment'
|
||||||
|
description: 'Install dependencies for macOS builds'
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
python-version:
|
||||||
|
description: 'Python version to use'
|
||||||
|
required: false
|
||||||
|
default: '3.10'
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Install Homebrew packages
|
||||||
|
shell: sh
|
||||||
|
run: /opt/homebrew/bin/brew install openmpi
|
||||||
|
|
||||||
|
- name: Verify MetalToolchain installed
|
||||||
|
shell: bash
|
||||||
|
run: xcodebuild -showComponent MetalToolchain
|
||||||
|
|
||||||
|
- uses: conda-incubator/setup-miniconda@v3
|
||||||
|
with:
|
||||||
|
miniconda-version: "latest"
|
||||||
|
python-version: ${{ inputs.python-version }}
|
||||||
69
.github/actions/test-linux/action.yml
vendored
Normal file
69
.github/actions/test-linux/action.yml
vendored
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
name: 'Run Linux tests'
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
cpu-only:
|
||||||
|
description: 'Skip GPU tests'
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Run MPI tests
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
echo "::group::MPI tests"
|
||||||
|
mpirun --bind-to none --allow-run-as-root -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||||
|
echo "::endgroup::"
|
||||||
|
|
||||||
|
- name: Run distributed tests
|
||||||
|
if: ${{ inputs.cpu-only == 'true' }}
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
echo "::group::Distributed tests"
|
||||||
|
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||||
|
if grep -Fq '[WARN]' stderr.log ; then
|
||||||
|
grep -F '[WARN]' stderr.log
|
||||||
|
echo "Distributed ring test failed";
|
||||||
|
exit 1;
|
||||||
|
fi
|
||||||
|
echo "::endgroup::"
|
||||||
|
|
||||||
|
- name: Run Python tests - CPU
|
||||||
|
if: ${{ inputs.cpu-only == 'true' }}
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
DEVICE: cpu
|
||||||
|
run: |
|
||||||
|
echo "::group::Python tests - CPU"
|
||||||
|
python -m unittest discover python/tests -v
|
||||||
|
echo "::endgroup::"
|
||||||
|
|
||||||
|
- name: Run Python tests - GPU
|
||||||
|
if: ${{ inputs.cpu-only == 'false' }}
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
DEVICE: gpu
|
||||||
|
run: |
|
||||||
|
echo "::group::Python tests - GPU"
|
||||||
|
python -m tests discover python/tests -v
|
||||||
|
echo "::endgroup::"
|
||||||
|
|
||||||
|
- name: Run CPP tests - CPU
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
DEVICE: cpu
|
||||||
|
run: |
|
||||||
|
echo "::group::CPP tests - CPU"
|
||||||
|
./build/tests/tests
|
||||||
|
echo "::endgroup::"
|
||||||
|
|
||||||
|
- name: Run CPP tests - GPU
|
||||||
|
if: ${{ inputs.cpu-only == 'false' }}
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
DEVICE: gpu
|
||||||
|
run: |
|
||||||
|
echo "::group::CPP tests - GPU"
|
||||||
|
./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
|
||||||
|
echo "::endgroup::"
|
||||||
6
.github/dependabot.yml
vendored
Normal file
6
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: "github-actions"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
27
.github/scripts/setup+build-cpp-linux-fedora-container.sh
vendored
Executable file
27
.github/scripts/setup+build-cpp-linux-fedora-container.sh
vendored
Executable file
@@ -0,0 +1,27 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
# [Setup] Install dependencies inside the container.
|
||||||
|
dnf update -y
|
||||||
|
dnf install -y \
|
||||||
|
blas-devel \
|
||||||
|
lapack-devel \
|
||||||
|
openblas-devel \
|
||||||
|
make \
|
||||||
|
cmake \
|
||||||
|
clang \
|
||||||
|
git
|
||||||
|
dnf clean all
|
||||||
|
|
||||||
|
# [C++] CI Build Sanity Check: Verifies code compilation, not for release.
|
||||||
|
export CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON"
|
||||||
|
export DEBUG=1
|
||||||
|
export CMAKE_C_COMPILER=/usr/bin/clang
|
||||||
|
export CMAKE_CXX_COMPILER=/usr/bin/clang++
|
||||||
|
|
||||||
|
mkdir -p build
|
||||||
|
pushd build
|
||||||
|
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||||
|
make -j $(nproc)
|
||||||
|
./tests/tests
|
||||||
|
popd
|
||||||
28
.github/workflows/documentation.yml
vendored
Normal file
28
.github/workflows/documentation.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
name: Documentation
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/build-docs
|
||||||
|
|
||||||
|
deploy:
|
||||||
|
needs: build
|
||||||
|
permissions:
|
||||||
|
pages: write
|
||||||
|
id-token: write
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
environment:
|
||||||
|
name: github-pages
|
||||||
|
url: ${{ steps.deployment.outputs.page_url }}
|
||||||
|
steps:
|
||||||
|
- name: Deploy to GitHub Pages
|
||||||
|
id: deployment
|
||||||
|
uses: actions/deploy-pages@v4
|
||||||
98
.github/workflows/nightly.yml
vendored
Normal file
98
.github/workflows/nightly.yml
vendored
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
name: Nightly Build
|
||||||
|
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: 33 6 * * 1-5
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build_linux_release:
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python_version: ["3.10", "3.14"]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/setup-linux
|
||||||
|
- uses: ./.github/actions/build-linux-release
|
||||||
|
with:
|
||||||
|
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||||
|
arch: "x86_64"
|
||||||
|
- name: Upload mlx artifacts
|
||||||
|
uses: actions/upload-artifact@v5
|
||||||
|
with:
|
||||||
|
name: linux-wheels-${{ matrix.python_version }}
|
||||||
|
path: wheelhouse/mlx-*.whl
|
||||||
|
retention-days: 7
|
||||||
|
- name: Upload mlx-cpu artifacts
|
||||||
|
if: matrix.python_version == '3.10'
|
||||||
|
uses: actions/upload-artifact@v5
|
||||||
|
with:
|
||||||
|
name: mlx-cpu
|
||||||
|
path: wheelhouse/mlx_cpu-*.whl
|
||||||
|
retention-days: 7
|
||||||
|
|
||||||
|
build_linux_with_tests:
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python_version: ["3.11", "3.12", "3.13", "3.14"]
|
||||||
|
runner:
|
||||||
|
- ubuntu-22.04
|
||||||
|
- ubuntu-22.04-arm
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/setup-linux
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python_version }}
|
||||||
|
- uses: ./.github/actions/build-linux
|
||||||
|
- uses: ./.github/actions/test-linux
|
||||||
|
with:
|
||||||
|
cpu-only: true
|
||||||
|
|
||||||
|
build_mac_release:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.10", "3.13"]
|
||||||
|
runs-on: [self-hosted, macos]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/setup-macos
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- uses: ./.github/actions/build-macos
|
||||||
|
- name: Build macOS 15 package
|
||||||
|
uses: ./.github/actions/build-macos-release
|
||||||
|
with:
|
||||||
|
macos-target: 15.0
|
||||||
|
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||||
|
- name: Build macOS 14 package
|
||||||
|
uses: ./.github/actions/build-macos-release
|
||||||
|
with:
|
||||||
|
macos-target: 14.0
|
||||||
|
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||||
|
|
||||||
|
build_cuda_release:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
runs-on: ubuntu-22-large
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/setup-linux
|
||||||
|
with:
|
||||||
|
toolkit: 'cuda-12.9'
|
||||||
|
- name: Build Python package
|
||||||
|
uses: ./.github/actions/build-cuda-release
|
||||||
|
with:
|
||||||
|
toolkit: 'cuda-12.9'
|
||||||
|
- name: Upload artifacts
|
||||||
|
uses: actions/upload-artifact@v5
|
||||||
|
with:
|
||||||
|
name: mlx-cuda
|
||||||
|
path: wheelhouse/mlx_cuda-*.whl
|
||||||
|
retention-days: 7
|
||||||
103
.github/workflows/pull_request.yml
vendored
103
.github/workflows/pull_request.yml
vendored
@@ -1,20 +1,103 @@
|
|||||||
|
name: Build and Test
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
|
# For testing CI without starting a pull request:
|
||||||
|
- test/*
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check_lint:
|
check_lint:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
- uses: actions/setup-python@v4
|
- uses: pre-commit/action@v3.0.1
|
||||||
|
|
||||||
|
linux_build_and_test:
|
||||||
|
needs: check_lint
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
runner:
|
||||||
|
- ubuntu-22.04
|
||||||
|
- ubuntu-22.04-arm
|
||||||
|
fail-fast: false
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/setup-linux
|
||||||
|
- uses: ./.github/actions/build-linux
|
||||||
|
- uses: ./.github/actions/test-linux
|
||||||
with:
|
with:
|
||||||
python-version: 3.8
|
cpu-only: true
|
||||||
- name: Install dependencies
|
|
||||||
|
mac_build_and_test:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
macos-target: ["14.0", "15.0"]
|
||||||
|
runs-on: [self-hosted, macos]
|
||||||
|
env:
|
||||||
|
MACOSX_DEPLOYMENT_TARGET: ${{ matrix.macos-target }}
|
||||||
|
needs: check_lint
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/setup-macos
|
||||||
|
- uses: ./.github/actions/build-macos
|
||||||
|
|
||||||
|
cuda_build_and_test:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
toolkit: ['cuda-12.6', 'cuda-12.9']
|
||||||
|
runs-on: gpu-t4-4-core
|
||||||
|
needs: check_lint
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/setup-linux
|
||||||
|
with:
|
||||||
|
toolkit: ${{ matrix.toolkit }}
|
||||||
|
- uses: ./.github/actions/build-cuda
|
||||||
|
with:
|
||||||
|
toolkit: ${{ matrix.toolkit }}
|
||||||
|
- uses: ./.github/actions/test-linux
|
||||||
|
|
||||||
|
build_documentation:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
needs: check_lint
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/build-docs
|
||||||
|
|
||||||
|
linux_fedora_build_cpp:
|
||||||
|
name: Linux Fedora CPP Build (${{ matrix.arch }})
|
||||||
|
needs: check_lint
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- host: ubuntu-22.04
|
||||||
|
arch: x86_64
|
||||||
|
- host: ubuntu-22.04-arm
|
||||||
|
arch: aarch64
|
||||||
|
|
||||||
|
runs-on: ${{ matrix.host }}
|
||||||
|
container:
|
||||||
|
image: fedora:42
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
|
- name: CPP Build Test - No Release
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
bash ./.github/scripts/setup+build-cpp-linux-fedora-container.sh
|
||||||
pip install pre-commit black isort clang-format
|
|
||||||
- name: Run lint
|
|
||||||
run: |
|
|
||||||
pre-commit run --all-files
|
|
||||||
|
|||||||
239
.github/workflows/release.yml
vendored
Normal file
239
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
name: PyPI Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v*'
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
dev_release:
|
||||||
|
description: "Do a dev release or regular release"
|
||||||
|
required: true
|
||||||
|
default: "false"
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
setup:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Set publishing variables
|
||||||
|
run: echo "Publishing setup complete"
|
||||||
|
|
||||||
|
build_documentation:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
runs-on: [self-hosted, macos]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/build-docs
|
||||||
|
|
||||||
|
deploy_documentation:
|
||||||
|
needs: build_documentation
|
||||||
|
permissions:
|
||||||
|
pages: write
|
||||||
|
id-token: write
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
environment:
|
||||||
|
name: github-pages
|
||||||
|
url: ${{ steps.deployment.outputs.page_url }}
|
||||||
|
steps:
|
||||||
|
- name: Deploy to GitHub Pages
|
||||||
|
id: deployment
|
||||||
|
uses: actions/deploy-pages@v4
|
||||||
|
|
||||||
|
build_linux_release:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||||
|
arch: ['x86_64', 'aarch64']
|
||||||
|
runs-on: ${{ matrix.arch == 'x86_64' && 'ubuntu-22.04' || 'ubuntu-22.04-arm' }}
|
||||||
|
env:
|
||||||
|
PYPI_RELEASE: 1
|
||||||
|
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/setup-linux
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python_version }}
|
||||||
|
- uses: ./.github/actions/build-linux-release
|
||||||
|
with:
|
||||||
|
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||||
|
arch: ${{ matrix.arch }}
|
||||||
|
- name: Upload MLX artifacts
|
||||||
|
uses: actions/upload-artifact@v5
|
||||||
|
with:
|
||||||
|
overwrite: true
|
||||||
|
name: linux-wheels-${{ matrix.python_version }}
|
||||||
|
path: wheelhouse/mlx-*.whl
|
||||||
|
- name: Upload CPU artifacts
|
||||||
|
if: matrix.python_version == '3.10'
|
||||||
|
uses: actions/upload-artifact@v5
|
||||||
|
with:
|
||||||
|
overwrite: true
|
||||||
|
name: mlx-cpu
|
||||||
|
path: wheelhouse/mlx_cpu-*.whl
|
||||||
|
|
||||||
|
build_mac_release:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||||
|
runs-on: [self-hosted, macos]
|
||||||
|
env:
|
||||||
|
PYPI_RELEASE: 1
|
||||||
|
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/setup-macos
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
pip install --upgrade pip
|
||||||
|
pip install cmake setuptools nanobind==2.4.0
|
||||||
|
pip install -e . -v
|
||||||
|
- name: Generate package stubs
|
||||||
|
shell: bash -l {0}
|
||||||
|
run: |
|
||||||
|
pip install typing_extensions
|
||||||
|
python setup.py generate_stubs
|
||||||
|
- name: Build macOS 14 package
|
||||||
|
uses: ./.github/actions/build-macos-release
|
||||||
|
with:
|
||||||
|
macos-target: 14.0
|
||||||
|
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||||
|
- name: Build macOS 15 package
|
||||||
|
uses: ./.github/actions/build-macos-release
|
||||||
|
with:
|
||||||
|
macos-target: 15.0
|
||||||
|
build-backend: ${{ matrix.python-version == '3.10' }}
|
||||||
|
- name: Upload MLX artifacts
|
||||||
|
uses: actions/upload-artifact@v5
|
||||||
|
with:
|
||||||
|
overwrite: true
|
||||||
|
name: mac-wheels-${{ matrix.python-version }}
|
||||||
|
path: dist/mlx-*.whl
|
||||||
|
- name: Upload Metal artifacts
|
||||||
|
if: matrix.python-version == '3.10'
|
||||||
|
uses: actions/upload-artifact@v5
|
||||||
|
with:
|
||||||
|
overwrite: true
|
||||||
|
name: mlx-metal
|
||||||
|
path: dist/mlx_metal-*.whl
|
||||||
|
|
||||||
|
build_cuda_release:
|
||||||
|
if: github.repository == 'ml-explore/mlx'
|
||||||
|
runs-on: ubuntu-22-large
|
||||||
|
env:
|
||||||
|
PYPI_RELEASE: 1
|
||||||
|
DEV_RELEASE: ${{ github.event.inputs.dev_release == 'true' && 1 || 0 }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
- uses: ./.github/actions/setup-linux
|
||||||
|
with:
|
||||||
|
toolkit: 'cuda-12.9'
|
||||||
|
- name: Build Python package
|
||||||
|
uses: ./.github/actions/build-cuda-release
|
||||||
|
with:
|
||||||
|
toolkit: 'cuda-12.9'
|
||||||
|
- name: Upload artifacts
|
||||||
|
uses: actions/upload-artifact@v5
|
||||||
|
with:
|
||||||
|
overwrite: true
|
||||||
|
name: mlx-cuda
|
||||||
|
path: wheelhouse/mlx_cuda-*.whl
|
||||||
|
|
||||||
|
pypi-publish:
|
||||||
|
name: Upload release to PyPI
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [setup, build_linux_release, build_mac_release]
|
||||||
|
permissions:
|
||||||
|
id-token: write
|
||||||
|
environment:
|
||||||
|
name: pypi
|
||||||
|
url: https://pypi.org/p/mlx
|
||||||
|
steps:
|
||||||
|
- uses: actions/download-artifact@v6
|
||||||
|
with:
|
||||||
|
pattern: linux-wheels-*
|
||||||
|
merge-multiple: true
|
||||||
|
path: dist
|
||||||
|
- uses: actions/download-artifact@v6
|
||||||
|
with:
|
||||||
|
pattern: mac-wheels-*
|
||||||
|
merge-multiple: true
|
||||||
|
path: dist
|
||||||
|
- name: Display structure of downloaded files
|
||||||
|
run: ls -R dist
|
||||||
|
- name: Publish package distributions to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
|
with:
|
||||||
|
repository-url: https://upload.pypi.org/legacy/
|
||||||
|
|
||||||
|
pypi-publish-cuda:
|
||||||
|
name: Upload CUDA release to PyPI
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [setup, build_cuda_release]
|
||||||
|
permissions:
|
||||||
|
id-token: write
|
||||||
|
environment:
|
||||||
|
name: pypi
|
||||||
|
url: https://pypi.org/p/mlx-cuda
|
||||||
|
steps:
|
||||||
|
- uses: actions/download-artifact@v6
|
||||||
|
with:
|
||||||
|
name: mlx-cuda
|
||||||
|
path: dist
|
||||||
|
- name: Display structure of downloaded files
|
||||||
|
run: ls -R dist
|
||||||
|
- name: Publish package distributions to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
|
with:
|
||||||
|
repository-url: https://upload.pypi.org/legacy/
|
||||||
|
|
||||||
|
pypi-publish-cpu:
|
||||||
|
name: Upload CPU release to PyPI
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [setup, build_linux_release]
|
||||||
|
permissions:
|
||||||
|
id-token: write
|
||||||
|
environment:
|
||||||
|
name: pypi
|
||||||
|
url: https://pypi.org/p/mlx-cpu
|
||||||
|
steps:
|
||||||
|
- uses: actions/download-artifact@v6
|
||||||
|
with:
|
||||||
|
name: mlx-cpu
|
||||||
|
path: dist
|
||||||
|
- name: Display structure of downloaded files
|
||||||
|
run: ls -R dist
|
||||||
|
- name: Publish package distributions to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
|
with:
|
||||||
|
repository-url: https://upload.pypi.org/legacy/
|
||||||
|
|
||||||
|
pypi-publish-metal:
|
||||||
|
name: Upload Metal release to PyPI
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [setup, build_mac_release]
|
||||||
|
permissions:
|
||||||
|
id-token: write
|
||||||
|
environment:
|
||||||
|
name: pypi
|
||||||
|
url: https://pypi.org/p/mlx-metal
|
||||||
|
steps:
|
||||||
|
- uses: actions/download-artifact@v6
|
||||||
|
with:
|
||||||
|
name: mlx-metal
|
||||||
|
path: dist
|
||||||
|
- name: Display structure of downloaded files
|
||||||
|
run: ls -R dist
|
||||||
|
- name: Publish package distributions to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
|
with:
|
||||||
|
repository-url: https://upload.pypi.org/legacy/
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -36,6 +36,7 @@ share/python-wheels/
|
|||||||
.installed.cfg
|
.installed.cfg
|
||||||
*.egg
|
*.egg
|
||||||
MANIFEST
|
MANIFEST
|
||||||
|
uv.lock
|
||||||
|
|
||||||
# vim
|
# vim
|
||||||
*.swp
|
*.swp
|
||||||
@@ -76,6 +77,9 @@ build/
|
|||||||
*.out
|
*.out
|
||||||
*.app
|
*.app
|
||||||
|
|
||||||
|
# Debug symbols
|
||||||
|
*.pdb
|
||||||
|
|
||||||
# VSCode
|
# VSCode
|
||||||
.vscode/
|
.vscode/
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|||||||
@@ -1,16 +1,27 @@
|
|||||||
repos:
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v6.0.0
|
||||||
|
hooks:
|
||||||
|
- id: check-yaml
|
||||||
|
# - id: end-of-file-fixer
|
||||||
|
# - id: trailing-whitespace
|
||||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
rev: v17.0.6
|
rev: v19.1.7
|
||||||
hooks:
|
hooks:
|
||||||
- id: clang-format
|
- id: clang-format
|
||||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||||
rev: 23.12.1
|
rev: 25.1.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
|
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/pycqa/isort
|
||||||
rev: 5.13.2
|
rev: 6.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
args:
|
args:
|
||||||
- --profile=black
|
- --profile=black
|
||||||
|
- repo: https://github.com/cheshirekow/cmake-format-precommit
|
||||||
|
rev: v0.6.13
|
||||||
|
hooks:
|
||||||
|
- id: cmake-format
|
||||||
|
|||||||
@@ -7,16 +7,29 @@ with a short description of your contribution(s) below. For example:
|
|||||||
|
|
||||||
MLX was developed with contributions from the following individuals:
|
MLX was developed with contributions from the following individuals:
|
||||||
|
|
||||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
|
||||||
- Juarez Bochi: Fixed bug in cross attention.
|
- Juarez Bochi: Fixed bug in cross attention.
|
||||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
|
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
|
||||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
|
||||||
|
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||||
|
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
||||||
|
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||||
|
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
|
||||||
|
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||||
|
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||||
|
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||||
|
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
|
# Organizations
|
||||||
|
|
||||||
|
MLX has received contributions from the following companies:
|
||||||
|
- NVIDIA Corporation & Affiliates
|
||||||
|
|
||||||
# Third-Party Software
|
# Third-Party Software
|
||||||
|
|
||||||
MLX leverages several third-party software, listed here together with
|
MLX leverages several third-party software, listed here together with
|
||||||
|
|||||||
24
CITATION.cff
Normal file
24
CITATION.cff
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
cff-version: 1.2.0
|
||||||
|
title: mlx
|
||||||
|
message: >-
|
||||||
|
If you use this software, please cite it using the
|
||||||
|
metadata from this file.
|
||||||
|
type: software
|
||||||
|
authors:
|
||||||
|
- given-names: Awni
|
||||||
|
family-names: Hannun
|
||||||
|
affiliation: Apple
|
||||||
|
- given-names: Jagrit
|
||||||
|
family-names: Digani
|
||||||
|
affiliation: Apple
|
||||||
|
- given-names: Angelos
|
||||||
|
family-names: Katharopoulos
|
||||||
|
affiliation: Apple
|
||||||
|
- given-names: Ronan
|
||||||
|
family-names: Collobert
|
||||||
|
affiliation: Apple
|
||||||
|
repository-code: 'https://github.com/ml-explore'
|
||||||
|
abstract: >-
|
||||||
|
MLX: efficient and flexible machine learning on Apple
|
||||||
|
silicon
|
||||||
|
license: MIT
|
||||||
330
CMakeLists.txt
330
CMakeLists.txt
@@ -1,6 +1,24 @@
|
|||||||
cmake_minimum_required(VERSION 3.24)
|
cmake_minimum_required(VERSION 3.25)
|
||||||
|
|
||||||
project(mlx LANGUAGES C CXX)
|
if(NOT MLX_VERSION)
|
||||||
|
file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
|
||||||
|
string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
|
||||||
|
set(_major ${CMAKE_MATCH_1})
|
||||||
|
string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
|
||||||
|
set(_minor ${CMAKE_MATCH_1})
|
||||||
|
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
|
||||||
|
set(_patch ${CMAKE_MATCH_1})
|
||||||
|
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
|
||||||
|
set(MLX_VERSION ${MLX_PROJECT_VERSION})
|
||||||
|
else()
|
||||||
|
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
|
||||||
|
${MLX_VERSION})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
project(
|
||||||
|
mlx
|
||||||
|
LANGUAGES C CXX
|
||||||
|
VERSION ${MLX_PROJECT_VERSION})
|
||||||
|
|
||||||
# ----------------------------- Setup -----------------------------
|
# ----------------------------- Setup -----------------------------
|
||||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||||
@@ -8,6 +26,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||||
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
|
||||||
# ----------------------------- Configuration -----------------------------
|
# ----------------------------- Configuration -----------------------------
|
||||||
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
|
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
|
||||||
@@ -15,37 +34,51 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
|||||||
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||||
|
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||||
|
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||||
|
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||||
|
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||||
|
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||||
|
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||||
|
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||||
|
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||||
|
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
|
||||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
|
||||||
if(NOT MLX_VERSION)
|
|
||||||
set(MLX_VERSION 0.2.0)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# --------------------- Processor tests -------------------------
|
# --------------------- Processor tests -------------------------
|
||||||
|
message(
|
||||||
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
STATUS
|
||||||
|
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
|
||||||
set(MLX_BUILD_ARM OFF)
|
)
|
||||||
|
|
||||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||||
|
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||||
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
|
if(NOT MLX_ENABLE_X64_MAC)
|
||||||
message(FATAL_ERROR
|
message(
|
||||||
|
FATAL_ERROR
|
||||||
"Building for x86_64 on macOS is not supported."
|
"Building for x86_64 on macOS is not supported."
|
||||||
" If you are on an Apple silicon system, check the build"
|
" If you are on an Apple silicon system, check the build"
|
||||||
" documentation for possible fixes: "
|
" documentation for possible fixes: "
|
||||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
|
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
|
||||||
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
)
|
||||||
message(WARNING
|
else()
|
||||||
"Building for x86_64 on macOS is not supported."
|
set(MLX_BUILD_METAL OFF)
|
||||||
" If you are on an Apple silicon system, "
|
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||||
" make sure you are building for arm64.")
|
endif()
|
||||||
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
|
endif()
|
||||||
set(MLX_BUILD_ARM ON)
|
else()
|
||||||
|
set(MLX_BUILD_METAL OFF)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
else()
|
if(MLX_USE_CCACHE)
|
||||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
find_program(CCACHE_PROGRAM ccache)
|
||||||
|
if(CCACHE_PROGRAM)
|
||||||
|
message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
|
||||||
|
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||||
|
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||||
|
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# ----------------------------- Lib -----------------------------
|
# ----------------------------- Lib -----------------------------
|
||||||
@@ -56,103 +89,198 @@ cmake_policy(SET CMP0135 NEW)
|
|||||||
|
|
||||||
add_library(mlx)
|
add_library(mlx)
|
||||||
|
|
||||||
|
# Supress warnings: note: parameter passing for argument of type
|
||||||
|
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||||
|
# 10.1
|
||||||
|
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||||
|
|
||||||
|
if(MLX_BUILD_CUDA)
|
||||||
|
enable_language(CUDA)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
find_library(METAL_LIB Metal)
|
find_library(METAL_LIB Metal)
|
||||||
find_library(FOUNDATION_LIB Foundation)
|
find_library(FOUNDATION_LIB Foundation)
|
||||||
find_library(QUARTZ_LIB QuartzCore)
|
find_library(QUARTZ_LIB QuartzCore)
|
||||||
|
if(METAL_LIB)
|
||||||
|
message(STATUS "Metal found ${METAL_LIB}")
|
||||||
|
else()
|
||||||
|
message(
|
||||||
|
FATAL_ERROR
|
||||||
|
"Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
if(MLX_METAL_DEBUG)
|
||||||
message(STATUS "Metal not found. Unable to build GPU")
|
add_compile_definitions(MLX_METAL_DEBUG)
|
||||||
set(MLX_BUILD_METAL OFF)
|
endif()
|
||||||
elseif (MLX_BUILD_METAL)
|
|
||||||
message(STATUS "Building METAL sources")
|
|
||||||
add_compile_definitions(_METAL_)
|
|
||||||
|
|
||||||
# Throw an error if xcrun not found
|
# Throw an error if xcrun not found
|
||||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
execute_process(
|
||||||
OUTPUT_VARIABLE MACOS_VERSION
|
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||||
COMMAND_ERROR_IS_FATAL ANY)
|
OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
|
||||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||||
|
message(
|
||||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
FATAL_ERROR
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
|
||||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
|
||||||
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
|
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
|
|
||||||
endif()
|
endif()
|
||||||
|
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
||||||
|
|
||||||
FetchContent_Declare(
|
set(METAL_CPP_URL
|
||||||
metal_cpp
|
https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
|
||||||
URL ${METAL_CPP_URL}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||||
|
if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0)
|
||||||
|
message(FATAL_ERROR "MLX requires macOS >= 14.0")
|
||||||
|
endif()
|
||||||
|
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||||
|
endif()
|
||||||
|
execute_process(
|
||||||
|
COMMAND
|
||||||
|
zsh "-c"
|
||||||
|
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||||
|
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||||
FetchContent_MakeAvailable(metal_cpp)
|
FetchContent_MakeAvailable(metal_cpp)
|
||||||
target_include_directories(
|
target_include_directories(
|
||||||
mlx PUBLIC
|
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
||||||
$<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
$<INSTALL_INTERFACE:include/metal_cpp>)
|
||||||
$<INSTALL_INTERFACE:include/metal_cpp>
|
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||||
)
|
|
||||||
target_link_libraries(
|
|
||||||
mlx
|
|
||||||
${METAL_LIB}
|
|
||||||
${FOUNDATION_LIB}
|
|
||||||
${QUARTZ_LIB})
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||||
|
# With newer clang/gcc versions following libs are implicitly linked, but when
|
||||||
|
# building on old distributions they need to be explicitly listed.
|
||||||
|
target_link_libraries(mlx PRIVATE dl pthread)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(WIN32)
|
||||||
|
if(MSVC)
|
||||||
|
# GGUF does not build with MSVC.
|
||||||
|
set(MLX_BUILD_GGUF OFF)
|
||||||
|
# There is no prebuilt OpenBLAS distribution for MSVC.
|
||||||
|
set(MLX_BUILD_BLAS_FROM_SOURCE ON)
|
||||||
|
endif()
|
||||||
|
# Windows implementation of dlfcn.h APIs.
|
||||||
|
FetchContent_Declare(
|
||||||
|
dlfcn-win32
|
||||||
|
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
|
||||||
|
GIT_TAG v1.4.1
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
block()
|
||||||
|
set(BUILD_SHARED_LIBS OFF)
|
||||||
|
FetchContent_MakeAvailable(dlfcn-win32)
|
||||||
|
endblock()
|
||||||
|
target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
|
||||||
|
target_link_libraries(mlx PRIVATE dl)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(MLX_BUILD_CPU)
|
||||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||||
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
if(ACCELERATE_LIBRARY)
|
||||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||||
set(MLX_BUILD_ACCELERATE ON)
|
set(MLX_BUILD_ACCELERATE ON)
|
||||||
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
|
|
||||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
|
||||||
else()
|
else()
|
||||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
message(STATUS "Accelerate not found, using default backend.")
|
||||||
set(MLX_BUILD_ACCELERATE OFF)
|
set(MLX_BUILD_ACCELERATE OFF)
|
||||||
#set(BLA_VENDOR Generic)
|
endif()
|
||||||
|
|
||||||
|
if(MLX_BUILD_ACCELERATE)
|
||||||
|
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||||
|
add_compile_definitions(MLX_USE_ACCELERATE)
|
||||||
|
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||||
|
elseif(MLX_BUILD_BLAS_FROM_SOURCE)
|
||||||
|
# Download and build OpenBLAS from source code.
|
||||||
|
FetchContent_Declare(
|
||||||
|
openblas
|
||||||
|
GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
|
||||||
|
GIT_TAG v0.3.28
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
set(BUILD_STATIC_LIBS ON) # link statically
|
||||||
|
set(NOFORTRAN ON) # msvc has no fortran compiler
|
||||||
|
FetchContent_MakeAvailable(openblas)
|
||||||
|
target_link_libraries(mlx PRIVATE openblas)
|
||||||
|
target_include_directories(
|
||||||
|
mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
|
||||||
|
"${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
|
||||||
|
else()
|
||||||
|
if(${CMAKE_HOST_APPLE})
|
||||||
|
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||||
|
# openblas instead.
|
||||||
|
set(BLA_VENDOR OpenBLAS)
|
||||||
|
set(LAPACK_ROOT
|
||||||
|
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
||||||
|
endif()
|
||||||
|
# Search and link with lapack.
|
||||||
|
find_package(LAPACK REQUIRED)
|
||||||
|
if(NOT LAPACK_FOUND)
|
||||||
|
message(FATAL_ERROR "Must have LAPACK installed")
|
||||||
|
endif()
|
||||||
|
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
|
||||||
|
/usr/local/opt/openblas/include)
|
||||||
|
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||||
|
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||||
|
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
|
||||||
|
# List blas after lapack otherwise we may accidentally incldue an old
|
||||||
|
# version of lapack.h from the include dirs of blas.
|
||||||
find_package(BLAS REQUIRED)
|
find_package(BLAS REQUIRED)
|
||||||
if(NOT BLAS_FOUND)
|
if(NOT BLAS_FOUND)
|
||||||
message(FATAL_ERROR "Must have BLAS installed")
|
message(FATAL_ERROR "Must have BLAS installed")
|
||||||
endif()
|
endif()
|
||||||
# TODO find a cleaner way to do this
|
# TODO find a cleaner way to do this
|
||||||
find_path(BLAS_INCLUDE_DIRS cblas.h
|
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
|
||||||
/usr/include
|
|
||||||
/usr/local/include
|
|
||||||
$ENV{BLAS_HOME}/include)
|
$ENV{BLAS_HOME}/include)
|
||||||
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||||
target_link_libraries(mlx ${BLAS_LIBRARIES})
|
target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
|
||||||
find_package(LAPACK REQUIRED)
|
|
||||||
if (NOT LAPACK_FOUND)
|
|
||||||
message(FATAL_ERROR "Must have LAPACK installed")
|
|
||||||
endif()
|
endif()
|
||||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
else()
|
||||||
/usr/include
|
set(MLX_BUILD_ACCELERATE OFF)
|
||||||
/usr/local/include)
|
|
||||||
message(STATUS "Lapack lib" ${LAPACK_LIBRARIES})
|
|
||||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
|
||||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
|
||||||
target_link_libraries(mlx ${LAPACK_LIBRARIES})
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
message(STATUS "Downloading json")
|
||||||
|
FetchContent_Declare(
|
||||||
|
json
|
||||||
|
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
|
||||||
|
FetchContent_MakeAvailable(json)
|
||||||
|
target_include_directories(
|
||||||
|
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||||
|
|
||||||
target_include_directories(
|
target_include_directories(
|
||||||
mlx
|
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||||
PUBLIC
|
$<INSTALL_INTERFACE:include>)
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
|
||||||
$<INSTALL_INTERFACE:include>
|
# Do not add mlx_EXPORTS define for shared library.
|
||||||
)
|
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||||
|
|
||||||
|
if(USE_SYSTEM_FMT)
|
||||||
|
find_package(fmt REQUIRED)
|
||||||
|
else()
|
||||||
|
FetchContent_Declare(
|
||||||
|
fmt
|
||||||
|
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||||
|
GIT_TAG 10.2.1
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
FetchContent_MakeAvailable(fmt)
|
||||||
|
endif()
|
||||||
|
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||||
|
|
||||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||||
message(STATUS "Building Python bindings.")
|
message(STATUS "Building Python bindings.")
|
||||||
find_package(Python COMPONENTS Interpreter Development)
|
find_package(
|
||||||
find_package(pybind11 CONFIG REQUIRED)
|
Python 3.8
|
||||||
|
COMPONENTS Interpreter Development.Module
|
||||||
|
REQUIRED)
|
||||||
|
execute_process(
|
||||||
|
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
OUTPUT_VARIABLE nanobind_ROOT)
|
||||||
|
find_package(nanobind CONFIG REQUIRED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@@ -169,8 +297,6 @@ if (MLX_BUILD_BENCHMARKS)
|
|||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------- Installation -----------------------------
|
# ----------------------------- Installation -----------------------------
|
||||||
include(GNUInstallDirs)
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
@@ -181,17 +307,17 @@ install(
|
|||||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||||
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
INCLUDES
|
||||||
)
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||||
|
|
||||||
|
|
||||||
# Install headers
|
# Install headers
|
||||||
install(
|
install(
|
||||||
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
||||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||||
COMPONENT headers
|
COMPONENT headers
|
||||||
FILES_MATCHING PATTERN "*.h"
|
FILES_MATCHING
|
||||||
)
|
PATTERN "*.h"
|
||||||
|
PATTERN "backend/metal/kernels.h" EXCLUDE)
|
||||||
|
|
||||||
# Install metal dependencies
|
# Install metal dependencies
|
||||||
if(MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
@@ -200,8 +326,7 @@ if (MLX_BUILD_METAL)
|
|||||||
install(
|
install(
|
||||||
DIRECTORY ${metal_cpp_SOURCE_DIR}/
|
DIRECTORY ${metal_cpp_SOURCE_DIR}/
|
||||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
|
||||||
COMPONENT metal_cpp_source
|
COMPONENT metal_cpp_source)
|
||||||
)
|
|
||||||
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@@ -213,31 +338,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
|
|||||||
install(
|
install(
|
||||||
EXPORT MLXTargets
|
EXPORT MLXTargets
|
||||||
FILE MLXTargets.cmake
|
FILE MLXTargets.cmake
|
||||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
||||||
)
|
|
||||||
|
|
||||||
include(CMakePackageConfigHelpers)
|
include(CMakePackageConfigHelpers)
|
||||||
|
|
||||||
write_basic_package_version_file(
|
write_basic_package_version_file(
|
||||||
${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
||||||
COMPATIBILITY SameMajorVersion
|
COMPATIBILITY SameMajorVersion
|
||||||
VERSION ${MLX_VERSION}
|
VERSION ${MLX_VERSION})
|
||||||
)
|
|
||||||
|
|
||||||
configure_package_config_file(
|
configure_package_config_file(
|
||||||
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
|
${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
|
||||||
${MLX_CMAKE_BUILD_CONFIG}
|
|
||||||
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
||||||
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
||||||
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
|
PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
|
||||||
)
|
MLX_CMAKE_INSTALL_MODULE_DIR)
|
||||||
|
|
||||||
install(
|
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
||||||
FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
|
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
||||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
|
||||||
)
|
|
||||||
|
|
||||||
install(
|
install(DIRECTORY ${CMAKE_MODULE_PATH}/
|
||||||
DIRECTORY ${CMAKE_MODULE_PATH}/
|
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
|
||||||
DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -17,11 +17,11 @@ possible.
|
|||||||
|
|
||||||
You can also run the formatters manually as follows:
|
You can also run the formatters manually as follows:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
clang-format -i file.cpp
|
clang-format -i file.cpp
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```shell
|
||||||
black file.py
|
black file.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
include CMakeLists.txt
|
include CMakeLists.txt
|
||||||
|
include mlx.pc.in
|
||||||
recursive-include mlx/ *
|
recursive-include mlx/ *
|
||||||
|
include cmake/*
|
||||||
include python/src/*
|
include python/src/*
|
||||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||||
|
|||||||
37
README.md
37
README.md
@@ -6,15 +6,17 @@
|
|||||||
|
|
||||||
[](https://circleci.com/gh/ml-explore/mlx)
|
[](https://circleci.com/gh/ml-explore/mlx)
|
||||||
|
|
||||||
MLX is an array framework for machine learning on Apple silicon, brought to you
|
MLX is an array framework for machine learning on Apple silicon,
|
||||||
by Apple machine learning research.
|
brought to you by Apple machine learning research.
|
||||||
|
|
||||||
Some key features of MLX include:
|
Some key features of MLX include:
|
||||||
|
|
||||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
|
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||||
MLX also has a fully featured C++ API, which closely mirrors the Python API.
|
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
||||||
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
|
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
||||||
that closely follow PyTorch to simplify building more complex models.
|
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||||
|
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
||||||
|
more complex models.
|
||||||
|
|
||||||
- **Composable function transformations**: MLX supports composable function
|
- **Composable function transformations**: MLX supports composable function
|
||||||
transformations for automatic differentiation, automatic vectorization,
|
transformations for automatic differentiation, automatic vectorization,
|
||||||
@@ -66,18 +68,23 @@ in the documentation.
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
|
||||||
|
macOS, run:
|
||||||
|
|
||||||
**With `pip`**:
|
```bash
|
||||||
|
|
||||||
```
|
|
||||||
pip install mlx
|
pip install mlx
|
||||||
```
|
```
|
||||||
|
|
||||||
**With `conda`**:
|
To install the CUDA backend on Linux, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mlx[cuda]
|
||||||
```
|
```
|
||||||
conda install -c conda-forge mlx
|
|
||||||
|
To install a CPU-only Linux package, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mlx[cpu]
|
||||||
```
|
```
|
||||||
|
|
||||||
Checkout the
|
Checkout the
|
||||||
@@ -86,13 +93,13 @@ for more information on building the C++ and Python APIs from source.
|
|||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
|
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
|
||||||
on contributing to MLX. See the
|
on contributing to MLX. See the
|
||||||
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
|
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
|
||||||
information on building from source, and running tests.
|
information on building from source, and running tests.
|
||||||
|
|
||||||
We are grateful for all of [our
|
We are grateful for all of [our
|
||||||
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
|
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
|
||||||
to MLX and wish to be acknowledged, please add your name to the list in your
|
to MLX and wish to be acknowledged, please add your name to the list in your
|
||||||
pull request.
|
pull request.
|
||||||
|
|
||||||
@@ -103,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
|
|||||||
MLX useful in your research and wish to cite it, please use the following
|
MLX useful in your research and wish to cite it, please use the following
|
||||||
BibTex entry:
|
BibTex entry:
|
||||||
|
|
||||||
```
|
```text
|
||||||
@software{mlx2023,
|
@software{mlx2023,
|
||||||
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
||||||
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
||||||
|
|||||||
@@ -5,35 +5,35 @@
|
|||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
#include "time_utils.h"
|
#include "time_utils.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
void time_value_and_grad() {
|
void time_value_and_grad() {
|
||||||
auto x = ones({200, 1000});
|
auto x = mx::ones({200, 1000});
|
||||||
eval(x);
|
mx::eval(x);
|
||||||
auto fn = [](array x) {
|
auto fn = [](mx::array x) {
|
||||||
for (int i = 0; i < 20; ++i) {
|
for (int i = 0; i < 20; ++i) {
|
||||||
x = log(exp(x));
|
x = mx::log(mx::exp(x));
|
||||||
}
|
}
|
||||||
return sum(x);
|
return mx::sum(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto grad_fn = grad(fn);
|
auto grad_fn = mx::grad(fn);
|
||||||
auto independent_value_and_grad = [&]() {
|
auto independent_value_and_grad = [&]() {
|
||||||
auto value = fn(x);
|
auto value = fn(x);
|
||||||
auto dfdx = grad_fn(x);
|
auto dfdx = grad_fn(x);
|
||||||
return std::vector<array>{value, dfdx};
|
return std::vector<mx::array>{value, dfdx};
|
||||||
};
|
};
|
||||||
TIME(independent_value_and_grad);
|
TIME(independent_value_and_grad);
|
||||||
|
|
||||||
auto value_and_grad_fn = value_and_grad(fn);
|
auto value_and_grad_fn = mx::value_and_grad(fn);
|
||||||
auto combined_value_and_grad = [&]() {
|
auto combined_value_and_grad = [&]() {
|
||||||
auto [value, dfdx] = value_and_grad_fn(x);
|
auto [value, dfdx] = value_and_grad_fn(x);
|
||||||
return std::vector<array>{value, dfdx};
|
return std::vector<mx::array>{value, dfdx};
|
||||||
};
|
};
|
||||||
TIME(combined_value_and_grad);
|
TIME(combined_value_and_grad);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||||
time_value_and_grad();
|
time_value_and_grad();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,21 +4,21 @@
|
|||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
#include "time_utils.h"
|
#include "time_utils.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
void time_add_op() {
|
void time_add_op() {
|
||||||
std::vector<int> sizes(1, 1);
|
std::vector<int> sizes(1, 1);
|
||||||
for (int i = 0; i < 9; ++i) {
|
for (int i = 0; i < 9; ++i) {
|
||||||
sizes.push_back(10 * sizes.back());
|
sizes.push_back(10 * sizes.back());
|
||||||
}
|
}
|
||||||
set_default_device(Device::cpu);
|
set_default_device(mx::Device::cpu);
|
||||||
for (auto size : sizes) {
|
for (auto size : sizes) {
|
||||||
auto a = random::uniform({size});
|
auto a = mx::random::uniform({size});
|
||||||
auto b = random::uniform({size});
|
auto b = mx::random::uniform({size});
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
std::cout << "Size " << size << std::endl;
|
std::cout << "Size " << size << std::endl;
|
||||||
TIMEM("cpu", add, a, b, Device::cpu);
|
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
|
||||||
TIMEM("gpu", add, a, b, Device::gpu);
|
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,110 +1,111 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
#include "time_utils.h"
|
#include "time_utils.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
void time_irregular_binary_ops_1D() {
|
void time_irregular_binary_ops_1D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
int size = 1000000;
|
int size = 1000000;
|
||||||
int step = 2;
|
int step = 2;
|
||||||
auto a = random::uniform({size});
|
auto a = mx::random::uniform({size});
|
||||||
auto b = random::uniform({size});
|
auto b = mx::random::uniform({size});
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
a = slice(a, {0}, {size}, {step});
|
a = slice(a, {0}, {size}, {step});
|
||||||
b = slice(b, {0}, {size}, {step});
|
b = slice(b, {0}, {size}, {step});
|
||||||
TIMEM("1D strided", add, a, b, device);
|
TIMEM("1D strided", mx::add, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_binary_ops_2D() {
|
void time_irregular_binary_ops_2D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
int size = 2048;
|
int size = 2048;
|
||||||
auto a = random::uniform({size, size});
|
auto a = mx::random::uniform({size, size});
|
||||||
auto b = random::uniform({size, size});
|
auto b = mx::random::uniform({size, size});
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIMEM("2D regular", add, a, b, device);
|
TIMEM("2D regular", mx::add, a, b, device);
|
||||||
|
|
||||||
b = transpose(b);
|
b = mx::transpose(b);
|
||||||
eval(b);
|
mx::eval(b);
|
||||||
TIMEM("2D transpose", add, a, b, device);
|
TIMEM("2D mx::transpose", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({size});
|
b = mx::random::uniform({size});
|
||||||
eval(b);
|
mx::eval(b);
|
||||||
TIMEM("2D broadcast dim 0", add, a, b, device);
|
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
|
||||||
|
|
||||||
b = reshape(b, {size, 1});
|
b = mx::reshape(b, {size, 1});
|
||||||
eval(b);
|
mx::eval(b);
|
||||||
TIMEM("2D broadcast dim 1", add, a, b, device);
|
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_binary_ops_3D() {
|
void time_irregular_binary_ops_3D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
int d0 = 32;
|
int d0 = 32;
|
||||||
int d1 = 512;
|
int d1 = 512;
|
||||||
int d2 = 512;
|
int d2 = 512;
|
||||||
auto a = random::uniform({d0, d1, d2});
|
auto a = mx::random::uniform({d0, d1, d2});
|
||||||
auto b = random::uniform({d0, d1, d2});
|
auto b = mx::random::uniform({d0, d1, d2});
|
||||||
TIMEM("3D regular", add, a, b, device);
|
TIMEM("3D regular", mx::add, a, b, device);
|
||||||
|
|
||||||
b = transpose(b, {0, 2, 1});
|
b = mx::transpose(b, {0, 2, 1});
|
||||||
TIMEM("3D transpose", add, a, b, device);
|
TIMEM("3D mx::transpose", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d1, d2});
|
b = mx::random::uniform({d1, d2});
|
||||||
TIMEM("3D broadcast dim 0", add, a, b, device);
|
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d0, 1, d2});
|
b = mx::random::uniform({d0, 1, d2});
|
||||||
TIMEM("3D broadcast dim 1", add, a, b, device);
|
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d0, d1, 1});
|
b = mx::random::uniform({d0, d1, 1});
|
||||||
TIMEM("3D broadcast dim 2", add, a, b, device);
|
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d2});
|
b = mx::random::uniform({d2});
|
||||||
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
|
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d1, 1});
|
b = mx::random::uniform({d1, 1});
|
||||||
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
|
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d0, 1, 1});
|
b = mx::random::uniform({d0, 1, 1});
|
||||||
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
|
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_binary_ops_4D() {
|
void time_irregular_binary_ops_4D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
std::vector<int> shape = {8, 8, 512, 512};
|
mx::Shape shape = {8, 8, 512, 512};
|
||||||
auto a = random::uniform(shape);
|
auto a = mx::random::uniform(shape);
|
||||||
auto b = random::uniform(shape);
|
auto b = mx::random::uniform(shape);
|
||||||
|
|
||||||
TIMEM("4D regular", add, a, b, device);
|
TIMEM("4D regular", mx::add, a, b, device);
|
||||||
|
|
||||||
b = transpose(b, {0, 1, 3, 2});
|
b = mx::transpose(b, {0, 1, 3, 2});
|
||||||
TIMEM("4D transpose", add, a, b, device);
|
TIMEM("4D mx::transpose", mx::add, a, b, device);
|
||||||
|
|
||||||
std::string om = "4D broadcast dims ";
|
std::string om = "4D broadcast dims ";
|
||||||
for (int i = 0; i < shape.size(); ++i) {
|
for (int i = 0; i < shape.size(); ++i) {
|
||||||
shape[i] = 1;
|
shape[i] = 1;
|
||||||
b = random::uniform(shape);
|
b = mx::random::uniform(shape);
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << om << i;
|
msg << om << i;
|
||||||
TIMEM(msg.str(), add, a, b, device);
|
TIMEM(msg.str(), mx::add, a, b, device);
|
||||||
|
|
||||||
for (int j = i + 1; j < shape.size(); ++j) {
|
for (int j = i + 1; j < shape.size(); ++j) {
|
||||||
shape[j] = 1;
|
shape[j] = 1;
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << om << i << ", " << j;
|
msg << om << i << ", " << j;
|
||||||
b = random::uniform(shape);
|
b = mx::random::uniform(shape);
|
||||||
TIMEM(msg.str(), add, a, b, device);
|
TIMEM(msg.str(), mx::add, a, b, device);
|
||||||
shape[j] = a.shape(j);
|
shape[j] = a.shape(j);
|
||||||
|
|
||||||
for (int k = j + 1; k < shape.size(); ++k) {
|
for (int k = j + 1; k < shape.size(); ++k) {
|
||||||
shape[k] = 1;
|
shape[k] = 1;
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << om << i << ", " << j << ", " << k;
|
msg << om << i << ", " << j << ", " << k;
|
||||||
b = random::uniform(shape);
|
b = mx::random::uniform(shape);
|
||||||
TIMEM(msg.str(), add, a, b, device);
|
TIMEM(msg.str(), mx::add, a, b, device);
|
||||||
shape[k] = a.shape(k);
|
shape[k] = a.shape(k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -113,83 +114,83 @@ void time_irregular_binary_ops_4D() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_reshape() {
|
void time_irregular_reshape() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
std::vector<int> shape;
|
mx::Shape shape;
|
||||||
auto reshape_fn = [&shape, device](const array& a) {
|
auto reshape_fn = [&shape, device](const mx::array& a) {
|
||||||
return reshape(a, shape, device);
|
return mx::reshape(a, shape, device);
|
||||||
};
|
};
|
||||||
|
|
||||||
int size = 64;
|
int size = 64;
|
||||||
int d = 2 * size;
|
int d = 2 * size;
|
||||||
|
|
||||||
auto a = random::uniform({d, d, d});
|
auto a = mx::random::uniform({d, d, d});
|
||||||
|
|
||||||
shape = {8 * size, size, size};
|
shape = {8 * size, size, size};
|
||||||
TIMEM("3D contiguous", reshape_fn, a);
|
TIMEM("3D contiguous", reshape_fn, a);
|
||||||
|
|
||||||
a = transpose(a);
|
a = mx::transpose(a);
|
||||||
shape = {8 * size, size, size};
|
shape = {8 * size, size, size};
|
||||||
TIMEM("3D transpose", reshape_fn, a);
|
TIMEM("3D mx::transpose", reshape_fn, a);
|
||||||
|
|
||||||
a = transpose(a, {1, 2, 0});
|
a = mx::transpose(a, {1, 2, 0});
|
||||||
shape = {8 * size, size, size};
|
shape = {8 * size, size, size};
|
||||||
TIMEM("3D transpose dims 1 2", reshape_fn, a);
|
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d, d}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
|
||||||
TIMEM("3D broadcast dim 0", reshape_fn, a);
|
TIMEM("3D broadcast dim 0", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
|
||||||
TIMEM("3D broadcast dim 1", reshape_fn, a);
|
TIMEM("3D broadcast dim 1", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
|
||||||
TIMEM("3D broadcast dim 2", reshape_fn, a);
|
TIMEM("3D broadcast dim 2", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
|
||||||
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
|
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
|
||||||
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
|
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
|
||||||
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
|
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
|
||||||
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
|
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_astype_1D() {
|
void time_irregular_astype_1D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
int size = 1000000;
|
int size = 1000000;
|
||||||
int step = 2;
|
int step = 2;
|
||||||
auto a = random::uniform({size});
|
auto a = mx::random::uniform({size});
|
||||||
a = slice(a, {0}, {size}, {step});
|
a = slice(a, {0}, {size}, {step});
|
||||||
TIMEM("1D strided", astype, a, int32, device);
|
TIMEM("1D strided", mx::astype, a, mx::int32, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_astype_2D() {
|
void time_irregular_astype_2D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
int size = 2048;
|
int size = 2048;
|
||||||
std::vector<int> shape = {size, size};
|
mx::Shape shape = {size, size};
|
||||||
|
|
||||||
auto a = random::uniform(shape);
|
auto a = mx::random::uniform(shape);
|
||||||
TIMEM("2D regular", astype, a, int32, device);
|
TIMEM("2D regular", mx::astype, a, mx::int32, device);
|
||||||
|
|
||||||
a = transpose(a);
|
a = mx::transpose(a);
|
||||||
TIMEM("2D transpose", astype, a, int32, device);
|
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({size}), shape);
|
a = mx::broadcast_to(mx::random::uniform({size}), shape);
|
||||||
TIMEM("2D broadcast dim 0", astype, a, int32, device);
|
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({size, 1}), shape);
|
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
|
||||||
TIMEM("2D broadcast dim 1", astype, a, int32, device);
|
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
if (argc > 1) {
|
if (argc > 1) {
|
||||||
bool use_gpu = !strcmp(argv[1], "gpu");
|
bool use_gpu = !strcmp(argv[1], "gpu");
|
||||||
set_default_device(use_gpu ? Device::gpu : Device::cpu);
|
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
|
||||||
}
|
}
|
||||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||||
time_irregular_binary_ops_1D();
|
time_irregular_binary_ops_1D();
|
||||||
time_irregular_binary_ops_2D();
|
time_irregular_binary_ops_2D();
|
||||||
time_irregular_binary_ops_3D();
|
time_irregular_binary_ops_3D();
|
||||||
|
|||||||
@@ -3,20 +3,20 @@
|
|||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
#include "time_utils.h"
|
#include "time_utils.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
void time_creation_ops() {
|
void time_creation_ops() {
|
||||||
int M = 2000;
|
int M = 2000;
|
||||||
int N = 500;
|
int N = 500;
|
||||||
auto shape = {M, N};
|
auto shape = {M, N};
|
||||||
auto full_fp32 = [&]() { return full(shape, 3.3f); };
|
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
|
||||||
TIME(full_fp32);
|
TIME(full_fp32);
|
||||||
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
|
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
|
||||||
TIME(zeros_fp32);
|
TIME(zeros_fp32);
|
||||||
auto ones_fp32 = [&]() { return ones(shape, float32); };
|
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
|
||||||
TIME(ones_fp32);
|
TIME(ones_fp32);
|
||||||
|
|
||||||
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
|
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
|
||||||
TIME(arange_fp32);
|
TIME(arange_fp32);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,188 +24,212 @@ void time_type_conversions() {
|
|||||||
int M = 2000;
|
int M = 2000;
|
||||||
int N = 500;
|
int N = 500;
|
||||||
auto shape = {M, N};
|
auto shape = {M, N};
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
|
|
||||||
auto a = zeros(shape, float32);
|
auto a = mx::zeros(shape, mx::float32);
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
TIMEM("float32 to int32", astype, a, int32, device);
|
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
|
||||||
TIMEM("float32 to uint32", astype, a, uint32, device);
|
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||||
|
|
||||||
a = zeros(shape, int32);
|
a = mx::zeros(shape, mx::int32);
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
TIMEM("int32 to float32", astype, a, float32, device);
|
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
|
||||||
|
|
||||||
a = zeros(shape, bool_);
|
a = mx::zeros(shape, mx::bool_);
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
TIMEM("bool to float32", astype, a, float32, device);
|
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
|
||||||
TIMEM("bool to int32", astype, a, int32, device);
|
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
|
||||||
TIMEM("bool to uint32", astype, a, uint32, device);
|
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_random_generation() {
|
void time_random_generation() {
|
||||||
int M = 2000;
|
int M = 2000;
|
||||||
int N = 500;
|
int N = 500;
|
||||||
|
|
||||||
auto uniform = [&]() { return random::uniform({M, N}, float32); };
|
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
|
||||||
TIME(uniform);
|
TIME(uniform);
|
||||||
auto normal = [&]() { return random::normal({M, N}, float32); };
|
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
|
||||||
TIME(normal);
|
TIME(normal);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_unary_ops() {
|
void time_unary_ops() {
|
||||||
int M = 2000;
|
int M = 2000;
|
||||||
int N = 500;
|
int N = 500;
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
|
|
||||||
auto a = random::normal({M, N});
|
auto a = mx::random::normal({M, N});
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
TIME(mlx::core::abs, a, device);
|
TIME(mlx::core::abs, a, device);
|
||||||
TIME(negative, a, device);
|
TIME(mx::negative, a, device);
|
||||||
TIME(sign, a, device);
|
TIME(mx::sign, a, device);
|
||||||
TIME(square, a, device);
|
TIME(mx::square, a, device);
|
||||||
TIME(mlx::core::sqrt, a, device);
|
TIME(mlx::core::sqrt, a, device);
|
||||||
TIME(rsqrt, a, device);
|
TIME(mx::rsqrt, a, device);
|
||||||
TIME(mlx::core::exp, a, device);
|
TIME(mlx::core::exp, a, device);
|
||||||
|
|
||||||
a = random::uniform({M, N});
|
a = mx::random::uniform({M, N});
|
||||||
TIME(mlx::core::log, a, device);
|
TIME(mlx::core::log, a, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_binary_ops() {
|
void time_binary_ops() {
|
||||||
int M = 1000, N = 100, K = 10;
|
int M = 1000, N = 100, K = 10;
|
||||||
auto a = random::uniform({M, N, K});
|
auto condition = mx::random::randint(0, 2, {M, N, K});
|
||||||
auto b = random::uniform({M, N, K});
|
auto a = mx::random::uniform({M, N, K});
|
||||||
auto device = default_device();
|
auto b = mx::random::uniform({M, N, K});
|
||||||
eval(a, b);
|
auto device = mx::default_device();
|
||||||
|
mx::eval(a, b);
|
||||||
|
|
||||||
TIME(add, a, b, device);
|
TIME(mx::add, a, b, device);
|
||||||
TIME(subtract, a, b, device);
|
TIME(mx::subtract, a, b, device);
|
||||||
TIME(multiply, a, b, device);
|
TIME(mx::multiply, a, b, device);
|
||||||
TIME(divide, a, b, device);
|
TIME(mx::divide, a, b, device);
|
||||||
TIME(maximum, a, b, device);
|
TIME(mx::maximum, a, b, device);
|
||||||
TIME(minimum, a, b, device);
|
TIME(mx::minimum, a, b, device);
|
||||||
|
TIME(mx::where, condition, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({1});
|
condition = mx::array({true});
|
||||||
eval(b);
|
b = mx::random::uniform({1});
|
||||||
TIMEM("scalar", add, a, b, device);
|
mx::eval(b);
|
||||||
TIMEM("vector-scalar", subtract, a, b, device);
|
TIMEM("scalar", mx::add, a, b, device);
|
||||||
TIMEM("scalar-vector", subtract, b, a, device);
|
TIMEM("vector-scalar", mx::subtract, a, b, device);
|
||||||
TIMEM("scalar", multiply, a, b, device);
|
TIMEM("scalar-vector", mx::subtract, b, a, device);
|
||||||
TIMEM("vector-scalar", divide, a, b, device);
|
TIMEM("scalar", mx::multiply, a, b, device);
|
||||||
TIMEM("scalar-vector", divide, b, a, device);
|
TIMEM("vector-scalar", mx::divide, a, b, device);
|
||||||
|
TIMEM("scalar-vector", mx::divide, b, a, device);
|
||||||
|
TIMEM("scalar-vector", mx::where, condition, a, b, device);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({1}), {1000, 100});
|
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
|
||||||
b = broadcast_to(random::uniform({1}), {1000, 100});
|
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||||
eval(a, b);
|
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||||
TIMEM("scalar-scalar broadcast", add, a, b, device);
|
mx::eval(a, b);
|
||||||
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
|
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
|
||||||
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
|
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
|
||||||
TIMEM("scalar-scalar broadcast", divide, a, b, device);
|
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
|
||||||
|
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
|
||||||
|
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_strided_ops() {
|
void time_strided_ops() {
|
||||||
int M = 50, N = 50, O = 50, P = 50;
|
int M = 50, N = 50, O = 50, P = 50;
|
||||||
auto a = random::uniform({M, N, O, P});
|
auto a = mx::random::uniform({M, N, O, P});
|
||||||
auto b = random::uniform({M, N, O, P});
|
auto b = mx::random::uniform({M, N, O, P});
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIMEM("non-strided", add, a, b, device);
|
TIMEM("non-strided", mx::add, a, b, device);
|
||||||
a = transpose(a, {1, 0, 2, 3});
|
a = mx::transpose(a, {1, 0, 2, 3});
|
||||||
b = transpose(b, {3, 2, 0, 1});
|
b = mx::transpose(b, {3, 2, 0, 1});
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIMEM("strided", add, a, b, device);
|
TIMEM("strided", mx::add, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_comparisons() {
|
void time_comparisons() {
|
||||||
int M = 1000, N = 100, K = 10;
|
int M = 1000, N = 100, K = 10;
|
||||||
auto a = random::uniform({M, N, K});
|
auto a = mx::random::uniform({M, N, K});
|
||||||
auto b = random::uniform({M, N, K});
|
auto b = mx::random::uniform({M, N, K});
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIME(equal, a, b, device);
|
TIME(mx::equal, a, b, device);
|
||||||
TIME(greater, a, b, device);
|
TIME(mx::greater, a, b, device);
|
||||||
TIME(greater_equal, a, b, device);
|
TIME(mx::greater_equal, a, b, device);
|
||||||
TIME(less, a, b, device);
|
TIME(mx::less, a, b, device);
|
||||||
TIME(less_equal, a, b, device);
|
TIME(mx::less_equal, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_matvec() {
|
void time_matvec() {
|
||||||
int M = 2000, N = 200;
|
int M = 2000, N = 200;
|
||||||
auto a = random::uniform({M, N});
|
auto a = mx::random::uniform({M, N});
|
||||||
auto b = random::uniform({N});
|
auto b = mx::random::uniform({N});
|
||||||
auto c = random::uniform({M});
|
auto c = mx::random::uniform({M});
|
||||||
eval(a, b, c);
|
mx::eval(a, b, c);
|
||||||
auto matvec = [&]() { return matmul(a, b); };
|
auto matvec = [&]() { return mx::matmul(a, b); };
|
||||||
TIME(matvec);
|
TIME(matvec);
|
||||||
|
|
||||||
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
|
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
|
||||||
TIME(matvec_transpose);
|
TIME(matvec_transpose);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_matmul() {
|
void time_matmul() {
|
||||||
int M = 1000, N = 1000, K = 1000;
|
int M = 1000, N = 1000, K = 1000;
|
||||||
auto a = random::uniform({M, K});
|
auto a = mx::random::uniform({M, K});
|
||||||
auto b = random::uniform({K, N});
|
auto b = mx::random::uniform({K, N});
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIME(matmul, a, b, device);
|
TIME(mx::matmul, a, b, device);
|
||||||
|
|
||||||
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
|
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
|
||||||
TIME(transpose_matmul);
|
TIME(transpose_matmul);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_reductions() {
|
void time_reductions() {
|
||||||
auto a = random::normal({10000, 1000});
|
auto a = mx::random::normal({10000, 1000});
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
auto sum_all = [&a]() { return sum(a, false); };
|
auto sum_all = [&a]() { return mx::sum(a, false); };
|
||||||
TIME(sum_all);
|
TIME(sum_all);
|
||||||
|
|
||||||
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
|
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
|
||||||
TIME(sum_along_0);
|
TIME(sum_along_0);
|
||||||
|
|
||||||
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
|
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
|
||||||
TIME(sum_along_1);
|
TIME(sum_along_1);
|
||||||
|
|
||||||
auto prod_all = [&a]() { return prod(a, false); };
|
auto prod_all = [&a]() { return mx::prod(a, false); };
|
||||||
TIME(prod_all);
|
TIME(prod_all);
|
||||||
|
|
||||||
auto all_true = [&a]() { return all(a, false); };
|
auto all_true = [&a]() { return mx::all(a, false); };
|
||||||
TIME(all_true);
|
TIME(all_true);
|
||||||
|
|
||||||
auto all_along_0 = [&a]() { return all(a, 0, false); };
|
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
|
||||||
TIME(all_along_0);
|
TIME(all_along_0);
|
||||||
|
|
||||||
auto all_along_1 = [&a]() { return all(a, 1, false); };
|
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
|
||||||
TIME(all_along_1);
|
TIME(all_along_1);
|
||||||
|
|
||||||
auto any_true = [&a]() { return any(a, false); };
|
auto any_true = [&a]() { return mx::any(a, false); };
|
||||||
TIME(any_true);
|
TIME(any_true);
|
||||||
|
|
||||||
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
|
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
|
||||||
TIME(argmin_along_0);
|
TIME(argmin_along_0);
|
||||||
|
|
||||||
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
|
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||||
TIME(argmin_along_1);
|
TIME(argmin_along_1);
|
||||||
|
|
||||||
|
auto indices = mx::array({1});
|
||||||
|
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
|
||||||
|
std::vector<int> axes{0};
|
||||||
|
auto b = scatter(a, {indices}, updates, axes);
|
||||||
|
mx::eval(b);
|
||||||
|
|
||||||
|
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
|
||||||
|
TIME(max_along_0);
|
||||||
|
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||||
|
TIME(max_along_1);
|
||||||
|
|
||||||
|
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||||
|
TIME(min_along_0);
|
||||||
|
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||||
|
TIME(min_along_1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_gather_scatter() {
|
void time_gather_scatter() {
|
||||||
auto a = random::normal({1000, 768});
|
auto a = mx::random::normal({1000, 768});
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
auto indices = random::randint(0, 1000, {256});
|
auto indices = mx::random::randint(0, 1000, {256});
|
||||||
eval(indices);
|
mx::eval(indices);
|
||||||
|
|
||||||
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
|
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
|
||||||
TIME(embedding_lookup);
|
TIME(embedding_lookup);
|
||||||
|
|
||||||
indices = random::randint(0, 768 * 1000, {256 * 768});
|
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
|
||||||
eval(indices);
|
mx::eval(indices);
|
||||||
|
|
||||||
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
|
auto single_element_lookup = [&a, &indices]() {
|
||||||
|
return mx::take(a, indices);
|
||||||
|
};
|
||||||
TIME(single_element_lookup);
|
TIME(single_element_lookup);
|
||||||
|
|
||||||
indices = random::randint(0, 1000, {256});
|
indices = mx::random::randint(0, 1000, {256});
|
||||||
auto updates = random::normal({256, 1, 768});
|
auto updates = mx::random::normal({256, 1, 768});
|
||||||
eval(indices, updates);
|
mx::eval(indices, updates);
|
||||||
|
|
||||||
auto embedding_update = [&a, &indices, &updates]() {
|
auto embedding_update = [&a, &indices, &updates]() {
|
||||||
return scatter(a, indices, updates, 0);
|
return scatter(a, indices, updates, 0);
|
||||||
@@ -217,10 +241,10 @@ void time_gather_scatter() {
|
|||||||
};
|
};
|
||||||
TIME(embedding_add);
|
TIME(embedding_add);
|
||||||
|
|
||||||
a = reshape(a, {-1});
|
a = mx::reshape(a, {-1});
|
||||||
indices = random::randint(0, 768 * 1000, {768 * 256});
|
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
|
||||||
updates = random::normal({256 * 768, 1});
|
updates = mx::random::normal({256 * 768, 1});
|
||||||
eval(a, indices, updates);
|
mx::eval(a, indices, updates);
|
||||||
|
|
||||||
auto single_element_update = [&a, &indices, &updates]() {
|
auto single_element_update = [&a, &indices, &updates]() {
|
||||||
return scatter(a, indices, updates, 0);
|
return scatter(a, indices, updates, 0);
|
||||||
@@ -234,21 +258,21 @@ void time_gather_scatter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void time_divmod() {
|
void time_divmod() {
|
||||||
auto a = random::normal({1000});
|
auto a = mx::random::normal({1000});
|
||||||
auto b = random::normal({1000});
|
auto b = mx::random::normal({1000});
|
||||||
eval({a, b});
|
mx::eval({a, b});
|
||||||
|
|
||||||
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
|
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
|
||||||
TIME(divmod_fused);
|
TIME(divmod_fused);
|
||||||
|
|
||||||
auto divmod_separate = [&a, &b]() {
|
auto divmod_separate = [&a, &b]() {
|
||||||
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
|
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
|
||||||
};
|
};
|
||||||
TIME(divmod_separate);
|
TIME(divmod_separate);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||||
time_creation_ops();
|
time_creation_ops();
|
||||||
time_type_conversions();
|
time_type_conversions();
|
||||||
time_unary_ops();
|
time_unary_ops();
|
||||||
|
|||||||
@@ -18,13 +18,12 @@
|
|||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
|
||||||
#define TIMEM(MSG, FUNC, ...) \
|
#define TIMEM(MSG, FUNC, ...) \
|
||||||
std::cout << "Timing " \
|
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
|
||||||
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
|
<< std::flush << std::setprecision(5) \
|
||||||
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
|
||||||
<< std::endl;
|
|
||||||
|
|
||||||
template <typename F, typename... Args>
|
template <typename F, typename... Args>
|
||||||
double time_fn(F fn, Args... args) {
|
double time_fn(F fn, Args&&... args) {
|
||||||
// warmup
|
// warmup
|
||||||
for (int i = 0; i < 5; ++i) {
|
for (int i = 0; i < 5; ++i) {
|
||||||
eval(fn(std::forward<Args>(args)...));
|
eval(fn(std::forward<Args>(args)...));
|
||||||
|
|||||||
@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
|
|||||||
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
|
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
|
||||||
|
|
||||||
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
|
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
|
||||||
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
|
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
|
||||||
np.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||||
|
|
||||||
dtypes = ("float32", "float16")
|
dtypes = ("float32", "float16", "complex64")
|
||||||
transposes = ("nn", "nt", "tn")
|
transposes = ("nn", "nt", "tn")
|
||||||
shapes = (
|
shapes = (
|
||||||
(16, 234, 768, 3072),
|
(16, 234, 768, 3072),
|
||||||
@@ -187,7 +185,7 @@ if __name__ == "__main__":
|
|||||||
diff = gflops_mx / gflops_pt - 1.0
|
diff = gflops_mx / gflops_pt - 1.0
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
|
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
|
||||||
)
|
)
|
||||||
if gflops_pt >= 2.0 * gflops_mx:
|
if gflops_pt >= 2.0 * gflops_mx:
|
||||||
print("ATTENTION ^^^^^^^")
|
print("ATTENTION ^^^^^^^")
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
@@ -196,7 +195,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
|
|||||||
|
|
||||||
|
|
||||||
for transpose in (False, True):
|
for transpose in (False, True):
|
||||||
for dtype in ("float32", "float16"):
|
for dtype in ("float32", "float16", "complex64"):
|
||||||
fig, axs = plt.subplots(
|
fig, axs = plt.subplots(
|
||||||
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
|
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
|
||||||
)
|
)
|
||||||
@@ -215,7 +214,7 @@ for transpose in (False, True):
|
|||||||
fig.suptitle(f"{device_name}: {dtype} {op_name}")
|
fig.suptitle(f"{device_name}: {dtype} {op_name}")
|
||||||
fig.savefig(
|
fig.savefig(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
|
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|||||||
@@ -144,6 +144,13 @@ def reduction(op, axis, x):
|
|||||||
mx.eval(ys)
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def sum_and_add(axis, x, y):
|
||||||
|
z = x.sum(axis=axis, keepdims=True)
|
||||||
|
for i in range(50):
|
||||||
|
z = (z + y).sum(axis=axis, keepdims=True)
|
||||||
|
mx.eval(z)
|
||||||
|
|
||||||
|
|
||||||
def softmax(axis, x):
|
def softmax(axis, x):
|
||||||
ys = []
|
ys = []
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
@@ -380,10 +387,6 @@ if __name__ == "__main__":
|
|||||||
if len(args.axis) > 1:
|
if len(args.axis) > 1:
|
||||||
args.axis.pop(0)
|
args.axis.pop(0)
|
||||||
|
|
||||||
if args.print_pid:
|
|
||||||
print(os.getpid())
|
|
||||||
input("Press enter to run")
|
|
||||||
|
|
||||||
if args.cpu:
|
if args.cpu:
|
||||||
mx.set_default_device(mx.cpu)
|
mx.set_default_device(mx.cpu)
|
||||||
else:
|
else:
|
||||||
@@ -406,6 +409,10 @@ if __name__ == "__main__":
|
|||||||
x = xs[0]
|
x = xs[0]
|
||||||
axis = args.axis[0]
|
axis = args.axis[0]
|
||||||
|
|
||||||
|
if args.print_pid:
|
||||||
|
print(os.getpid())
|
||||||
|
input("Press enter to run")
|
||||||
|
|
||||||
if args.benchmark == "matmul_square":
|
if args.benchmark == "matmul_square":
|
||||||
print(bench(matmul_square, x))
|
print(bench(matmul_square, x))
|
||||||
|
|
||||||
@@ -505,5 +512,8 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "selu":
|
elif args.benchmark == "selu":
|
||||||
print(bench(selu, x))
|
print(bench(selu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "sum_and_add":
|
||||||
|
print(bench(sum_and_add, axis, *xs))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown benchmark")
|
raise ValueError("Unknown benchmark")
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.cuda
|
||||||
import torch.mps
|
import torch.mps
|
||||||
|
|
||||||
|
|
||||||
@@ -44,8 +45,10 @@ def bench(f, *args):
|
|||||||
|
|
||||||
|
|
||||||
def sync_if_needed(x):
|
def sync_if_needed(x):
|
||||||
if x.device != torch.device("cpu"):
|
if x.device == torch.device("mps"):
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
|
elif x.device == torch.device("cuda"):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
|
|||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sum_and_add(axis, x, y):
|
||||||
|
z = x.sum(axis=axis, keepdims=True)
|
||||||
|
for i in range(50):
|
||||||
|
z = (z + y).sum(axis=axis, keepdims=True)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def softmax(axis, x):
|
def softmax(axis, x):
|
||||||
ys = []
|
ys = []
|
||||||
@@ -185,7 +196,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
|
|||||||
def mish(x: torch.Tensor) -> torch.Tensor:
|
def mish(x: torch.Tensor) -> torch.Tensor:
|
||||||
y = x
|
y = x
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
return torch.nn.functional.mish(y)
|
y = torch.nn.functional.mish(y)
|
||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@@ -283,6 +294,14 @@ def topk(axis, x):
|
|||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step_function(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = torch.where(y < 0, 0, 1)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def selu(x):
|
def selu(x):
|
||||||
y = x
|
y = x
|
||||||
@@ -331,12 +350,12 @@ if __name__ == "__main__":
|
|||||||
if len(args.axis) > 1:
|
if len(args.axis) > 1:
|
||||||
args.axis.pop(0)
|
args.axis.pop(0)
|
||||||
|
|
||||||
if args.print_pid:
|
|
||||||
print(os.getpid())
|
|
||||||
input("Press enter to run")
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
device = "cpu" if args.cpu else "mps"
|
device = "mps"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
if args.cpu:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
types = args.dtype
|
types = args.dtype
|
||||||
if not types:
|
if not types:
|
||||||
@@ -354,6 +373,10 @@ if __name__ == "__main__":
|
|||||||
x = xs[0]
|
x = xs[0]
|
||||||
axis = args.axis[0]
|
axis = args.axis[0]
|
||||||
|
|
||||||
|
if args.print_pid:
|
||||||
|
print(os.getpid())
|
||||||
|
input("Press enter to run")
|
||||||
|
|
||||||
if args.benchmark == "matmul_square":
|
if args.benchmark == "matmul_square":
|
||||||
print(bench(matmul_square, x))
|
print(bench(matmul_square, x))
|
||||||
|
|
||||||
@@ -446,5 +469,14 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "topk":
|
elif args.benchmark == "topk":
|
||||||
print(bench(topk, axis, x))
|
print(bench(topk, axis, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "step":
|
||||||
|
print(bench(step_function, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "selu":
|
||||||
|
print(bench(selu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "sum_and_add":
|
||||||
|
print(bench(sum_and_add, axis, *xs))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown benchmark")
|
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||||
|
|||||||
@@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
|
|||||||
result = run(*args, capture_output=True, **kwargs)
|
result = run(*args, capture_output=True, **kwargs)
|
||||||
return float(result.stdout)
|
return float(result.stdout)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
|
raise ValueError(
|
||||||
|
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def compare(args):
|
def compare(args):
|
||||||
@@ -80,10 +82,8 @@ if __name__ == "__main__":
|
|||||||
_filter = make_predicate(args.filter, args.negative_filter)
|
_filter = make_predicate(args.filter, args.negative_filter)
|
||||||
|
|
||||||
if args.mlx_dtypes:
|
if args.mlx_dtypes:
|
||||||
compare_filtered = (
|
compare_filtered = lambda x: (
|
||||||
lambda x: compare_mlx_dtypes(
|
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
|
||||||
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
|
|
||||||
)
|
|
||||||
if _filter(x)
|
if _filter(x)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|||||||
107
benchmarks/python/compile_bench.py
Normal file
107
benchmarks/python/compile_bench.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
def bench_gelu():
|
||||||
|
def gelu(x):
|
||||||
|
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1000, 1024))
|
||||||
|
|
||||||
|
def gen_fun(fun):
|
||||||
|
def bench_fun(x):
|
||||||
|
for _ in range(10):
|
||||||
|
x = fun(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
return bench_fun
|
||||||
|
|
||||||
|
time_fn(gen_fun(gelu), x, msg="fixed gelu")
|
||||||
|
time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")
|
||||||
|
|
||||||
|
def randint():
|
||||||
|
return random.randint(1, x.shape[0])
|
||||||
|
|
||||||
|
def gen_fun(fun):
|
||||||
|
def bench_fun(x, y):
|
||||||
|
x = x[: randint()]
|
||||||
|
for _ in range(10):
|
||||||
|
x = fun(x)
|
||||||
|
y = fun(y)
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
return bench_fun
|
||||||
|
|
||||||
|
y = mx.random.uniform(shape=(1000, 1024))
|
||||||
|
time_fn(gen_fun(gelu), x, y, msg="variable gelu")
|
||||||
|
time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
|
||||||
|
time_fn(
|
||||||
|
gen_fun(mx.compile(gelu, shapeless=True)),
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
msg="shapeless variable gelu",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bench_layernorm():
|
||||||
|
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||||
|
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||||
|
mx.eval(weight, bias)
|
||||||
|
|
||||||
|
def layernorm(x):
|
||||||
|
x = x.astype(mx.float32)
|
||||||
|
means = mx.mean(x, axis=-1, keepdims=True)
|
||||||
|
var = mx.var(x, axis=-1, keepdims=True)
|
||||||
|
x = (x - means) * mx.rsqrt(var + 1e-4)
|
||||||
|
x = x.astype(mx.float16)
|
||||||
|
return weight * x + bias
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)
|
||||||
|
|
||||||
|
def gen_fun(fun):
|
||||||
|
def bench_fun(x):
|
||||||
|
for _ in range(10):
|
||||||
|
x = fun(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
return bench_fun
|
||||||
|
|
||||||
|
time_fn(gen_fun(layernorm), x, msg="fixed layernorm")
|
||||||
|
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm")
|
||||||
|
|
||||||
|
def randint():
|
||||||
|
return random.randint(1, x.shape[0])
|
||||||
|
|
||||||
|
def gen_fun(fun):
|
||||||
|
def bench_fun(x):
|
||||||
|
x = x[: randint()]
|
||||||
|
for _ in range(10):
|
||||||
|
x = fun(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
return bench_fun
|
||||||
|
|
||||||
|
random.seed(0)
|
||||||
|
time_fn(gen_fun(layernorm), x, msg="variable layernorm")
|
||||||
|
random.seed(0)
|
||||||
|
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm")
|
||||||
|
random.seed(0)
|
||||||
|
time_fn(
|
||||||
|
gen_fun(mx.compile(layernorm, shapeless=True)),
|
||||||
|
x,
|
||||||
|
msg="shapeless variable layernorm",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser("Compile benchmarks.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
bench_gelu()
|
||||||
|
bench_layernorm()
|
||||||
123
benchmarks/python/conv1d_bench.py
Normal file
123
benchmarks/python/conv1d_bench.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||||
|
device_name = device_name.decode("utf-8").strip("\n")
|
||||||
|
|
||||||
|
N_warmup = 10
|
||||||
|
N_iter_bench = 100
|
||||||
|
N_iter_func = 5
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def make_mx_conv_1D(strides=1, padding=0, groups=1):
|
||||||
|
def mx_conv_1D(a, b):
|
||||||
|
ys = []
|
||||||
|
for _ in range(N_iter_func):
|
||||||
|
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return mx_conv_1D
|
||||||
|
|
||||||
|
|
||||||
|
def make_pt_conv_1D(strides=1, padding=0, groups=1):
|
||||||
|
@torch.no_grad()
|
||||||
|
def pt_conv_1D(a, b):
|
||||||
|
ys = []
|
||||||
|
for _ in range(N_iter_func):
|
||||||
|
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return pt_conv_1D
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
|
||||||
|
scale = 1.0 / math.sqrt(wH * C)
|
||||||
|
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
|
||||||
|
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
|
||||||
|
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
|
||||||
|
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
f_mx = make_mx_conv_1D(strides, padding, groups)
|
||||||
|
f_pt = make_pt_conv_1D(strides, padding, groups)
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||||
|
out_pt = torch.conv1d(
|
||||||
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 1))
|
||||||
|
out_pt = out_pt.numpy(force=True)
|
||||||
|
|
||||||
|
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||||
|
|
||||||
|
dtypes = ("float32",)
|
||||||
|
shapes = (
|
||||||
|
(4, 32, 32, 5, 32, 1, 2, 1),
|
||||||
|
(4, 32, 32, 5, 32, 1, 2, 2),
|
||||||
|
(4, 32, 32, 5, 32, 1, 2, 4),
|
||||||
|
(4, 32, 32, 5, 32, 1, 2, 8),
|
||||||
|
(4, 32, 32, 5, 32, 1, 2, 8),
|
||||||
|
(4, 32, 32, 5, 32, 1, 2, 16),
|
||||||
|
(4, 32, 32, 5, 32, 1, 2, 32),
|
||||||
|
(4, 32, 256, 5, 512, 1, 2, 2),
|
||||||
|
(4, 32, 256, 5, 512, 1, 2, 128),
|
||||||
|
(4, 32, 256, 5, 512, 1, 2, 256),
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
|
||||||
|
for N, iH, C, wH, O, strides, padding, groups in shapes:
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
time_mlx, time_torch = bench_shape(
|
||||||
|
N, iH, C, wH, O, strides, padding, np_dtype, groups
|
||||||
|
)
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
if time_mlx >= 2.0 * time_torch:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
||||||
127
benchmarks/python/conv2d_bench_cpu.py
Normal file
127
benchmarks/python/conv2d_bench_cpu.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
N_warmup = 1
|
||||||
|
N_iter_bench = 10
|
||||||
|
N_iter_func = 5
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
def mx_conv_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return mx_conv_2D
|
||||||
|
|
||||||
|
|
||||||
|
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
@torch.no_grad()
|
||||||
|
def pt_conv_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return pt_conv_2D
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||||
|
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||||
|
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||||
|
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
|
||||||
|
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")
|
||||||
|
|
||||||
|
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||||
|
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||||
|
out_pt = torch.conv2d(
|
||||||
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||||
|
out_pt = out_pt.numpy(force=True)
|
||||||
|
|
||||||
|
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||||
|
|
||||||
|
dtypes = ("float32",)
|
||||||
|
shapes = (
|
||||||
|
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||||
|
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
|
||||||
|
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
|
||||||
|
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
|
||||||
|
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||||
|
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
print(
|
||||||
|
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||||
|
)
|
||||||
|
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
time_mlx, time_torch = bench_shape(
|
||||||
|
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||||
|
)
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
if time_mlx >= 2.0 * time_torch:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
||||||
143
benchmarks/python/conv2d_train_bench_cpu.py
Normal file
143
benchmarks/python/conv2d_train_bench_cpu.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn
|
||||||
|
import mlx.optimizers as opt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def bench_mlx(steps: int = 20) -> float:
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
class BenchNetMLX(mlx.nn.Module):
|
||||||
|
# simple encoder-decoder net
|
||||||
|
|
||||||
|
def __init__(self, in_channels, hidden_channels=32):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.net = mlx.nn.Sequential(
|
||||||
|
mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
||||||
|
mlx.nn.ReLU(),
|
||||||
|
mlx.nn.Conv2d(
|
||||||
|
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
mlx.nn.ReLU(),
|
||||||
|
mlx.nn.ConvTranspose2d(
|
||||||
|
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
mlx.nn.ReLU(),
|
||||||
|
mlx.nn.ConvTranspose2d(
|
||||||
|
hidden_channels, in_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, input):
|
||||||
|
return self.net(input)
|
||||||
|
|
||||||
|
benchNet = BenchNetMLX(3)
|
||||||
|
mx.eval(benchNet.parameters())
|
||||||
|
optim = opt.Adam(learning_rate=1e-3)
|
||||||
|
|
||||||
|
inputs = mx.random.normal([10, 256, 256, 3])
|
||||||
|
|
||||||
|
params = benchNet.parameters()
|
||||||
|
optim.init(params)
|
||||||
|
|
||||||
|
state = [benchNet.state, optim.state]
|
||||||
|
|
||||||
|
def loss_fn(params, image):
|
||||||
|
benchNet.update(params)
|
||||||
|
pred_image = benchNet(image)
|
||||||
|
return (pred_image - image).abs().mean()
|
||||||
|
|
||||||
|
def step(params, image):
|
||||||
|
loss, grads = mx.value_and_grad(loss_fn)(params, image)
|
||||||
|
optim.update(benchNet, grads)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
total_time = 0.0
|
||||||
|
print("MLX:")
|
||||||
|
for i in range(steps):
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
step(benchNet.parameters(), inputs)
|
||||||
|
mx.eval(state)
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
|
||||||
|
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
||||||
|
total_time += (end_time - start_time) * 1000
|
||||||
|
|
||||||
|
return total_time
|
||||||
|
|
||||||
|
|
||||||
|
def bench_torch(steps: int = 20) -> float:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
class BenchNetTorch(torch.nn.Module):
|
||||||
|
# simple encoder-decoder net
|
||||||
|
|
||||||
|
def __init__(self, in_channels, hidden_channels=32):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.net = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv2d(
|
||||||
|
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.ConvTranspose2d(
|
||||||
|
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.ConvTranspose2d(
|
||||||
|
hidden_channels, in_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return self.net(input)
|
||||||
|
|
||||||
|
benchNet = BenchNetTorch(3).to(device)
|
||||||
|
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
inputs = torch.randn(10, 3, 256, 256, device=device)
|
||||||
|
|
||||||
|
def loss_fn(pred_image, image):
|
||||||
|
return (pred_image - image).abs().mean()
|
||||||
|
|
||||||
|
total_time = 0.0
|
||||||
|
print("PyTorch:")
|
||||||
|
for i in range(steps):
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
optim.zero_grad()
|
||||||
|
pred_image = benchNet(inputs)
|
||||||
|
loss = loss_fn(pred_image, inputs)
|
||||||
|
loss.backward()
|
||||||
|
optim.step()
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
|
||||||
|
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
||||||
|
total_time += (end_time - start_time) * 1000
|
||||||
|
|
||||||
|
return total_time
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
steps = 20
|
||||||
|
time_mlx = bench_mlx(steps)
|
||||||
|
time_torch = bench_torch(steps)
|
||||||
|
|
||||||
|
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
|
||||||
|
print(f"total time of MLX: {time_mlx:9.2f} ms")
|
||||||
|
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
|
||||||
|
print(f"total time of PyTorch: {time_torch:9.2f} ms")
|
||||||
|
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
129
benchmarks/python/conv2d_transpose_bench_cpu.py
Normal file
129
benchmarks/python/conv2d_transpose_bench_cpu.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
N_warmup = 1
|
||||||
|
N_iter_bench = 10
|
||||||
|
N_iter_func = 5
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
def mx_conv_transpose_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = mx.conv_transpose2d(
|
||||||
|
a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
|
||||||
|
)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return mx_conv_transpose_2D
|
||||||
|
|
||||||
|
|
||||||
|
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
@torch.no_grad()
|
||||||
|
def pt_conv_transpose_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = torch.conv_transpose2d(
|
||||||
|
a, b, stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
ys.append(y)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return pt_conv_transpose_2D
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||||
|
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||||
|
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||||
|
b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
|
||||||
|
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")
|
||||||
|
|
||||||
|
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
|
||||||
|
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
out_mx = mx.conv_transpose2d(
|
||||||
|
a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
|
||||||
|
)
|
||||||
|
out_pt = torch.conv_transpose2d(
|
||||||
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||||
|
out_pt = out_pt.numpy(force=True)
|
||||||
|
|
||||||
|
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||||
|
|
||||||
|
dtypes = ("float32",)
|
||||||
|
shapes = (
|
||||||
|
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||||
|
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
print(
|
||||||
|
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||||
|
)
|
||||||
|
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
time_mlx, time_torch = bench_shape(
|
||||||
|
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||||
|
)
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
if time_mlx >= 2.0 * time_torch:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
||||||
110
benchmarks/python/conv3d_bench_cpu.py
Normal file
110
benchmarks/python/conv3d_bench_cpu.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
N_warmup = 1
|
||||||
|
N_iter_bench = 10
|
||||||
|
N_iter_func = 5
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
def mx_conv_3D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return mx_conv_3D
|
||||||
|
|
||||||
|
|
||||||
|
def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
@torch.no_grad()
|
||||||
|
def pt_conv_3D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return pt_conv_3D
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
|
||||||
|
scale = 1.0 / math.sqrt(kD * kH * kW * C)
|
||||||
|
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
|
||||||
|
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
|
||||||
|
b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")
|
||||||
|
|
||||||
|
f_mx = make_mx_conv_3D(strides, padding, groups)
|
||||||
|
f_pt = make_pt_conv_3D(strides, padding, groups)
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||||
|
out_pt = torch.conv3d(
|
||||||
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
|
||||||
|
out_pt = out_pt.numpy(force=True)
|
||||||
|
|
||||||
|
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||||
|
|
||||||
|
dtypes = ("float32",)
|
||||||
|
shapes = (
|
||||||
|
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
|
||||||
|
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
print(
|
||||||
|
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||||
|
)
|
||||||
|
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
time_mlx, time_torch = bench_shape(
|
||||||
|
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
|
||||||
|
)
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
if time_mlx >= 2.0 * time_torch:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
||||||
143
benchmarks/python/conv3d_train_bench_cpu.py
Normal file
143
benchmarks/python/conv3d_train_bench_cpu.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn
|
||||||
|
import mlx.optimizers as opt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
class BenchNetMLX(mlx.nn.Module):
|
||||||
|
# simple encoder-decoder net
|
||||||
|
|
||||||
|
def __init__(self, in_channels, hidden_channels=16):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.net = mlx.nn.Sequential(
|
||||||
|
mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
||||||
|
mlx.nn.ReLU(),
|
||||||
|
mlx.nn.Conv3d(
|
||||||
|
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
mlx.nn.ReLU(),
|
||||||
|
mlx.nn.ConvTranspose3d(
|
||||||
|
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
mlx.nn.ReLU(),
|
||||||
|
mlx.nn.ConvTranspose3d(
|
||||||
|
hidden_channels, in_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, input):
|
||||||
|
return self.net(input)
|
||||||
|
|
||||||
|
benchNet = BenchNetMLX(3)
|
||||||
|
mx.eval(benchNet.parameters())
|
||||||
|
optim = opt.Adam(learning_rate=1e-3)
|
||||||
|
|
||||||
|
inputs = mx.random.normal(shape)
|
||||||
|
|
||||||
|
params = benchNet.parameters()
|
||||||
|
optim.init(params)
|
||||||
|
|
||||||
|
state = [benchNet.state, optim.state]
|
||||||
|
|
||||||
|
def loss_fn(params, image):
|
||||||
|
benchNet.update(params)
|
||||||
|
pred_image = benchNet(image)
|
||||||
|
return (pred_image - image).abs().mean()
|
||||||
|
|
||||||
|
def step(params, image):
|
||||||
|
loss, grads = mx.value_and_grad(loss_fn)(params, image)
|
||||||
|
optim.update(benchNet, grads)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
total_time = 0.0
|
||||||
|
print("MLX:")
|
||||||
|
for i in range(steps):
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
step(benchNet.parameters(), inputs)
|
||||||
|
mx.eval(state)
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
|
||||||
|
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
||||||
|
total_time += (end_time - start_time) * 1000
|
||||||
|
|
||||||
|
return total_time
|
||||||
|
|
||||||
|
|
||||||
|
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
class BenchNetTorch(torch.nn.Module):
|
||||||
|
# simple encoder-decoder net
|
||||||
|
|
||||||
|
def __init__(self, in_channels, hidden_channels=16):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.net = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.Conv3d(
|
||||||
|
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.ConvTranspose3d(
|
||||||
|
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
torch.nn.ConvTranspose3d(
|
||||||
|
hidden_channels, in_channels, kernel_size=3, padding=1
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return self.net(input)
|
||||||
|
|
||||||
|
benchNet = BenchNetTorch(3).to(device)
|
||||||
|
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
inputs = torch.randn(*shape, device=device)
|
||||||
|
|
||||||
|
def loss_fn(pred_image, image):
|
||||||
|
return (pred_image - image).abs().mean()
|
||||||
|
|
||||||
|
total_time = 0.0
|
||||||
|
print("PyTorch:")
|
||||||
|
for i in range(steps):
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
optim.zero_grad()
|
||||||
|
pred_image = benchNet(inputs)
|
||||||
|
loss = loss_fn(pred_image, inputs)
|
||||||
|
loss.backward()
|
||||||
|
optim.step()
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
|
||||||
|
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
||||||
|
total_time += (end_time - start_time) * 1000
|
||||||
|
|
||||||
|
return total_time
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
steps = 10
|
||||||
|
time_mlx = bench_mlx(steps)
|
||||||
|
time_torch = bench_torch(steps)
|
||||||
|
|
||||||
|
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
|
||||||
|
print(f"total time of MLX: {time_mlx:9.2f} ms")
|
||||||
|
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
|
||||||
|
print(f"total time of PyTorch: {time_torch:9.2f} ms")
|
||||||
|
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
116
benchmarks/python/conv3d_transpose_bench_cpu.py
Normal file
116
benchmarks/python/conv3d_transpose_bench_cpu.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
N_warmup = 1
|
||||||
|
N_iter_bench = 10
|
||||||
|
N_iter_func = 5
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
|
||||||
|
def mx_conv_3D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = mx.conv_transpose3d(
|
||||||
|
a, b, stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return mx_conv_3D
|
||||||
|
|
||||||
|
|
||||||
|
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
|
||||||
|
@torch.no_grad()
|
||||||
|
def pt_conv_3D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = torch.conv_transpose3d(
|
||||||
|
a, b, stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
ys.append(y)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return pt_conv_3D
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
|
||||||
|
scale = 1.0 / math.sqrt(kD * kH * kW * C)
|
||||||
|
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
|
||||||
|
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
|
||||||
|
b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")
|
||||||
|
|
||||||
|
f_mx = make_mx_conv_3D(strides, padding, groups)
|
||||||
|
f_pt = make_pt_conv_3D(strides, padding, groups)
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
out_mx = mx.conv_transpose3d(
|
||||||
|
a_mx, b_mx, stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.conv_transpose3d(
|
||||||
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
|
||||||
|
out_pt = out_pt.numpy(force=True)
|
||||||
|
|
||||||
|
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||||
|
|
||||||
|
dtypes = ("float32",)
|
||||||
|
shapes = (
|
||||||
|
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
|
||||||
|
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
print(
|
||||||
|
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||||
|
)
|
||||||
|
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
time_mlx, time_torch = bench_shape(
|
||||||
|
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
|
||||||
|
)
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
if time_mlx >= 2.0 * time_torch:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
||||||
135
benchmarks/python/conv_bench.py
Normal file
135
benchmarks/python/conv_bench.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||||
|
device_name = device_name.decode("utf-8").strip("\n")
|
||||||
|
|
||||||
|
N_warmup = 10
|
||||||
|
N_iter_bench = 100
|
||||||
|
N_iter_func = 5
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
def mx_conv_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return mx_conv_2D
|
||||||
|
|
||||||
|
|
||||||
|
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
@torch.no_grad()
|
||||||
|
def pt_conv_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return pt_conv_2D
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||||
|
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||||
|
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||||
|
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||||
|
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||||
|
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||||
|
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||||
|
out_pt = torch.conv2d(
|
||||||
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||||
|
out_pt = out_pt.numpy(force=True)
|
||||||
|
|
||||||
|
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||||
|
|
||||||
|
dtypes = ("float32",)
|
||||||
|
shapes = (
|
||||||
|
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
|
||||||
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
|
||||||
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
|
||||||
|
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||||
|
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
print(
|
||||||
|
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||||
|
)
|
||||||
|
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
time_mlx, time_torch = bench_shape(
|
||||||
|
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||||
|
)
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
if time_mlx >= 2.0 * time_torch:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
||||||
135
benchmarks/python/conv_transpose_bench.py
Normal file
135
benchmarks/python/conv_transpose_bench.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
N_warmup = 10
|
||||||
|
N_iter_bench = 100
|
||||||
|
N_iter_func = 5
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
def mx_conv_transpose_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = mx.conv_transpose2d(
|
||||||
|
a, b, stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return mx_conv_transpose_2D
|
||||||
|
|
||||||
|
|
||||||
|
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
@torch.no_grad()
|
||||||
|
def pt_conv_transpose_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = torch.conv_transpose2d(
|
||||||
|
a, b, stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return pt_conv_transpose_2D
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||||
|
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||||
|
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||||
|
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||||
|
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
|
||||||
|
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
|
||||||
|
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
out_mx = mx.conv_transpose2d(
|
||||||
|
a_mx, b_mx, stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.conv_transpose2d(
|
||||||
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||||
|
out_pt = out_pt.numpy(force=True)
|
||||||
|
|
||||||
|
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
||||||
|
|
||||||
|
dtypes = ("float32",)
|
||||||
|
shapes = (
|
||||||
|
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||||
|
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
|
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||||
|
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
print(
|
||||||
|
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||||
|
)
|
||||||
|
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
time_mlx, time_torch = bench_shape(
|
||||||
|
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||||
|
)
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
if time_mlx >= 2.0 * time_torch:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
||||||
107
benchmarks/python/conv_unaligned_bench.py
Normal file
107
benchmarks/python/conv_unaligned_bench.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
N_warmup = 10
|
||||||
|
N_iter_bench = 100
|
||||||
|
N_iter_func = 5
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
def mx_conv_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return mx_conv_2D
|
||||||
|
|
||||||
|
|
||||||
|
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
@torch.no_grad()
|
||||||
|
def pt_conv_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return pt_conv_2D
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||||
|
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||||
|
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||||
|
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||||
|
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||||
|
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||||
|
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||||
|
out_pt = torch.conv2d(
|
||||||
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||||
|
out_pt = out_pt.numpy(force=True)
|
||||||
|
|
||||||
|
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dtype = "float32"
|
||||||
|
shapes = (
|
||||||
|
(4, 32, 32, 21, 3, 3, 128),
|
||||||
|
(4, 32, 32, 21, 3, 3, 37),
|
||||||
|
(4, 32, 32, 370, 3, 3, 370),
|
||||||
|
(4, 32, 32, 370, 7, 7, 128),
|
||||||
|
(2, 320, 640, 21, 7, 7, 21),
|
||||||
|
)
|
||||||
|
for N, H, W, C, kh, kw, O in shapes:
|
||||||
|
time_mlx, time_torch = bench_shape(
|
||||||
|
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
|
||||||
|
)
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
if time_mlx >= 2.0 * time_torch:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
||||||
66
benchmarks/python/distributed_bench.py
Normal file
66
benchmarks/python/distributed_bench.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Run with:
|
||||||
|
mpirun -n 2 python /path/to/distributed_bench.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
def time_fn(fn, *args, **kwargs):
|
||||||
|
msg = kwargs.pop("msg", None)
|
||||||
|
world = mx.distributed.init()
|
||||||
|
if world.rank() == 0:
|
||||||
|
if msg:
|
||||||
|
print(f"Timing {msg} ...", end=" ")
|
||||||
|
else:
|
||||||
|
print(f"Timing {fn.__name__} ...", end=" ")
|
||||||
|
|
||||||
|
# warmup
|
||||||
|
for _ in range(5):
|
||||||
|
mx.eval(fn(*args, **kwargs))
|
||||||
|
|
||||||
|
num_iters = 100
|
||||||
|
tic = time.perf_counter()
|
||||||
|
for _ in range(num_iters):
|
||||||
|
x = mx.eval(fn(*args, **kwargs))
|
||||||
|
toc = time.perf_counter()
|
||||||
|
|
||||||
|
msec = 1e3 * (toc - tic) / num_iters
|
||||||
|
if world.rank() == 0:
|
||||||
|
print(f"{msec:.5f} msec")
|
||||||
|
|
||||||
|
|
||||||
|
def time_all_sum():
|
||||||
|
shape = (4096,)
|
||||||
|
x = mx.random.uniform(shape=shape)
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
def sine(x):
|
||||||
|
for _ in range(20):
|
||||||
|
x = mx.sin(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(sine, x)
|
||||||
|
|
||||||
|
def all_sum_plain(x):
|
||||||
|
for _ in range(20):
|
||||||
|
x = mx.distributed.all_sum(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(all_sum_plain, x)
|
||||||
|
|
||||||
|
def all_sum_with_sine(x):
|
||||||
|
for _ in range(20):
|
||||||
|
x = mx.sin(x)
|
||||||
|
x = mx.distributed.all_sum(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(all_sum_with_sine, x)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_all_sum()
|
||||||
84
benchmarks/python/einsum_bench.py
Normal file
84
benchmarks/python/einsum_bench.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def timeit(fn, its=100, args=[]):
|
||||||
|
for _ in range(5):
|
||||||
|
fn(*args)
|
||||||
|
tic = time.perf_counter()
|
||||||
|
for _ in range(its):
|
||||||
|
fn(*args)
|
||||||
|
toc = time.perf_counter()
|
||||||
|
return 1e3 * (toc - tic) / its
|
||||||
|
|
||||||
|
|
||||||
|
def time_little_einsum_path():
|
||||||
|
subscripts = "ik,kj->ij"
|
||||||
|
x = mx.ones((32, 32))
|
||||||
|
y = mx.ones((32, 32))
|
||||||
|
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
|
||||||
|
|
||||||
|
x = np.array(x)
|
||||||
|
y = np.array(y)
|
||||||
|
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
|
||||||
|
print("Timing little einsum path...")
|
||||||
|
print(f"MLX ... {mx_time:.3f} ms")
|
||||||
|
print(f"NumPy... {np_time:.3f} ms")
|
||||||
|
|
||||||
|
|
||||||
|
def time_big_einsum_path():
|
||||||
|
chars = list("abcdefgh")
|
||||||
|
char_to_dim = {c: v for v, c in enumerate(chars)}
|
||||||
|
|
||||||
|
num_inputs = 10
|
||||||
|
inputs = []
|
||||||
|
subscripts = []
|
||||||
|
for _ in range(num_inputs):
|
||||||
|
subscript = np.random.choice(chars, size=5, replace=False).tolist()
|
||||||
|
subscripts.append("".join(subscript))
|
||||||
|
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
|
||||||
|
subscripts = ",".join(subscripts)
|
||||||
|
|
||||||
|
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
|
||||||
|
|
||||||
|
inputs = [mx.array(x) for x in inputs]
|
||||||
|
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
|
||||||
|
print("Timing big einsum path...")
|
||||||
|
print(f"MLX ... {mx_time:.3f} ms")
|
||||||
|
print(f"NumPy... {np_time:.3f} ms")
|
||||||
|
|
||||||
|
|
||||||
|
def time_attention():
|
||||||
|
def regular_attention(x):
|
||||||
|
# shape [batch, sequence, num_heads, head_dim]
|
||||||
|
queries, keys, values = x, x, x
|
||||||
|
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
|
||||||
|
scores = mx.softmax(scores, axis=-1)
|
||||||
|
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
|
||||||
|
mx.eval(output)
|
||||||
|
|
||||||
|
def einsum_attention(x):
|
||||||
|
# shape [batch, sequence, num_heads, head_dim]
|
||||||
|
queries, keys, values = x, x, x
|
||||||
|
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
|
||||||
|
scores = mx.softmax(scores, axis=-1)
|
||||||
|
output = mx.einsum("ijtu,iujk->itjk", scores, values)
|
||||||
|
mx.eval(output)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(8, 512, 32, 128))
|
||||||
|
|
||||||
|
regular_time = timeit(regular_attention, args=(x,))
|
||||||
|
ein_time = timeit(einsum_attention, args=(x,))
|
||||||
|
print("Timing einsum attention...")
|
||||||
|
print(f"Regular ... {regular_time:.3f} ms")
|
||||||
|
print(f"Einsum ... {ein_time:.3f} ms")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_little_einsum_path()
|
||||||
|
time_big_einsum_path()
|
||||||
|
time_attention()
|
||||||
118
benchmarks/python/fft_bench.py
Normal file
118
benchmarks/python/fft_bench.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import sympy
|
||||||
|
import torch
|
||||||
|
from time_utils import measure_runtime
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def bandwidth_gb(runtime_ms, system_size):
|
||||||
|
bytes_per_fft = np.dtype(np.complex64).itemsize * 2
|
||||||
|
bytes_per_gb = 1e9
|
||||||
|
ms_per_s = 1e3
|
||||||
|
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
|
||||||
|
|
||||||
|
|
||||||
|
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
|
||||||
|
def fft_mlx(x):
|
||||||
|
if dim == 1:
|
||||||
|
out = mx.fft.fft(x)
|
||||||
|
elif dim == 2:
|
||||||
|
out = mx.fft.fft2(x)
|
||||||
|
mx.eval(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def fft_mps(x):
|
||||||
|
if dim == 1:
|
||||||
|
out = torch.fft.fft(x)
|
||||||
|
elif dim == 2:
|
||||||
|
out = torch.fft.fft2(x)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return out
|
||||||
|
|
||||||
|
bandwidths = []
|
||||||
|
for n in fft_sizes:
|
||||||
|
batch_size = system_size // n**dim
|
||||||
|
shape = [batch_size] + [n for _ in range(dim)]
|
||||||
|
if backend == "mlx":
|
||||||
|
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||||
|
x = mx.array(x_np)
|
||||||
|
mx.eval(x)
|
||||||
|
fft = fft_mlx
|
||||||
|
elif backend == "mps":
|
||||||
|
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||||
|
x = torch.tensor(x_np, device="mps")
|
||||||
|
torch.mps.synchronize()
|
||||||
|
fft = fft_mps
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
runtime_ms = measure_runtime(fft, x=x)
|
||||||
|
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
|
||||||
|
print(n, bandwidth)
|
||||||
|
bandwidths.append(bandwidth)
|
||||||
|
|
||||||
|
return np.array(bandwidths)
|
||||||
|
|
||||||
|
|
||||||
|
def time_fft():
|
||||||
|
x = np.array(range(2, 512))
|
||||||
|
system_size = int(2**26)
|
||||||
|
|
||||||
|
print("MLX GPU")
|
||||||
|
with mx.stream(mx.gpu):
|
||||||
|
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||||
|
|
||||||
|
print("MPS GPU")
|
||||||
|
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
|
||||||
|
|
||||||
|
print("CPU")
|
||||||
|
system_size = int(2**20)
|
||||||
|
with mx.stream(mx.cpu):
|
||||||
|
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||||
|
|
||||||
|
x = np.array(x)
|
||||||
|
|
||||||
|
all_indices = x - x[0]
|
||||||
|
radix_2to13 = (
|
||||||
|
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
|
||||||
|
)
|
||||||
|
bluesteins = (
|
||||||
|
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
for indices, name in [
|
||||||
|
(all_indices, "All"),
|
||||||
|
(radix_2to13, "Radix 2-13"),
|
||||||
|
(bluesteins, "Bluestein's"),
|
||||||
|
]:
|
||||||
|
# plot bandwidths
|
||||||
|
print(name)
|
||||||
|
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
|
||||||
|
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
|
||||||
|
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
|
||||||
|
plt.title(f"MLX FFT Benchmark -- {name}")
|
||||||
|
plt.xlabel("N")
|
||||||
|
plt.ylabel("Bandwidth (GB/s)")
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig(f"{name}.png")
|
||||||
|
plt.clf()
|
||||||
|
|
||||||
|
av_gpu_bandwidth = np.mean(gpu_bandwidths)
|
||||||
|
av_mps_bandwidth = np.mean(mps_bandwidths)
|
||||||
|
av_cpu_bandwidth = np.mean(cpu_bandwidths)
|
||||||
|
print("Average bandwidths:")
|
||||||
|
print("GPU:", av_gpu_bandwidth)
|
||||||
|
print("MPS:", av_mps_bandwidth)
|
||||||
|
print("CPU:", av_cpu_bandwidth)
|
||||||
|
|
||||||
|
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
|
||||||
|
print("Percent MLX faster than MPS: ", portion_faster * 100)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_fft()
|
||||||
@@ -1,22 +1,10 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from time import time
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import torch
|
import torch
|
||||||
|
from time_utils import measure_runtime
|
||||||
|
|
||||||
def measure_runtime(fn, **kwargs):
|
|
||||||
# Warmup
|
|
||||||
for _ in range(5):
|
|
||||||
fn(**kwargs)
|
|
||||||
|
|
||||||
tic = time()
|
|
||||||
iters = 10
|
|
||||||
for _ in range(iters):
|
|
||||||
fn(**kwargs)
|
|
||||||
return (time() - tic) * 1000 / iters
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_gather_mlx(x_shape, idx_shape):
|
def benchmark_gather_mlx(x_shape, idx_shape):
|
||||||
|
|||||||
74
benchmarks/python/gather_mm_bench.py
Normal file
74
benchmarks/python/gather_mm_bench.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
N = 1024
|
||||||
|
D = 1024
|
||||||
|
M = 1024
|
||||||
|
E = 32
|
||||||
|
I = 4
|
||||||
|
|
||||||
|
|
||||||
|
def gather_sort(x, indices):
|
||||||
|
N, M = indices.shape
|
||||||
|
indices = indices.flatten()
|
||||||
|
order = mx.argsort(indices)
|
||||||
|
inv_order = mx.argsort(order)
|
||||||
|
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_unsort(x, inv_order, shape=None):
|
||||||
|
x = x[inv_order]
|
||||||
|
if shape is not None:
|
||||||
|
x = mx.unflatten(x, 0, shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gather_mm_simulate(x, w, indices):
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
for i in range(2):
|
||||||
|
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
|
||||||
|
x = y[:, None]
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def time_gather_mm():
|
||||||
|
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||||
|
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||||
|
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||||
|
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||||
|
|
||||||
|
def gather_mm(x, w1, w2, indices, sort):
|
||||||
|
idx = indices
|
||||||
|
inv_order = None
|
||||||
|
if sort:
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||||
|
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||||
|
if sort:
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||||
|
|
||||||
|
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||||
|
mx.eval(x, w1, w2)
|
||||||
|
|
||||||
|
def equivalent_matmul(x, w1, w2):
|
||||||
|
x = x @ w1.T
|
||||||
|
x = x @ w2.T
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(equivalent_matmul, x, w1, w2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_gather_mm()
|
||||||
84
benchmarks/python/gather_qmm_bench.py
Normal file
84
benchmarks/python/gather_qmm_bench.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
N = 1024
|
||||||
|
D = 1024
|
||||||
|
M = 1024
|
||||||
|
E = 32
|
||||||
|
I = 4
|
||||||
|
|
||||||
|
|
||||||
|
def gather_sort(x, indices):
|
||||||
|
N, M = indices.shape
|
||||||
|
indices = indices.flatten()
|
||||||
|
order = mx.argsort(indices)
|
||||||
|
inv_order = mx.argsort(order)
|
||||||
|
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_unsort(x, inv_order, shape=None):
|
||||||
|
x = x[inv_order]
|
||||||
|
if shape is not None:
|
||||||
|
x = mx.unflatten(x, 0, shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gather_mm_simulate(x, w, indices):
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
for i in range(2):
|
||||||
|
y = mx.concatenate(
|
||||||
|
[
|
||||||
|
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
|
||||||
|
for i, j in enumerate(idx.tolist())
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
x = y[:, None]
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def time_gather_qmm():
|
||||||
|
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||||
|
w1 = mx.quantize(w1)
|
||||||
|
w2 = mx.quantize(w2)
|
||||||
|
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||||
|
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||||
|
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||||
|
|
||||||
|
def gather_mm(x, w1, w2, indices, sort):
|
||||||
|
idx = indices
|
||||||
|
inv_order = None
|
||||||
|
if sort:
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||||
|
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||||
|
if sort:
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||||
|
|
||||||
|
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||||
|
w1 = mx.quantize(w1)
|
||||||
|
w2 = mx.quantize(w2)
|
||||||
|
mx.eval(x, w1, w2)
|
||||||
|
|
||||||
|
def equivalent_matmul(x, w1, w2):
|
||||||
|
x = mx.quantized_matmul(x, *w1, transpose=True)
|
||||||
|
x = mx.quantized_matmul(x, *w2, transpose=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(equivalent_matmul, x, w1, w2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_gather_qmm()
|
||||||
70
benchmarks/python/hadamard_bench.py
Normal file
70
benchmarks/python/hadamard_bench.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
from time_utils import measure_runtime
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def had(x):
|
||||||
|
y = mx.hadamard_transform(x)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def copy(x):
|
||||||
|
y = x + 1.0
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def run(dtype):
|
||||||
|
system_size = 2**26
|
||||||
|
outputs = {}
|
||||||
|
for test_fn in (had, copy):
|
||||||
|
for m in [1, 12, 20, 28]:
|
||||||
|
if test_fn == copy:
|
||||||
|
key = "copy"
|
||||||
|
elif m == 1:
|
||||||
|
key = "had_2^k"
|
||||||
|
else:
|
||||||
|
key = "had_m*2^k"
|
||||||
|
outputs.setdefault(key, {})
|
||||||
|
for k in range(7, 14):
|
||||||
|
n = m * 2**k
|
||||||
|
if n > 2**15:
|
||||||
|
continue
|
||||||
|
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
|
||||||
|
x = mx.array(x_np)
|
||||||
|
runtime_ms = measure_runtime(test_fn, x=x)
|
||||||
|
bytes_per_gb = 1e9
|
||||||
|
ms_per_s = 1e3
|
||||||
|
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
|
||||||
|
bandwidth_gb = (
|
||||||
|
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
|
||||||
|
)
|
||||||
|
print(n, bandwidth_gb)
|
||||||
|
outputs[key][n] = bandwidth_gb
|
||||||
|
|
||||||
|
colors = {
|
||||||
|
"copy": "black",
|
||||||
|
"had_2^k": "steelblue",
|
||||||
|
"had_m*2^k": "skyblue",
|
||||||
|
}
|
||||||
|
for key, output in outputs.items():
|
||||||
|
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
|
||||||
|
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
|
||||||
|
plt.xlabel("N")
|
||||||
|
plt.ylabel("Bandwidth (GB/s)")
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig(f"bench_{dtype.__name__}.png")
|
||||||
|
plt.clf()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--fp16", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
dtype = np.float16 if args.fp16 else np.float32
|
||||||
|
run(dtype)
|
||||||
82
benchmarks/python/layer_norm_bench.py
Normal file
82
benchmarks/python/layer_norm_bench.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
def layer_norm(x, w, b, eps):
|
||||||
|
ot = x.dtype
|
||||||
|
x = x.astype(mx.float32)
|
||||||
|
mu = mx.mean(x, -1, keepdims=True)
|
||||||
|
v = mx.var(x, -1, keepdims=True)
|
||||||
|
y = (x - mu) * mx.rsqrt(v + eps)
|
||||||
|
if w is not None:
|
||||||
|
y = y * w
|
||||||
|
if b is not None:
|
||||||
|
y = y + b
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def time_layer_norm(N, dt):
|
||||||
|
L = 1024
|
||||||
|
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
||||||
|
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
||||||
|
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
||||||
|
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
|
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
|
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
|
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
|
mx.eval(x, w, b, y)
|
||||||
|
|
||||||
|
def layer_norm_loop(f, x, w, b):
|
||||||
|
for _ in range(32):
|
||||||
|
x = f(x, w, b)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
|
||||||
|
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
|
||||||
|
|
||||||
|
def layer_norm_grad_loop(g, x, w, b):
|
||||||
|
gx, gw, gb = x, w, b
|
||||||
|
for _ in range(32):
|
||||||
|
gx, gw, gb = g(gx, gw, gb, y)
|
||||||
|
return gx, gw, gb
|
||||||
|
|
||||||
|
time_fn(layer_norm_grad_loop, g1, x, w, b)
|
||||||
|
time_fn(layer_norm_grad_loop, g2, x, w, b)
|
||||||
|
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
|
||||||
|
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
|
||||||
|
|
||||||
|
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
||||||
|
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
||||||
|
g1 = mx.grad(f1, argnums=(0,))
|
||||||
|
g2 = mx.grad(f2, argnums=(0,))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
|
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
|
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
|
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
|
mx.eval(x, w, b, y)
|
||||||
|
|
||||||
|
def layer_norm_grad_x_loop(g, x):
|
||||||
|
gx = x
|
||||||
|
for _ in range(32):
|
||||||
|
gx = g(gx, y)
|
||||||
|
return gx
|
||||||
|
|
||||||
|
time_fn(layer_norm_grad_x_loop, g1, x)
|
||||||
|
time_fn(layer_norm_grad_x_loop, g2, x)
|
||||||
|
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
|
||||||
|
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
for dt in [mx.float32, mx.float16, mx.bfloat16]:
|
||||||
|
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
|
||||||
|
print(dt, n)
|
||||||
|
time_layer_norm(n, dt)
|
||||||
212
benchmarks/python/masked_scatter.py
Normal file
212
benchmarks/python/masked_scatter.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from copy import copy
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from matplotlib.ticker import FuncFormatter
|
||||||
|
|
||||||
|
RESULTS_DIR = "./results"
|
||||||
|
|
||||||
|
|
||||||
|
if not os.path.isdir(RESULTS_DIR):
|
||||||
|
os.mkdir(RESULTS_DIR)
|
||||||
|
|
||||||
|
DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||||
|
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")
|
||||||
|
|
||||||
|
TORCH_DEVICE = torch.device(
|
||||||
|
"mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
N_WARMUP = 5
|
||||||
|
N_ITER_BENCH = 50
|
||||||
|
N_ITER_FUNC = 20
|
||||||
|
|
||||||
|
VECTOR_LENGTHS = [4096 * (2**i) for i in range(10)]
|
||||||
|
MASK_DENSITIES = [0.01, 0.1, 0.25, 0.5]
|
||||||
|
D_TYPES = ("float32", "float16")
|
||||||
|
|
||||||
|
|
||||||
|
def _power_of_two_formatter(value, _position):
|
||||||
|
if value <= 0:
|
||||||
|
return ""
|
||||||
|
exponent = int(round(math.log2(value)))
|
||||||
|
if abs(value - (1 << exponent)) / value > 1e-6:
|
||||||
|
return f"{value:g}"
|
||||||
|
return f"$2^{{{exponent}}}$"
|
||||||
|
|
||||||
|
|
||||||
|
def torch_sync():
|
||||||
|
if TORCH_DEVICE.type == "cuda":
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
elif TORCH_DEVICE.type == "mps":
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def masked_scatter_mlx(self_arr, mask_arr, src_arr):
|
||||||
|
outs = []
|
||||||
|
for _ in range(N_ITER_FUNC):
|
||||||
|
out = copy(self_arr)
|
||||||
|
out[mask_arr] = src_arr
|
||||||
|
outs.append(out)
|
||||||
|
mx.eval(outs)
|
||||||
|
return outs
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def masked_scatter_torch(self_tensor, mask_tensor, src_tensor):
|
||||||
|
outs = []
|
||||||
|
for _ in range(N_ITER_FUNC):
|
||||||
|
out = self_tensor.clone()
|
||||||
|
out.masked_scatter_(mask_tensor, src_tensor)
|
||||||
|
outs.append(out)
|
||||||
|
torch_sync()
|
||||||
|
return outs
|
||||||
|
|
||||||
|
|
||||||
|
def measure(fn):
|
||||||
|
for _ in range(N_WARMUP):
|
||||||
|
fn()
|
||||||
|
start = time.perf_counter_ns()
|
||||||
|
for _ in range(N_ITER_BENCH):
|
||||||
|
fn()
|
||||||
|
end = time.perf_counter_ns()
|
||||||
|
return (end - start) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_touched(length, true_count, item_size):
|
||||||
|
mask_bytes = length
|
||||||
|
self_bytes = length * item_size * 2 # read + write
|
||||||
|
src_bytes = true_count * item_size
|
||||||
|
return (mask_bytes + self_bytes + src_bytes) * N_ITER_FUNC * N_ITER_BENCH
|
||||||
|
|
||||||
|
|
||||||
|
def build_case(length, density, np_dtype, torch_dtype):
|
||||||
|
true_count = max(1, int(round(length * density)))
|
||||||
|
|
||||||
|
rng = np.random.default_rng()
|
||||||
|
self_np = rng.normal(0.0, 1.0, length).astype(np_dtype)
|
||||||
|
mask_np = np.zeros(length, dtype=bool)
|
||||||
|
mask_np[:true_count] = True
|
||||||
|
rng.shuffle(mask_np)
|
||||||
|
src_np = rng.normal(0.0, 1.0, true_count).astype(np_dtype)
|
||||||
|
|
||||||
|
self_mlx = mx.array(self_np)
|
||||||
|
mask_mlx = mx.array(mask_np)
|
||||||
|
src_mlx = mx.array(src_np)
|
||||||
|
|
||||||
|
self_torch = torch.from_numpy(self_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
|
||||||
|
mask_torch = torch.from_numpy(mask_np).to(device=TORCH_DEVICE)
|
||||||
|
src_torch = torch.from_numpy(src_np).to(device=TORCH_DEVICE, dtype=torch_dtype)
|
||||||
|
|
||||||
|
# Correctness check once per configuration
|
||||||
|
mx_out = mx.array(self_np)
|
||||||
|
mx_out[mask_mlx] = src_mlx
|
||||||
|
mx.eval(mx_out)
|
||||||
|
torch_out = self_torch.clone()
|
||||||
|
torch_out.masked_scatter_(mask_torch, src_torch)
|
||||||
|
|
||||||
|
atol = 5e-3 if np_dtype == np.float16 else 1e-5
|
||||||
|
if not np.allclose(np.array(mx_out), torch_out.cpu().numpy(), atol=atol):
|
||||||
|
raise AssertionError("masked_scatter results diverged between MLX and Torch")
|
||||||
|
|
||||||
|
return (self_mlx, mask_mlx, src_mlx, self_torch, mask_torch, src_torch, true_count)
|
||||||
|
|
||||||
|
|
||||||
|
def bench_case(length, density, dtype):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
torch_dtype = getattr(torch, dtype)
|
||||||
|
(
|
||||||
|
self_mlx,
|
||||||
|
mask_mlx,
|
||||||
|
src_mlx,
|
||||||
|
self_torch,
|
||||||
|
mask_torch,
|
||||||
|
src_torch,
|
||||||
|
true_count,
|
||||||
|
) = build_case(length, density, np_dtype, torch_dtype)
|
||||||
|
|
||||||
|
time_mlx = measure(partial(masked_scatter_mlx, self_mlx, mask_mlx, src_mlx))
|
||||||
|
time_torch = measure(
|
||||||
|
partial(masked_scatter_torch, self_torch, mask_torch, src_torch)
|
||||||
|
)
|
||||||
|
|
||||||
|
total_bytes = bytes_touched(length, true_count, np_dtype().itemsize)
|
||||||
|
bytes_per_gb = float(1024**3)
|
||||||
|
mlx_gbps = (total_bytes / bytes_per_gb) / time_mlx
|
||||||
|
torch_gbps = (total_bytes / bytes_per_gb) / time_torch
|
||||||
|
|
||||||
|
return time_mlx, time_torch, mlx_gbps, torch_gbps
|
||||||
|
|
||||||
|
|
||||||
|
def plot_density(ax_perf, ax_speedup, density, dtype):
|
||||||
|
mlx_gbps = []
|
||||||
|
torch_gbps = []
|
||||||
|
mlx_times = []
|
||||||
|
torch_times = []
|
||||||
|
|
||||||
|
for length in VECTOR_LENGTHS:
|
||||||
|
t_mlx, t_torch, gbps_mlx, gbps_torch = bench_case(length, density, dtype)
|
||||||
|
mlx_gbps.append(gbps_mlx)
|
||||||
|
torch_gbps.append(gbps_torch)
|
||||||
|
mlx_times.append(t_mlx)
|
||||||
|
torch_times.append(t_torch)
|
||||||
|
|
||||||
|
ax_perf.plot(VECTOR_LENGTHS, mlx_gbps, "tab:blue", label="MLX")
|
||||||
|
ax_perf.plot(VECTOR_LENGTHS, torch_gbps, "tab:red", label="Torch")
|
||||||
|
ax_perf.set_xscale("log", base=2)
|
||||||
|
ax_perf.set_xticks(VECTOR_LENGTHS)
|
||||||
|
formatter = FuncFormatter(_power_of_two_formatter)
|
||||||
|
ax_perf.xaxis.set_major_formatter(formatter)
|
||||||
|
ax_perf.set_title(f"density={density:.2f}")
|
||||||
|
ax_perf.set_ylabel("GB/s")
|
||||||
|
ax_perf.grid(True, which="both", linestyle=":", alpha=0.4)
|
||||||
|
ax_perf.legend()
|
||||||
|
|
||||||
|
speedup = np.array(torch_times) / np.array(mlx_times)
|
||||||
|
ax_speedup.plot(VECTOR_LENGTHS, speedup, "tab:green")
|
||||||
|
ax_speedup.axhline(1.0, color="tab:gray", linestyle="--")
|
||||||
|
ax_speedup.set_xscale("log", base=2)
|
||||||
|
ax_speedup.set_xticks(VECTOR_LENGTHS)
|
||||||
|
ax_speedup.xaxis.set_major_formatter(formatter)
|
||||||
|
ax_speedup.set_ylabel("Speedup (Torch_t / MLX_t)")
|
||||||
|
ax_speedup.grid(True, which="both", linestyle=":", alpha=0.4)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
for dtype in D_TYPES:
|
||||||
|
fig, axs = plt.subplots(
|
||||||
|
len(MASK_DENSITIES),
|
||||||
|
2,
|
||||||
|
figsize=(10, 12),
|
||||||
|
layout="constrained",
|
||||||
|
sharex=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, density in enumerate(MASK_DENSITIES):
|
||||||
|
plot_density(axs[i][0], axs[i][1], density, dtype)
|
||||||
|
axs[i][0].set_xlabel("vector length")
|
||||||
|
axs[i][1].set_xlabel("vector length")
|
||||||
|
|
||||||
|
fig.suptitle(
|
||||||
|
f"{DEVICE_NAME.replace('Apple ', '')} ({TORCH_DEVICE.type}) | dtype={dtype}"
|
||||||
|
)
|
||||||
|
output_path = os.path.join(
|
||||||
|
RESULTS_DIR,
|
||||||
|
f"{DEVICE_NAME.replace(' ', '_')}_masked_scatter_{dtype}.pdf",
|
||||||
|
)
|
||||||
|
fig.savefig(output_path)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
63
benchmarks/python/rms_norm_bench.py
Normal file
63
benchmarks/python/rms_norm_bench.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm(x, w, eps):
|
||||||
|
ot = x.dtype
|
||||||
|
x = x.astype(mx.float32)
|
||||||
|
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
||||||
|
y = (x * n).astype(ot)
|
||||||
|
if w is not None:
|
||||||
|
y = y * w
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def time_rms_norm():
|
||||||
|
f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
|
||||||
|
f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
|
||||||
|
g1 = mx.grad(f1, argnums=(0, 1))
|
||||||
|
g2 = mx.grad(f2, argnums=(0, 1))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||||
|
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||||
|
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||||
|
mx.eval(x, w, y)
|
||||||
|
|
||||||
|
def rms_norm_loop(g, x, w):
|
||||||
|
gx, gw = x, w
|
||||||
|
for _ in range(32):
|
||||||
|
gx, gw = g(gx, gw, y)
|
||||||
|
return gx, gw
|
||||||
|
|
||||||
|
time_fn(rms_norm_loop, g1, x, w)
|
||||||
|
time_fn(rms_norm_loop, g2, x, w)
|
||||||
|
time_fn(rms_norm_loop, mx.compile(g1), x, w)
|
||||||
|
time_fn(rms_norm_loop, mx.compile(g2), x, w)
|
||||||
|
|
||||||
|
f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
|
||||||
|
f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
|
||||||
|
g1 = mx.grad(f1, argnums=(0,))
|
||||||
|
g2 = mx.grad(f2, argnums=(0,))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||||
|
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||||
|
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||||
|
mx.eval(x, w, y)
|
||||||
|
|
||||||
|
def rms_norm_loop(g, x):
|
||||||
|
gx = x
|
||||||
|
for _ in range(32):
|
||||||
|
gx = g(gx, y)
|
||||||
|
return gx
|
||||||
|
|
||||||
|
time_fn(rms_norm_loop, g1, x)
|
||||||
|
time_fn(rms_norm_loop, g2, x)
|
||||||
|
time_fn(rms_norm_loop, mx.compile(g1), x)
|
||||||
|
time_fn(rms_norm_loop, mx.compile(g2), x)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_rms_norm()
|
||||||
35
benchmarks/python/rope_bench.py
Normal file
35
benchmarks/python/rope_bench.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
|
||||||
|
def time_rope():
|
||||||
|
rope = nn.RoPE(64)
|
||||||
|
|
||||||
|
# vec
|
||||||
|
x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
def rope_vec(x):
|
||||||
|
for _ in range(32):
|
||||||
|
x = rope(x, offset=100)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(rope_vec, x)
|
||||||
|
|
||||||
|
# matrix
|
||||||
|
x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
def rope_mat(x):
|
||||||
|
for _ in range(32):
|
||||||
|
x = rope(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(rope_mat, x)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_rope()
|
||||||
96
benchmarks/python/scatter_bench.py
Normal file
96
benchmarks/python/scatter_bench.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import torch
|
||||||
|
from time_utils import measure_runtime
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
|
||||||
|
def scatter(dst, x, idx):
|
||||||
|
dst[tuple(idx)] = x
|
||||||
|
mx.eval(dst)
|
||||||
|
|
||||||
|
idx = []
|
||||||
|
for idx_shape in idx_shapes:
|
||||||
|
idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
|
||||||
|
x = mx.random.normal(x_shape).astype(mx.float32)
|
||||||
|
dst = mx.random.normal(dst_shape).astype(mx.float32)
|
||||||
|
|
||||||
|
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
|
||||||
|
print(f"MLX: {runtime:.3f}ms")
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
|
||||||
|
def scatter(dst, x, idx, device):
|
||||||
|
dst[tuple(idx)] = x
|
||||||
|
if device == torch.device("mps"):
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
idx = []
|
||||||
|
for idx_shape in idx_shapes:
|
||||||
|
idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
|
||||||
|
x = torch.randn(x_shape, dtype=torch.float32).to(device)
|
||||||
|
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
|
||||||
|
|
||||||
|
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
|
||||||
|
print(f"PyTorch: {runtime:.3f}ms")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser("Gather benchmarks.")
|
||||||
|
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.cpu:
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
device = torch.device("cpu")
|
||||||
|
else:
|
||||||
|
device = torch.device("mps")
|
||||||
|
|
||||||
|
dst_shapes = [
|
||||||
|
(10, 64),
|
||||||
|
(100_000, 64),
|
||||||
|
(1_000_000, 64),
|
||||||
|
(100_000,),
|
||||||
|
(200_000,),
|
||||||
|
(20_000_000,),
|
||||||
|
(10000, 64),
|
||||||
|
(100, 64),
|
||||||
|
(100, 10_000, 64),
|
||||||
|
(10, 100, 100, 21),
|
||||||
|
(1_000, 1_000, 10),
|
||||||
|
]
|
||||||
|
idx_shapes = [
|
||||||
|
[(1_000_000,)],
|
||||||
|
[(1_000_000,)],
|
||||||
|
[(100_000,)],
|
||||||
|
[(1_000_000,)],
|
||||||
|
[(20_000_000,)],
|
||||||
|
[(20_000_000,)],
|
||||||
|
[(1000000,)],
|
||||||
|
[(10000000,)],
|
||||||
|
[(1_000,)],
|
||||||
|
[(10_000,)],
|
||||||
|
[(1_000,), (1_000,)],
|
||||||
|
]
|
||||||
|
x_shapes = [
|
||||||
|
(1_000_000, 64),
|
||||||
|
(1_000_000, 64),
|
||||||
|
(100_000, 64),
|
||||||
|
(1_000_000,),
|
||||||
|
(20_000_000,),
|
||||||
|
(20_000_000,),
|
||||||
|
(1000000, 64),
|
||||||
|
(10000000, 64),
|
||||||
|
(1_000, 10_000, 64),
|
||||||
|
(10_000, 100, 100, 21),
|
||||||
|
(1_000, 10),
|
||||||
|
]
|
||||||
|
|
||||||
|
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
|
||||||
|
print("=" * 20)
|
||||||
|
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
|
||||||
|
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
|
||||||
|
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
|
||||||
223
benchmarks/python/sdpa_bench.py
Normal file
223
benchmarks/python/sdpa_bench.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||||
|
device_name = device_name.decode("utf-8").strip("\n")
|
||||||
|
|
||||||
|
N_warmup = 5
|
||||||
|
N_iter_bench = 40
|
||||||
|
N_iter_func = 8
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, *args):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(*args)
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(*args)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
|
||||||
|
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
|
||||||
|
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
|
||||||
|
|
||||||
|
scale = 1.0 / math.sqrt(D)
|
||||||
|
|
||||||
|
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
|
||||||
|
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
|
||||||
|
q_mx = mx.array(q_np)
|
||||||
|
k_mx = mx.array(k_np)
|
||||||
|
v_mx = mx.array(v_np)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if mask == "additive":
|
||||||
|
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
|
||||||
|
mask = mx.array(mask_np)
|
||||||
|
elif mask == "bool":
|
||||||
|
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
|
||||||
|
mask = mx.array(mask_np)
|
||||||
|
|
||||||
|
return q_mx, k_mx, v_mx, scale, mask
|
||||||
|
|
||||||
|
|
||||||
|
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||||
|
q_dtype = q.dtype
|
||||||
|
q = q * mx.array(scale, q_dtype)
|
||||||
|
n_q_heads = q.shape[-3]
|
||||||
|
n_kv_heads = k.shape[-3]
|
||||||
|
n_repeats = n_q_heads // n_kv_heads
|
||||||
|
|
||||||
|
B = q.shape[0]
|
||||||
|
L = q.shape[2]
|
||||||
|
kL = k.shape[2]
|
||||||
|
|
||||||
|
if n_repeats > 1:
|
||||||
|
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
||||||
|
k = mx.expand_dims(k, 2)
|
||||||
|
v = mx.expand_dims(v, 2)
|
||||||
|
|
||||||
|
scores = q @ mx.swapaxes(k, -1, -2)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
|
||||||
|
if mask == "causal":
|
||||||
|
q_offset = max(0, kL - L)
|
||||||
|
q_indices = mx.arange(q_offset, q_offset + L)
|
||||||
|
k_indices = mx.arange(kL)
|
||||||
|
mask = q_indices[:, None] >= k_indices[None]
|
||||||
|
|
||||||
|
if n_repeats > 1 and mask.ndim >= 3:
|
||||||
|
if mask.shape[-3] == 1:
|
||||||
|
mask = mx.expand_dims(mask, -3)
|
||||||
|
else:
|
||||||
|
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
|
||||||
|
|
||||||
|
if mask.dtype == mx.bool_:
|
||||||
|
scores = mx.where(mask, scores, -np.float32(np.inf))
|
||||||
|
else:
|
||||||
|
scores += mask
|
||||||
|
|
||||||
|
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||||
|
|
||||||
|
out = scores @ v
|
||||||
|
if n_repeats > 1:
|
||||||
|
out = mx.reshape(out, [B, n_q_heads, L, -1])
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def mlx_fused_attn(q, k, v, scale, mask):
|
||||||
|
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
|
||||||
|
if transpose:
|
||||||
|
q_t = mx.transpose(q, (0, 2, 1, 3))
|
||||||
|
k_t = mx.transpose(k, (0, 2, 1, 3))
|
||||||
|
v_t = mx.transpose(v, (0, 2, 1, 3))
|
||||||
|
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
|
||||||
|
return mx.transpose(o_t, (0, 2, 1, 3))
|
||||||
|
else:
|
||||||
|
return f(q, k, v, scale=scale, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
|
||||||
|
q_out = q
|
||||||
|
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
|
||||||
|
|
||||||
|
mx.eval(q_out)
|
||||||
|
return q_out
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(
|
||||||
|
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
|
||||||
|
):
|
||||||
|
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
|
||||||
|
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
time_mlx_unfused = bench(
|
||||||
|
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
|
)
|
||||||
|
time_mlx_fused = bench(
|
||||||
|
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
|
)
|
||||||
|
|
||||||
|
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
|
||||||
|
o_mlx_unfused = do_attention(
|
||||||
|
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
|
)
|
||||||
|
|
||||||
|
atol = 1e-5 if dtype == "float32" else 2e-4
|
||||||
|
|
||||||
|
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx_fused, time_mlx_unfused
|
||||||
|
|
||||||
|
|
||||||
|
def get_gflop_count(B, M, N, K):
|
||||||
|
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||||
|
|
||||||
|
dtypes = ("float16", "float32")[:1]
|
||||||
|
transposes = (False,)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
shapes_64 = (
|
||||||
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
|
( 1, 32, 32, 64, 32, 32),
|
||||||
|
( 1, 64, 64, 64, 32, 32),
|
||||||
|
( 1, 128, 128, 64, 32, 32),
|
||||||
|
( 1, 256, 256, 64, 32, 32),
|
||||||
|
( 1, 512, 512, 64, 32, 32),
|
||||||
|
( 1, 1024, 1024, 64, 32, 8),
|
||||||
|
( 1, 2048, 2048, 64, 32, 8),
|
||||||
|
( 1, 4096, 4096, 64, 32, 8),
|
||||||
|
)
|
||||||
|
|
||||||
|
shapes_80 = (
|
||||||
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
|
( 1, 1024, 1024, 80, 32, 8),
|
||||||
|
( 1, 2048, 2048, 80, 32, 8),
|
||||||
|
( 1, 4096, 4096, 80, 32, 8),
|
||||||
|
)
|
||||||
|
|
||||||
|
shapes_128 = (
|
||||||
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
|
( 1, 1024, 1024, 128, 32, 8),
|
||||||
|
( 1, 2048, 2048, 128, 32, 8),
|
||||||
|
( 1, 4096, 4096, 128, 32, 8),
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
shapes = shapes_64 + shapes_80 + shapes_128
|
||||||
|
|
||||||
|
masks = [None, "bool", "causal"]
|
||||||
|
|
||||||
|
print(
|
||||||
|
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
for transpose in transposes:
|
||||||
|
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
||||||
|
for mask_in in masks:
|
||||||
|
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||||
|
B,
|
||||||
|
qsl,
|
||||||
|
ksl,
|
||||||
|
head_dim,
|
||||||
|
n_q_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
dtype,
|
||||||
|
transpose,
|
||||||
|
mask_in,
|
||||||
|
)
|
||||||
|
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||||
|
t_str = 1 if transpose else 0
|
||||||
|
print(
|
||||||
|
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
95
benchmarks/python/sdpa_vector_bench.py
Normal file
95
benchmarks/python/sdpa_vector_bench.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
L = 16384
|
||||||
|
H = 32
|
||||||
|
H_k = H // 4
|
||||||
|
D = 128
|
||||||
|
V = 128
|
||||||
|
dtype = mx.float16
|
||||||
|
loops = 10
|
||||||
|
|
||||||
|
|
||||||
|
def upproject(x, w):
|
||||||
|
if w is None:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return x @ w.T
|
||||||
|
|
||||||
|
|
||||||
|
def attention(q, k, v, mask=None, w=None):
|
||||||
|
def _sdpa(q, k, v):
|
||||||
|
B, Hq, L, D = q.shape
|
||||||
|
_, Hk, S, _ = k.shape
|
||||||
|
_, _, _, V = v.shape
|
||||||
|
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||||
|
k = k[:, :, None, :, :]
|
||||||
|
v = v[:, :, None, :, :]
|
||||||
|
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||||
|
if mask is not None:
|
||||||
|
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
|
||||||
|
s = mx.where(m, s, mx.finfo(s.dtype).min)
|
||||||
|
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||||
|
o = p @ v
|
||||||
|
return o.reshape(B, Hq, L, V)
|
||||||
|
|
||||||
|
for i in range(loops):
|
||||||
|
q = _sdpa(q, k, v)
|
||||||
|
q = upproject(q, w)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
def sdpa(q, k, v, mask=None, w=None):
|
||||||
|
for i in range(loops):
|
||||||
|
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||||
|
q = upproject(q, w)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
def time_self_attention_primitives():
|
||||||
|
mx.random.seed(3)
|
||||||
|
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||||
|
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||||
|
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
|
||||||
|
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
|
||||||
|
mx.eval(q, k, v, w)
|
||||||
|
time_fn(attention, q, k, v, w=w)
|
||||||
|
|
||||||
|
|
||||||
|
def time_self_attention_sdpa():
|
||||||
|
mx.random.seed(3)
|
||||||
|
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||||
|
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||||
|
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
|
||||||
|
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
|
||||||
|
mx.eval(q, k, v, w)
|
||||||
|
time_fn(sdpa, q, k, v, w=w)
|
||||||
|
|
||||||
|
|
||||||
|
def time_self_attention_sdpa_with_mask():
|
||||||
|
mx.random.seed(3)
|
||||||
|
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||||
|
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||||
|
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
|
||||||
|
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
|
||||||
|
mask = mx.full((L,), True)
|
||||||
|
mask[L // 2 :] = False
|
||||||
|
mx.eval(q, k, v, mask, w)
|
||||||
|
|
||||||
|
def sdpa_mask(*args):
|
||||||
|
return sdpa(*args, mask=mask, w=w)
|
||||||
|
|
||||||
|
def attention_mask(*args):
|
||||||
|
return attention(*args, mask=mask, w=w)
|
||||||
|
|
||||||
|
time_fn(attention_mask, q, k, v)
|
||||||
|
time_fn(sdpa_mask, q, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_self_attention_sdpa()
|
||||||
|
time_self_attention_primitives()
|
||||||
|
time_self_attention_sdpa_with_mask()
|
||||||
@@ -51,6 +51,20 @@ def time_maximum():
|
|||||||
time_fn(mx.maximum, a, b)
|
time_fn(mx.maximum, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def time_max():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
a[1, 1] = mx.nan
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.max, a, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def time_min():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
a[1, 1] = mx.nan
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.min, a, 0)
|
||||||
|
|
||||||
|
|
||||||
def time_negative():
|
def time_negative():
|
||||||
a = mx.random.uniform(shape=(10000, 1000))
|
a = mx.random.uniform(shape=(10000, 1000))
|
||||||
mx.eval(a)
|
mx.eval(a)
|
||||||
@@ -108,6 +122,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
time_add()
|
time_add()
|
||||||
time_matmul()
|
time_matmul()
|
||||||
|
time_min()
|
||||||
|
time_max()
|
||||||
time_maximum()
|
time_maximum()
|
||||||
time_exp()
|
time_exp()
|
||||||
time_negative()
|
time_negative()
|
||||||
|
|||||||
55
benchmarks/python/synchronize_bench.py
Normal file
55
benchmarks/python/synchronize_bench.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
rank = mx.distributed.init().rank()
|
||||||
|
|
||||||
|
|
||||||
|
def timeit(fn, a):
|
||||||
|
|
||||||
|
# warmup
|
||||||
|
for _ in range(5):
|
||||||
|
mx.eval(fn(a))
|
||||||
|
|
||||||
|
its = 10
|
||||||
|
tic = time.perf_counter()
|
||||||
|
for _ in range(its):
|
||||||
|
mx.eval(fn(a))
|
||||||
|
toc = time.perf_counter()
|
||||||
|
ms = 1000 * (toc - tic) / its
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
def all_reduce_benchmark():
|
||||||
|
a = mx.ones((5, 5), mx.int32)
|
||||||
|
|
||||||
|
its_per_eval = 100
|
||||||
|
|
||||||
|
def fn(x):
|
||||||
|
for _ in range(its_per_eval):
|
||||||
|
x = mx.distributed.all_sum(x)
|
||||||
|
x = x - 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
ms = timeit(fn, a) / its_per_eval
|
||||||
|
if rank == 0:
|
||||||
|
print(f"All Reduce: time per iteration {ms:.6f} (ms)")
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_benchmark():
|
||||||
|
a = mx.ones((5, 5), mx.int32)
|
||||||
|
its_per_eval = 100
|
||||||
|
|
||||||
|
def fn(x):
|
||||||
|
for _ in range(its_per_eval):
|
||||||
|
x = mx.distributed.all_gather(x)[0]
|
||||||
|
return x
|
||||||
|
|
||||||
|
ms = timeit(fn, a) / its_per_eval
|
||||||
|
if rank == 0:
|
||||||
|
print(f"All gather: time per iteration {ms:.6f} (ms)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
all_reduce_benchmark()
|
||||||
|
all_gather_benchmark()
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -6,6 +6,10 @@ import mlx.core as mx
|
|||||||
|
|
||||||
|
|
||||||
def time_fn(fn, *args, **kwargs):
|
def time_fn(fn, *args, **kwargs):
|
||||||
|
msg = kwargs.pop("msg", None)
|
||||||
|
if msg:
|
||||||
|
print(f"Timing {msg} ...", end=" ")
|
||||||
|
else:
|
||||||
print(f"Timing {fn.__name__} ...", end=" ")
|
print(f"Timing {fn.__name__} ...", end=" ")
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
@@ -20,3 +24,15 @@ def time_fn(fn, *args, **kwargs):
|
|||||||
|
|
||||||
msec = 1e3 * (toc - tic) / num_iters
|
msec = 1e3 * (toc - tic) / num_iters
|
||||||
print(f"{msec:.5f} msec")
|
print(f"{msec:.5f} msec")
|
||||||
|
|
||||||
|
|
||||||
|
def measure_runtime(fn, **kwargs):
|
||||||
|
# Warmup
|
||||||
|
for _ in range(5):
|
||||||
|
fn(**kwargs)
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
iters = 100
|
||||||
|
for _ in range(iters):
|
||||||
|
fn(**kwargs)
|
||||||
|
return (time.time() - tic) * 1000 / iters
|
||||||
|
|||||||
54
cmake/FindNCCL.cmake
Normal file
54
cmake/FindNCCL.cmake
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
|
||||||
|
# directories.
|
||||||
|
|
||||||
|
set(NCCL_ROOT_DIR
|
||||||
|
$ENV{NCCL_ROOT_DIR}
|
||||||
|
CACHE PATH "Folder contains NVIDIA NCCL")
|
||||||
|
|
||||||
|
find_path(
|
||||||
|
NCCL_INCLUDE_DIRS
|
||||||
|
NAMES nccl.h
|
||||||
|
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||||
|
|
||||||
|
if($ENV{USE_STATIC_NCCL})
|
||||||
|
message(
|
||||||
|
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
||||||
|
set(NCCL_LIBNAME "libnccl_static.a")
|
||||||
|
else()
|
||||||
|
set(NCCL_LIBNAME "nccl")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_library(
|
||||||
|
NCCL_LIBRARIES
|
||||||
|
NAMES ${NCCL_LIBNAME}
|
||||||
|
HINTS ${NCCL_LIB_DIR}
|
||||||
|
${NCCL_ROOT_DIR}
|
||||||
|
${NCCL_ROOT_DIR}/lib
|
||||||
|
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
||||||
|
${NCCL_ROOT_DIR}/lib64
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||||
|
|
||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
||||||
|
NCCL_LIBRARIES)
|
||||||
|
|
||||||
|
if(NCCL_FOUND)
|
||||||
|
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||||
|
message(
|
||||||
|
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
||||||
|
file(
|
||||||
|
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
||||||
|
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
||||||
|
LIMIT_COUNT 1)
|
||||||
|
if(NCCL_MAJOR_VERSION_DEFINED)
|
||||||
|
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
||||||
|
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
||||||
|
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
||||||
|
endif()
|
||||||
|
message(
|
||||||
|
STATUS
|
||||||
|
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||||
|
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||||
|
endif()
|
||||||
3
cmake/Findnvpl.cmake
Normal file
3
cmake/Findnvpl.cmake
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# This file does nothing but to suppress the cmake warning: "By not providing
|
||||||
|
# Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
|
||||||
|
# find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.
|
||||||
@@ -1,56 +1,50 @@
|
|||||||
include(CMakeParseArguments)
|
include(CMakeParseArguments)
|
||||||
|
|
||||||
|
# clang format off
|
||||||
|
#
|
||||||
# ##############################################################################
|
# ##############################################################################
|
||||||
# Build metal library
|
# Build metal library
|
||||||
#
|
#
|
||||||
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
|
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
|
||||||
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
|
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
|
||||||
#
|
#
|
||||||
# Args:
|
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
||||||
# TARGET: Custom target to be added for the metal library
|
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
||||||
# TITLE: Name of the .metallib
|
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||||
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
|
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
|
||||||
# SOURCES: List of source files
|
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
|
||||||
# INCLUDE_DIRS: List of include dirs
|
|
||||||
# DEPS: List of dependency files (like headers)
|
|
||||||
#
|
#
|
||||||
|
# clang format on
|
||||||
|
|
||||||
macro(mlx_build_metallib)
|
macro(mlx_build_metallib)
|
||||||
# Parse args
|
# Parse args
|
||||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
|
||||||
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
||||||
cmake_parse_arguments(
|
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||||
MTLLIB
|
|
||||||
""
|
|
||||||
"${oneValueArgs}"
|
|
||||||
"${multiValueArgs}"
|
|
||||||
${ARGN}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set output
|
# Set output
|
||||||
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
||||||
|
|
||||||
# Collect compile options
|
# Collect compile options
|
||||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
|
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||||
|
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
|
||||||
|
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
|
||||||
|
-frecord-sources)
|
||||||
|
endif()
|
||||||
|
|
||||||
# Prepare metallib build command
|
# Prepare metallib build command
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||||
COMMAND xcrun -sdk macosx metal
|
COMMAND
|
||||||
|
xcrun -sdk macosx metal
|
||||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||||
${MTLLIB_COMPILE_OPTIONS}
|
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
|
||||||
${MTLLIB_SOURCES}
|
|
||||||
-o ${MTLLIB_BUILD_TARGET}
|
|
||||||
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||||
COMMAND_EXPAND_LISTS
|
COMMAND_EXPAND_LISTS
|
||||||
COMMENT "Building ${MTLLIB_TITLE}.metallib"
|
COMMENT "Building ${MTLLIB_TITLE}.metallib"
|
||||||
VERBATIM
|
VERBATIM)
|
||||||
)
|
|
||||||
|
|
||||||
# Add metallib custom target
|
# Add metallib custom target
|
||||||
add_custom_target(
|
add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
|
||||||
${MTLLIB_TARGET}
|
|
||||||
DEPENDS
|
|
||||||
${MTLLIB_BUILD_TARGET}
|
|
||||||
)
|
|
||||||
|
|
||||||
endmacro(mlx_build_metallib)
|
endmacro(mlx_build_metallib)
|
||||||
1
docs/.gitignore
vendored
1
docs/.gitignore
vendored
@@ -1,2 +1,3 @@
|
|||||||
src/python/_autosummary*/
|
src/python/_autosummary*/
|
||||||
src/python/nn/_autosummary*/
|
src/python/nn/_autosummary*/
|
||||||
|
src/python/optimizers/_autosummary*/
|
||||||
|
|||||||
50
docs/Doxyfile
Normal file
50
docs/Doxyfile
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
################################################################################
|
||||||
|
# Primary project setup. #
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
PROJECT_NAME = "MLX"
|
||||||
|
OUTPUT_DIRECTORY = build
|
||||||
|
XML_OUTPUT = xml
|
||||||
|
HTML_OUTPUT = html
|
||||||
|
STRIP_FROM_PATH = ../
|
||||||
|
INPUT = ../mlx
|
||||||
|
FILE_PATTERNS = *.h
|
||||||
|
EXCLUDE_PATTERNS = */private/*
|
||||||
|
CREATE_SUBDIRS = NO
|
||||||
|
FULL_PATH_NAMES = YES
|
||||||
|
RECURSIVE = YES
|
||||||
|
GENERATE_HTML = NO
|
||||||
|
GENERATE_LATEX = NO
|
||||||
|
GENERATE_XML = YES
|
||||||
|
XML_PROGRAMLISTING = YES
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Doxygen preprocessor / parser control. #
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
ENABLE_PREPROCESSING = YES
|
||||||
|
MACRO_EXPANSION = YES
|
||||||
|
EXPAND_ONLY_PREDEF = NO
|
||||||
|
SKIP_FUNCTION_MACROS = NO
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Compound extraction control. #
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
EXTRACT_ALL = YES
|
||||||
|
EXTRACT_PACKAGE = YES
|
||||||
|
EXTRACT_STATIC = YES
|
||||||
|
CASE_SENSE_NAMES = NO
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Docstring control / customization. #
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
JAVADOC_AUTOBRIEF = YES
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Warning suppression. #
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
QUIET = YES
|
||||||
|
WARN_IF_UNDOCUMENTED = NO
|
||||||
@@ -2,12 +2,16 @@
|
|||||||
|
|
||||||
### Setup (do once)
|
### Setup (do once)
|
||||||
|
|
||||||
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
|
Install Doxygen:
|
||||||
for example with `conda`:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
conda install sphinx
|
brew install doxygen
|
||||||
pip install sphinx-book-theme
|
```
|
||||||
|
|
||||||
|
Install Python packages:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
### Build
|
### Build
|
||||||
@@ -15,7 +19,7 @@ pip install sphinx-book-theme
|
|||||||
Build the docs from `mlx/docs/`
|
Build the docs from `mlx/docs/`
|
||||||
|
|
||||||
```
|
```
|
||||||
make html
|
doxygen && make html
|
||||||
```
|
```
|
||||||
|
|
||||||
View the docs by running a server in `mlx/docs/build/html/`:
|
View the docs by running a server in `mlx/docs/build/html/`:
|
||||||
|
|||||||
5
docs/requirements.txt
Normal file
5
docs/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
sphinx
|
||||||
|
breathe
|
||||||
|
sphinx-book-theme
|
||||||
|
sphinx-copybutton
|
||||||
|
mlx
|
||||||
BIN
docs/src/_static/metal_debugger/capture.png
Normal file
BIN
docs/src/_static/metal_debugger/capture.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.2 MiB |
BIN
docs/src/_static/metal_debugger/schema.png
Normal file
BIN
docs/src/_static/metal_debugger/schema.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 746 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 7.2 KiB After Width: | Height: | Size: 76 KiB |
BIN
docs/src/_static/mlx_logo_dark.png
Normal file
BIN
docs/src/_static/mlx_logo_dark.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 48 KiB |
20
docs/src/_templates/nn-module-template.rst
Normal file
20
docs/src/_templates/nn-module-template.rst
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
{{ fullname | escape | underline}}
|
||||||
|
|
||||||
|
.. currentmodule:: {{ module }}
|
||||||
|
|
||||||
|
.. autoclass:: {{ objname }}
|
||||||
|
|
||||||
|
{% block methods %}
|
||||||
|
|
||||||
|
{% if methods %}
|
||||||
|
.. rubric:: {{ _('Methods') }}
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
{% for item in methods %}
|
||||||
|
{%- if item not in inherited_members and item != "__init__" %}
|
||||||
|
~{{ name }}.{{ item }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{% endif %}
|
||||||
|
{% endblock %}
|
||||||
|
|
||||||
@@ -10,7 +10,7 @@ import mlx.core as mx
|
|||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
project = "MLX"
|
project = "MLX"
|
||||||
copyright = "2023, MLX Contributors"
|
copyright = "2023, Apple"
|
||||||
author = "MLX Contributors"
|
author = "MLX Contributors"
|
||||||
version = ".".join(mx.__version__.split(".")[:3])
|
version = ".".join(mx.__version__.split(".")[:3])
|
||||||
release = version
|
release = version
|
||||||
@@ -18,26 +18,33 @@ release = version
|
|||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
extensions = [
|
extensions = [
|
||||||
|
"sphinx_copybutton",
|
||||||
"sphinx.ext.autodoc",
|
"sphinx.ext.autodoc",
|
||||||
"sphinx.ext.autosummary",
|
"sphinx.ext.autosummary",
|
||||||
"sphinx.ext.intersphinx",
|
"sphinx.ext.intersphinx",
|
||||||
"sphinx.ext.napoleon",
|
"sphinx.ext.napoleon",
|
||||||
|
"breathe",
|
||||||
]
|
]
|
||||||
|
|
||||||
python_use_unqualified_type_names = True
|
python_use_unqualified_type_names = True
|
||||||
autosummary_generate = True
|
autosummary_generate = True
|
||||||
|
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
|
||||||
|
|
||||||
intersphinx_mapping = {
|
intersphinx_mapping = {
|
||||||
"https://docs.python.org/3": None,
|
"python": ("https://docs.python.org/3", None),
|
||||||
"https://numpy.org/doc/stable/": None,
|
"numpy": ("https://numpy.org/doc/stable/", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
breathe_projects = {"mlx": "../build/xml"}
|
||||||
|
breathe_default_project = "mlx"
|
||||||
|
|
||||||
templates_path = ["_templates"]
|
templates_path = ["_templates"]
|
||||||
html_static_path = ["_static"]
|
html_static_path = ["_static"]
|
||||||
source_suffix = ".rst"
|
source_suffix = ".rst"
|
||||||
master_doc = "index"
|
main_doc = "index"
|
||||||
highlight_language = "python"
|
highlight_language = "python"
|
||||||
pygments_style = "sphinx"
|
pygments_style = "sphinx"
|
||||||
|
add_module_names = False
|
||||||
|
|
||||||
# -- Options for HTML output -------------------------------------------------
|
# -- Options for HTML output -------------------------------------------------
|
||||||
|
|
||||||
@@ -48,11 +55,45 @@ html_theme_options = {
|
|||||||
"repository_url": "https://github.com/ml-explore/mlx",
|
"repository_url": "https://github.com/ml-explore/mlx",
|
||||||
"use_repository_button": True,
|
"use_repository_button": True,
|
||||||
"navigation_with_keys": False,
|
"navigation_with_keys": False,
|
||||||
|
"logo": {
|
||||||
|
"image_light": "_static/mlx_logo.png",
|
||||||
|
"image_dark": "_static/mlx_logo_dark.png",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
html_logo = "_static/mlx_logo.png"
|
html_favicon = html_theme_options["logo"]["image_light"]
|
||||||
|
|
||||||
|
|
||||||
# -- Options for HTMLHelp output ---------------------------------------------
|
# -- Options for HTMLHelp output ---------------------------------------------
|
||||||
|
|
||||||
htmlhelp_basename = "mlx_doc"
|
htmlhelp_basename = "mlx_doc"
|
||||||
|
|
||||||
|
|
||||||
|
def setup(app):
|
||||||
|
from sphinx.util import inspect
|
||||||
|
|
||||||
|
wrapped_isfunc = inspect.isfunction
|
||||||
|
|
||||||
|
def isfunc(obj):
|
||||||
|
type_name = str(type(obj))
|
||||||
|
if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
|
||||||
|
return True
|
||||||
|
return wrapped_isfunc(obj)
|
||||||
|
|
||||||
|
inspect.isfunction = isfunc
|
||||||
|
|
||||||
|
|
||||||
|
# -- Options for LaTeX output ------------------------------------------------
|
||||||
|
|
||||||
|
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
|
||||||
|
latex_elements = {
|
||||||
|
"preamble": r"""
|
||||||
|
\usepackage{enumitem}
|
||||||
|
\setlistdepth{5}
|
||||||
|
\setlist[itemize,1]{label=$\bullet$}
|
||||||
|
\setlist[itemize,2]{label=$\bullet$}
|
||||||
|
\setlist[itemize,3]{label=$\bullet$}
|
||||||
|
\setlist[itemize,4]{label=$\bullet$}
|
||||||
|
\setlist[itemize,5]{label=$\bullet$}
|
||||||
|
\renewlist{itemize}{itemize}{5}
|
||||||
|
""",
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,4 +3,5 @@
|
|||||||
Operations
|
Operations
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
.. doxygengroup:: ops
|
||||||
|
:content-only:
|
||||||
|
|||||||
445
docs/src/dev/custom_metal_kernels.rst
Normal file
445
docs/src/dev/custom_metal_kernels.rst
Normal file
@@ -0,0 +1,445 @@
|
|||||||
|
.. _custom_metal_kernels:
|
||||||
|
|
||||||
|
Custom Metal Kernels
|
||||||
|
====================
|
||||||
|
|
||||||
|
MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
||||||
|
|
||||||
|
Simple Example
|
||||||
|
--------------
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
source = """
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
T tmp = inp[elem];
|
||||||
|
out[elem] = metal::exp(tmp);
|
||||||
|
"""
|
||||||
|
|
||||||
|
kernel = mx.fast.metal_kernel(
|
||||||
|
name="myexp",
|
||||||
|
input_names=["inp"],
|
||||||
|
output_names=["out"],
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
|
||||||
|
def exp_elementwise(a: mx.array):
|
||||||
|
outputs = kernel(
|
||||||
|
inputs=[a],
|
||||||
|
template=[("T", mx.float32)],
|
||||||
|
grid=(a.size, 1, 1),
|
||||||
|
threadgroup=(256, 1, 1),
|
||||||
|
output_shapes=[a.shape],
|
||||||
|
output_dtypes=[a.dtype],
|
||||||
|
)
|
||||||
|
return outputs[0]
|
||||||
|
|
||||||
|
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||||
|
b = exp_elementwise(a)
|
||||||
|
assert mx.allclose(b, mx.exp(a))
|
||||||
|
|
||||||
|
Every time you make a kernel, a new Metal library is created and possibly
|
||||||
|
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||||
|
:func:`fast.metal_kernel` and then use it many times.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Only pass the body of the Metal kernel in ``source``. The function
|
||||||
|
signature is generated automatically.
|
||||||
|
|
||||||
|
The full function signature will be generated using:
|
||||||
|
|
||||||
|
* The shapes/dtypes of ``inputs``
|
||||||
|
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
|
||||||
|
so we will add ``const device float16_t* inp`` to the signature.
|
||||||
|
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
|
||||||
|
in ``source``.
|
||||||
|
* The list of ``output_dtypes``
|
||||||
|
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
|
||||||
|
so we add ``device float16_t* out``.
|
||||||
|
* Template parameters passed using ``template``
|
||||||
|
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
|
||||||
|
and instantiates the template with ``custom_kernel_myexp_float<float>``.
|
||||||
|
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
|
||||||
|
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
|
||||||
|
These will be added as function arguments.
|
||||||
|
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
|
||||||
|
|
||||||
|
Putting this all together, the generated function signature for ``myexp`` is as follows:
|
||||||
|
|
||||||
|
.. code-block:: cpp
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
[[kernel]] void custom_kernel_myexp_float(
|
||||||
|
const device float16_t* inp [[buffer(0)]],
|
||||||
|
device float16_t* out [[buffer(1)]],
|
||||||
|
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
|
||||||
|
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
T tmp = inp[elem];
|
||||||
|
out[elem] = metal::exp(tmp);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||||
|
|
||||||
|
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
||||||
|
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
||||||
|
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
||||||
|
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
||||||
|
dimension should be less than or equal to the corresponding grid dimension.
|
||||||
|
|
||||||
|
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
||||||
|
generated code for debugging purposes.
|
||||||
|
|
||||||
|
Using Shape/Strides
|
||||||
|
-------------------
|
||||||
|
|
||||||
|
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
||||||
|
is ``True`` by default. This will copy the array inputs if needed
|
||||||
|
before the kernel is launched to ensure that the memory layout is row
|
||||||
|
contiguous. Generally this makes writing the kernel easier, since we don't
|
||||||
|
have to worry about gaps or the ordering of the dims when indexing.
|
||||||
|
|
||||||
|
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
||||||
|
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
||||||
|
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
||||||
|
the right elements for each thread.
|
||||||
|
|
||||||
|
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
||||||
|
relying on a copy from ``ensure_row_contiguous``:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
source = """
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||||
|
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||||
|
T tmp = inp[loc];
|
||||||
|
// Output arrays are always row contiguous
|
||||||
|
out[elem] = metal::exp(tmp);
|
||||||
|
"""
|
||||||
|
|
||||||
|
kernel = mx.fast.metal_kernel(
|
||||||
|
name="myexp_strided",
|
||||||
|
input_names=["inp"],
|
||||||
|
output_names=["out"],
|
||||||
|
source=source,
|
||||||
|
ensure_row_contiguous=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def exp_elementwise(a: mx.array):
|
||||||
|
outputs = kernel(
|
||||||
|
inputs=[a],
|
||||||
|
template=[("T", mx.float32)],
|
||||||
|
grid=(a.size, 1, 1),
|
||||||
|
threadgroup=(256, 1, 1),
|
||||||
|
output_shapes=[a.shape],
|
||||||
|
output_dtypes=[a.dtype],
|
||||||
|
)
|
||||||
|
return outputs[0]
|
||||||
|
|
||||||
|
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||||
|
# make non-contiguous
|
||||||
|
a = a[::2]
|
||||||
|
b = exp_elementwise(a)
|
||||||
|
assert mx.allclose(b, mx.exp(a))
|
||||||
|
|
||||||
|
Complex Example
|
||||||
|
-----------------------------
|
||||||
|
|
||||||
|
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
|
||||||
|
|
||||||
|
We'll start with the following MLX implementation using standard ops:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def grid_sample_ref(x, grid):
|
||||||
|
N, H_in, W_in, _ = x.shape
|
||||||
|
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||||
|
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||||
|
|
||||||
|
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||||
|
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||||
|
|
||||||
|
ix_ne = ix_nw + 1
|
||||||
|
iy_ne = iy_nw
|
||||||
|
|
||||||
|
ix_sw = ix_nw
|
||||||
|
iy_sw = iy_nw + 1
|
||||||
|
|
||||||
|
ix_se = ix_nw + 1
|
||||||
|
iy_se = iy_nw + 1
|
||||||
|
|
||||||
|
nw = (ix_se - ix) * (iy_se - iy)
|
||||||
|
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||||
|
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||||
|
se = (ix - ix_nw) * (iy - iy_nw)
|
||||||
|
|
||||||
|
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||||
|
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||||
|
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||||
|
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||||
|
|
||||||
|
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||||
|
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||||
|
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||||
|
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||||
|
|
||||||
|
I_nw *= mask_nw[..., None]
|
||||||
|
I_ne *= mask_ne[..., None]
|
||||||
|
I_sw *= mask_sw[..., None]
|
||||||
|
I_se *= mask_se[..., None]
|
||||||
|
|
||||||
|
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
|
||||||
|
to write a fast GPU kernel for both the forward and backward passes.
|
||||||
|
|
||||||
|
First we'll implement the forward pass as a fused kernel:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
source = """
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
int H = x_shape[1];
|
||||||
|
int W = x_shape[2];
|
||||||
|
int C = x_shape[3];
|
||||||
|
int gH = grid_shape[1];
|
||||||
|
int gW = grid_shape[2];
|
||||||
|
|
||||||
|
int w_stride = C;
|
||||||
|
int h_stride = W * w_stride;
|
||||||
|
int b_stride = H * h_stride;
|
||||||
|
|
||||||
|
uint grid_idx = elem / C * 2;
|
||||||
|
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||||
|
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||||
|
|
||||||
|
int ix_nw = floor(ix);
|
||||||
|
int iy_nw = floor(iy);
|
||||||
|
|
||||||
|
int ix_ne = ix_nw + 1;
|
||||||
|
int iy_ne = iy_nw;
|
||||||
|
|
||||||
|
int ix_sw = ix_nw;
|
||||||
|
int iy_sw = iy_nw + 1;
|
||||||
|
|
||||||
|
int ix_se = ix_nw + 1;
|
||||||
|
int iy_se = iy_nw + 1;
|
||||||
|
|
||||||
|
T nw = (ix_se - ix) * (iy_se - iy);
|
||||||
|
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||||
|
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||||
|
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||||
|
|
||||||
|
int batch_idx = elem / C / gH / gW * b_stride;
|
||||||
|
int channel_idx = elem % C;
|
||||||
|
int base_idx = batch_idx + channel_idx;
|
||||||
|
|
||||||
|
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||||
|
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||||
|
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||||
|
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||||
|
|
||||||
|
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||||
|
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||||
|
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||||
|
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||||
|
|
||||||
|
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||||
|
"""
|
||||||
|
|
||||||
|
kernel = mx.fast.metal_kernel(
|
||||||
|
name="grid_sample",
|
||||||
|
input_names=["x", "grid"],
|
||||||
|
output_names=["out"],
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
|
||||||
|
@mx.custom_function
|
||||||
|
def grid_sample(x, grid):
|
||||||
|
|
||||||
|
assert x.ndim == 4, "`x` must be 4D."
|
||||||
|
assert grid.ndim == 4, "`grid` must be 4D."
|
||||||
|
|
||||||
|
B, _, _, C = x.shape
|
||||||
|
_, gN, gM, D = grid.shape
|
||||||
|
out_shape = (B, gN, gM, C)
|
||||||
|
|
||||||
|
assert D == 2, "Last dim of `grid` must be size 2."
|
||||||
|
|
||||||
|
outputs = kernel(
|
||||||
|
inputs=[x, grid],
|
||||||
|
template=[("T", x.dtype)],
|
||||||
|
output_shapes=[out_shape],
|
||||||
|
output_dtypes=[x.dtype],
|
||||||
|
grid=(np.prod(out_shape), 1, 1),
|
||||||
|
threadgroup=(256, 1, 1),
|
||||||
|
)
|
||||||
|
return outputs[0]
|
||||||
|
|
||||||
|
For a reasonably sized input such as:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
x.shape = (8, 1024, 1024, 64)
|
||||||
|
grid.shape = (8, 256, 256, 2)
|
||||||
|
|
||||||
|
On an M1 Max, we see a big performance improvement:
|
||||||
|
|
||||||
|
``55.7ms -> 6.7ms => 8x speed up``
|
||||||
|
|
||||||
|
Grid Sample VJP
|
||||||
|
---------------
|
||||||
|
|
||||||
|
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||||
|
define its custom vjp transform so MLX can differentiate it.
|
||||||
|
|
||||||
|
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||||
|
requires a few extra :func:`fast.metal_kernel` features:
|
||||||
|
|
||||||
|
* ``init_value=0``
|
||||||
|
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||||
|
|
||||||
|
* ``atomic_outputs=True``
|
||||||
|
Designate all of the kernel outputs as ``atomic`` in the function signature.
|
||||||
|
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
|
||||||
|
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
|
||||||
|
|
||||||
|
We can then implement the backwards pass as follows:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
source = """
|
||||||
|
uint elem = thread_position_in_grid.x;
|
||||||
|
int H = x_shape[1];
|
||||||
|
int W = x_shape[2];
|
||||||
|
int C = x_shape[3];
|
||||||
|
// Pad C to the nearest larger simdgroup size multiple
|
||||||
|
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||||
|
|
||||||
|
int gH = grid_shape[1];
|
||||||
|
int gW = grid_shape[2];
|
||||||
|
|
||||||
|
int w_stride = C;
|
||||||
|
int h_stride = W * w_stride;
|
||||||
|
int b_stride = H * h_stride;
|
||||||
|
|
||||||
|
uint grid_idx = elem / C_padded * 2;
|
||||||
|
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||||
|
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||||
|
|
||||||
|
int ix_nw = floor(ix);
|
||||||
|
int iy_nw = floor(iy);
|
||||||
|
|
||||||
|
int ix_ne = ix_nw + 1;
|
||||||
|
int iy_ne = iy_nw;
|
||||||
|
|
||||||
|
int ix_sw = ix_nw;
|
||||||
|
int iy_sw = iy_nw + 1;
|
||||||
|
|
||||||
|
int ix_se = ix_nw + 1;
|
||||||
|
int iy_se = iy_nw + 1;
|
||||||
|
|
||||||
|
T nw = (ix_se - ix) * (iy_se - iy);
|
||||||
|
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||||
|
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||||
|
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||||
|
|
||||||
|
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||||
|
int channel_idx = elem % C_padded;
|
||||||
|
int base_idx = batch_idx + channel_idx;
|
||||||
|
|
||||||
|
T gix = T(0);
|
||||||
|
T giy = T(0);
|
||||||
|
if (channel_idx < C) {
|
||||||
|
int cot_index = elem / C_padded * C + channel_idx;
|
||||||
|
T cot = cotangent[cot_index];
|
||||||
|
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||||
|
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||||
|
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||||
|
|
||||||
|
T I_nw = x[offset];
|
||||||
|
gix -= I_nw * (iy_se - iy) * cot;
|
||||||
|
giy -= I_nw * (ix_se - ix) * cot;
|
||||||
|
}
|
||||||
|
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||||
|
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||||
|
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||||
|
|
||||||
|
T I_ne = x[offset];
|
||||||
|
gix += I_ne * (iy_sw - iy) * cot;
|
||||||
|
giy -= I_ne * (ix - ix_sw) * cot;
|
||||||
|
}
|
||||||
|
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||||
|
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||||
|
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||||
|
|
||||||
|
T I_sw = x[offset];
|
||||||
|
gix -= I_sw * (iy - iy_ne) * cot;
|
||||||
|
giy += I_sw * (ix_ne - ix) * cot;
|
||||||
|
}
|
||||||
|
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||||
|
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||||
|
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||||
|
|
||||||
|
T I_se = x[offset];
|
||||||
|
gix += I_se * (iy - iy_nw) * cot;
|
||||||
|
giy += I_se * (ix - ix_nw) * cot;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
T gix_mult = W / 2;
|
||||||
|
T giy_mult = H / 2;
|
||||||
|
|
||||||
|
// Reduce across each simdgroup first.
|
||||||
|
// This is much faster than relying purely on atomics.
|
||||||
|
gix = simd_sum(gix);
|
||||||
|
giy = simd_sum(giy);
|
||||||
|
|
||||||
|
if (thread_index_in_simdgroup == 0) {
|
||||||
|
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||||
|
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
kernel = mx.fast.metal_kernel(
|
||||||
|
name="grid_sample_grad",
|
||||||
|
input_names=["x", "grid", "cotangent"],
|
||||||
|
output_names=["x_grad", "grid_grad"],
|
||||||
|
source=source,
|
||||||
|
atomic_outputs=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@grid_sample.vjp
|
||||||
|
def grid_sample_vjp(primals, cotangent, _):
|
||||||
|
x, grid = primals
|
||||||
|
B, _, _, C = x.shape
|
||||||
|
_, gN, gM, D = grid.shape
|
||||||
|
|
||||||
|
assert D == 2, "Last dim of `grid` must be size 2."
|
||||||
|
|
||||||
|
# pad the output channels to simd group size
|
||||||
|
# so that our `simd_sum`s don't overlap.
|
||||||
|
simdgroup_size = 32
|
||||||
|
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||||
|
grid_size = B * gN * gM * C_padded
|
||||||
|
outputs = kernel(
|
||||||
|
inputs=[x, grid, cotangent],
|
||||||
|
template=[("T", x.dtype)],
|
||||||
|
output_shapes=[x.shape, grid.shape],
|
||||||
|
output_dtypes=[x.dtype, x.dtype],
|
||||||
|
grid=(grid_size, 1, 1),
|
||||||
|
threadgroup=(256, 1, 1),
|
||||||
|
init_value=0,
|
||||||
|
)
|
||||||
|
return outputs[0], outputs[1]
|
||||||
|
|
||||||
|
There's an even larger speed up for the vjp:
|
||||||
|
|
||||||
|
``676.4ms -> 16.7ms => 40x speed up``
|
||||||
File diff suppressed because it is too large
Load Diff
68
docs/src/dev/metal_debugger.rst
Normal file
68
docs/src/dev/metal_debugger.rst
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
Metal Debugger
|
||||||
|
==============
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
Profiling is a key step for performance optimization. You can build MLX with
|
||||||
|
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
|
||||||
|
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:
|
||||||
|
|
||||||
|
* Records source during Metal compilation, for later inspection while
|
||||||
|
debugging.
|
||||||
|
* Labels Metal objects such as command queues, improving capture readability.
|
||||||
|
|
||||||
|
To build with debugging enabled in Python prepend
|
||||||
|
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
|
||||||
|
|
||||||
|
The :func:`metal.start_capture` function initiates a capture of all MLX GPU
|
||||||
|
work.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
To capture a GPU trace you must run the application with
|
||||||
|
``MTL_CAPTURE_ENABLED=1``.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
a = mx.random.uniform(shape=(512, 512))
|
||||||
|
b = mx.random.uniform(shape=(512, 512))
|
||||||
|
mx.eval(a, b)
|
||||||
|
|
||||||
|
trace_file = "mlx_trace.gputrace"
|
||||||
|
|
||||||
|
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
|
||||||
|
# that the path trace_file does not already exist.
|
||||||
|
mx.metal.start_capture(trace_file)
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
mx.eval(mx.add(a, b))
|
||||||
|
|
||||||
|
mx.metal.stop_capture()
|
||||||
|
|
||||||
|
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
|
||||||
|
has a great overview of all operations. Checkout the `Metal debugger
|
||||||
|
documentation`_ for more information.
|
||||||
|
|
||||||
|
.. image:: ../_static/metal_debugger/capture.png
|
||||||
|
:class: dark-light
|
||||||
|
|
||||||
|
Xcode Workflow
|
||||||
|
--------------
|
||||||
|
|
||||||
|
You can skip saving to a path by running within Xcode. First, generate an
|
||||||
|
Xcode project using CMake.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
mkdir build && cd build
|
||||||
|
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
|
||||||
|
open mlx.xcodeproj
|
||||||
|
|
||||||
|
Select the ``metal_capture`` example schema and run.
|
||||||
|
|
||||||
|
.. image:: ../_static/metal_debugger/schema.png
|
||||||
|
:class: dark-light
|
||||||
|
|
||||||
|
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger
|
||||||
121
docs/src/dev/mlx_in_cpp.rst
Normal file
121
docs/src/dev/mlx_in_cpp.rst
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
.. _mlx_in_cpp:
|
||||||
|
|
||||||
|
Using MLX in C++
|
||||||
|
================
|
||||||
|
|
||||||
|
You can use MLX in a C++ project with CMake.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
This guide is based one the following `example using MLX in C++
|
||||||
|
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
|
||||||
|
|
||||||
|
First install MLX:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install -U mlx
|
||||||
|
|
||||||
|
You can also install the MLX Python package from source or just the C++
|
||||||
|
library. For more information see the :ref:`documentation on installing MLX
|
||||||
|
<build_and_install>`.
|
||||||
|
|
||||||
|
Next make an example program in ``example.cpp``:
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
auto x = mx::array({1, 2, 3});
|
||||||
|
auto y = mx::array({1, 2, 3});
|
||||||
|
std::cout << x + y << std::endl;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
The next step is to setup a CMake file in ``CMakeLists.txt``:
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
cmake_minimum_required(VERSION 3.27)
|
||||||
|
|
||||||
|
project(example LANGUAGES CXX)
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
|
|
||||||
|
Depending on how you installed MLX, you may need to tell CMake where to
|
||||||
|
find it.
|
||||||
|
|
||||||
|
If you installed MLX with Python, then add the following to the CMake file:
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
find_package(
|
||||||
|
Python 3.9
|
||||||
|
COMPONENTS Interpreter Development.Module
|
||||||
|
REQUIRED)
|
||||||
|
execute_process(
|
||||||
|
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
OUTPUT_VARIABLE MLX_ROOT)
|
||||||
|
|
||||||
|
If you installed the MLX C++ package to a system path, then CMake should be
|
||||||
|
able to find it. If you installed it to a non-standard location or CMake can't
|
||||||
|
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
set(MLX_ROOT "/path/to/mlx/")
|
||||||
|
|
||||||
|
Next, instruct CMake to find MLX:
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
find_package(MLX CONFIG REQUIRED)
|
||||||
|
|
||||||
|
Finally, add the ``example.cpp`` program as an executable and link MLX.
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
add_executable(example example.cpp)
|
||||||
|
target_link_libraries(example PRIVATE mlx)
|
||||||
|
|
||||||
|
You can build the example with:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||||
|
cmake --build build
|
||||||
|
|
||||||
|
And run it with:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
./build/example
|
||||||
|
|
||||||
|
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
|
||||||
|
|
||||||
|
.. list-table:: Package Variables
|
||||||
|
:widths: 20 20
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - Variable
|
||||||
|
- Description
|
||||||
|
* - MLX_FOUND
|
||||||
|
- ``True`` if MLX is found
|
||||||
|
* - MLX_INCLUDE_DIRS
|
||||||
|
- Include directory
|
||||||
|
* - MLX_LIBRARIES
|
||||||
|
- Libraries to link against
|
||||||
|
* - MLX_CXX_FLAGS
|
||||||
|
- Additional compiler flags
|
||||||
|
* - MLX_BUILD_ACCELERATE
|
||||||
|
- ``True`` if MLX was built with Accelerate
|
||||||
|
* - MLX_BUILD_METAL
|
||||||
|
- ``True`` if MLX was built with Metal
|
||||||
@@ -15,7 +15,7 @@ module to concisely define the model architecture.
|
|||||||
Attention layer
|
Attention layer
|
||||||
^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
We will start with the llama attention layer which notably uses the RoPE
|
We will start with the Llama attention layer which notably uses the RoPE
|
||||||
positional encoding. [1]_ In addition, our attention layer will optionally use a
|
positional encoding. [1]_ In addition, our attention layer will optionally use a
|
||||||
key/value cache that will be concatenated with the provided keys and values to
|
key/value cache that will be concatenated with the provided keys and values to
|
||||||
support efficient inference.
|
support efficient inference.
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ set:
|
|||||||
Next, setup the problem parameters and load the data. To load the data, you need our
|
Next, setup the problem parameters and load the data. To load the data, you need our
|
||||||
`mnist data loader
|
`mnist data loader
|
||||||
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
|
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
|
||||||
we will import as `mnist`.
|
we will import as ``mnist``.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,9 @@ are the CPU and GPU.
|
|||||||
usage/function_transforms
|
usage/function_transforms
|
||||||
usage/compile
|
usage/compile
|
||||||
usage/numpy
|
usage/numpy
|
||||||
|
usage/distributed
|
||||||
usage/using_streams
|
usage/using_streams
|
||||||
|
usage/export
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:caption: Examples
|
:caption: Examples
|
||||||
@@ -58,14 +60,21 @@ are the CPU and GPU.
|
|||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
||||||
python/array
|
python/array
|
||||||
|
python/data_types
|
||||||
python/devices_and_streams
|
python/devices_and_streams
|
||||||
|
python/export
|
||||||
python/ops
|
python/ops
|
||||||
python/random
|
python/random
|
||||||
python/transforms
|
python/transforms
|
||||||
|
python/fast
|
||||||
python/fft
|
python/fft
|
||||||
python/linalg
|
python/linalg
|
||||||
|
python/metal
|
||||||
|
python/cuda
|
||||||
|
python/memory_management
|
||||||
python/nn
|
python/nn
|
||||||
python/optimizers
|
python/optimizers
|
||||||
|
python/distributed
|
||||||
python/tree_utils
|
python/tree_utils
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
@@ -79,3 +88,6 @@ are the CPU and GPU.
|
|||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
||||||
dev/extensions
|
dev/extensions
|
||||||
|
dev/metal_debugger
|
||||||
|
dev/custom_metal_kernels
|
||||||
|
dev/mlx_in_cpp
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
.. _build_and_install:
|
||||||
|
|
||||||
Build and Install
|
Build and Install
|
||||||
=================
|
=================
|
||||||
|
|
||||||
@@ -11,22 +13,48 @@ silicon computer is
|
|||||||
|
|
||||||
pip install mlx
|
pip install mlx
|
||||||
|
|
||||||
To install from PyPI you must meet the following requirements:
|
To install from PyPI your system must meet the following requirements:
|
||||||
|
|
||||||
- Using an M series chip (Apple silicon)
|
- Using an M series chip (Apple silicon)
|
||||||
- Using a native Python >= 3.8
|
- Using a native Python >= 3.10
|
||||||
- macOS >= 13.3
|
- macOS >= 14.0
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
MLX is only available on devices running macOS >= 13.3
|
MLX is only available on devices running macOS >= 14.0 and higher.
|
||||||
It is highly recommended to use macOS 14 (Sonoma)
|
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
MLX is also available on conda-forge. To install MLX with conda do:
|
MLX has a CUDA backend which you can install with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
conda install conda-forge::mlx
|
pip install mlx[cuda]
|
||||||
|
|
||||||
|
To install the CUDA package from PyPi your system must meet the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Nvidia architecture >= SM 7.0 (Volta)
|
||||||
|
- Nvidia driver >= 550.54.14
|
||||||
|
- CUDA toolkit >= 12.0
|
||||||
|
- Linux distribution with glibc >= 2.35
|
||||||
|
- Python >= 3.10
|
||||||
|
|
||||||
|
|
||||||
|
CPU-only (Linux)
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
For a CPU-only version of MLX that runs on Linux use:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install mlx[cpu]
|
||||||
|
|
||||||
|
To install the CPU-only package from PyPi your system must meet the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Linux distribution with glibc >= 2.35
|
||||||
|
- Python >= 3.10
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
@@ -53,8 +81,8 @@ Build Requirements
|
|||||||
^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
||||||
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
|
||||||
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
|
- Xcode >= 15.0 and macOS SDK >= 14.0
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
|
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
|
||||||
@@ -63,6 +91,8 @@ Build Requirements
|
|||||||
Python API
|
Python API
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
.. _python install:
|
||||||
|
|
||||||
To build and install the MLX python library from source, first, clone MLX from
|
To build and install the MLX python library from source, first, clone MLX from
|
||||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
@@ -70,44 +100,43 @@ To build and install the MLX python library from source, first, clone MLX from
|
|||||||
|
|
||||||
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
|
||||||
|
|
||||||
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
|
Then simply build and install MLX using pip:
|
||||||
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
|
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
pip install "pybind11[global]"
|
pip install .
|
||||||
conda install pybind11
|
|
||||||
brew install pybind11
|
|
||||||
|
|
||||||
Then simply build and install it using pip:
|
For developing, install the package with development dependencies, and use an
|
||||||
|
editable install:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
For developing use an editable install:
|
Once the development dependencies are installed, you can build faster with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
|
python setup.py build_ext --inplace
|
||||||
|
|
||||||
To make sure the install is working run the tests with:
|
Run the tests with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
pip install ".[testing]"
|
|
||||||
python -m unittest discover python/tests
|
python -m unittest discover python/tests
|
||||||
|
|
||||||
Optional: Install stubs to enable auto completions and type checking from your IDE:
|
Optional: Install stubs to enable auto completions and type checking from your
|
||||||
|
IDE:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
pip install ".[dev]"
|
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
|
|
||||||
C++ API
|
C++ API
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
|
.. _cpp install:
|
||||||
|
|
||||||
Currently, MLX must be built and installed from source.
|
Currently, MLX must be built and installed from source.
|
||||||
|
|
||||||
Similarly to the python library, to build and install the MLX C++ library start
|
Similarly to the python library, to build and install the MLX C++ library start
|
||||||
@@ -156,9 +185,18 @@ should point to the path to the built metal library.
|
|||||||
- OFF
|
- OFF
|
||||||
* - MLX_BUILD_METAL
|
* - MLX_BUILD_METAL
|
||||||
- ON
|
- ON
|
||||||
|
* - MLX_BUILD_CPU
|
||||||
|
- ON
|
||||||
* - MLX_BUILD_PYTHON_BINDINGS
|
* - MLX_BUILD_PYTHON_BINDINGS
|
||||||
- OFF
|
- OFF
|
||||||
|
* - MLX_METAL_DEBUG
|
||||||
|
- OFF
|
||||||
|
* - MLX_BUILD_SAFETENSORS
|
||||||
|
- ON
|
||||||
|
* - MLX_BUILD_GGUF
|
||||||
|
- ON
|
||||||
|
* - MLX_METAL_JIT
|
||||||
|
- OFF
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
@@ -177,10 +215,82 @@ should point to the path to the built metal library.
|
|||||||
|
|
||||||
xcrun -sdk macosx --show-sdk-version
|
xcrun -sdk macosx --show-sdk-version
|
||||||
|
|
||||||
|
|
||||||
|
Binary Size Minimization
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
|
||||||
|
and ``BUILD_SHARED_LIBS=ON``.
|
||||||
|
|
||||||
|
The MLX CMake build has several additional options to make smaller binaries.
|
||||||
|
For example, if you don't need the CPU backend or support for safetensors and
|
||||||
|
GGUF, you can do:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
cmake .. \
|
||||||
|
-DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||||
|
-DBUILD_SHARED_LIBS=ON \
|
||||||
|
-DMLX_BUILD_CPU=OFF \
|
||||||
|
-DMLX_BUILD_SAFETENSORS=OFF \
|
||||||
|
-DMLX_BUILD_GGUF=OFF \
|
||||||
|
-DMLX_METAL_JIT=ON
|
||||||
|
|
||||||
|
The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
|
||||||
|
contains pre-built GPU kernels. This substantially reduces the size of the
|
||||||
|
Metal library by run-time compiling kernels the first time they are used in MLX
|
||||||
|
on a given machine. Note run-time compilation incurs a cold-start cost which can
|
||||||
|
be anywhere from a few hundred milliseconds to a few seconds depending on the
|
||||||
|
application. Once a kernel is compiled, it will be cached by the system. The
|
||||||
|
Metal kernel cache persists across reboots.
|
||||||
|
|
||||||
|
Linux
|
||||||
|
^^^^^
|
||||||
|
|
||||||
|
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||||
|
For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
apt-get update -y
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
From here follow the instructions to install either the :ref:`Python <python
|
||||||
|
install>` or :ref:`C++ <cpp install>` APIs.
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||||
|
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
apt-get update -y
|
||||||
|
apt-get -y install cuda-toolkit-12-9
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
|
||||||
|
|
||||||
|
|
||||||
|
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||||
|
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||||
|
|
||||||
|
To build the C++ package run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
mkdir -p build && cd build
|
||||||
|
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
||||||
Metal not found
|
Metal not found
|
||||||
~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -207,7 +317,7 @@ x86 Shell
|
|||||||
|
|
||||||
.. _build shell:
|
.. _build shell:
|
||||||
|
|
||||||
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
|
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
|
||||||
Rosetta instead of natively.
|
Rosetta instead of natively.
|
||||||
|
|
||||||
To fix this, find the application in Finder (``/Applications`` for iTerm,
|
To fix this, find the application in Finder (``/Applications`` for iTerm,
|
||||||
@@ -231,4 +341,4 @@ Also check that cmake is using the correct architecture:
|
|||||||
|
|
||||||
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
|
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
|
||||||
but the build errors out with "Building for x86_64 on macOS is not supported."
|
but the build errors out with "Building for x86_64 on macOS is not supported."
|
||||||
wipe your build cahce with ``rm -rf build/`` and try again.
|
wipe your build cache with ``rm -rf build/`` and try again.
|
||||||
|
|||||||
@@ -10,27 +10,42 @@ Array
|
|||||||
|
|
||||||
array
|
array
|
||||||
array.astype
|
array.astype
|
||||||
|
array.at
|
||||||
array.item
|
array.item
|
||||||
array.tolist
|
array.tolist
|
||||||
array.dtype
|
array.dtype
|
||||||
|
array.itemsize
|
||||||
|
array.nbytes
|
||||||
array.ndim
|
array.ndim
|
||||||
array.shape
|
array.shape
|
||||||
array.size
|
array.size
|
||||||
Dtype
|
array.real
|
||||||
|
array.imag
|
||||||
array.abs
|
array.abs
|
||||||
array.all
|
array.all
|
||||||
array.any
|
array.any
|
||||||
array.argmax
|
array.argmax
|
||||||
array.argmin
|
array.argmin
|
||||||
|
array.conj
|
||||||
array.cos
|
array.cos
|
||||||
array.dtype
|
array.cummax
|
||||||
|
array.cummin
|
||||||
|
array.cumprod
|
||||||
|
array.cumsum
|
||||||
|
array.diag
|
||||||
|
array.diagonal
|
||||||
array.exp
|
array.exp
|
||||||
|
array.flatten
|
||||||
array.log
|
array.log
|
||||||
|
array.log10
|
||||||
array.log1p
|
array.log1p
|
||||||
|
array.log2
|
||||||
|
array.logcumsumexp
|
||||||
array.logsumexp
|
array.logsumexp
|
||||||
array.max
|
array.max
|
||||||
array.mean
|
array.mean
|
||||||
array.min
|
array.min
|
||||||
|
array.moveaxis
|
||||||
array.prod
|
array.prod
|
||||||
array.reciprocal
|
array.reciprocal
|
||||||
array.reshape
|
array.reshape
|
||||||
@@ -40,7 +55,11 @@ Array
|
|||||||
array.split
|
array.split
|
||||||
array.sqrt
|
array.sqrt
|
||||||
array.square
|
array.square
|
||||||
|
array.squeeze
|
||||||
|
array.std
|
||||||
array.sum
|
array.sum
|
||||||
|
array.swapaxes
|
||||||
array.transpose
|
array.transpose
|
||||||
array.T
|
array.T
|
||||||
array.var
|
array.var
|
||||||
|
array.view
|
||||||
|
|||||||
9
docs/src/python/cuda.rst
Normal file
9
docs/src/python/cuda.rst
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
CUDA
|
||||||
|
=====
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.cuda
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
is_available
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
.. _data_types:
|
.. _data_types:
|
||||||
|
|
||||||
:orphan:
|
|
||||||
|
|
||||||
Data Types
|
Data Types
|
||||||
==========
|
==========
|
||||||
|
|
||||||
@@ -44,9 +42,37 @@ The default floating point type is ``float32`` and the default integer type is
|
|||||||
* - ``int64``
|
* - ``int64``
|
||||||
- 8
|
- 8
|
||||||
- 64-bit signed integer
|
- 64-bit signed integer
|
||||||
|
* - ``bfloat16``
|
||||||
|
- 2
|
||||||
|
- 16-bit brain float (e8, m7)
|
||||||
* - ``float16``
|
* - ``float16``
|
||||||
- 2
|
- 2
|
||||||
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
|
- 16-bit IEEE float (e5, m10)
|
||||||
* - ``float32``
|
* - ``float32``
|
||||||
- 4
|
- 4
|
||||||
- 32-bit float
|
- 32-bit float
|
||||||
|
* - ``float64``
|
||||||
|
- 8
|
||||||
|
- 64-bit double
|
||||||
|
* - ``complex64``
|
||||||
|
- 8
|
||||||
|
- 64-bit complex float
|
||||||
|
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Arrays with type ``float64`` only work with CPU operations. Using
|
||||||
|
``float64`` arrays on the GPU will result in an exception.
|
||||||
|
|
||||||
|
|
||||||
|
Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
|
||||||
|
documentation for more information. Use :func:`issubdtype` to determine if one
|
||||||
|
``dtype`` (or category) is a subtype of another category.
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
Dtype
|
||||||
|
DtypeCategory
|
||||||
|
issubdtype
|
||||||
|
finfo
|
||||||
|
|||||||
@@ -9,9 +9,11 @@ Devices and Streams
|
|||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
|
|
||||||
Device
|
Device
|
||||||
|
Stream
|
||||||
default_device
|
default_device
|
||||||
set_default_device
|
set_default_device
|
||||||
Stream
|
|
||||||
default_stream
|
default_stream
|
||||||
new_stream
|
new_stream
|
||||||
set_default_stream
|
set_default_stream
|
||||||
|
stream
|
||||||
|
synchronize
|
||||||
|
|||||||
22
docs/src/python/distributed.rst
Normal file
22
docs/src/python/distributed.rst
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
.. _distributed:
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.distributed
|
||||||
|
|
||||||
|
Distributed Communication
|
||||||
|
==========================
|
||||||
|
|
||||||
|
MLX provides a distributed communication package using MPI. The MPI library is
|
||||||
|
loaded at runtime; if MPI is available then distributed communication is also
|
||||||
|
made available.
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
Group
|
||||||
|
is_available
|
||||||
|
init
|
||||||
|
all_sum
|
||||||
|
all_gather
|
||||||
|
send
|
||||||
|
recv
|
||||||
|
recv_like
|
||||||
14
docs/src/python/export.rst
Normal file
14
docs/src/python/export.rst
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
.. _export:
|
||||||
|
|
||||||
|
Export Functions
|
||||||
|
================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
export_function
|
||||||
|
import_function
|
||||||
|
exporter
|
||||||
|
export_to_dot
|
||||||
16
docs/src/python/fast.rst
Normal file
16
docs/src/python/fast.rst
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
.. _fast:
|
||||||
|
|
||||||
|
Fast
|
||||||
|
====
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.fast
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
rms_norm
|
||||||
|
layer_norm
|
||||||
|
rope
|
||||||
|
scaled_dot_product_attention
|
||||||
|
metal_kernel
|
||||||
|
cuda_kernel
|
||||||
@@ -20,3 +20,5 @@ FFT
|
|||||||
irfft2
|
irfft2
|
||||||
rfftn
|
rfftn
|
||||||
irfftn
|
irfftn
|
||||||
|
fftshift
|
||||||
|
ifftshift
|
||||||
|
|||||||
@@ -8,5 +8,20 @@ Linear Algebra
|
|||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
inv
|
||||||
|
tri_inv
|
||||||
norm
|
norm
|
||||||
|
cholesky
|
||||||
|
cholesky_inv
|
||||||
|
cross
|
||||||
qr
|
qr
|
||||||
|
svd
|
||||||
|
eigvals
|
||||||
|
eig
|
||||||
|
eigvalsh
|
||||||
|
eigh
|
||||||
|
lu
|
||||||
|
lu_factor
|
||||||
|
pinv
|
||||||
|
solve
|
||||||
|
solve_triangular
|
||||||
|
|||||||
16
docs/src/python/memory_management.rst
Normal file
16
docs/src/python/memory_management.rst
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
Memory Management
|
||||||
|
=================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
get_active_memory
|
||||||
|
get_peak_memory
|
||||||
|
reset_peak_memory
|
||||||
|
get_cache_memory
|
||||||
|
set_memory_limit
|
||||||
|
set_cache_limit
|
||||||
|
set_wired_limit
|
||||||
|
clear_cache
|
||||||
12
docs/src/python/metal.rst
Normal file
12
docs/src/python/metal.rst
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
Metal
|
||||||
|
=====
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.metal
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
is_available
|
||||||
|
device_info
|
||||||
|
start_capture
|
||||||
|
stop_capture
|
||||||
@@ -173,6 +173,8 @@ In detail:
|
|||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
|
|
||||||
value_and_grad
|
value_and_grad
|
||||||
|
quantize
|
||||||
|
average_gradients
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
|
|
||||||
|
|||||||
@@ -12,13 +12,29 @@ simple functions.
|
|||||||
:toctree: _autosummary_functions
|
:toctree: _autosummary_functions
|
||||||
:template: nn-module-template.rst
|
:template: nn-module-template.rst
|
||||||
|
|
||||||
|
elu
|
||||||
|
celu
|
||||||
gelu
|
gelu
|
||||||
gelu_approx
|
gelu_approx
|
||||||
gelu_fast_approx
|
gelu_fast_approx
|
||||||
|
glu
|
||||||
|
hard_shrink
|
||||||
|
hard_tanh
|
||||||
|
hardswish
|
||||||
|
leaky_relu
|
||||||
|
log_sigmoid
|
||||||
|
log_softmax
|
||||||
mish
|
mish
|
||||||
prelu
|
prelu
|
||||||
relu
|
relu
|
||||||
|
relu2
|
||||||
|
relu6
|
||||||
selu
|
selu
|
||||||
softshrink
|
sigmoid
|
||||||
silu
|
silu
|
||||||
|
softmax
|
||||||
|
softmin
|
||||||
|
softplus
|
||||||
|
softshrink
|
||||||
step
|
step
|
||||||
|
tanh
|
||||||
|
|||||||
@@ -10,29 +10,61 @@ Layers
|
|||||||
:template: nn-module-template.rst
|
:template: nn-module-template.rst
|
||||||
|
|
||||||
ALiBi
|
ALiBi
|
||||||
|
AvgPool1d
|
||||||
|
AvgPool2d
|
||||||
|
AvgPool3d
|
||||||
BatchNorm
|
BatchNorm
|
||||||
|
CELU
|
||||||
Conv1d
|
Conv1d
|
||||||
Conv2d
|
Conv2d
|
||||||
|
Conv3d
|
||||||
|
ConvTranspose1d
|
||||||
|
ConvTranspose2d
|
||||||
|
ConvTranspose3d
|
||||||
Dropout
|
Dropout
|
||||||
Dropout2d
|
Dropout2d
|
||||||
Dropout3d
|
Dropout3d
|
||||||
Embedding
|
Embedding
|
||||||
|
ELU
|
||||||
GELU
|
GELU
|
||||||
|
GLU
|
||||||
GroupNorm
|
GroupNorm
|
||||||
|
GRU
|
||||||
|
HardShrink
|
||||||
|
HardTanh
|
||||||
|
Hardswish
|
||||||
InstanceNorm
|
InstanceNorm
|
||||||
LayerNorm
|
LayerNorm
|
||||||
|
LeakyReLU
|
||||||
Linear
|
Linear
|
||||||
|
LogSigmoid
|
||||||
|
LogSoftmax
|
||||||
|
LSTM
|
||||||
|
MaxPool1d
|
||||||
|
MaxPool2d
|
||||||
|
MaxPool3d
|
||||||
Mish
|
Mish
|
||||||
MultiHeadAttention
|
MultiHeadAttention
|
||||||
PReLU
|
PReLU
|
||||||
|
QuantizedEmbedding
|
||||||
QuantizedLinear
|
QuantizedLinear
|
||||||
RMSNorm
|
RMSNorm
|
||||||
ReLU
|
ReLU
|
||||||
|
ReLU2
|
||||||
|
ReLU6
|
||||||
|
RNN
|
||||||
RoPE
|
RoPE
|
||||||
SELU
|
SELU
|
||||||
Sequential
|
Sequential
|
||||||
|
Sigmoid
|
||||||
SiLU
|
SiLU
|
||||||
SinusoidalPositionalEncoding
|
SinusoidalPositionalEncoding
|
||||||
|
Softmin
|
||||||
Softshrink
|
Softshrink
|
||||||
|
Softsign
|
||||||
|
Softmax
|
||||||
|
Softplus
|
||||||
Step
|
Step
|
||||||
|
Tanh
|
||||||
Transformer
|
Transformer
|
||||||
|
Upsample
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ Module
|
|||||||
Module.named_modules
|
Module.named_modules
|
||||||
Module.parameters
|
Module.parameters
|
||||||
Module.save_weights
|
Module.save_weights
|
||||||
|
Module.set_dtype
|
||||||
Module.train
|
Module.train
|
||||||
Module.trainable_parameters
|
Module.trainable_parameters
|
||||||
Module.unfreeze
|
Module.unfreeze
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user