Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: v0.19.0...sign-warns (738 commits, 24828b1b2f … d15fa13daf)
.circleci/config.yml

```diff
@@ -7,15 +7,9 @@ parameters:
   nightly_build:
     type: boolean
     default: false
-  weekly_build:
-    type: boolean
-    default: false
   test_release:
     type: boolean
     default: false
-  linux_release:
-    type: boolean
-    default: false
 
 jobs:
   build_documentation:
@@ -24,21 +18,22 @@ jobs:
         type: boolean
         default: false
     macos:
-      xcode: "15.2.0"
-    resource_class: macos.m1.medium.gen1
+      xcode: "26.0.0"
+    resource_class: m4pro.medium
     steps:
       - checkout
       - run:
          name: Install
          command: |
-            brew install python@3.9
+            xcodebuild -downloadComponent MetalToolchain
+            brew install python@3.10
            brew install doxygen
-            python3.9 -m venv env
+            python3.10 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install -r docs/requirements.txt
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
+            pip install . -v
      - when:
          condition:
            not: << parameters.upload-docs >>
@@ -70,9 +65,9 @@ jobs:
            git push -f origin gh-pages
 
  linux_build_and_test:
-    docker:
-      - image: cimg/python:3.9
+    machine:
+      image: ubuntu-2204:current
+    resource_class: large
    steps:
      - checkout
      - run:
@@ -84,33 +79,36 @@
      - run:
          name: Install dependencies
          command: |
-            pip install --upgrade cmake
-            pip install nanobind==2.2.0
-            pip install numpy
+            export DEBIAN_FRONTEND=noninteractive
+            export NEEDRESTART_MODE=a
            sudo apt-get update
-            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
+            curl -LsSf https://astral.sh/uv/install.sh | sh
      - run:
          name: Install Python package
          command: |
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            python3 setup.py build_ext --inplace
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            python3 setup.py develop
+            uv venv
+            uv pip install cmake
+            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+            uv pip install -e ".[dev]" -v
      - run:
          name: Generate package stubs
          command: |
-            echo "stubs"
-            pip install typing_extensions
-            python setup.py generate_stubs
+            uv pip install typing_extensions
+            uv run --no-project setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
-            python3 -m unittest discover python/tests -v
+            source .venv/bin/activate
+            python -m unittest discover python/tests -v
+            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
+            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
+            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
      - run:
          name: Build CPP only
          command: |
+            source .venv/bin/activate
            mkdir -p build && cd build
            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
            make -j `nproc`
@@ -122,57 +120,64 @@
    parameters:
      xcode_version:
        type: string
-        default: "15.2.0"
+        default: "26.0.0"
+      macosx_deployment_target:
+        type: string
+        default: ""
    macos:
      xcode: << parameters.xcode_version >>
-    resource_class: macos.m1.medium.gen1
+    environment:
+      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
+    resource_class: m4pro.medium
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
-            brew install python@3.9
-            brew install openmpi
-            python3.9 -m venv env
-            source env/bin/activate
-            pip install --upgrade pip
-            pip install --upgrade cmake
-            pip install nanobind==2.2.0
-            pip install numpy
-            pip install torch
-            pip install tensorflow
-            pip install unittest-xml-reporting
+            xcodebuild -downloadComponent MetalToolchain
+            HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
+            brew install openmpi uv
      - run:
          name: Install Python package
          command: |
-            source env/bin/activate
-            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
+            uv venv --python 3.10
+            uv pip install \
+              nanobind==2.4.0 \
+              cmake \
+              numpy \
+              torch \
+              tensorflow \
+              unittest-xml-reporting
+            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
+            uv pip install -e . -v
      - run:
          name: Generate package stubs
          command: |
-            source env/bin/activate
-            pip install typing_extensions
-            python setup.py generate_stubs
+            uv pip install typing_extensions
+            uv run --no-project setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
+            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
+            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
      - run:
          name: Build example extension
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            cd examples/extensions
-            pip install -r requirements.txt
-            python setup.py build_ext -j8
+            uv pip install -r requirements.txt
+            uv run --no-project setup.py build_ext --inplace
+            uv run --no-project python test.py
      - store_test_results:
          path: test-results
      - run:
          name: Build CPP only
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
      - run:
          name: Run CPP tests
@@ -181,7 +186,7 @@
      - run:
          name: Build small binary
          command: |
-            source env/bin/activate
+            source .venv/bin/activate
            cd build/
            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
                  -DBUILD_SHARED_LIBS=ON \
```
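The recurring pattern in the job changes above is the swap from a Homebrew Python virtualenv to `uv`. As a minimal sketch, the new install-and-test flow can be replicated locally with the same commands the CI steps use (taken verbatim from the diff; run from a checkout of the repo):

```sh
# Install uv and build MLX in editable mode, mirroring linux_build_and_test
# above (the Metal backend is off on Linux anyway).
curl -LsSf https://astral.sh/uv/install.sh | sh
uv venv
uv pip install cmake
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
uv pip install -e ".[dev]" -v

# Run the Python test suite from the environment uv created.
source .venv/bin/activate
python -m unittest discover python/tests -v
```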
```diff
@@ -193,40 +198,112 @@
      - run:
          name: Run Python tests with JIT
          command: |
-            source env/bin/activate
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
-            pip install -e . -v
+            uv pip install -e . -v
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
            METAL_DEBUG_ERROR_MODE=0 \
-            python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
+            uv run --no-project python -m xmlrunner discover \
+              -v python/tests \
+              -o test-results/gpu_jit
+
+  cuda_build_and_test:
+    parameters:
+      image_date:
+        type: string
+        default: "2023.11.1"
+    machine:
+      image: "linux-cuda-12:<< parameters.image_date >>"
+    resource_class: gpu.nvidia.small.gen2
+    steps:
+      - checkout
+      - restore_cache:
+          keys:
+            - cuda-<< parameters.image_date >>-{{ arch }}-
+      - run:
+          name: Install dependencies
+          command: |
+            sudo apt-get update
+            sudo apt-get install libcudnn9-dev-cuda-12
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install libnccl2 libnccl-dev
+            curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
+            sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
+            rm -rf ccache-4.11.3-linux-x86_64
+            curl -LsSf https://astral.sh/uv/install.sh | sh
+      - run:
+          name: Set CCache size
+          command: ccache --max-size 1G
+      - run:
+          name: Install Python package
+          command: |
+            uv venv
+            uv pip install cmake
+            DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+            uv pip install -e ".[dev]" -v
+      - run:
+          name: Run Python tests
+          command: |
+            source .venv/bin/activate
+            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
+            LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
+      - run:
+          name: Build CPP only
+          command: |
+            source .venv/bin/activate
+            cmake . -B build \
+              -DMLX_BUILD_CUDA=ON \
+              -DCMAKE_CUDA_COMPILER=`which nvcc` \
+              -DCMAKE_BUILD_TYPE=DEBUG
+            cmake --build build -j `nproc`
+      - run:
+          name: Run CPP tests
+          command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
+      - run:
+          name: CCache report
+          command: |
+            ccache --show-stats
+            ccache --zero-stats
+            ccache --cleanup
+      - save_cache:
+          key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
+          paths:
+            - /home/circleci/.cache/ccache
+
  build_release:
    parameters:
      python_version:
        type: string
-        default: "3.9"
+        default: "3.10"
      xcode_version:
        type: string
-        default: "15.2.0"
+        default: "26.0.0"
      build_env:
        type: string
        default: ""
+      macosx_deployment_target:
+        type: string
+        default: ""
    macos:
      xcode: << parameters.xcode_version >>
-    resource_class: macos.m1.medium.gen1
+    resource_class: m4pro.medium
+    environment:
+      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
-            brew install python@<< parameters.python_version >>
-            brew install openmpi
-            python<< parameters.python_version >> -m venv env
-            source env/bin/activate
-            pip install --upgrade pip
+            xcodebuild -downloadComponent MetalToolchain
+            mkdir -p ~/miniconda3
+            curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
+            bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
+            rm ~/miniconda3/miniconda.sh
+            source ~/miniconda3/bin/activate
+            conda init --all
+            conda create -n env python=<< parameters.python_version >> -y
+            conda activate env
            pip install --upgrade cmake
-            pip install nanobind==2.2.0
+            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install twine
@@ -234,30 +311,38 @@
      - run:
          name: Install Python package
          command: |
-            source env/bin/activate
-            DEV_RELEASE=1 \
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+            conda activate env
+            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
            pip install . -v
      - run:
          name: Generate package stubs
          command: |
-            source env/bin/activate
+            conda activate env
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Build Python package
          command: |
-            source env/bin/activate
-            << parameters.build_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
-            python -m build -w
+            conda activate env
+            python setup.py clean --all
+            << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
+      - when:
+          condition:
+            equal: ["3.10", << parameters.python_version >>]
+          steps:
+            - run:
+                name: Build common package
+                command: |
+                  conda activate env
+                  python setup.py clean --all
+                  << parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
-                  source env/bin/activate
+                  conda activate env
                  twine upload dist/*
      - store_artifacts:
          path: dist/
@@ -266,53 +351,101 @@
    parameters:
      python_version:
        type: string
-        default: "3.9"
-      extra_env:
+        default: "3.10"
+      build_env:
        type: string
-        default: "DEV_RELEASE=1"
-    docker:
-      - image: ubuntu:20.04
+        default: ""
+    machine:
+      image: ubuntu-2204:current
+    resource_class: large
    steps:
      - checkout
      - run:
          name: Build wheel
          command: |
            PYTHON=python<< parameters.python_version >>
-            apt-get update
-            apt-get upgrade -y
-            DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
-            apt-get install -y apt-utils
-            apt-get install -y software-properties-common
-            add-apt-repository -y ppa:deadsnakes/ppa
-            apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
-            apt-get install -y libblas-dev liblapack-dev liblapacke-dev
-            apt-get install -y build-essential git
+            export DEBIAN_FRONTEND=noninteractive
+            export NEEDRESTART_MODE=a
+            sudo apt-get update
+            TZ=Etc/UTC sudo apt-get -y install tzdata
+            sudo add-apt-repository -y ppa:deadsnakes/ppa
+            sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
+            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            $PYTHON -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
-            pip install nanobind==2.2.0
-            pip install --upgrade setuptools
-            pip install numpy
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
-            << parameters.extra_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            pip install . -v
+            << parameters.build_env >> pip install ".[dev]" -v
            pip install typing_extensions
            python setup.py generate_stubs
-            << parameters.extra_env >> \
-            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
-            python -m build --wheel
-            auditwheel show dist/*
-            auditwheel repair dist/* --plat manylinux_2_31_x86_64
+            python setup.py clean --all
+            MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
+            bash python/scripts/repair_linux.sh
+      - when:
+          condition:
+            equal: ["3.10", << parameters.python_version >>]
+          steps:
+            - run:
+                name: Build common package
+                command: |
+                  source env/bin/activate
+                  python setup.py clean --all
+                  << parameters.build_env >> MLX_BUILD_STAGE=2 \
+                  python -m build -w
+                  auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
+      - when:
+          condition: << parameters.build_env >>
+          steps:
+            - run:
+                name: Upload packages
+                command: |
+                  source env/bin/activate
+                  twine upload wheelhouse/*.whl
+      - store_artifacts:
+          path: wheelhouse/
+
+  build_cuda_release:
+    parameters:
+      build_env:
+        type: string
+        default: ""
+    machine:
+      image: ubuntu-2204:current
+    resource_class: xlarge
+    steps:
+      - checkout
+      - run:
+          name: Build wheel
+          command: |
+            export DEBIAN_FRONTEND=noninteractive
+            export NEEDRESTART_MODE=a
+            wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
+            sudo dpkg -i cuda-keyring_1.1-1_all.deb
+            sudo apt-get update
+            sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install zip
+            pip install auditwheel
+            pip install patchelf
+            pip install build
+            pip install twine
+            export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
+            export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
+            << parameters.build_env >> MLX_BUILD_STAGE=2 \
+            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+            python -m build -w
+            bash python/scripts/repair_cuda.sh
+      - when:
+          condition: << parameters.build_env >>
+          steps:
            - run:
                name: Upload package
                command: |
-                  source env/bin/activate
-                  twine upload wheelhouse/*
+                  twine upload wheelhouse/*.whl
      - store_artifacts:
          path: wheelhouse/
 
```
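The release jobs above split wheel building into two `MLX_BUILD_STAGE` passes: stage 1 builds the main platform wheel ("Build Python package"), and stage 2 — run only on the Python 3.10 matrix entry — builds the common package (`mlx_cpu*` wheels on Linux, the CUDA wheel in `build_cuda_release`). Condensed from the steps above, the sequence each release job runs is:

```sh
# Stage 1: platform wheel (the "Build Python package" step).
python setup.py clean --all
MLX_BUILD_STAGE=1 python -m build -w

# Stage 2: common package, built once per release; in CI this is guarded by
# the `equal: ["3.10", << parameters.python_version >>]` condition.
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
```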
```diff
@@ -324,21 +457,23 @@ workflows:
            pattern: "^(?!pull/)[-\\w]+$"
            value: << pipeline.git.branch >>
        - not: << pipeline.parameters.nightly_build >>
-        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - mac_build_and_test:
          matrix:
            parameters:
-              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              macosx_deployment_target: ["13.5", "15.0"]
      - linux_build_and_test
+      - cuda_build_and_test:
+          matrix:
+            parameters:
+              image_date: ["2023.11.1", "2025.05.1"]
      - build_documentation
 
  build_pypi_release:
    when:
      and:
        - not: << pipeline.parameters.nightly_build >>
-        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - build_release:
@@ -349,9 +484,10 @@ workflows:
              ignore: /.*/
          matrix:
            parameters:
-              python_version: ["3.9", "3.10", "3.11", "3.12"]
-              xcode_version: ["15.0.0", "15.2.0"]
+              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
+              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["PYPI_RELEASE=1"]
+              xcode_version: ["26.0.0"]
      - build_documentation:
          filters:
            tags:
@@ -359,6 +495,25 @@ workflows:
            branches:
              ignore: /.*/
          upload-docs: true
+      - build_linux_release:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          matrix:
+            parameters:
+              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
+              build_env: ["PYPI_RELEASE=1"]
+      - build_cuda_release:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          matrix:
+            parameters:
+              build_env: ["PYPI_RELEASE=1"]
 
  prb:
    when:
@@ -374,9 +529,14 @@ workflows:
          requires: [ hold ]
          matrix:
            parameters:
-              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              macosx_deployment_target: ["13.5", "15.0"]
      - linux_build_and_test:
          requires: [ hold ]
+      - cuda_build_and_test:
+          requires: [ hold ]
+          matrix:
+            parameters:
+              image_date: ["2023.11.1", "2025.05.1"]
  nightly_build:
    when:
      and:
@@ -386,28 +546,34 @@ workflows:
      - build_release:
          matrix:
            parameters:
-              python_version: ["3.9", "3.10", "3.11", "3.12"]
-              xcode_version: ["15.0.0", "15.2.0"]
-  weekly_build:
+              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
+              macosx_deployment_target: ["13.5", "14.0", "15.0"]
+              xcode_version: ["26.0.0"]
+      - build_linux_release:
+          matrix:
+            parameters:
+              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
+      - build_cuda_release
+
+  build_dev_release:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
-        - << pipeline.parameters.weekly_build >>
+        - << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          matrix:
            parameters:
-              python_version: ["3.9", "3.10", "3.11", "3.12"]
-              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
+              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["DEV_RELEASE=1"]
-  linux_test_release:
-    when:
-      and:
-        - equal: [ main, << pipeline.git.branch >> ]
-        - << pipeline.parameters.linux_release >>
-    jobs:
+              xcode_version: ["26.0.0"]
      - build_linux_release:
          matrix:
            parameters:
-              python_version: ["3.9", "3.10", "3.11", "3.12"]
-              extra_env: ["PYPI_RELEASE=1"]
+              python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
+              build_env: ["DEV_RELEASE=1"]
+      - build_cuda_release:
+          matrix:
+            parameters:
+              build_env: ["DEV_RELEASE=1"]
```
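The `nightly_build` and `test_release` pipeline parameters gate the scheduled workflows above. Parameterized pipelines like these are typically kicked off through CircleCI's v2 pipeline API; the command below is a hypothetical sketch of such a trigger, not part of the diff, and `CIRCLECI_TOKEN` is a placeholder for a real API token:

```sh
# Hypothetical manual trigger of the nightly workflow via CircleCI's API.
curl -X POST https://circleci.com/api/v2/project/gh/ml-explore/mlx/pipeline \
  -H "Circle-Token: ${CIRCLECI_TOKEN}" \
  -H "Content-Type: application/json" \
  -d '{"parameters": {"nightly_build": true}}'
```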
.gitignore (4 changes, vendored)

```diff
@@ -36,6 +36,7 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+uv.lock
 
 # vim
 *.swp
@@ -76,6 +77,9 @@ build/
 *.out
 *.app
 
+# Debug symbols
+*.pdb
+
 # VSCode
 .vscode/
 .DS_Store
```
.pre-commit-config.yaml

```diff
@@ -1,15 +1,16 @@
 repos:
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.8
+    rev: v19.1.7
     hooks:
       - id: clang-format
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.8.0
+    rev: 25.1.0
     hooks:
       - id: black
+
   - repo: https://github.com/pycqa/isort
-    rev: 5.13.2
+    rev: 6.0.0
     hooks:
       - id: isort
         args:
```
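To pick up the bumped hook versions locally, the standard `pre-commit` commands apply:

```sh
pip install pre-commit
pre-commit install            # run the hooks on every commit
pre-commit run --all-files    # re-check/re-format the whole tree once
```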
ACKNOWLEDGMENTS.md

```diff
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
@@ -19,11 +19,17 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 </a>
 
+# Organizations
+
+MLX has received contributions from the following companies:
+- NVIDIA Corporation & Affiliates
+
 # Third-Party Software
 
 MLX leverages several third-party software, listed here together with
```
CMakeLists.txt (188 changes)

```diff
@@ -1,13 +1,36 @@
-cmake_minimum_required(VERSION 3.24)
+cmake_minimum_required(VERSION 3.25)
 
-project(mlx LANGUAGES C CXX)
+if(NOT MLX_VERSION)
+  file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
+  string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
+  set(_major ${CMAKE_MATCH_1})
+  string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
+  set(_minor ${CMAKE_MATCH_1})
+  string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
+  set(_patch ${CMAKE_MATCH_1})
+  set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
+  set(MLX_VERSION ${MLX_PROJECT_VERSION})
+else()
+  string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
+                       ${MLX_VERSION})
+endif()
+
+project(
+  mlx
+  LANGUAGES C CXX
+  VERSION ${MLX_PROJECT_VERSION})
+
+if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
+  add_compile_options(-Wall -Wextra)
+endif()
 
 # ----------------------------- Setup -----------------------------
 set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
-set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(CMAKE_INSTALL_MESSAGE NEVER)
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
 # ----------------------------- Configuration -----------------------------
 option(MLX_BUILD_TESTS "Build tests for mlx" ON)
@@ -16,26 +39,23 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
 option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
 option(MLX_BUILD_METAL "Build metal backend" ON)
 option(MLX_BUILD_CPU "Build cpu backend" ON)
+option(MLX_BUILD_CUDA "Build cuda backend" OFF)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
 option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
+option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
+option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
-
-if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.19.0)
-endif()
+option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
 
 # --------------------- Processor tests -------------------------
 
 message(
   STATUS
     "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
 )
 
-set(MLX_BUILD_ARM OFF)
-
 if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
   if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
     if(NOT MLX_ENABLE_X64_MAC)
@@ -51,14 +71,17 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
       message(WARNING "Building for x86_64 arch is not officially supported.")
     endif()
   endif()
 
 else()
   set(MLX_BUILD_METAL OFF)
-  message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
 endif()
 
-if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
-  set(MLX_BUILD_ARM ON)
+if(MLX_USE_CCACHE)
+  find_program(CCACHE_PROGRAM ccache)
+  if(CCACHE_PROGRAM)
+    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+    set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+  endif()
 endif()
 
 # ----------------------------- Lib -----------------------------
@@ -69,18 +92,21 @@ cmake_policy(SET CMP0135 NEW)
 
 add_library(mlx)
 
-if(MLX_BUILD_METAL)
-  set(METAL_LIB "-framework Metal")
-  set(FOUNDATION_LIB "-framework Foundation")
-  set(QUARTZ_LIB "-framework QuartzCore")
+if(MLX_BUILD_CUDA)
+  enable_language(CUDA)
 endif()
 
-if(MLX_BUILD_METAL AND NOT METAL_LIB)
-  message(STATUS "Metal not found. Unable to build GPU")
-  set(MLX_BUILD_METAL OFF)
-  set(MLX_METAL_DEBUG OFF)
-elseif(MLX_BUILD_METAL)
-  message(STATUS "Building METAL sources")
+if(MLX_BUILD_METAL)
+  find_library(METAL_LIB Metal)
+  find_library(FOUNDATION_LIB Foundation)
+  find_library(QUARTZ_LIB QuartzCore)
+  if(METAL_LIB)
+    message(STATUS "Metal found ${METAL_LIB}")
+  else()
+    message(
+      FATAL_ERROR
+        "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
+  endif()
 
   if(MLX_METAL_DEBUG)
     add_compile_definitions(MLX_METAL_DEBUG)
@@ -89,25 +115,27 @@ elseif(MLX_BUILD_METAL)
   # Throw an error if xcrun not found
   execute_process(
     COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
-    OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
+    OUTPUT_VARIABLE MACOS_SDK_VERSION
+    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
 
-  if(${MACOS_VERSION} LESS 14.0)
+  if(${MACOS_SDK_VERSION} LESS 14.0)
     message(
       FATAL_ERROR
         "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
   endif()
-  message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
+  message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
 
   set(METAL_CPP_URL
-      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
-  )
-  # Get the metal version
+      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
+
+  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
+    set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
+  endif()
   execute_process(
     COMMAND
       zsh "-c"
-      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
+      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
     OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
 
   FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
 
   FetchContent_MakeAvailable(metal_cpp)
@@ -115,20 +143,64 @@ elseif(MLX_BUILD_METAL)
     mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
                $<INSTALL_INTERFACE:include/metal_cpp>)
   target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
-add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
+endif()
+
+if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+  # With newer clang/gcc versions following libs are implicitly linked, but when
+  # building on old distributions they need to be explicitly listed.
+  target_link_libraries(mlx PRIVATE dl pthread)
+endif()
+
+if(WIN32)
+  if(MSVC)
+    # GGUF does not build with MSVC.
+    set(MLX_BUILD_GGUF OFF)
+    # There is no prebuilt OpenBLAS distribution for MSVC.
+    set(MLX_BUILD_BLAS_FROM_SOURCE ON)
+  endif()
+  # Windows implementation of dlfcn.h APIs.
+  FetchContent_Declare(
+    dlfcn-win32
+    GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
+    GIT_TAG v1.4.1
+    EXCLUDE_FROM_ALL)
+  block()
+    set(BUILD_SHARED_LIBS OFF)
+    FetchContent_MakeAvailable(dlfcn-win32)
+  endblock()
+  target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
+  target_link_libraries(mlx PRIVATE dl)
 endif()
 
 if(MLX_BUILD_CPU)
   find_library(ACCELERATE_LIBRARY Accelerate)
-  if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
+  if(ACCELERATE_LIBRARY)
     message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
     set(MLX_BUILD_ACCELERATE ON)
-    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
-    add_compile_definitions(ACCELERATE_NEW_LAPACK)
   else()
-    message(STATUS "Accelerate or arm neon not found, using default backend.")
+    message(STATUS "Accelerate not found, using default backend.")
     set(MLX_BUILD_ACCELERATE OFF)
+  endif()
+
+  if(MLX_BUILD_ACCELERATE)
+    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
+    add_compile_definitions(MLX_USE_ACCELERATE)
+    add_compile_definitions(ACCELERATE_NEW_LAPACK)
+  elseif(MLX_BUILD_BLAS_FROM_SOURCE)
+    # Download and build OpenBLAS from source code.
+    FetchContent_Declare(
+      openblas
+      GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
+      GIT_TAG v0.3.28
+      EXCLUDE_FROM_ALL)
+    set(BUILD_STATIC_LIBS ON) # link statically
+    set(NOFORTRAN ON) # msvc has no fortran compiler
+    FetchContent_MakeAvailable(openblas)
+    target_link_libraries(mlx PRIVATE openblas)
+    target_include_directories(
+      mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
+                  "${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
+  else()
   if(${CMAKE_HOST_APPLE})
     # The blas shipped in macOS SDK is not supported, search homebrew for
     # openblas instead.
@@ -146,7 +218,7 @@ if(MLX_BUILD_CPU)
     message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
     message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
-    target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
+    target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
     # List blas after lapack otherwise we may accidentally incldue an old
     # version of lapack.h from the include dirs of blas.
     find_package(BLAS REQUIRED)
@@ -159,29 +231,19 @@ if(MLX_BUILD_CPU)
     message(STATUS "Blas lib " ${BLAS_LIBRARIES})
     message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
     target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
-    target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
+    target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
   endif()
 else()
   set(MLX_BUILD_ACCELERATE OFF)
 endif()
 
-find_package(MPI)
-if(MPI_FOUND)
-  execute_process(
-    COMMAND zsh "-c" "mpirun --version"
-    OUTPUT_VARIABLE MPI_VERSION
-    ERROR_QUIET)
-  if(${MPI_VERSION} MATCHES ".*Open MPI.*")
-    target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
-  elseif(MPI_VERSION STREQUAL "")
-    set(MPI_FOUND FALSE)
+message(STATUS "Downloading json")
+FetchContent_Declare(
+  json
+  URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
+FetchContent_MakeAvailable(json)
+target_include_directories(
+  mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
```
|
|
||||||
message(
|
|
||||||
WARNING "MPI found but mpirun is not available. Building without MPI.")
|
|
||||||
else()
|
|
||||||
set(MPI_FOUND FALSE)
|
|
||||||
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||||
|
|
||||||
@@ -189,12 +251,19 @@ target_include_directories(
|
|||||||
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||||
$<INSTALL_INTERFACE:include>)
|
$<INSTALL_INTERFACE:include>)
|
||||||
|
|
||||||
FetchContent_Declare(
|
# Do not add mlx_EXPORTS define for shared library.
|
||||||
|
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||||
|
|
||||||
|
if(USE_SYSTEM_FMT)
|
||||||
|
find_package(fmt REQUIRED)
|
||||||
|
else()
|
||||||
|
FetchContent_Declare(
|
||||||
fmt
|
fmt
|
||||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||||
GIT_TAG 10.2.1
|
GIT_TAG 10.2.1
|
||||||
EXCLUDE_FROM_ALL)
|
EXCLUDE_FROM_ALL)
|
||||||
FetchContent_MakeAvailable(fmt)
|
FetchContent_MakeAvailable(fmt)
|
||||||
|
endif()
|
||||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||||
|
|
||||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||||
@@ -206,8 +275,7 @@ if(MLX_BUILD_PYTHON_BINDINGS)
|
|||||||
execute_process(
|
execute_process(
|
||||||
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
OUTPUT_VARIABLE NB_DIR)
|
OUTPUT_VARIABLE nanobind_ROOT)
|
||||||
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
|
||||||
find_package(nanobind CONFIG REQUIRED)
|
find_package(nanobind CONFIG REQUIRED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -17,11 +17,11 @@ possible.
|
|||||||
|
|
||||||
You can also run the formatters manually as follows:
|
You can also run the formatters manually as follows:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
clang-format -i file.cpp
|
clang-format -i file.cpp
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```shell
|
||||||
black file.py
|
black file.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
include CMakeLists.txt
|
include CMakeLists.txt
|
||||||
|
include mlx.pc.in
|
||||||
recursive-include mlx/ *
|
recursive-include mlx/ *
|
||||||
|
include cmake/*
|
||||||
include python/src/*
|
include python/src/*
|
||||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||||
|
|||||||
33
README.md
33
README.md
@@ -6,33 +6,33 @@
|
|||||||
|
|
||||||
[](https://circleci.com/gh/ml-explore/mlx)
|
[](https://circleci.com/gh/ml-explore/mlx)
|
||||||
|
|
||||||
MLX is an array framework for machine learning research on Apple silicon,
|
MLX is an array framework for machine learning on Apple silicon,
|
||||||
brought to you by Apple machine learning research.
|
brought to you by Apple machine learning research.
|
||||||
|
|
||||||
Some key features of MLX include:
|
Some key features of MLX include:
|
||||||
|
|
||||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||||
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
||||||
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
||||||
the Python API. MLX has higher-level packages like `mlx.nn` and
|
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||||
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
||||||
more complex models.
|
more complex models.
|
||||||
|
|
||||||
- **Composable function transformations**: MLX supports composable function
|
- **Composable function transformations**: MLX supports composable function
|
||||||
transformations for automatic differentiation, automatic vectorization,
|
transformations for automatic differentiation, automatic vectorization,
|
||||||
and computation graph optimization.
|
and computation graph optimization.
|
||||||
|
|
||||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||||
materialized when needed.
|
materialized when needed.
|
||||||
|
|
||||||
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
||||||
dynamically. Changing the shapes of function arguments does not trigger
|
dynamically. Changing the shapes of function arguments does not trigger
|
||||||
slow compilations, and debugging is simple and intuitive.
|
slow compilations, and debugging is simple and intuitive.
|
||||||
|
|
||||||
- **Multi-device**: Operations can run on any of the supported devices
|
- **Multi-device**: Operations can run on any of the supported devices
|
||||||
(currently the CPU and the GPU).
|
(currently the CPU and the GPU).
|
||||||
|
|
||||||
- **Unified memory**: A notable difference from MLX and other frameworks
|
- **Unified memory**: A notable difference from MLX and other frameworks
|
||||||
is the *unified memory model*. Arrays in MLX live in shared memory.
|
is the *unified memory model*. Arrays in MLX live in shared memory.
|
||||||
Operations on MLX arrays can be performed on any of the supported
|
Operations on MLX arrays can be performed on any of the supported
|
||||||
device types without transferring data.
|
device types without transferring data.
|
||||||
@@ -68,18 +68,23 @@ in the documentation.
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
|
||||||
|
macOS, run:
|
||||||
|
|
||||||
**With `pip`**:
|
```bash
|
||||||
|
|
||||||
```
|
|
||||||
pip install mlx
|
pip install mlx
|
||||||
```
|
```
|
||||||
|
|
||||||
**With `conda`**:
|
To install the CUDA backend on Linux, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mlx[cuda]
|
||||||
```
|
```
|
||||||
conda install -c conda-forge mlx
|
|
||||||
|
To install a CPU-only Linux package, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mlx[cpu]
|
||||||
```
|
```
|
||||||
|
|
||||||
Checkout the
|
Checkout the
|
||||||
@@ -105,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
|
|||||||
MLX useful in your research and wish to cite it, please use the following
|
MLX useful in your research and wish to cite it, please use the following
|
||||||
BibTex entry:
|
BibTex entry:
|
||||||
|
|
||||||
```
|
```text
|
||||||
@software{mlx2023,
|
@software{mlx2023,
|
||||||
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
||||||
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
||||||
|
|||||||
@@ -5,35 +5,35 @@
|
|||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
#include "time_utils.h"
|
#include "time_utils.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
void time_value_and_grad() {
|
void time_value_and_grad() {
|
||||||
auto x = ones({200, 1000});
|
auto x = mx::ones({200, 1000});
|
||||||
eval(x);
|
mx::eval(x);
|
||||||
auto fn = [](array x) {
|
auto fn = [](mx::array x) {
|
||||||
for (int i = 0; i < 20; ++i) {
|
for (int i = 0; i < 20; ++i) {
|
||||||
x = log(exp(x));
|
x = mx::log(mx::exp(x));
|
||||||
}
|
}
|
||||||
return sum(x);
|
return mx::sum(x);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto grad_fn = grad(fn);
|
auto grad_fn = mx::grad(fn);
|
||||||
auto independent_value_and_grad = [&]() {
|
auto independent_value_and_grad = [&]() {
|
||||||
auto value = fn(x);
|
auto value = fn(x);
|
||||||
auto dfdx = grad_fn(x);
|
auto dfdx = grad_fn(x);
|
||||||
return std::vector<array>{value, dfdx};
|
return std::vector<mx::array>{value, dfdx};
|
||||||
};
|
};
|
||||||
TIME(independent_value_and_grad);
|
TIME(independent_value_and_grad);
|
||||||
|
|
||||||
auto value_and_grad_fn = value_and_grad(fn);
|
auto value_and_grad_fn = mx::value_and_grad(fn);
|
||||||
auto combined_value_and_grad = [&]() {
|
auto combined_value_and_grad = [&]() {
|
||||||
auto [value, dfdx] = value_and_grad_fn(x);
|
auto [value, dfdx] = value_and_grad_fn(x);
|
||||||
return std::vector<array>{value, dfdx};
|
return std::vector<mx::array>{value, dfdx};
|
||||||
};
|
};
|
||||||
TIME(combined_value_and_grad);
|
TIME(combined_value_and_grad);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||||
time_value_and_grad();
|
time_value_and_grad();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,21 +4,21 @@
|
|||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
#include "time_utils.h"
|
#include "time_utils.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
void time_add_op() {
|
void time_add_op() {
|
||||||
std::vector<int> sizes(1, 1);
|
std::vector<int> sizes(1, 1);
|
||||||
for (int i = 0; i < 9; ++i) {
|
for (int i = 0; i < 9; ++i) {
|
||||||
sizes.push_back(10 * sizes.back());
|
sizes.push_back(10 * sizes.back());
|
||||||
}
|
}
|
||||||
set_default_device(Device::cpu);
|
set_default_device(mx::Device::cpu);
|
||||||
for (auto size : sizes) {
|
for (auto size : sizes) {
|
||||||
auto a = random::uniform({size});
|
auto a = mx::random::uniform({size});
|
||||||
auto b = random::uniform({size});
|
auto b = mx::random::uniform({size});
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
std::cout << "Size " << size << std::endl;
|
std::cout << "Size " << size << std::endl;
|
||||||
TIMEM("cpu", add, a, b, Device::cpu);
|
TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
|
||||||
TIMEM("gpu", add, a, b, Device::gpu);
|
TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,110 +1,111 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
#include "time_utils.h"
|
#include "time_utils.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
void time_irregular_binary_ops_1D() {
|
void time_irregular_binary_ops_1D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
int size = 1000000;
|
int size = 1000000;
|
||||||
int step = 2;
|
int step = 2;
|
||||||
auto a = random::uniform({size});
|
auto a = mx::random::uniform({size});
|
||||||
auto b = random::uniform({size});
|
auto b = mx::random::uniform({size});
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
a = slice(a, {0}, {size}, {step});
|
a = slice(a, {0}, {size}, {step});
|
||||||
b = slice(b, {0}, {size}, {step});
|
b = slice(b, {0}, {size}, {step});
|
||||||
TIMEM("1D strided", add, a, b, device);
|
TIMEM("1D strided", mx::add, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_binary_ops_2D() {
|
void time_irregular_binary_ops_2D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
int size = 2048;
|
int size = 2048;
|
||||||
auto a = random::uniform({size, size});
|
auto a = mx::random::uniform({size, size});
|
||||||
auto b = random::uniform({size, size});
|
auto b = mx::random::uniform({size, size});
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIMEM("2D regular", add, a, b, device);
|
TIMEM("2D regular", mx::add, a, b, device);
|
||||||
|
|
||||||
b = transpose(b);
|
b = mx::transpose(b);
|
||||||
eval(b);
|
mx::eval(b);
|
||||||
TIMEM("2D transpose", add, a, b, device);
|
TIMEM("2D mx::transpose", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({size});
|
b = mx::random::uniform({size});
|
||||||
eval(b);
|
mx::eval(b);
|
||||||
TIMEM("2D broadcast dim 0", add, a, b, device);
|
TIMEM("2D broadcast dim 0", mx::add, a, b, device);
|
||||||
|
|
||||||
b = reshape(b, {size, 1});
|
b = mx::reshape(b, {size, 1});
|
||||||
eval(b);
|
mx::eval(b);
|
||||||
TIMEM("2D broadcast dim 1", add, a, b, device);
|
TIMEM("2D broadcast dim 1", mx::add, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_binary_ops_3D() {
|
void time_irregular_binary_ops_3D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
int d0 = 32;
|
int d0 = 32;
|
||||||
int d1 = 512;
|
int d1 = 512;
|
||||||
int d2 = 512;
|
int d2 = 512;
|
||||||
auto a = random::uniform({d0, d1, d2});
|
auto a = mx::random::uniform({d0, d1, d2});
|
||||||
auto b = random::uniform({d0, d1, d2});
|
auto b = mx::random::uniform({d0, d1, d2});
|
||||||
TIMEM("3D regular", add, a, b, device);
|
TIMEM("3D regular", mx::add, a, b, device);
|
||||||
|
|
||||||
b = transpose(b, {0, 2, 1});
|
b = mx::transpose(b, {0, 2, 1});
|
||||||
TIMEM("3D transpose", add, a, b, device);
|
TIMEM("3D mx::transpose", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d1, d2});
|
b = mx::random::uniform({d1, d2});
|
||||||
TIMEM("3D broadcast dim 0", add, a, b, device);
|
TIMEM("3D broadcast dim 0", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d0, 1, d2});
|
b = mx::random::uniform({d0, 1, d2});
|
||||||
TIMEM("3D broadcast dim 1", add, a, b, device);
|
TIMEM("3D broadcast dim 1", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d0, d1, 1});
|
b = mx::random::uniform({d0, d1, 1});
|
||||||
TIMEM("3D broadcast dim 2", add, a, b, device);
|
TIMEM("3D broadcast dim 2", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d2});
|
b = mx::random::uniform({d2});
|
||||||
TIMEM("3D broadcast dims 0, 1", add, a, b, device);
|
TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d1, 1});
|
b = mx::random::uniform({d1, 1});
|
||||||
TIMEM("3D broadcast dims 0, 2", add, a, b, device);
|
TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
|
||||||
|
|
||||||
b = random::uniform({d0, 1, 1});
|
b = mx::random::uniform({d0, 1, 1});
|
||||||
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
|
TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_binary_ops_4D() {
|
void time_irregular_binary_ops_4D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
std::vector<int> shape = {8, 8, 512, 512};
|
std::vector<int> shape = {8, 8, 512, 512};
|
||||||
auto a = random::uniform(shape);
|
auto a = mx::random::uniform(shape);
|
||||||
auto b = random::uniform(shape);
|
auto b = mx::random::uniform(shape);
|
||||||
|
|
||||||
TIMEM("4D regular", add, a, b, device);
|
TIMEM("4D regular", mx::add, a, b, device);
|
||||||
|
|
||||||
b = transpose(b, {0, 1, 3, 2});
|
b = mx::transpose(b, {0, 1, 3, 2});
|
||||||
TIMEM("4D transpose", add, a, b, device);
|
TIMEM("4D mx::transpose", mx::add, a, b, device);
|
||||||
|
|
||||||
std::string om = "4D broadcast dims ";
|
std::string om = "4D broadcast dims ";
|
||||||
for (int i = 0; i < shape.size(); ++i) {
|
for (int i = 0; i < shape.size(); ++i) {
|
||||||
shape[i] = 1;
|
shape[i] = 1;
|
||||||
b = random::uniform(shape);
|
b = mx::random::uniform(shape);
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << om << i;
|
msg << om << i;
|
||||||
TIMEM(msg.str(), add, a, b, device);
|
TIMEM(msg.str(), mx::add, a, b, device);
|
||||||
|
|
||||||
for (int j = i + 1; j < shape.size(); ++j) {
|
for (int j = i + 1; j < shape.size(); ++j) {
|
||||||
shape[j] = 1;
|
shape[j] = 1;
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << om << i << ", " << j;
|
msg << om << i << ", " << j;
|
||||||
b = random::uniform(shape);
|
b = mx::random::uniform(shape);
|
||||||
TIMEM(msg.str(), add, a, b, device);
|
TIMEM(msg.str(), mx::add, a, b, device);
|
||||||
shape[j] = a.shape(j);
|
shape[j] = a.shape(j);
|
||||||
|
|
||||||
for (int k = j + 1; k < shape.size(); ++k) {
|
for (int k = j + 1; k < shape.size(); ++k) {
|
||||||
shape[k] = 1;
|
shape[k] = 1;
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << om << i << ", " << j << ", " << k;
|
msg << om << i << ", " << j << ", " << k;
|
||||||
b = random::uniform(shape);
|
b = mx::random::uniform(shape);
|
||||||
TIMEM(msg.str(), add, a, b, device);
|
TIMEM(msg.str(), mx::add, a, b, device);
|
||||||
shape[k] = a.shape(k);
|
shape[k] = a.shape(k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -113,83 +114,83 @@ void time_irregular_binary_ops_4D() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_reshape() {
|
void time_irregular_reshape() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
std::vector<int> shape;
|
std::vector<int> shape;
|
||||||
auto reshape_fn = [&shape, device](const array& a) {
|
auto reshape_fn = [&shape, device](const mx::array& a) {
|
||||||
return reshape(a, shape, device);
|
return mx::reshape(a, shape, device);
|
||||||
};
|
};
|
||||||
|
|
||||||
int size = 64;
|
int size = 64;
|
||||||
int d = 2 * size;
|
int d = 2 * size;
|
||||||
|
|
||||||
auto a = random::uniform({d, d, d});
|
auto a = mx::random::uniform({d, d, d});
|
||||||
|
|
||||||
shape = {8 * size, size, size};
|
shape = {8 * size, size, size};
|
||||||
TIMEM("3D contiguous", reshape_fn, a);
|
TIMEM("3D contiguous", reshape_fn, a);
|
||||||
|
|
||||||
a = transpose(a);
|
a = mx::transpose(a);
|
||||||
shape = {8 * size, size, size};
|
shape = {8 * size, size, size};
|
||||||
TIMEM("3D transpose", reshape_fn, a);
|
TIMEM("3D mx::transpose", reshape_fn, a);
|
||||||
|
|
||||||
a = transpose(a, {1, 2, 0});
|
a = mx::transpose(a, {1, 2, 0});
|
||||||
shape = {8 * size, size, size};
|
shape = {8 * size, size, size};
|
||||||
TIMEM("3D transpose dims 1 2", reshape_fn, a);
|
TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d, d}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
|
||||||
TIMEM("3D broadcast dim 0", reshape_fn, a);
|
TIMEM("3D broadcast dim 0", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
|
||||||
TIMEM("3D broadcast dim 1", reshape_fn, a);
|
TIMEM("3D broadcast dim 1", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
|
||||||
TIMEM("3D broadcast dim 2", reshape_fn, a);
|
TIMEM("3D broadcast dim 2", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
|
||||||
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
|
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d, 1}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
|
||||||
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
|
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
|
||||||
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
|
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
|
a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
|
||||||
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
|
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_astype_1D() {
|
void time_irregular_astype_1D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
int size = 1000000;
|
int size = 1000000;
|
||||||
int step = 2;
|
int step = 2;
|
||||||
auto a = random::uniform({size});
|
auto a = mx::random::uniform({size});
|
||||||
a = slice(a, {0}, {size}, {step});
|
a = slice(a, {0}, {size}, {step});
|
||||||
TIMEM("1D strided", astype, a, int32, device);
|
TIMEM("1D strided", mx::astype, a, mx::int32, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_irregular_astype_2D() {
|
void time_irregular_astype_2D() {
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
int size = 2048;
|
int size = 2048;
|
||||||
std::vector<int> shape = {size, size};
|
std::vector<int> shape = {size, size};
|
||||||
|
|
||||||
auto a = random::uniform(shape);
|
auto a = mx::random::uniform(shape);
|
||||||
TIMEM("2D regular", astype, a, int32, device);
|
TIMEM("2D regular", mx::astype, a, mx::int32, device);
|
||||||
|
|
||||||
a = transpose(a);
|
a = mx::transpose(a);
|
||||||
TIMEM("2D transpose", astype, a, int32, device);
|
TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({size}), shape);
|
a = mx::broadcast_to(mx::random::uniform({size}), shape);
|
||||||
TIMEM("2D broadcast dim 0", astype, a, int32, device);
|
TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
|
||||||
|
|
||||||
a = broadcast_to(random::uniform({size, 1}), shape);
|
a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
|
||||||
TIMEM("2D broadcast dim 1", astype, a, int32, device);
|
TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
if (argc > 1) {
|
if (argc > 1) {
|
||||||
bool use_gpu = !strcmp(argv[1], "gpu");
|
bool use_gpu = !strcmp(argv[1], "gpu");
|
||||||
set_default_device(use_gpu ? Device::gpu : Device::cpu);
|
set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
|
||||||
}
|
}
|
||||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||||
time_irregular_binary_ops_1D();
|
time_irregular_binary_ops_1D();
|
||||||
time_irregular_binary_ops_2D();
|
time_irregular_binary_ops_2D();
|
||||||
time_irregular_binary_ops_3D();
|
time_irregular_binary_ops_3D();
|
||||||
|
|||||||
@@ -3,20 +3,20 @@
|
|||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
#include "time_utils.h"
|
#include "time_utils.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
void time_creation_ops() {
|
void time_creation_ops() {
|
||||||
int M = 2000;
|
int M = 2000;
|
||||||
int N = 500;
|
int N = 500;
|
||||||
auto shape = {M, N};
|
auto shape = {M, N};
|
||||||
auto full_fp32 = [&]() { return full(shape, 3.3f); };
|
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
|
||||||
TIME(full_fp32);
|
TIME(full_fp32);
|
||||||
auto zeros_fp32 = [&]() { return zeros(shape, float32); };
|
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
|
||||||
TIME(zeros_fp32);
|
TIME(zeros_fp32);
|
||||||
auto ones_fp32 = [&]() { return ones(shape, float32); };
|
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
|
||||||
TIME(ones_fp32);
|
TIME(ones_fp32);
|
||||||
|
|
||||||
auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
|
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
|
||||||
TIME(arange_fp32);
|
TIME(arange_fp32);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,194 +24,212 @@ void time_type_conversions() {
|
|||||||
int M = 2000;
|
int M = 2000;
|
||||||
int N = 500;
|
int N = 500;
|
||||||
auto shape = {M, N};
|
auto shape = {M, N};
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
|
|
||||||
auto a = zeros(shape, float32);
|
auto a = mx::zeros(shape, mx::float32);
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
TIMEM("float32 to int32", astype, a, int32, device);
|
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
|
||||||
TIMEM("float32 to uint32", astype, a, uint32, device);
|
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||||
|
|
||||||
a = zeros(shape, int32);
|
a = mx::zeros(shape, mx::int32);
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
TIMEM("int32 to float32", astype, a, float32, device);
|
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
|
||||||
|
|
||||||
a = zeros(shape, bool_);
|
a = mx::zeros(shape, mx::bool_);
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
TIMEM("bool to float32", astype, a, float32, device);
|
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
|
||||||
TIMEM("bool to int32", astype, a, int32, device);
|
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
|
||||||
TIMEM("bool to uint32", astype, a, uint32, device);
|
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_random_generation() {
|
void time_random_generation() {
|
||||||
int M = 2000;
|
int M = 2000;
|
||||||
int N = 500;
|
int N = 500;
|
||||||
|
|
||||||
auto uniform = [&]() { return random::uniform({M, N}, float32); };
|
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
|
||||||
TIME(uniform);
|
TIME(uniform);
|
||||||
auto normal = [&]() { return random::normal({M, N}, float32); };
|
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
|
||||||
TIME(normal);
|
TIME(normal);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_unary_ops() {
|
void time_unary_ops() {
|
||||||
int M = 2000;
|
int M = 2000;
|
||||||
int N = 500;
|
int N = 500;
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
|
|
||||||
auto a = random::normal({M, N});
|
auto a = mx::random::normal({M, N});
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
TIME(mlx::core::abs, a, device);
|
TIME(mlx::core::abs, a, device);
|
||||||
TIME(negative, a, device);
|
TIME(mx::negative, a, device);
|
||||||
TIME(sign, a, device);
|
TIME(mx::sign, a, device);
|
||||||
TIME(square, a, device);
|
TIME(mx::square, a, device);
|
||||||
TIME(mlx::core::sqrt, a, device);
|
TIME(mlx::core::sqrt, a, device);
|
||||||
TIME(rsqrt, a, device);
|
TIME(mx::rsqrt, a, device);
|
||||||
TIME(mlx::core::exp, a, device);
|
TIME(mlx::core::exp, a, device);
|
||||||
|
|
||||||
a = random::uniform({M, N});
|
a = mx::random::uniform({M, N});
|
||||||
TIME(mlx::core::log, a, device);
|
TIME(mlx::core::log, a, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_binary_ops() {
|
void time_binary_ops() {
|
||||||
int M = 1000, N = 100, K = 10;
|
int M = 1000, N = 100, K = 10;
|
||||||
auto condition = random::randint(0, 2, {M, N, K});
|
auto condition = mx::random::randint(0, 2, {M, N, K});
|
||||||
auto a = random::uniform({M, N, K});
|
auto a = mx::random::uniform({M, N, K});
|
||||||
auto b = random::uniform({M, N, K});
|
auto b = mx::random::uniform({M, N, K});
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
|
|
||||||
TIME(add, a, b, device);
|
TIME(mx::add, a, b, device);
|
||||||
TIME(subtract, a, b, device);
|
TIME(mx::subtract, a, b, device);
|
||||||
TIME(multiply, a, b, device);
|
TIME(mx::multiply, a, b, device);
|
||||||
TIME(divide, a, b, device);
|
TIME(mx::divide, a, b, device);
|
||||||
TIME(maximum, a, b, device);
|
TIME(mx::maximum, a, b, device);
|
||||||
TIME(minimum, a, b, device);
|
TIME(mx::minimum, a, b, device);
|
||||||
TIME(where, condition, a, b, device);
|
TIME(mx::where, condition, a, b, device);
|
||||||
|
|
||||||
condition = array({true});
|
condition = mx::array({true});
|
||||||
b = random::uniform({1});
|
b = mx::random::uniform({1});
|
||||||
eval(b);
|
mx::eval(b);
|
||||||
TIMEM("scalar", add, a, b, device);
|
TIMEM("scalar", mx::add, a, b, device);
|
||||||
TIMEM("vector-scalar", subtract, a, b, device);
|
TIMEM("vector-scalar", mx::subtract, a, b, device);
|
||||||
TIMEM("scalar-vector", subtract, b, a, device);
|
TIMEM("scalar-vector", mx::subtract, b, a, device);
|
||||||
TIMEM("scalar", multiply, a, b, device);
|
TIMEM("scalar", mx::multiply, a, b, device);
|
||||||
TIMEM("vector-scalar", divide, a, b, device);
|
TIMEM("vector-scalar", mx::divide, a, b, device);
|
||||||
TIMEM("scalar-vector", divide, b, a, device);
|
TIMEM("scalar-vector", mx::divide, b, a, device);
|
||||||
TIMEM("scalar-vector", where, condition, a, b, device);
|
TIMEM("scalar-vector", mx::where, condition, a, b, device);
|
||||||
|
|
||||||
condition = broadcast_to(array({true}), {1000, 100});
|
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
|
||||||
a = broadcast_to(random::uniform({1}), {1000, 100});
|
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||||
b = broadcast_to(random::uniform({1}), {1000, 100});
|
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIMEM("scalar-scalar broadcast", add, a, b, device);
|
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
|
||||||
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
|
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
|
||||||
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
|
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
|
||||||
TIMEM("scalar-scalar broadcast", divide, a, b, device);
|
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
|
||||||
TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
|
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_strided_ops() {
|
void time_strided_ops() {
|
||||||
int M = 50, N = 50, O = 50, P = 50;
|
int M = 50, N = 50, O = 50, P = 50;
|
||||||
auto a = random::uniform({M, N, O, P});
|
auto a = mx::random::uniform({M, N, O, P});
|
||||||
auto b = random::uniform({M, N, O, P});
|
auto b = mx::random::uniform({M, N, O, P});
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIMEM("non-strided", add, a, b, device);
|
TIMEM("non-strided", mx::add, a, b, device);
|
||||||
a = transpose(a, {1, 0, 2, 3});
|
a = mx::transpose(a, {1, 0, 2, 3});
|
||||||
b = transpose(b, {3, 2, 0, 1});
|
b = mx::transpose(b, {3, 2, 0, 1});
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIMEM("strided", add, a, b, device);
|
TIMEM("strided", mx::add, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_comparisons() {
|
void time_comparisons() {
|
||||||
int M = 1000, N = 100, K = 10;
|
int M = 1000, N = 100, K = 10;
|
||||||
auto a = random::uniform({M, N, K});
|
auto a = mx::random::uniform({M, N, K});
|
||||||
auto b = random::uniform({M, N, K});
|
auto b = mx::random::uniform({M, N, K});
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIME(equal, a, b, device);
|
TIME(mx::equal, a, b, device);
|
||||||
TIME(greater, a, b, device);
|
TIME(mx::greater, a, b, device);
|
||||||
TIME(greater_equal, a, b, device);
|
TIME(mx::greater_equal, a, b, device);
|
||||||
TIME(less, a, b, device);
|
TIME(mx::less, a, b, device);
|
||||||
TIME(less_equal, a, b, device);
|
TIME(mx::less_equal, a, b, device);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_matvec() {
|
void time_matvec() {
|
||||||
int M = 2000, N = 200;
|
int M = 2000, N = 200;
|
||||||
auto a = random::uniform({M, N});
|
auto a = mx::random::uniform({M, N});
|
||||||
auto b = random::uniform({N});
|
auto b = mx::random::uniform({N});
|
||||||
auto c = random::uniform({M});
|
auto c = mx::random::uniform({M});
|
||||||
eval(a, b, c);
|
mx::eval(a, b, c);
|
||||||
auto matvec = [&]() { return matmul(a, b); };
|
auto matvec = [&]() { return mx::matmul(a, b); };
|
||||||
TIME(matvec);
|
TIME(matvec);
|
||||||
|
|
||||||
auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
|
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
|
||||||
TIME(matvec_transpose);
|
TIME(matvec_transpose);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_matmul() {
|
void time_matmul() {
|
||||||
int M = 1000, N = 1000, K = 1000;
|
int M = 1000, N = 1000, K = 1000;
|
||||||
auto a = random::uniform({M, K});
|
auto a = mx::random::uniform({M, K});
|
||||||
auto b = random::uniform({K, N});
|
auto b = mx::random::uniform({K, N});
|
||||||
auto device = default_device();
|
auto device = mx::default_device();
|
||||||
eval(a, b);
|
mx::eval(a, b);
|
||||||
TIME(matmul, a, b, device);
|
TIME(mx::matmul, a, b, device);
|
||||||
|
|
||||||
auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
|
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
|
||||||
TIME(transpose_matmul);
|
TIME(transpose_matmul);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_reductions() {
|
void time_reductions() {
|
||||||
auto a = random::normal({10000, 1000});
|
auto a = mx::random::normal({10000, 1000});
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
auto sum_all = [&a]() { return sum(a, false); };
|
auto sum_all = [&a]() { return mx::sum(a, false); };
|
||||||
TIME(sum_all);
|
TIME(sum_all);
|
||||||
|
|
||||||
auto sum_along_0 = [&a]() { return sum(a, 0, false); };
|
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
|
||||||
TIME(sum_along_0);
|
TIME(sum_along_0);
|
||||||
|
|
||||||
auto sum_along_1 = [&a]() { return sum(a, 1, false); };
|
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
|
||||||
TIME(sum_along_1);
|
TIME(sum_along_1);
|
||||||
|
|
||||||
auto prod_all = [&a]() { return prod(a, false); };
|
auto prod_all = [&a]() { return mx::prod(a, false); };
|
||||||
TIME(prod_all);
|
TIME(prod_all);
|
||||||
|
|
||||||
auto all_true = [&a]() { return all(a, false); };
|
auto all_true = [&a]() { return mx::all(a, false); };
|
||||||
TIME(all_true);
|
TIME(all_true);
|
||||||
|
|
||||||
auto all_along_0 = [&a]() { return all(a, 0, false); };
|
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
|
||||||
TIME(all_along_0);
|
TIME(all_along_0);
|
||||||
|
|
||||||
auto all_along_1 = [&a]() { return all(a, 1, false); };
|
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
|
||||||
TIME(all_along_1);
|
TIME(all_along_1);
|
||||||
|
|
||||||
auto any_true = [&a]() { return any(a, false); };
|
auto any_true = [&a]() { return mx::any(a, false); };
|
||||||
TIME(any_true);
|
TIME(any_true);
|
||||||
|
|
||||||
auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
|
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
|
||||||
TIME(argmin_along_0);
|
TIME(argmin_along_0);
|
||||||
|
|
||||||
auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
|
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||||
TIME(argmin_along_1);
|
TIME(argmin_along_1);
|
||||||
|
|
||||||
|
auto indices = mx::array({1});
|
||||||
|
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
|
||||||
|
std::vector<int> axes{0};
|
||||||
|
auto b = scatter(a, {indices}, updates, axes);
|
||||||
|
mx::eval(b);
|
||||||
|
|
||||||
|
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
|
||||||
|
TIME(max_along_0);
|
||||||
|
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||||
|
TIME(max_along_1);
|
||||||
|
|
||||||
|
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||||
|
TIME(min_along_0);
|
||||||
|
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||||
|
TIME(min_along_1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_gather_scatter() {
|
void time_gather_scatter() {
|
||||||
auto a = random::normal({1000, 768});
|
auto a = mx::random::normal({1000, 768});
|
||||||
eval(a);
|
mx::eval(a);
|
||||||
auto indices = random::randint(0, 1000, {256});
|
auto indices = mx::random::randint(0, 1000, {256});
|
||||||
eval(indices);
|
mx::eval(indices);
|
||||||
|
|
||||||
auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
|
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
|
||||||
TIME(embedding_lookup);
|
TIME(embedding_lookup);
|
||||||
|
|
||||||
indices = random::randint(0, 768 * 1000, {256 * 768});
|
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
|
||||||
eval(indices);
|
mx::eval(indices);
|
||||||
|
|
||||||
auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
|
auto single_element_lookup = [&a, &indices]() {
|
||||||
|
return mx::take(a, indices);
|
||||||
|
};
|
||||||
TIME(single_element_lookup);
|
TIME(single_element_lookup);
|
||||||
|
|
||||||
indices = random::randint(0, 1000, {256});
|
indices = mx::random::randint(0, 1000, {256});
|
||||||
auto updates = random::normal({256, 1, 768});
|
auto updates = mx::random::normal({256, 1, 768});
|
||||||
eval(indices, updates);
|
mx::eval(indices, updates);
|
||||||
|
|
||||||
auto embedding_update = [&a, &indices, &updates]() {
|
auto embedding_update = [&a, &indices, &updates]() {
|
||||||
return scatter(a, indices, updates, 0);
|
return scatter(a, indices, updates, 0);
|
||||||
@@ -223,10 +241,10 @@ void time_gather_scatter() {
|
|||||||
};
|
};
|
||||||
TIME(embedding_add);
|
TIME(embedding_add);
|
||||||
|
|
||||||
a = reshape(a, {-1});
|
a = mx::reshape(a, {-1});
|
||||||
indices = random::randint(0, 768 * 1000, {768 * 256});
|
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
|
||||||
updates = random::normal({256 * 768, 1});
|
updates = mx::random::normal({256 * 768, 1});
|
||||||
eval(a, indices, updates);
|
mx::eval(a, indices, updates);
|
||||||
|
|
||||||
auto single_element_update = [&a, &indices, &updates]() {
|
auto single_element_update = [&a, &indices, &updates]() {
|
||||||
return scatter(a, indices, updates, 0);
|
return scatter(a, indices, updates, 0);
|
||||||
@@ -240,21 +258,21 @@ void time_gather_scatter() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void time_divmod() {
|
void time_divmod() {
|
||||||
auto a = random::normal({1000});
|
auto a = mx::random::normal({1000});
|
||||||
auto b = random::normal({1000});
|
auto b = mx::random::normal({1000});
|
||||||
eval({a, b});
|
mx::eval({a, b});
|
||||||
|
|
||||||
auto divmod_fused = [&a, &b]() { return divmod(a, b); };
|
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
|
||||||
TIME(divmod_fused);
|
TIME(divmod_fused);
|
||||||
|
|
||||||
auto divmod_separate = [&a, &b]() {
|
auto divmod_separate = [&a, &b]() {
|
||||||
return std::vector<array>{floor_divide(a, b), remainder(a, b)};
|
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
|
||||||
};
|
};
|
||||||
TIME(divmod_separate);
|
TIME(divmod_separate);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
std::cout << "Benchmarks for " << default_device() << std::endl;
|
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
||||||
time_creation_ops();
|
time_creation_ops();
|
||||||
time_type_conversions();
|
time_type_conversions();
|
||||||
time_unary_ops();
|
time_unary_ops();
|
||||||
|
|||||||
@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
|
|||||||
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
|
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
|
||||||
|
|
||||||
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
|
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
|
||||||
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
|
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
|
||||||
np.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||||
|
|
||||||
dtypes = ("float32", "float16")
|
dtypes = ("float32", "float16", "complex64")
|
||||||
transposes = ("nn", "nt", "tn")
|
transposes = ("nn", "nt", "tn")
|
||||||
shapes = (
|
shapes = (
|
||||||
(16, 234, 768, 3072),
|
(16, 234, 768, 3072),
|
||||||
@@ -187,7 +185,7 @@ if __name__ == "__main__":
|
|||||||
diff = gflops_mx / gflops_pt - 1.0
|
diff = gflops_mx / gflops_pt - 1.0
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
|
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
|
||||||
)
|
)
|
||||||
if gflops_pt >= 2.0 * gflops_mx:
|
if gflops_pt >= 2.0 * gflops_mx:
|
||||||
print("ATTENTION ^^^^^^^")
|
print("ATTENTION ^^^^^^^")
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
|
|||||||
|
|
||||||
|
|
||||||
for transpose in (False, True):
|
for transpose in (False, True):
|
||||||
for dtype in ("float32", "float16"):
|
for dtype in ("float32", "float16", "complex64"):
|
||||||
fig, axs = plt.subplots(
|
fig, axs = plt.subplots(
|
||||||
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
|
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
|
||||||
)
|
)
|
||||||
@@ -215,7 +215,7 @@ for transpose in (False, True):
|
|||||||
fig.suptitle(f"{device_name}: {dtype} {op_name}")
|
fig.suptitle(f"{device_name}: {dtype} {op_name}")
|
||||||
fig.savefig(
|
fig.savefig(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
|
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|||||||
@@ -144,6 +144,13 @@ def reduction(op, axis, x):
|
|||||||
mx.eval(ys)
|
mx.eval(ys)
|
||||||
|
|
||||||
|
|
||||||
|
def sum_and_add(axis, x, y):
|
||||||
|
z = x.sum(axis=axis, keepdims=True)
|
||||||
|
for i in range(50):
|
||||||
|
z = (z + y).sum(axis=axis, keepdims=True)
|
||||||
|
mx.eval(z)
|
||||||
|
|
||||||
|
|
||||||
def softmax(axis, x):
|
def softmax(axis, x):
|
||||||
ys = []
|
ys = []
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
@@ -505,5 +512,8 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "selu":
|
elif args.benchmark == "selu":
|
||||||
print(bench(selu, x))
|
print(bench(selu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "sum_and_add":
|
||||||
|
print(bench(sum_and_add, axis, *xs))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown benchmark")
|
raise ValueError("Unknown benchmark")
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.cuda
|
||||||
import torch.mps
|
import torch.mps
|
||||||
|
|
||||||
|
|
||||||
@@ -44,8 +45,10 @@ def bench(f, *args):
|
|||||||
|
|
||||||
|
|
||||||
def sync_if_needed(x):
|
def sync_if_needed(x):
|
||||||
if x.device != torch.device("cpu"):
|
if x.device == torch.device("mps"):
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
|
elif x.device == torch.device("cuda"):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
|
|||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sum_and_add(axis, x, y):
|
||||||
|
z = x.sum(axis=axis, keepdims=True)
|
||||||
|
for i in range(50):
|
||||||
|
z = (z + y).sum(axis=axis, keepdims=True)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def softmax(axis, x):
|
def softmax(axis, x):
|
||||||
ys = []
|
ys = []
|
||||||
@@ -340,7 +351,11 @@ if __name__ == "__main__":
|
|||||||
args.axis.pop(0)
|
args.axis.pop(0)
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
device = "cpu" if args.cpu else "mps"
|
device = "mps"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
if args.cpu:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
types = args.dtype
|
types = args.dtype
|
||||||
if not types:
|
if not types:
|
||||||
@@ -460,5 +475,8 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "selu":
|
elif args.benchmark == "selu":
|
||||||
print(bench(selu, x))
|
print(bench(selu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "sum_and_add":
|
||||||
|
print(bench(sum_and_add, axis, *xs))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||||
|
|||||||
107
benchmarks/python/conv_unaligned_bench.py
Normal file
107
benchmarks/python/conv_unaligned_bench.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
N_warmup = 10
|
||||||
|
N_iter_bench = 100
|
||||||
|
N_iter_func = 5
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
def mx_conv_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return mx_conv_2D
|
||||||
|
|
||||||
|
|
||||||
|
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
@torch.no_grad()
|
||||||
|
def pt_conv_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return pt_conv_2D
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||||
|
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||||
|
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||||
|
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||||
|
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||||
|
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||||
|
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||||
|
out_pt = torch.conv2d(
|
||||||
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||||
|
out_pt = out_pt.numpy(force=True)
|
||||||
|
|
||||||
|
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dtype = "float32"
|
||||||
|
shapes = (
|
||||||
|
(4, 32, 32, 21, 3, 3, 128),
|
||||||
|
(4, 32, 32, 21, 3, 3, 37),
|
||||||
|
(4, 32, 32, 370, 3, 3, 370),
|
||||||
|
(4, 32, 32, 370, 7, 7, 128),
|
||||||
|
(2, 320, 640, 21, 7, 7, 21),
|
||||||
|
)
|
||||||
|
for N, H, W, C, kh, kw, O in shapes:
|
||||||
|
time_mlx, time_torch = bench_shape(
|
||||||
|
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
|
||||||
|
)
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
if time_mlx >= 2.0 * time_torch:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from time import time
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
74
benchmarks/python/gather_mm_bench.py
Normal file
74
benchmarks/python/gather_mm_bench.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
N = 1024
|
||||||
|
D = 1024
|
||||||
|
M = 1024
|
||||||
|
E = 32
|
||||||
|
I = 4
|
||||||
|
|
||||||
|
|
||||||
|
def gather_sort(x, indices):
|
||||||
|
N, M = indices.shape
|
||||||
|
indices = indices.flatten()
|
||||||
|
order = mx.argsort(indices)
|
||||||
|
inv_order = mx.argsort(order)
|
||||||
|
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_unsort(x, inv_order, shape=None):
|
||||||
|
x = x[inv_order]
|
||||||
|
if shape is not None:
|
||||||
|
x = mx.unflatten(x, 0, shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gather_mm_simulate(x, w, indices):
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
for i in range(2):
|
||||||
|
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
|
||||||
|
x = y[:, None]
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def time_gather_mm():
|
||||||
|
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||||
|
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||||
|
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||||
|
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||||
|
|
||||||
|
def gather_mm(x, w1, w2, indices, sort):
|
||||||
|
idx = indices
|
||||||
|
inv_order = None
|
||||||
|
if sort:
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||||
|
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||||
|
if sort:
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||||
|
|
||||||
|
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||||
|
mx.eval(x, w1, w2)
|
||||||
|
|
||||||
|
def equivalent_matmul(x, w1, w2):
|
||||||
|
x = x @ w1.T
|
||||||
|
x = x @ w2.T
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(equivalent_matmul, x, w1, w2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_gather_mm()
|
||||||
84
benchmarks/python/gather_qmm_bench.py
Normal file
84
benchmarks/python/gather_qmm_bench.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
N = 1024
|
||||||
|
D = 1024
|
||||||
|
M = 1024
|
||||||
|
E = 32
|
||||||
|
I = 4
|
||||||
|
|
||||||
|
|
||||||
|
def gather_sort(x, indices):
|
||||||
|
N, M = indices.shape
|
||||||
|
indices = indices.flatten()
|
||||||
|
order = mx.argsort(indices)
|
||||||
|
inv_order = mx.argsort(order)
|
||||||
|
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_unsort(x, inv_order, shape=None):
|
||||||
|
x = x[inv_order]
|
||||||
|
if shape is not None:
|
||||||
|
x = mx.unflatten(x, 0, shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gather_mm_simulate(x, w, indices):
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
for i in range(2):
|
||||||
|
y = mx.concatenate(
|
||||||
|
[
|
||||||
|
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
|
||||||
|
for i, j in enumerate(idx.tolist())
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
x = y[:, None]
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def time_gather_qmm():
|
||||||
|
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||||
|
w1 = mx.quantize(w1)
|
||||||
|
w2 = mx.quantize(w2)
|
||||||
|
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||||
|
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||||
|
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||||
|
|
||||||
|
def gather_mm(x, w1, w2, indices, sort):
|
||||||
|
idx = indices
|
||||||
|
inv_order = None
|
||||||
|
if sort:
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||||
|
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||||
|
if sort:
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||||
|
|
||||||
|
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||||
|
w1 = mx.quantize(w1)
|
||||||
|
w2 = mx.quantize(w2)
|
||||||
|
mx.eval(x, w1, w2)
|
||||||
|
|
||||||
|
def equivalent_matmul(x, w1, w2):
|
||||||
|
x = mx.quantized_matmul(x, *w1, transpose=True)
|
||||||
|
x = mx.quantized_matmul(x, *w2, transpose=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(equivalent_matmul, x, w1, w2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_gather_qmm()
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
@@ -10,32 +12,71 @@ def layer_norm(x, w, b, eps):
|
|||||||
x = x.astype(mx.float32)
|
x = x.astype(mx.float32)
|
||||||
mu = mx.mean(x, -1, keepdims=True)
|
mu = mx.mean(x, -1, keepdims=True)
|
||||||
v = mx.var(x, -1, keepdims=True)
|
v = mx.var(x, -1, keepdims=True)
|
||||||
return (x - mu) * mx.rsqrt(v + eps) * w + b
|
y = (x - mu) * mx.rsqrt(v + eps)
|
||||||
|
if w is not None:
|
||||||
|
y = y * w
|
||||||
|
if b is not None:
|
||||||
|
y = y + b
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
def time_layer_norm():
|
def time_layer_norm(N, dt):
|
||||||
|
L = 1024
|
||||||
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
||||||
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
||||||
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
||||||
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
||||||
|
|
||||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
mx.eval(x, w, b, y)
|
mx.eval(x, w, b, y)
|
||||||
|
|
||||||
def layer_norm_loop(g, x, w, b):
|
def layer_norm_loop(f, x, w, b):
|
||||||
|
for _ in range(32):
|
||||||
|
x = f(x, w, b)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
|
||||||
|
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
|
||||||
|
|
||||||
|
def layer_norm_grad_loop(g, x, w, b):
|
||||||
gx, gw, gb = x, w, b
|
gx, gw, gb = x, w, b
|
||||||
for _ in range(32):
|
for _ in range(32):
|
||||||
gx, gw, gb = g(gx, gw, gb, y)
|
gx, gw, gb = g(gx, gw, gb, y)
|
||||||
return gx, gw, gb
|
return gx, gw, gb
|
||||||
|
|
||||||
time_fn(layer_norm_loop, g1, x, w, b)
|
time_fn(layer_norm_grad_loop, g1, x, w, b)
|
||||||
time_fn(layer_norm_loop, g2, x, w, b)
|
time_fn(layer_norm_grad_loop, g2, x, w, b)
|
||||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
|
||||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
|
||||||
|
|
||||||
|
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
||||||
|
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
||||||
|
g1 = mx.grad(f1, argnums=(0,))
|
||||||
|
g2 = mx.grad(f2, argnums=(0,))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
|
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
|
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
|
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
|
mx.eval(x, w, b, y)
|
||||||
|
|
||||||
|
def layer_norm_grad_x_loop(g, x):
|
||||||
|
gx = x
|
||||||
|
for _ in range(32):
|
||||||
|
gx = g(gx, y)
|
||||||
|
return gx
|
||||||
|
|
||||||
|
time_fn(layer_norm_grad_x_loop, g1, x)
|
||||||
|
time_fn(layer_norm_grad_x_loop, g2, x)
|
||||||
|
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
|
||||||
|
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
time_layer_norm()
|
for dt in [mx.float32, mx.float16, mx.bfloat16]:
|
||||||
|
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
|
||||||
|
print(dt, n)
|
||||||
|
time_layer_norm(n, dt)
|
||||||
|
|||||||
@@ -9,7 +9,10 @@ def rms_norm(x, w, eps):
|
|||||||
ot = x.dtype
|
ot = x.dtype
|
||||||
x = x.astype(mx.float32)
|
x = x.astype(mx.float32)
|
||||||
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
|
||||||
return (x * n).astype(ot) * w
|
y = (x * n).astype(ot)
|
||||||
|
if w is not None:
|
||||||
|
y = y * w
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
def time_rms_norm():
|
def time_rms_norm():
|
||||||
@@ -34,6 +37,27 @@ def time_rms_norm():
|
|||||||
time_fn(rms_norm_loop, mx.compile(g1), x, w)
|
time_fn(rms_norm_loop, mx.compile(g1), x, w)
|
||||||
time_fn(rms_norm_loop, mx.compile(g2), x, w)
|
time_fn(rms_norm_loop, mx.compile(g2), x, w)
|
||||||
|
|
||||||
|
f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
|
||||||
|
f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
|
||||||
|
g1 = mx.grad(f1, argnums=(0,))
|
||||||
|
g2 = mx.grad(f2, argnums=(0,))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||||
|
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||||
|
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||||
|
mx.eval(x, w, y)
|
||||||
|
|
||||||
|
def rms_norm_loop(g, x):
|
||||||
|
gx = x
|
||||||
|
for _ in range(32):
|
||||||
|
gx = g(gx, y)
|
||||||
|
return gx
|
||||||
|
|
||||||
|
time_fn(rms_norm_loop, g1, x)
|
||||||
|
time_fn(rms_norm_loop, g2, x)
|
||||||
|
time_fn(rms_norm_loop, mx.compile(g1), x)
|
||||||
|
time_fn(rms_norm_loop, mx.compile(g2), x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
time_rms_norm()
|
time_rms_norm()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from time_utils import measure_runtime
|
|||||||
|
|
||||||
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
|
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
|
||||||
def scatter(dst, x, idx):
|
def scatter(dst, x, idx):
|
||||||
dst[*idx] = x
|
dst[tuple(idx)] = x
|
||||||
mx.eval(dst)
|
mx.eval(dst)
|
||||||
|
|
||||||
idx = []
|
idx = []
|
||||||
@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
|
|||||||
|
|
||||||
|
|
||||||
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
|
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
|
||||||
def gather(dst, x, idx, device):
|
def scatter(dst, x, idx, device):
|
||||||
dst[*idx] = x
|
dst[tuple(idx)] = x
|
||||||
if device == torch.device("mps"):
|
if device == torch.device("mps"):
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
|
|
||||||
@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
|
|||||||
x = torch.randn(x_shape, dtype=torch.float32).to(device)
|
x = torch.randn(x_shape, dtype=torch.float32).to(device)
|
||||||
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
|
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
|
||||||
|
|
||||||
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
|
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
|
||||||
print(f"PyTorch: {runtime:.3f}ms")
|
print(f"PyTorch: {runtime:.3f}ms")
|
||||||
|
|
||||||
|
|
||||||
@@ -54,7 +54,7 @@ if __name__ == "__main__":
|
|||||||
(100_000, 64),
|
(100_000, 64),
|
||||||
(1_000_000, 64),
|
(1_000_000, 64),
|
||||||
(100_000,),
|
(100_000,),
|
||||||
(2_000_00,),
|
(200_000,),
|
||||||
(20_000_000,),
|
(20_000_000,),
|
||||||
(10000, 64),
|
(10000, 64),
|
||||||
(100, 64),
|
(100, 64),
|
||||||
@@ -91,6 +91,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
|
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
|
||||||
print("=" * 20)
|
print("=" * 20)
|
||||||
print(f"X {x_shape}, Indices {idx_shape}")
|
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
|
||||||
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
|
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
|
||||||
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
|
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
|
||||||
|
|||||||
@@ -1,62 +1,223 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from time_utils import time_fn
|
import numpy as np
|
||||||
|
|
||||||
MAX_SEQ = 300
|
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
||||||
START_SEQ = 100
|
device_name = device_name.decode("utf-8").strip("\n")
|
||||||
SEQ_INCREMENT = 50
|
|
||||||
|
N_warmup = 5
|
||||||
|
N_iter_bench = 40
|
||||||
|
N_iter_func = 8
|
||||||
|
|
||||||
|
|
||||||
def time_self_attention_primitives():
|
def bench(f, *args):
|
||||||
mx.random.seed(3)
|
for i in range(N_warmup):
|
||||||
B = 2
|
f(*args)
|
||||||
H = 38
|
|
||||||
D = 64
|
|
||||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
|
||||||
q = mx.random.uniform(shape=(B, H, R, D))
|
|
||||||
k = mx.random.uniform(shape=(B, H, R, D))
|
|
||||||
v = mx.random.uniform(shape=(B, H, R, D))
|
|
||||||
scale = 1.0 / math.sqrt(float(D))
|
|
||||||
mx.eval(q, k, v)
|
|
||||||
|
|
||||||
def sdpa_primitives(qs, ks, vs, alpha):
|
s = time.perf_counter_ns()
|
||||||
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
|
for i in range(N_iter_bench):
|
||||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
f(*args)
|
||||||
o = p @ vs
|
e = time.perf_counter_ns()
|
||||||
return o
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
time_fn(sdpa_primitives, q, k, v, scale)
|
|
||||||
|
|
||||||
|
|
||||||
def time_self_attention_sdpa():
|
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
||||||
mx.random.seed(3)
|
np_dtype = getattr(np, dtype)
|
||||||
B = 2
|
|
||||||
H = 38
|
|
||||||
D = 64
|
|
||||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
|
||||||
q = mx.random.uniform(shape=(B, H, R, D))
|
|
||||||
k = mx.random.uniform(shape=(B, H, R, D))
|
|
||||||
v = mx.random.uniform(shape=(B, H, R, D))
|
|
||||||
scale = 1.0 / math.sqrt(float(D))
|
|
||||||
mx.eval(q, k, v)
|
|
||||||
|
|
||||||
def sdpa_fused(qs, ks, vs, alpha):
|
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
|
||||||
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
|
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
|
||||||
return o
|
|
||||||
|
|
||||||
time_fn(sdpa_fused, q, k, v, scale)
|
scale = 1.0 / math.sqrt(D)
|
||||||
|
|
||||||
|
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
|
||||||
|
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
|
||||||
|
q_mx = mx.array(q_np)
|
||||||
|
k_mx = mx.array(k_np)
|
||||||
|
v_mx = mx.array(v_np)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if mask == "additive":
|
||||||
|
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
|
||||||
|
mask = mx.array(mask_np)
|
||||||
|
elif mask == "bool":
|
||||||
|
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
|
||||||
|
mask = mx.array(mask_np)
|
||||||
|
|
||||||
|
return q_mx, k_mx, v_mx, scale, mask
|
||||||
|
|
||||||
|
|
||||||
|
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||||
|
q_dtype = q.dtype
|
||||||
|
q = q * mx.array(scale, q_dtype)
|
||||||
|
n_q_heads = q.shape[-3]
|
||||||
|
n_kv_heads = k.shape[-3]
|
||||||
|
n_repeats = n_q_heads // n_kv_heads
|
||||||
|
|
||||||
|
B = q.shape[0]
|
||||||
|
L = q.shape[2]
|
||||||
|
kL = k.shape[2]
|
||||||
|
|
||||||
|
if n_repeats > 1:
|
||||||
|
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
||||||
|
k = mx.expand_dims(k, 2)
|
||||||
|
v = mx.expand_dims(v, 2)
|
||||||
|
|
||||||
|
scores = q @ mx.swapaxes(k, -1, -2)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
|
||||||
|
if mask == "causal":
|
||||||
|
q_offset = max(0, kL - L)
|
||||||
|
q_indices = mx.arange(q_offset, q_offset + L)
|
||||||
|
k_indices = mx.arange(kL)
|
||||||
|
mask = q_indices[:, None] >= k_indices[None]
|
||||||
|
|
||||||
|
if n_repeats > 1 and mask.ndim >= 3:
|
||||||
|
if mask.shape[-3] == 1:
|
||||||
|
mask = mx.expand_dims(mask, -3)
|
||||||
|
else:
|
||||||
|
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
|
||||||
|
|
||||||
|
if mask.dtype == mx.bool_:
|
||||||
|
scores = mx.where(mask, scores, -np.float32(np.inf))
|
||||||
|
else:
|
||||||
|
scores += mask
|
||||||
|
|
||||||
|
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||||
|
|
||||||
|
out = scores @ v
|
||||||
|
if n_repeats > 1:
|
||||||
|
out = mx.reshape(out, [B, n_q_heads, L, -1])
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def mlx_fused_attn(q, k, v, scale, mask):
|
||||||
|
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
|
||||||
|
if transpose:
|
||||||
|
q_t = mx.transpose(q, (0, 2, 1, 3))
|
||||||
|
k_t = mx.transpose(k, (0, 2, 1, 3))
|
||||||
|
v_t = mx.transpose(v, (0, 2, 1, 3))
|
||||||
|
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
|
||||||
|
return mx.transpose(o_t, (0, 2, 1, 3))
|
||||||
|
else:
|
||||||
|
return f(q, k, v, scale=scale, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
|
||||||
|
q_out = q
|
||||||
|
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
|
||||||
|
|
||||||
|
mx.eval(q_out)
|
||||||
|
return q_out
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(
|
||||||
|
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
|
||||||
|
):
|
||||||
|
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
|
||||||
|
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
time_mlx_unfused = bench(
|
||||||
|
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
|
)
|
||||||
|
time_mlx_fused = bench(
|
||||||
|
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
|
)
|
||||||
|
|
||||||
|
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
|
||||||
|
o_mlx_unfused = do_attention(
|
||||||
|
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
|
)
|
||||||
|
|
||||||
|
atol = 1e-5 if dtype == "float32" else 2e-4
|
||||||
|
|
||||||
|
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx_fused, time_mlx_unfused
|
||||||
|
|
||||||
|
|
||||||
|
def get_gflop_count(B, M, N, K):
|
||||||
|
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
||||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
|
||||||
args = parser.parse_args()
|
|
||||||
if args.gpu:
|
|
||||||
mx.set_default_device(mx.gpu)
|
|
||||||
else:
|
|
||||||
mx.set_default_device(mx.cpu)
|
|
||||||
|
|
||||||
time_self_attention_sdpa()
|
dtypes = ("float16", "float32")[:1]
|
||||||
time_self_attention_primitives()
|
transposes = (False,)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
shapes_64 = (
|
||||||
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
|
( 1, 32, 32, 64, 32, 32),
|
||||||
|
( 1, 64, 64, 64, 32, 32),
|
||||||
|
( 1, 128, 128, 64, 32, 32),
|
||||||
|
( 1, 256, 256, 64, 32, 32),
|
||||||
|
( 1, 512, 512, 64, 32, 32),
|
||||||
|
( 1, 1024, 1024, 64, 32, 8),
|
||||||
|
( 1, 2048, 2048, 64, 32, 8),
|
||||||
|
( 1, 4096, 4096, 64, 32, 8),
|
||||||
|
)
|
||||||
|
|
||||||
|
shapes_80 = (
|
||||||
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
|
( 1, 1024, 1024, 80, 32, 8),
|
||||||
|
( 1, 2048, 2048, 80, 32, 8),
|
||||||
|
( 1, 4096, 4096, 80, 32, 8),
|
||||||
|
)
|
||||||
|
|
||||||
|
shapes_128 = (
|
||||||
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
|
( 1, 1024, 1024, 128, 32, 8),
|
||||||
|
( 1, 2048, 2048, 128, 32, 8),
|
||||||
|
( 1, 4096, 4096, 128, 32, 8),
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
shapes = shapes_64 + shapes_80 + shapes_128
|
||||||
|
|
||||||
|
masks = [None, "bool", "causal"]
|
||||||
|
|
||||||
|
print(
|
||||||
|
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
|
||||||
|
)
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
for transpose in transposes:
|
||||||
|
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
||||||
|
for mask_in in masks:
|
||||||
|
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||||
|
B,
|
||||||
|
qsl,
|
||||||
|
ksl,
|
||||||
|
head_dim,
|
||||||
|
n_q_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
dtype,
|
||||||
|
transpose,
|
||||||
|
mask_in,
|
||||||
|
)
|
||||||
|
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||||
|
t_str = 1 if transpose else 0
|
||||||
|
print(
|
||||||
|
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,46 +4,92 @@ import math
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
|
|
||||||
L = 1024
|
L = 16384
|
||||||
H = 32
|
H = 32
|
||||||
H_k = 32 // 4
|
H_k = H // 4
|
||||||
D = 128
|
D = 128
|
||||||
|
V = 128
|
||||||
|
dtype = mx.float16
|
||||||
|
loops = 10
|
||||||
|
|
||||||
|
|
||||||
def attention(q, k, v):
|
def upproject(x, w):
|
||||||
|
if w is None:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return x @ w.T
|
||||||
|
|
||||||
|
|
||||||
|
def attention(q, k, v, mask=None, w=None):
|
||||||
|
def _sdpa(q, k, v):
|
||||||
B, Hq, L, D = q.shape
|
B, Hq, L, D = q.shape
|
||||||
_, Hk, S, _ = k.shape
|
_, Hk, S, _ = k.shape
|
||||||
|
_, _, _, V = v.shape
|
||||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||||
k = k[:, :, None, :, :]
|
k = k[:, :, None, :, :]
|
||||||
v = v[:, :, None, :, :]
|
v = v[:, :, None, :, :]
|
||||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
s = q @ k.transpose(0, 1, 2, 4, 3)
|
||||||
|
if mask is not None:
|
||||||
|
m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
|
||||||
|
s = mx.where(m, s, mx.finfo(s.dtype).min)
|
||||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||||
o = p @ v
|
o = p @ v
|
||||||
return o.reshape(B, Hq, L, D)
|
return o.reshape(B, Hq, L, V)
|
||||||
|
|
||||||
|
for i in range(loops):
|
||||||
|
q = _sdpa(q, k, v)
|
||||||
|
q = upproject(q, w)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
def sdpa(q, k, v):
|
def sdpa(q, k, v, mask=None, w=None):
|
||||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
for i in range(loops):
|
||||||
|
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||||
|
q = upproject(q, w)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
def time_self_attention_primitives():
|
def time_self_attention_primitives():
|
||||||
mx.random.seed(3)
|
mx.random.seed(3)
|
||||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
|
||||||
mx.eval(q, k, v)
|
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
|
||||||
time_fn(attention, q, k, v)
|
mx.eval(q, k, v, w)
|
||||||
|
time_fn(attention, q, k, v, w=w)
|
||||||
|
|
||||||
|
|
||||||
def time_self_attention_sdpa():
|
def time_self_attention_sdpa():
|
||||||
mx.random.seed(3)
|
mx.random.seed(3)
|
||||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
|
||||||
mx.eval(q, k, v)
|
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
|
||||||
time_fn(sdpa, q, k, v)
|
mx.eval(q, k, v, w)
|
||||||
|
time_fn(sdpa, q, k, v, w=w)
|
||||||
|
|
||||||
|
|
||||||
|
def time_self_attention_sdpa_with_mask():
|
||||||
|
mx.random.seed(3)
|
||||||
|
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
||||||
|
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
||||||
|
v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
|
||||||
|
w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
|
||||||
|
mask = mx.full((L,), True)
|
||||||
|
mask[L // 2 :] = False
|
||||||
|
mx.eval(q, k, v, mask, w)
|
||||||
|
|
||||||
|
def sdpa_mask(*args):
|
||||||
|
return sdpa(*args, mask=mask, w=w)
|
||||||
|
|
||||||
|
def attention_mask(*args):
|
||||||
|
return attention(*args, mask=mask, w=w)
|
||||||
|
|
||||||
|
time_fn(attention_mask, q, k, v)
|
||||||
|
time_fn(sdpa_mask, q, k, v)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
time_self_attention_sdpa()
|
time_self_attention_sdpa()
|
||||||
time_self_attention_primitives()
|
time_self_attention_primitives()
|
||||||
|
time_self_attention_sdpa_with_mask()
|
||||||
|
|||||||
@@ -51,6 +51,20 @@ def time_maximum():
|
|||||||
time_fn(mx.maximum, a, b)
|
time_fn(mx.maximum, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def time_max():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
a[1, 1] = mx.nan
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.max, a, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def time_min():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
a[1, 1] = mx.nan
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.min, a, 0)
|
||||||
|
|
||||||
|
|
||||||
def time_negative():
|
def time_negative():
|
||||||
a = mx.random.uniform(shape=(10000, 1000))
|
a = mx.random.uniform(shape=(10000, 1000))
|
||||||
mx.eval(a)
|
mx.eval(a)
|
||||||
@@ -108,6 +122,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
time_add()
|
time_add()
|
||||||
time_matmul()
|
time_matmul()
|
||||||
|
time_min()
|
||||||
|
time_max()
|
||||||
time_maximum()
|
time_maximum()
|
||||||
time_exp()
|
time_exp()
|
||||||
time_negative()
|
time_negative()
|
||||||
|
|||||||
55
benchmarks/python/synchronize_bench.py
Normal file
55
benchmarks/python/synchronize_bench.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
rank = mx.distributed.init().rank()
|
||||||
|
|
||||||
|
|
||||||
|
def timeit(fn, a):
|
||||||
|
|
||||||
|
# warmup
|
||||||
|
for _ in range(5):
|
||||||
|
mx.eval(fn(a))
|
||||||
|
|
||||||
|
its = 10
|
||||||
|
tic = time.perf_counter()
|
||||||
|
for _ in range(its):
|
||||||
|
mx.eval(fn(a))
|
||||||
|
toc = time.perf_counter()
|
||||||
|
ms = 1000 * (toc - tic) / its
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
def all_reduce_benchmark():
|
||||||
|
a = mx.ones((5, 5), mx.int32)
|
||||||
|
|
||||||
|
its_per_eval = 100
|
||||||
|
|
||||||
|
def fn(x):
|
||||||
|
for _ in range(its_per_eval):
|
||||||
|
x = mx.distributed.all_sum(x)
|
||||||
|
x = x - 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
ms = timeit(fn, a) / its_per_eval
|
||||||
|
if rank == 0:
|
||||||
|
print(f"All Reduce: time per iteration {ms:.6f} (ms)")
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_benchmark():
|
||||||
|
a = mx.ones((5, 5), mx.int32)
|
||||||
|
its_per_eval = 100
|
||||||
|
|
||||||
|
def fn(x):
|
||||||
|
for _ in range(its_per_eval):
|
||||||
|
x = mx.distributed.all_gather(x)[0]
|
||||||
|
return x
|
||||||
|
|
||||||
|
ms = timeit(fn, a) / its_per_eval
|
||||||
|
if rank == 0:
|
||||||
|
print(f"All gather: time per iteration {ms:.6f} (ms)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
all_reduce_benchmark()
|
||||||
|
all_gather_benchmark()
|
||||||
54
cmake/FindNCCL.cmake
Normal file
54
cmake/FindNCCL.cmake
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
|
||||||
|
# directories.
|
||||||
|
|
||||||
|
set(NCCL_ROOT_DIR
|
||||||
|
$ENV{NCCL_ROOT_DIR}
|
||||||
|
CACHE PATH "Folder contains NVIDIA NCCL")
|
||||||
|
|
||||||
|
find_path(
|
||||||
|
NCCL_INCLUDE_DIRS
|
||||||
|
NAMES nccl.h
|
||||||
|
HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/include)
|
||||||
|
|
||||||
|
if($ENV{USE_STATIC_NCCL})
|
||||||
|
message(
|
||||||
|
STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
|
||||||
|
set(NCCL_LIBNAME "libnccl_static.a")
|
||||||
|
else()
|
||||||
|
set(NCCL_LIBNAME "nccl")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_library(
|
||||||
|
NCCL_LIBRARIES
|
||||||
|
NAMES ${NCCL_LIBNAME}
|
||||||
|
HINTS ${NCCL_LIB_DIR}
|
||||||
|
${NCCL_ROOT_DIR}
|
||||||
|
${NCCL_ROOT_DIR}/lib
|
||||||
|
${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
|
||||||
|
${NCCL_ROOT_DIR}/lib64
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib
|
||||||
|
${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
||||||
|
|
||||||
|
include(FindPackageHandleStandardArgs)
|
||||||
|
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
|
||||||
|
NCCL_LIBRARIES)
|
||||||
|
|
||||||
|
if(NCCL_FOUND)
|
||||||
|
set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
|
||||||
|
message(
|
||||||
|
STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
|
||||||
|
file(
|
||||||
|
STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED
|
||||||
|
REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
|
||||||
|
LIMIT_COUNT 1)
|
||||||
|
if(NCCL_MAJOR_VERSION_DEFINED)
|
||||||
|
string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" ""
|
||||||
|
NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED})
|
||||||
|
message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
|
||||||
|
endif()
|
||||||
|
message(
|
||||||
|
STATUS
|
||||||
|
"Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
|
||||||
|
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
|
||||||
|
endif()
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
include(CMakeParseArguments)
|
include(CMakeParseArguments)
|
||||||
|
|
||||||
|
# clang format off
|
||||||
|
#
|
||||||
# ##############################################################################
|
# ##############################################################################
|
||||||
# Build metal library
|
# Build metal library
|
||||||
#
|
#
|
||||||
@@ -9,11 +11,14 @@ include(CMakeParseArguments)
|
|||||||
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
||||||
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
||||||
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||||
# files (like headers)
|
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
|
||||||
|
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
|
||||||
#
|
#
|
||||||
|
# clang format on
|
||||||
|
|
||||||
macro(mlx_build_metallib)
|
macro(mlx_build_metallib)
|
||||||
# Parse args
|
# Parse args
|
||||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
|
||||||
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
||||||
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||||
|
|
||||||
@@ -21,7 +26,11 @@ macro(mlx_build_metallib)
|
|||||||
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
||||||
|
|
||||||
# Collect compile options
|
# Collect compile options
|
||||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
|
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||||
|
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
|
||||||
|
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
|
||||||
|
-frecord-sources)
|
||||||
|
endif()
|
||||||
|
|
||||||
# Prepare metallib build command
|
# Prepare metallib build command
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
|
|||||||
CREATE_SUBDIRS = NO
|
CREATE_SUBDIRS = NO
|
||||||
FULL_PATH_NAMES = YES
|
FULL_PATH_NAMES = YES
|
||||||
RECURSIVE = YES
|
RECURSIVE = YES
|
||||||
GENERATE_HTML = YES
|
GENERATE_HTML = NO
|
||||||
GENERATE_LATEX = NO
|
GENERATE_LATEX = NO
|
||||||
GENERATE_XML = YES
|
GENERATE_XML = YES
|
||||||
XML_PROGRAMLISTING = YES
|
XML_PROGRAMLISTING = YES
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
sphinx
|
sphinx
|
||||||
breathe
|
breathe
|
||||||
sphinx-book-theme
|
sphinx-book-theme
|
||||||
|
sphinx-copybutton
|
||||||
mlx
|
mlx
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import mlx.core as mx
|
|||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
project = "MLX"
|
project = "MLX"
|
||||||
copyright = "2023, MLX Contributors"
|
copyright = "2023, Apple"
|
||||||
author = "MLX Contributors"
|
author = "MLX Contributors"
|
||||||
version = ".".join(mx.__version__.split(".")[:3])
|
version = ".".join(mx.__version__.split(".")[:3])
|
||||||
release = version
|
release = version
|
||||||
@@ -18,6 +18,7 @@ release = version
|
|||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
extensions = [
|
extensions = [
|
||||||
|
"sphinx_copybutton",
|
||||||
"sphinx.ext.autodoc",
|
"sphinx.ext.autodoc",
|
||||||
"sphinx.ext.autosummary",
|
"sphinx.ext.autosummary",
|
||||||
"sphinx.ext.intersphinx",
|
"sphinx.ext.intersphinx",
|
||||||
@@ -60,6 +61,7 @@ html_theme_options = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
html_favicon = html_theme_options["logo"]["image_light"]
|
||||||
|
|
||||||
# -- Options for HTMLHelp output ---------------------------------------------
|
# -- Options for HTMLHelp output ---------------------------------------------
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
.. _custom_metal_kernels:
|
||||||
|
|
||||||
Custom Metal Kernels
|
Custom Metal Kernels
|
||||||
====================
|
====================
|
||||||
|
|
||||||
@@ -6,11 +8,12 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
|||||||
Simple Example
|
Simple Example
|
||||||
--------------
|
--------------
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
T tmp = inp[elem];
|
T tmp = inp[elem];
|
||||||
@@ -23,6 +26,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
|||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source,
|
source=source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def exp_elementwise(a: mx.array):
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
template=[("T", mx.float32)],
|
template=[("T", mx.float32)],
|
||||||
@@ -37,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
|||||||
b = exp_elementwise(a)
|
b = exp_elementwise(a)
|
||||||
assert mx.allclose(b, mx.exp(a))
|
assert mx.allclose(b, mx.exp(a))
|
||||||
|
|
||||||
|
Every time you make a kernel, a new Metal library is created and possibly
|
||||||
|
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||||
|
:func:`fast.metal_kernel` and then use it many times.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
We are only required to pass the body of the Metal kernel in ``source``.
|
Only pass the body of the Metal kernel in ``source``. The function
|
||||||
|
signature is generated automatically.
|
||||||
|
|
||||||
The full function signature will be generated using:
|
The full function signature will be generated using:
|
||||||
|
|
||||||
@@ -76,25 +86,34 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
|||||||
|
|
||||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||||
|
|
||||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
||||||
|
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
||||||
|
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
||||||
|
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
||||||
|
dimension should be less than or equal to the corresponding grid dimension.
|
||||||
|
|
||||||
|
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
||||||
|
generated code for debugging purposes.
|
||||||
|
|
||||||
Using Shape/Strides
|
Using Shape/Strides
|
||||||
-------------------
|
-------------------
|
||||||
|
|
||||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
||||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
is ``True`` by default. This will copy the array inputs if needed
|
||||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
before the kernel is launched to ensure that the memory layout is row
|
||||||
when indexing.
|
contiguous. Generally this makes writing the kernel easier, since we don't
|
||||||
|
have to worry about gaps or the ordering of the dims when indexing.
|
||||||
|
|
||||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
||||||
input array ``a`` if any are present in ``source``.
|
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
||||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
||||||
|
the right elements for each thread.
|
||||||
|
|
||||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
||||||
|
relying on a copy from ``ensure_row_contiguous``:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||||
@@ -108,8 +127,11 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
|
|||||||
name="myexp_strided",
|
name="myexp_strided",
|
||||||
input_names=["inp"],
|
input_names=["inp"],
|
||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source
|
source=source,
|
||||||
|
ensure_row_contiguous=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def exp_elementwise(a: mx.array):
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
template=[("T", mx.float32)],
|
template=[("T", mx.float32)],
|
||||||
@@ -117,7 +139,6 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
|
|||||||
threadgroup=(256, 1, 1),
|
threadgroup=(256, 1, 1),
|
||||||
output_shapes=[a.shape],
|
output_shapes=[a.shape],
|
||||||
output_dtypes=[a.dtype],
|
output_dtypes=[a.dtype],
|
||||||
ensure_row_contiguous=False,
|
|
||||||
)
|
)
|
||||||
return outputs[0]
|
return outputs[0]
|
||||||
|
|
||||||
@@ -177,25 +198,13 @@ We'll start with the following MLX implementation using standard ops:
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
|
||||||
to write a fast GPU kernel for both the forward and backward passes.
|
to write a fast GPU kernel for both the forward and backward passes.
|
||||||
|
|
||||||
First we'll implement the forward pass as a fused kernel:
|
First we'll implement the forward pass as a fused kernel:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@mx.custom_function
|
|
||||||
def grid_sample(x, grid):
|
|
||||||
|
|
||||||
assert x.ndim == 4, "`x` must be 4D."
|
|
||||||
assert grid.ndim == 4, "`grid` must be 4D."
|
|
||||||
|
|
||||||
B, _, _, C = x.shape
|
|
||||||
_, gN, gM, D = grid.shape
|
|
||||||
out_shape = (B, gN, gM, C)
|
|
||||||
|
|
||||||
assert D == 2, "Last dim of `grid` must be size 2."
|
|
||||||
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
int H = x_shape[1];
|
int H = x_shape[1];
|
||||||
@@ -245,12 +254,26 @@ First we'll implement the forward pass as a fused kernel:
|
|||||||
|
|
||||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kernel = mx.fast.metal_kernel(
|
kernel = mx.fast.metal_kernel(
|
||||||
name="grid_sample",
|
name="grid_sample",
|
||||||
input_names=["x", "grid"],
|
input_names=["x", "grid"],
|
||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source,
|
source=source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@mx.custom_function
|
||||||
|
def grid_sample(x, grid):
|
||||||
|
|
||||||
|
assert x.ndim == 4, "`x` must be 4D."
|
||||||
|
assert grid.ndim == 4, "`grid` must be 4D."
|
||||||
|
|
||||||
|
B, _, _, C = x.shape
|
||||||
|
_, gN, gM, D = grid.shape
|
||||||
|
out_shape = (B, gN, gM, C)
|
||||||
|
|
||||||
|
assert D == 2, "Last dim of `grid` must be size 2."
|
||||||
|
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[x, grid],
|
inputs=[x, grid],
|
||||||
template=[("T", x.dtype)],
|
template=[("T", x.dtype)],
|
||||||
@@ -275,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
|
|||||||
Grid Sample VJP
|
Grid Sample VJP
|
||||||
---------------
|
---------------
|
||||||
|
|
||||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||||
its custom vjp transform so MLX can differentiate it.
|
define its custom vjp transform so MLX can differentiate it.
|
||||||
|
|
||||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
requires a few extra :func:`fast.metal_kernel` features:
|
||||||
|
|
||||||
* ``init_value=0``
|
* ``init_value=0``
|
||||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||||
@@ -293,14 +316,6 @@ We can then implement the backwards pass as follows:
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@grid_sample.vjp
|
|
||||||
def grid_sample_vjp(primals, cotangent, _):
|
|
||||||
x, grid = primals
|
|
||||||
B, _, _, C = x.shape
|
|
||||||
_, gN, gM, D = grid.shape
|
|
||||||
|
|
||||||
assert D == 2, "Last dim of `grid` must be size 2."
|
|
||||||
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
int H = x_shape[1];
|
int H = x_shape[1];
|
||||||
@@ -400,6 +415,15 @@ We can then implement the backwards pass as follows:
|
|||||||
source=source,
|
source=source,
|
||||||
atomic_outputs=True,
|
atomic_outputs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@grid_sample.vjp
|
||||||
|
def grid_sample_vjp(primals, cotangent, _):
|
||||||
|
x, grid = primals
|
||||||
|
B, _, _, C = x.shape
|
||||||
|
_, gN, gM, D = grid.shape
|
||||||
|
|
||||||
|
assert D == 2, "Last dim of `grid` must be size 2."
|
||||||
|
|
||||||
# pad the output channels to simd group size
|
# pad the output channels to simd group size
|
||||||
# so that our `simd_sum`s don't overlap.
|
# so that our `simd_sum`s don't overlap.
|
||||||
simdgroup_size = 32
|
simdgroup_size = 32
|
||||||
|
|||||||
@@ -22,12 +22,12 @@ You can do that in MLX directly:
|
|||||||
This function performs that operation while leaving the implementation and
|
This function performs that operation while leaving the implementation and
|
||||||
function transformations to MLX.
|
function transformations to MLX.
|
||||||
|
|
||||||
However you may need to customize the underlying implementation, perhaps to
|
However, you may want to customize the underlying implementation, perhaps to
|
||||||
make it faster or for custom differentiation. In this tutorial we will go
|
make it faster. In this tutorial we will go through adding custom extensions.
|
||||||
through adding custom extensions. It will cover:
|
It will cover:
|
||||||
|
|
||||||
* The structure of the MLX library.
|
* The structure of the MLX library.
|
||||||
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
|
* Implementing a CPU operation.
|
||||||
* Implementing a GPU operation using metal.
|
* Implementing a GPU operation using metal.
|
||||||
* Adding the ``vjp`` and ``jvp`` function transformation.
|
* Adding the ``vjp`` and ``jvp`` function transformation.
|
||||||
* Building a custom extension and binding it to python.
|
* Building a custom extension and binding it to python.
|
||||||
@@ -45,7 +45,7 @@ Operations
|
|||||||
Operations are the front-end functions that operate on arrays. They are defined
|
Operations are the front-end functions that operate on arrays. They are defined
|
||||||
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
|
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
|
||||||
|
|
||||||
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
|
We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
|
||||||
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
|
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
|
||||||
C++:
|
C++:
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ C++:
|
|||||||
* Scale and sum two vectors element-wise
|
* Scale and sum two vectors element-wise
|
||||||
* z = alpha * x + beta * y
|
* z = alpha * x + beta * y
|
||||||
*
|
*
|
||||||
* Follow numpy style broadcasting between x and y
|
* Use NumPy-style broadcasting between x and y
|
||||||
* Inputs are upcasted to floats if needed
|
* Inputs are upcasted to floats if needed
|
||||||
**/
|
**/
|
||||||
array axpby(
|
array axpby(
|
||||||
@@ -66,7 +66,7 @@ C++:
|
|||||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||||
);
|
);
|
||||||
|
|
||||||
The simplest way to this operation is in terms of existing operations:
|
The simplest way to implement this is with existing operations:
|
||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
|
|
||||||
@@ -93,9 +93,9 @@ Primitives
|
|||||||
^^^^^^^^^^^
|
^^^^^^^^^^^
|
||||||
|
|
||||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||||
defines how to create outputs arrays given a input arrays. Further, a
|
defines how to create output arrays given input arrays. Further, a
|
||||||
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
||||||
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
|
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
|
||||||
more concrete:
|
more concrete:
|
||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
@@ -128,7 +128,7 @@ more concrete:
|
|||||||
/** The vector-Jacobian product. */
|
/** The vector-Jacobian product. */
|
||||||
std::vector<array> vjp(
|
std::vector<array> vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const array& cotan,
|
const std::vector<array>& cotangents,
|
||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
@@ -138,13 +138,13 @@ more concrete:
|
|||||||
* representing the vectorized computation and the axis which
|
* representing the vectorized computation and the axis which
|
||||||
* corresponds to the output vectorized dimension.
|
* corresponds to the output vectorized dimension.
|
||||||
*/
|
*/
|
||||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) override;
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
/** Print the primitive. */
|
/** The name of primitive. */
|
||||||
void print(std::ostream& os) override {
|
const char* name() const override {
|
||||||
os << "Axpby";
|
return "Axpby";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Equivalence check **/
|
/** Equivalence check **/
|
||||||
@@ -153,9 +153,6 @@ more concrete:
|
|||||||
private:
|
private:
|
||||||
float alpha_;
|
float alpha_;
|
||||||
float beta_;
|
float beta_;
|
||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
|
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
|
||||||
@@ -188,7 +185,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
|
|||||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||||
|
|
||||||
// Upcast to float32 for non-floating point inputs x and y
|
// Upcast to float32 for non-floating point inputs x and y
|
||||||
auto out_dtype = is_floating_point(promoted_dtype)
|
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||||
? promoted_dtype
|
? promoted_dtype
|
||||||
: promote_types(promoted_dtype, float32);
|
: promote_types(promoted_dtype, float32);
|
||||||
|
|
||||||
@@ -234,11 +231,9 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
|
|||||||
Implementing the CPU Back-end
|
Implementing the CPU Back-end
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
Let's start by implementing a naive and generic version of
|
Let's start by implementing :meth:`Axpby::eval_cpu`.
|
||||||
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
|
||||||
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
|
||||||
|
|
||||||
Our naive method will go over each element of the output array, find the
|
The method will go over each element of the output array, find the
|
||||||
corresponding input elements of ``x`` and ``y`` and perform the operation
|
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||||
point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
||||||
|
|
||||||
@@ -246,36 +241,46 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void axpby_impl(
|
void axpby_impl(
|
||||||
const array& x,
|
const mx::array& x,
|
||||||
const array& y,
|
const mx::array& y,
|
||||||
array& out,
|
mx::array& out,
|
||||||
float alpha_,
|
float alpha_,
|
||||||
float beta_) {
|
float beta_,
|
||||||
// We only allocate memory when we are ready to fill the output
|
mx::Stream stream) {
|
||||||
// malloc_or_wait synchronously allocates available memory
|
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||||
// There may be a wait executed here if the allocation is requested
|
|
||||||
// under memory-pressured conditions
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
|
||||||
|
|
||||||
// Collect input and output data pointers
|
// Get the CPU command encoder and register input and output arrays
|
||||||
const T* x_ptr = x.data<T>();
|
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||||
const T* y_ptr = y.data<T>();
|
encoder.set_input_array(x);
|
||||||
T* out_ptr = out.data<T>();
|
encoder.set_input_array(y);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
// Launch the CPU kernel
|
||||||
|
encoder.dispatch([x_ptr = x.data<T>(),
|
||||||
|
y_ptr = y.data<T>(),
|
||||||
|
out_ptr = out.data<T>(),
|
||||||
|
size = out.size(),
|
||||||
|
shape = out.shape(),
|
||||||
|
x_strides = x.strides(),
|
||||||
|
y_strides = y.strides(),
|
||||||
|
alpha_,
|
||||||
|
beta_]() {
|
||||||
|
|
||||||
// Cast alpha and beta to the relevant types
|
// Cast alpha and beta to the relevant types
|
||||||
T alpha = static_cast<T>(alpha_);
|
T alpha = static_cast<T>(alpha_);
|
||||||
T beta = static_cast<T>(beta_);
|
T beta = static_cast<T>(beta_);
|
||||||
|
|
||||||
// Do the element-wise operation for each output
|
// Do the element-wise operation for each output
|
||||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
for (size_t out_idx = 0; out_idx < size; out_idx++) {
|
||||||
// Map linear indices to offsets in x and y
|
// Map linear indices to offsets in x and y
|
||||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
|
||||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
|
||||||
|
|
||||||
// We allocate the output to be contiguous and regularly strided
|
// We allocate the output to be contiguous and regularly strided
|
||||||
// (defaults to row major) and hence it doesn't need additional mapping
|
// (defaults to row major) and hence it doesn't need additional mapping
|
||||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Our implementation should work for all incoming floating point arrays.
|
Our implementation should work for all incoming floating point arrays.
|
||||||
@@ -284,112 +289,32 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
|
|||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
void Axpby::eval_cpu(
|
||||||
void Axpby::eval(
|
const std::vector<mx::array>& inputs,
|
||||||
const std::vector<array>& inputs,
|
std::vector<mx::array>& outputs) {
|
||||||
const std::vector<array>& outputs) {
|
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
auto& out = outputs[0];
|
auto& out = outputs[0];
|
||||||
|
|
||||||
// Dispatch to the correct dtype
|
// Dispatch to the correct dtype
|
||||||
if (out.dtype() == float32) {
|
if (out.dtype() == mx::float32) {
|
||||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
|
||||||
} else if (out.dtype() == float16) {
|
} else if (out.dtype() == mx::float16) {
|
||||||
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
|
||||||
} else if (out.dtype() == bfloat16) {
|
} else if (out.dtype() == mx::bfloat16) {
|
||||||
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
|
||||||
} else if (out.dtype() == complex64) {
|
} else if (out.dtype() == mx::complex64) {
|
||||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[Axpby] Only supports floating point types.");
|
"Axpby is only supported for floating point types.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
This is good as a fallback implementation. We can use the ``axpby`` routine
|
|
||||||
provided by the Accelerate_ framework for a faster implementation in certain
|
|
||||||
cases:
|
|
||||||
|
|
||||||
#. Accelerate does not provide implementations of ``axpby`` for half precision
|
|
||||||
floats. We can only use it for ``float32`` types.
|
|
||||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
|
|
||||||
elements have fixed strides between them. We only direct to Accelerate
|
|
||||||
if both ``x`` and ``y`` are row contiguous or column contiguous.
|
|
||||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
|
|
||||||
MLX expects to write the output to a new array. We must copy the elements
|
|
||||||
of ``y`` into the output and use that as an input to ``axpby``.
|
|
||||||
|
|
||||||
Let's write an implementation that uses Accelerate in the right conditions.
|
|
||||||
It allocates data for the output, copies ``y`` into it, and then calls the
|
|
||||||
:func:`catlas_saxpby` from accelerate.
|
|
||||||
|
|
||||||
.. code-block:: C++
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void axpby_impl_accelerate(
|
|
||||||
const array& x,
|
|
||||||
const array& y,
|
|
||||||
array& out,
|
|
||||||
float alpha_,
|
|
||||||
float beta_) {
|
|
||||||
// Accelerate library provides catlas_saxpby which does
|
|
||||||
// Y = (alpha * X) + (beta * Y) in place
|
|
||||||
// To use it, we first copy the data in y over to the output array
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
|
||||||
|
|
||||||
// We then copy over the elements using the contiguous vector specialization
|
|
||||||
copy_inplace(y, out, CopyType::Vector);
|
|
||||||
|
|
||||||
// Get x and y pointers for catlas_saxpby
|
|
||||||
const T* x_ptr = x.data<T>();
|
|
||||||
T* y_ptr = out.data<T>();
|
|
||||||
|
|
||||||
T alpha = static_cast<T>(alpha_);
|
|
||||||
T beta = static_cast<T>(beta_);
|
|
||||||
|
|
||||||
// Call the inplace accelerate operator
|
|
||||||
catlas_saxpby(
|
|
||||||
/* N = */ out.size(),
|
|
||||||
/* ALPHA = */ alpha,
|
|
||||||
/* X = */ x_ptr,
|
|
||||||
/* INCX = */ 1,
|
|
||||||
/* BETA = */ beta,
|
|
||||||
/* Y = */ y_ptr,
|
|
||||||
/* INCY = */ 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
For inputs that do not fit the criteria for accelerate, we fall back to
|
|
||||||
:meth:`Axpby::eval`. With this in mind, let's finish our
|
|
||||||
:meth:`Axpby::eval_cpu`.
|
|
||||||
|
|
||||||
.. code-block:: C++
|
|
||||||
|
|
||||||
/** Evaluate primitive on CPU using accelerate specializations */
|
|
||||||
void Axpby::eval_cpu(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
const std::vector<array>& outputs) {
|
|
||||||
assert(inputs.size() == 2);
|
|
||||||
auto& x = inputs[0];
|
|
||||||
auto& y = inputs[1];
|
|
||||||
auto& out = outputs[0];
|
|
||||||
|
|
||||||
// Accelerate specialization for contiguous single precision float arrays
|
|
||||||
if (out.dtype() == float32 &&
|
|
||||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
|
||||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
|
||||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to common back-end if specializations are not available
|
|
||||||
eval(inputs, outputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
|
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
|
||||||
you do not plan on running the operation on the GPU or using transforms on
|
you do not plan on running the operation on the GPU or using transforms on
|
||||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
primitive here.
|
||||||
|
|
||||||
Implementing the GPU Back-end
|
Implementing the GPU Back-end
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
@@ -420,8 +345,8 @@ element in the output.
|
|||||||
constant const float& alpha [[buffer(3)]],
|
constant const float& alpha [[buffer(3)]],
|
||||||
constant const float& beta [[buffer(4)]],
|
constant const float& beta [[buffer(4)]],
|
||||||
constant const int* shape [[buffer(5)]],
|
constant const int* shape [[buffer(5)]],
|
||||||
constant const size_t* x_strides [[buffer(6)]],
|
constant const int64_t* x_strides [[buffer(6)]],
|
||||||
constant const size_t* y_strides [[buffer(7)]],
|
constant const int64_t* y_strides [[buffer(7)]],
|
||||||
constant const int& ndim [[buffer(8)]],
|
constant const int& ndim [[buffer(8)]],
|
||||||
uint index [[thread_position_in_grid]]) {
|
uint index [[thread_position_in_grid]]) {
|
||||||
// Convert linear indices to offsets in array
|
// Convert linear indices to offsets in array
|
||||||
@@ -438,24 +363,10 @@ each instantiation a unique host name so we can identify it.
|
|||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
|
|
||||||
#define instantiate_axpby(type_name, type) \
|
instantiate_kernel("axpby_general_float32", axpby_general, float)
|
||||||
template [[host_name("axpby_general_" #type_name)]] \
|
instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
|
||||||
[[kernel]] void axpby_general<type>( \
|
instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
|
||||||
device const type* x [[buffer(0)]], \
|
instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
|
||||||
device const type* y [[buffer(1)]], \
|
|
||||||
device type* out [[buffer(2)]], \
|
|
||||||
constant const float& alpha [[buffer(3)]], \
|
|
||||||
constant const float& beta [[buffer(4)]], \
|
|
||||||
constant const int* shape [[buffer(5)]], \
|
|
||||||
constant const size_t* x_strides [[buffer(6)]], \
|
|
||||||
constant const size_t* y_strides [[buffer(7)]], \
|
|
||||||
constant const int& ndim [[buffer(8)]], \
|
|
||||||
uint index [[thread_position_in_grid]]);
|
|
||||||
|
|
||||||
instantiate_axpby(float32, float);
|
|
||||||
instantiate_axpby(float16, half);
|
|
||||||
instantiate_axpby(bfloat16, bfloat16_t);
|
|
||||||
instantiate_axpby(complex64, complex64_t);
|
|
||||||
|
|
||||||
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
|
The logic to determine the kernel, set the inputs, resolve the grid dimensions,
|
||||||
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
|
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
|
||||||
@@ -480,21 +391,21 @@ below.
|
|||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
// Allocate output memory
|
// Allocate output memory
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
// Resolve name of kernel
|
// Resolve name of kernel
|
||||||
std::ostringstream kname;
|
std::stream kname;
|
||||||
kname << "axpby_" << "general_" << type_to_name(out);
|
kname = "axpby_general_" + type_to_name(out);
|
||||||
|
|
||||||
// Make sure the metal library is available
|
// Load the metal library
|
||||||
d.register_library("mlx_ext");
|
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||||
|
|
||||||
// Make a kernel from this metal library
|
// Make a kernel from this metal library
|
||||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
auto kernel = d.get_kernel(kname, lib);
|
||||||
|
|
||||||
// Prepare to encode kernel
|
// Prepare to encode kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
// Kernel parameters are registered with buffer indices corresponding to
|
// Kernel parameters are registered with buffer indices corresponding to
|
||||||
// those in the kernel declaration at axpby.metal
|
// those in the kernel declaration at axpby.metal
|
||||||
@@ -509,14 +420,14 @@ below.
|
|||||||
compute_encoder.set_output_array(out, 2);
|
compute_encoder.set_output_array(out, 2);
|
||||||
|
|
||||||
// Encode alpha and beta
|
// Encode alpha and beta
|
||||||
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
|
compute_encoder.set_bytes(alpha_, 3);
|
||||||
compute_encoder->setBytes(&beta_, sizeof(float), 4);
|
compute_encoder.set_bytes(beta_, 4);
|
||||||
|
|
||||||
// Encode shape, strides and ndim
|
// Encode shape, strides and ndim
|
||||||
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
|
compute_encoder.set_vector_bytes(x.shape(), 5);
|
||||||
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
|
compute_encoder.set_vector_bytes(x.strides(), 6);
|
||||||
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
|
compute_encoder.set_bytes(y.strides(), 7);
|
||||||
compute_encoder->setBytes(&ndim, sizeof(int), 8);
|
compute_encoder.set_bytes(ndim, 8);
|
||||||
|
|
||||||
// We launch 1 thread for each input and make sure that the number of
|
// We launch 1 thread for each input and make sure that the number of
|
||||||
// threads in any given threadgroup is not higher than the max allowed
|
// threads in any given threadgroup is not higher than the max allowed
|
||||||
@@ -530,7 +441,7 @@ below.
|
|||||||
|
|
||||||
// Launch the grid with the given number of threads divided among
|
// Launch the grid with the given number of threads divided among
|
||||||
// the given threadgroups
|
// the given threadgroups
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
|
||||||
@@ -558,7 +469,7 @@ one we just defined:
|
|||||||
const std::vector<array>& tangents,
|
const std::vector<array>& tangents,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums) {
|
||||||
// Forward mode diff that pushes along the tangents
|
// Forward mode diff that pushes along the tangents
|
||||||
// The jvp transform on the primitive can built with ops
|
// The jvp transform on the primitive can be built with ops
|
||||||
// that are scheduled on the same stream as the primitive
|
// that are scheduled on the same stream as the primitive
|
||||||
|
|
||||||
// If argnums = {0}, we only push along x in which case the
|
// If argnums = {0}, we only push along x in which case the
|
||||||
@@ -570,7 +481,7 @@ one we just defined:
|
|||||||
auto scale_arr = array(scale, tangents[0].dtype());
|
auto scale_arr = array(scale, tangents[0].dtype());
|
||||||
return {multiply(scale_arr, tangents[0], stream())};
|
return {multiply(scale_arr, tangents[0], stream())};
|
||||||
}
|
}
|
||||||
// If, argnums = {0, 1}, we take contributions from both
|
// If argnums = {0, 1}, we take contributions from both
|
||||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||||
else {
|
else {
|
||||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||||
@@ -824,7 +735,7 @@ Let's look at a simple script and its results:
|
|||||||
|
|
||||||
print(f"c shape: {c.shape}")
|
print(f"c shape: {c.shape}")
|
||||||
print(f"c dtype: {c.dtype}")
|
print(f"c dtype: {c.dtype}")
|
||||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
print(f"c is correct: {mx.all(c == 6.0).item()}")
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
|
|
||||||
@@ -832,13 +743,13 @@ Output:
|
|||||||
|
|
||||||
c shape: [3, 4]
|
c shape: [3, 4]
|
||||||
c dtype: float32
|
c dtype: float32
|
||||||
c correctness: True
|
c is correct: True
|
||||||
|
|
||||||
Results
|
Results
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||||
with the naive :meth:`simple_axpby` we first defined on the CPU.
|
with the naive :meth:`simple_axpby` we first defined.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@@ -846,13 +757,11 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
|
|||||||
from mlx_sample_extensions import axpby
|
from mlx_sample_extensions import axpby
|
||||||
import time
|
import time
|
||||||
|
|
||||||
mx.set_default_device(mx.cpu)
|
|
||||||
|
|
||||||
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
||||||
return alpha * x + beta * y
|
return alpha * x + beta * y
|
||||||
|
|
||||||
M = 256
|
M = 4096
|
||||||
N = 512
|
N = 4096
|
||||||
|
|
||||||
x = mx.random.normal((M, N))
|
x = mx.random.normal((M, N))
|
||||||
y = mx.random.normal((M, N))
|
y = mx.random.normal((M, N))
|
||||||
@@ -863,24 +772,24 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
|
|||||||
|
|
||||||
def bench(f):
|
def bench(f):
|
||||||
# Warm up
|
# Warm up
|
||||||
for i in range(100):
|
for i in range(5):
|
||||||
z = f(x, y, alpha, beta)
|
z = f(x, y, alpha, beta)
|
||||||
mx.eval(z)
|
mx.eval(z)
|
||||||
|
|
||||||
# Timed run
|
# Timed run
|
||||||
s = time.time()
|
s = time.time()
|
||||||
for i in range(5000):
|
for i in range(100):
|
||||||
z = f(x, y, alpha, beta)
|
z = f(x, y, alpha, beta)
|
||||||
mx.eval(z)
|
mx.eval(z)
|
||||||
e = time.time()
|
e = time.time()
|
||||||
return e - s
|
return 1000 * (e - s) / 100
|
||||||
|
|
||||||
simple_time = bench(simple_axpby)
|
simple_time = bench(simple_axpby)
|
||||||
custom_time = bench(axpby)
|
custom_time = bench(axpby)
|
||||||
|
|
||||||
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
|
print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
|
||||||
|
|
||||||
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
|
The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
|
||||||
modest improvements right away!
|
modest improvements right away!
|
||||||
|
|
||||||
This operation is now good to be used to build other operations, in
|
This operation is now good to be used to build other operations, in
|
||||||
|
|||||||
121
docs/src/dev/mlx_in_cpp.rst
Normal file
121
docs/src/dev/mlx_in_cpp.rst
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
.. _mlx_in_cpp:
|
||||||
|
|
||||||
|
Using MLX in C++
|
||||||
|
================
|
||||||
|
|
||||||
|
You can use MLX in a C++ project with CMake.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
This guide is based one the following `example using MLX in C++
|
||||||
|
<https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_
|
||||||
|
|
||||||
|
First install MLX:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install -U mlx
|
||||||
|
|
||||||
|
You can also install the MLX Python package from source or just the C++
|
||||||
|
library. For more information see the :ref:`documentation on installing MLX
|
||||||
|
<build_and_install>`.
|
||||||
|
|
||||||
|
Next make an example program in ``example.cpp``:
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
namespace mx = mlx::core;
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
auto x = mx::array({1, 2, 3});
|
||||||
|
auto y = mx::array({1, 2, 3});
|
||||||
|
std::cout << x + y << std::endl;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
The next step is to setup a CMake file in ``CMakeLists.txt``:
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
cmake_minimum_required(VERSION 3.27)
|
||||||
|
|
||||||
|
project(example LANGUAGES CXX)
|
||||||
|
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
|
|
||||||
|
Depending on how you installed MLX, you may need to tell CMake where to
|
||||||
|
find it.
|
||||||
|
|
||||||
|
If you installed MLX with Python, then add the following to the CMake file:
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
find_package(
|
||||||
|
Python 3.9
|
||||||
|
COMPONENTS Interpreter Development.Module
|
||||||
|
REQUIRED)
|
||||||
|
execute_process(
|
||||||
|
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
OUTPUT_VARIABLE MLX_ROOT)
|
||||||
|
|
||||||
|
If you installed the MLX C++ package to a system path, then CMake should be
|
||||||
|
able to find it. If you installed it to a non-standard location or CMake can't
|
||||||
|
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
set(MLX_ROOT "/path/to/mlx/")
|
||||||
|
|
||||||
|
Next, instruct CMake to find MLX:
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
find_package(MLX CONFIG REQUIRED)
|
||||||
|
|
||||||
|
Finally, add the ``example.cpp`` program as an executable and link MLX.
|
||||||
|
|
||||||
|
.. code-block:: cmake
|
||||||
|
|
||||||
|
add_executable(example example.cpp)
|
||||||
|
target_link_libraries(example PRIVATE mlx)
|
||||||
|
|
||||||
|
You can build the example with:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||||
|
cmake --build build
|
||||||
|
|
||||||
|
And run it with:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
./build/example
|
||||||
|
|
||||||
|
Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:
|
||||||
|
|
||||||
|
.. list-table:: Package Variables
|
||||||
|
:widths: 20 20
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - Variable
|
||||||
|
- Description
|
||||||
|
* - MLX_FOUND
|
||||||
|
- ``True`` if MLX is found
|
||||||
|
* - MLX_INCLUDE_DIRS
|
||||||
|
- Include directory
|
||||||
|
* - MLX_LIBRARIES
|
||||||
|
- Libraries to link against
|
||||||
|
* - MLX_CXX_FLAGS
|
||||||
|
- Additional compiler flags
|
||||||
|
* - MLX_BUILD_ACCELERATE
|
||||||
|
- ``True`` if MLX was built with Accelerate
|
||||||
|
* - MLX_BUILD_METAL
|
||||||
|
- ``True`` if MLX was built with Metal
|
||||||
@@ -45,6 +45,7 @@ are the CPU and GPU.
|
|||||||
usage/numpy
|
usage/numpy
|
||||||
usage/distributed
|
usage/distributed
|
||||||
usage/using_streams
|
usage/using_streams
|
||||||
|
usage/export
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:caption: Examples
|
:caption: Examples
|
||||||
@@ -61,6 +62,7 @@ are the CPU and GPU.
|
|||||||
python/array
|
python/array
|
||||||
python/data_types
|
python/data_types
|
||||||
python/devices_and_streams
|
python/devices_and_streams
|
||||||
|
python/export
|
||||||
python/ops
|
python/ops
|
||||||
python/random
|
python/random
|
||||||
python/transforms
|
python/transforms
|
||||||
@@ -68,6 +70,8 @@ are the CPU and GPU.
|
|||||||
python/fft
|
python/fft
|
||||||
python/linalg
|
python/linalg
|
||||||
python/metal
|
python/metal
|
||||||
|
python/cuda
|
||||||
|
python/memory_management
|
||||||
python/nn
|
python/nn
|
||||||
python/optimizers
|
python/optimizers
|
||||||
python/distributed
|
python/distributed
|
||||||
@@ -86,3 +90,4 @@ are the CPU and GPU.
|
|||||||
dev/extensions
|
dev/extensions
|
||||||
dev/metal_debugger
|
dev/metal_debugger
|
||||||
dev/custom_metal_kernels
|
dev/custom_metal_kernels
|
||||||
|
dev/mlx_in_cpp
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
.. _build_and_install:
|
||||||
|
|
||||||
Build and Install
|
Build and Install
|
||||||
=================
|
=================
|
||||||
|
|
||||||
@@ -11,22 +13,49 @@ silicon computer is
|
|||||||
|
|
||||||
pip install mlx
|
pip install mlx
|
||||||
|
|
||||||
To install from PyPI you must meet the following requirements:
|
To install from PyPI your system must meet the following requirements:
|
||||||
|
|
||||||
- Using an M series chip (Apple silicon)
|
- Using an M series chip (Apple silicon)
|
||||||
- Using a native Python >= 3.9
|
- Using a native Python >= 3.10
|
||||||
- macOS >= 13.5
|
- macOS >= 13.5
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
MLX is only available on devices running macOS >= 13.5
|
MLX is only available on devices running macOS >= 13.5
|
||||||
It is highly recommended to use macOS 14 (Sonoma)
|
It is highly recommended to use macOS 14 (Sonoma)
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
MLX is also available on conda-forge. To install MLX with conda do:
|
MLX has a CUDA backend which you can install with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
conda install conda-forge::mlx
|
pip install mlx[cuda]
|
||||||
|
|
||||||
|
To install the CUDA package from PyPi your system must meet the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Nvidia architecture >= SM 7.0 (Volta)
|
||||||
|
- Nvidia driver >= 550.54.14
|
||||||
|
- CUDA toolkit >= 12.0
|
||||||
|
- Linux distribution with glibc >= 2.35
|
||||||
|
- Python >= 3.10
|
||||||
|
|
||||||
|
|
||||||
|
CPU-only (Linux)
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
For a CPU-only version of MLX that runs on Linux use:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install mlx[cpu]
|
||||||
|
|
||||||
|
To install the CPU-only package from PyPi your system must meet the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Linux distribution with glibc >= 2.35
|
||||||
|
- Python >= 3.10
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
@@ -53,7 +82,7 @@ Build Requirements
|
|||||||
^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
|
||||||
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
|
- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
|
||||||
- Xcode >= 15.0 and macOS SDK >= 14.0
|
- Xcode >= 15.0 and macOS SDK >= 14.0
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
@@ -63,6 +92,8 @@ Build Requirements
|
|||||||
Python API
|
Python API
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
.. _python install:
|
||||||
|
|
||||||
To build and install the MLX python library from source, first, clone MLX from
|
To build and install the MLX python library from source, first, clone MLX from
|
||||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
@@ -74,20 +105,20 @@ Then simply build and install MLX using pip:
|
|||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
|
pip install .
|
||||||
|
|
||||||
For developing, install the package with development dependencies, and use an
|
For developing, install the package with development dependencies, and use an
|
||||||
editable install:
|
editable install:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
Once the development dependencies are installed, you can build faster with:
|
Once the development dependencies are installed, you can build faster with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
|
python setup.py build_ext --inplace
|
||||||
|
|
||||||
Run the tests with:
|
Run the tests with:
|
||||||
|
|
||||||
@@ -105,6 +136,8 @@ IDE:
|
|||||||
C++ API
|
C++ API
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
|
.. _cpp install:
|
||||||
|
|
||||||
Currently, MLX must be built and installed from source.
|
Currently, MLX must be built and installed from source.
|
||||||
|
|
||||||
Similarly to the python library, to build and install the MLX C++ library start
|
Similarly to the python library, to build and install the MLX C++ library start
|
||||||
@@ -183,6 +216,7 @@ should point to the path to the built metal library.
|
|||||||
|
|
||||||
xcrun -sdk macosx --show-sdk-version
|
xcrun -sdk macosx --show-sdk-version
|
||||||
|
|
||||||
|
|
||||||
Binary Size Minimization
|
Binary Size Minimization
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -209,7 +243,51 @@ Metal library by run-time compiling kernels the first time they are used in MLX
|
|||||||
on a given machine. Note run-time compilation incurs a cold-start cost which can
|
on a given machine. Note run-time compilation incurs a cold-start cost which can
|
||||||
be anwywhere from a few hundred millisecond to a few seconds depending on the
|
be anwywhere from a few hundred millisecond to a few seconds depending on the
|
||||||
application. Once a kernel is compiled, it will be cached by the system. The
|
application. Once a kernel is compiled, it will be cached by the system. The
|
||||||
Metal kernel cache persists accross reboots.
|
Metal kernel cache persists across reboots.
|
||||||
|
|
||||||
|
Linux
|
||||||
|
^^^^^
|
||||||
|
|
||||||
|
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||||
|
For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
apt-get update -y
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
From here follow the instructions to install either the :ref:`Python <python
|
||||||
|
install>` or :ref:`C++ <cpp install>` APIs.
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||||
|
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
apt-get update -y
|
||||||
|
apt-get -y install cuda-toolkit-12-9
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev libcudnn9-dev-cuda-12 -y
|
||||||
|
|
||||||
|
|
||||||
|
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||||
|
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||||
|
|
||||||
|
To build the C++ package run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
mkdir -p build && cd build
|
||||||
|
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
@@ -240,7 +318,7 @@ x86 Shell
|
|||||||
|
|
||||||
.. _build shell:
|
.. _build shell:
|
||||||
|
|
||||||
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
|
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
|
||||||
Rosetta instead of natively.
|
Rosetta instead of natively.
|
||||||
|
|
||||||
To fix this, find the application in Finder (``/Applications`` for iTerm,
|
To fix this, find the application in Finder (``/Applications`` for iTerm,
|
||||||
@@ -264,4 +342,4 @@ Also check that cmake is using the correct architecture:
|
|||||||
|
|
||||||
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
|
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
|
||||||
but the build errors out with "Building for x86_64 on macOS is not supported."
|
but the build errors out with "Building for x86_64 on macOS is not supported."
|
||||||
wipe your build cahce with ``rm -rf build/`` and try again.
|
wipe your build cache with ``rm -rf build/`` and try again.
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ Array
|
|||||||
array.ndim
|
array.ndim
|
||||||
array.shape
|
array.shape
|
||||||
array.size
|
array.size
|
||||||
|
array.real
|
||||||
|
array.imag
|
||||||
array.abs
|
array.abs
|
||||||
array.all
|
array.all
|
||||||
array.any
|
array.any
|
||||||
@@ -38,6 +40,7 @@ Array
|
|||||||
array.log10
|
array.log10
|
||||||
array.log1p
|
array.log1p
|
||||||
array.log2
|
array.log2
|
||||||
|
array.logcumsumexp
|
||||||
array.logsumexp
|
array.logsumexp
|
||||||
array.max
|
array.max
|
||||||
array.mean
|
array.mean
|
||||||
|
|||||||
9
docs/src/python/cuda.rst
Normal file
9
docs/src/python/cuda.rst
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
CUDA
|
||||||
|
=====
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.cuda
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
is_available
|
||||||
@@ -51,11 +51,20 @@ The default floating point type is ``float32`` and the default integer type is
|
|||||||
* - ``float32``
|
* - ``float32``
|
||||||
- 4
|
- 4
|
||||||
- 32-bit float
|
- 32-bit float
|
||||||
|
* - ``float64``
|
||||||
|
- 4
|
||||||
|
- 64-bit double
|
||||||
* - ``complex64``
|
* - ``complex64``
|
||||||
- 8
|
- 8
|
||||||
- 64-bit complex float
|
- 64-bit complex float
|
||||||
|
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Arrays with type ``float64`` only work with CPU operations. Using
|
||||||
|
``float64`` arrays on the GPU will result in an exception.
|
||||||
|
|
||||||
|
|
||||||
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
|
Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
|
||||||
documentation for more information. Use :func:`issubdtype` to determine if one
|
documentation for more information. Use :func:`issubdtype` to determine if one
|
||||||
``dtype`` (or category) is a subtype of another category.
|
``dtype`` (or category) is a subtype of another category.
|
||||||
@@ -66,3 +75,4 @@ documentation for more information. Use :func:`issubdtype` to determine if one
|
|||||||
Dtype
|
Dtype
|
||||||
DtypeCategory
|
DtypeCategory
|
||||||
issubdtype
|
issubdtype
|
||||||
|
finfo
|
||||||
|
|||||||
14
docs/src/python/export.rst
Normal file
14
docs/src/python/export.rst
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
.. _export:
|
||||||
|
|
||||||
|
Export Functions
|
||||||
|
================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
export_function
|
||||||
|
import_function
|
||||||
|
exporter
|
||||||
|
export_to_dot
|
||||||
@@ -12,5 +12,5 @@ Fast
|
|||||||
layer_norm
|
layer_norm
|
||||||
rope
|
rope
|
||||||
scaled_dot_product_attention
|
scaled_dot_product_attention
|
||||||
affine_quantize
|
|
||||||
metal_kernel
|
metal_kernel
|
||||||
|
cuda_kernel
|
||||||
|
|||||||
@@ -20,3 +20,5 @@ FFT
|
|||||||
irfft2
|
irfft2
|
||||||
rfftn
|
rfftn
|
||||||
irfftn
|
irfftn
|
||||||
|
fftshift
|
||||||
|
ifftshift
|
||||||
|
|||||||
@@ -16,3 +16,12 @@ Linear Algebra
|
|||||||
cross
|
cross
|
||||||
qr
|
qr
|
||||||
svd
|
svd
|
||||||
|
eigvals
|
||||||
|
eig
|
||||||
|
eigvalsh
|
||||||
|
eigh
|
||||||
|
lu
|
||||||
|
lu_factor
|
||||||
|
pinv
|
||||||
|
solve
|
||||||
|
solve_triangular
|
||||||
|
|||||||
16
docs/src/python/memory_management.rst
Normal file
16
docs/src/python/memory_management.rst
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
Memory Management
|
||||||
|
=================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
get_active_memory
|
||||||
|
get_peak_memory
|
||||||
|
reset_peak_memory
|
||||||
|
get_cache_memory
|
||||||
|
set_memory_limit
|
||||||
|
set_cache_limit
|
||||||
|
set_wired_limit
|
||||||
|
clear_cache
|
||||||
@@ -8,12 +8,5 @@ Metal
|
|||||||
|
|
||||||
is_available
|
is_available
|
||||||
device_info
|
device_info
|
||||||
get_active_memory
|
|
||||||
get_peak_memory
|
|
||||||
reset_peak_memory
|
|
||||||
get_cache_memory
|
|
||||||
set_memory_limit
|
|
||||||
set_cache_limit
|
|
||||||
clear_cache
|
|
||||||
start_capture
|
start_capture
|
||||||
stop_capture
|
stop_capture
|
||||||
|
|||||||
@@ -174,6 +174,7 @@ In detail:
|
|||||||
|
|
||||||
value_and_grad
|
value_and_grad
|
||||||
quantize
|
quantize
|
||||||
|
average_gradients
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ simple functions.
|
|||||||
mish
|
mish
|
||||||
prelu
|
prelu
|
||||||
relu
|
relu
|
||||||
|
relu2
|
||||||
relu6
|
relu6
|
||||||
selu
|
selu
|
||||||
sigmoid
|
sigmoid
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ Layers
|
|||||||
ALiBi
|
ALiBi
|
||||||
AvgPool1d
|
AvgPool1d
|
||||||
AvgPool2d
|
AvgPool2d
|
||||||
|
AvgPool3d
|
||||||
BatchNorm
|
BatchNorm
|
||||||
CELU
|
CELU
|
||||||
Conv1d
|
Conv1d
|
||||||
@@ -41,6 +42,7 @@ Layers
|
|||||||
LSTM
|
LSTM
|
||||||
MaxPool1d
|
MaxPool1d
|
||||||
MaxPool2d
|
MaxPool2d
|
||||||
|
MaxPool3d
|
||||||
Mish
|
Mish
|
||||||
MultiHeadAttention
|
MultiHeadAttention
|
||||||
PReLU
|
PReLU
|
||||||
@@ -48,6 +50,7 @@ Layers
|
|||||||
QuantizedLinear
|
QuantizedLinear
|
||||||
RMSNorm
|
RMSNorm
|
||||||
ReLU
|
ReLU
|
||||||
|
ReLU2
|
||||||
ReLU6
|
ReLU6
|
||||||
RNN
|
RNN
|
||||||
RoPE
|
RoPE
|
||||||
|
|||||||
@@ -32,13 +32,16 @@ Operations
|
|||||||
atleast_2d
|
atleast_2d
|
||||||
atleast_3d
|
atleast_3d
|
||||||
bitwise_and
|
bitwise_and
|
||||||
|
bitwise_invert
|
||||||
bitwise_or
|
bitwise_or
|
||||||
bitwise_xor
|
bitwise_xor
|
||||||
block_masked_mm
|
block_masked_mm
|
||||||
|
broadcast_arrays
|
||||||
broadcast_to
|
broadcast_to
|
||||||
ceil
|
ceil
|
||||||
clip
|
clip
|
||||||
concatenate
|
concatenate
|
||||||
|
contiguous
|
||||||
conj
|
conj
|
||||||
conjugate
|
conjugate
|
||||||
convolve
|
convolve
|
||||||
@@ -89,6 +92,7 @@ Operations
|
|||||||
isneginf
|
isneginf
|
||||||
isposinf
|
isposinf
|
||||||
issubdtype
|
issubdtype
|
||||||
|
kron
|
||||||
left_shift
|
left_shift
|
||||||
less
|
less
|
||||||
less_equal
|
less_equal
|
||||||
@@ -99,6 +103,7 @@ Operations
|
|||||||
log10
|
log10
|
||||||
log1p
|
log1p
|
||||||
logaddexp
|
logaddexp
|
||||||
|
logcumsumexp
|
||||||
logical_not
|
logical_not
|
||||||
logical_and
|
logical_and
|
||||||
logical_or
|
logical_or
|
||||||
@@ -107,6 +112,7 @@ Operations
|
|||||||
max
|
max
|
||||||
maximum
|
maximum
|
||||||
mean
|
mean
|
||||||
|
median
|
||||||
meshgrid
|
meshgrid
|
||||||
min
|
min
|
||||||
minimum
|
minimum
|
||||||
@@ -144,6 +150,8 @@ Operations
|
|||||||
sign
|
sign
|
||||||
sin
|
sin
|
||||||
sinh
|
sinh
|
||||||
|
slice
|
||||||
|
slice_update
|
||||||
softmax
|
softmax
|
||||||
sort
|
sort
|
||||||
split
|
split
|
||||||
@@ -168,6 +176,7 @@ Operations
|
|||||||
tri
|
tri
|
||||||
tril
|
tril
|
||||||
triu
|
triu
|
||||||
|
unflatten
|
||||||
var
|
var
|
||||||
view
|
view
|
||||||
where
|
where
|
||||||
|
|||||||
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
|
|||||||
optimizer.update(model, grads)
|
optimizer.update(model, grads)
|
||||||
|
|
||||||
# Save the state
|
# Save the state
|
||||||
state = tree_flatten(optimizer.state)
|
state = tree_flatten(optimizer.state, destination={})
|
||||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
mx.save_safetensors("optimizer.safetensors", state)
|
||||||
|
|
||||||
# Later on, for example when loading from a checkpoint,
|
# Later on, for example when loading from a checkpoint,
|
||||||
# recreate the optimizer and load the state
|
# recreate the optimizer and load the state
|
||||||
optimizer = optim.Adam(learning_rate=1e-2)
|
optimizer = optim.Adam(learning_rate=1e-2)
|
||||||
|
|
||||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
state = tree_unflatten(mx.load("optimizer.safetensors"))
|
||||||
optimizer.state = state
|
optimizer.state = state
|
||||||
|
|
||||||
Note, not every optimizer configuation parameter is saved in the state. For
|
Note, not every optimizer configuation parameter is saved in the state. For
|
||||||
|
|||||||
@@ -18,3 +18,5 @@ Common Optimizers
|
|||||||
AdamW
|
AdamW
|
||||||
Adamax
|
Adamax
|
||||||
Lion
|
Lion
|
||||||
|
MultiOptimizer
|
||||||
|
Muon
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ Transforms
|
|||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
|
|
||||||
eval
|
eval
|
||||||
|
async_eval
|
||||||
compile
|
compile
|
||||||
custom_function
|
custom_function
|
||||||
disable_compile
|
disable_compile
|
||||||
|
|||||||
@@ -130,19 +130,12 @@ Now make an array, and benchmark both functions:
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
x = mx.random.uniform(shape=(32, 1000, 4096))
|
x = mx.random.uniform(shape=(32, 1000, 4096))
|
||||||
timeit(nn.gelu, x)
|
timeit(gelu, x)
|
||||||
timeit(mx.compile(nn.gelu), x)
|
timeit(mx.compile(gelu), x)
|
||||||
|
|
||||||
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
||||||
five times faster.
|
five times faster.
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
|
|
||||||
functions can still be helpful, but won't typically result in as large a
|
|
||||||
speedup as compiling operations that run on the GPU.
|
|
||||||
|
|
||||||
|
|
||||||
Debugging
|
Debugging
|
||||||
---------
|
---------
|
||||||
|
|
||||||
@@ -232,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,
|
|||||||
def fun(x, y):
|
def fun(x, y):
|
||||||
z = x + y
|
z = x + y
|
||||||
state.append(z)
|
state.append(z)
|
||||||
return mx.exp(z), state
|
return mx.exp(z)
|
||||||
|
|
||||||
fun(mx.array(1.0), mx.array(2.0))
|
fun(mx.array(1.0), mx.array(2.0))
|
||||||
# Prints [array(3, dtype=float32)]
|
# Prints [array(3, dtype=float32)]
|
||||||
@@ -428,3 +421,77 @@ the most opportunity to optimize the computation graph:
|
|||||||
# Compiling the outer function is good to do as it will likely
|
# Compiling the outer function is good to do as it will likely
|
||||||
# be faster even though the inner functions are compiled
|
# be faster even though the inner functions are compiled
|
||||||
fun = mx.compile(outer)
|
fun = mx.compile(outer)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
.. _shapeless_compile:
|
||||||
|
|
||||||
|
Shapeless Compilation
|
||||||
|
---------------------
|
||||||
|
|
||||||
|
When the shape of an input to a compiled function changes, the function is
|
||||||
|
recompiled. You can compile a function once and run it on inputs with
|
||||||
|
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
|
||||||
|
case changes to the shapes of the inputs do not cause the function to be
|
||||||
|
recompiled.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(x, y):
|
||||||
|
return mx.abs(x + y)
|
||||||
|
|
||||||
|
compiled_fun = mx.compile(fun, shapeless=True)
|
||||||
|
|
||||||
|
x = mx.array(1.0)
|
||||||
|
y = mx.array(-2.0)
|
||||||
|
|
||||||
|
# Firt call compiles the function
|
||||||
|
print(compiled_fun(x, y))
|
||||||
|
|
||||||
|
# Second call with different shapes
|
||||||
|
# does not recompile the function
|
||||||
|
x = mx.array([1.0, -6.0])
|
||||||
|
y = mx.array([-2.0, 3.0])
|
||||||
|
print(compiled_fun(x, y))
|
||||||
|
|
||||||
|
|
||||||
|
Use shapeless compilations carefully. Since compilation is not triggered when
|
||||||
|
shapes change, any graphs which are conditional on the input shapes will not
|
||||||
|
work as expected. Shape-dependent computations are common and sometimes subtle
|
||||||
|
to detect. For example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(x):
|
||||||
|
return x.reshape(x.shape[0] * x.shape[1], -1)
|
||||||
|
|
||||||
|
compiled_fun = mx.compile(fun, shapeless=True)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(2, 3, 4))
|
||||||
|
|
||||||
|
out = compiled_fun(x)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(5, 5, 3))
|
||||||
|
|
||||||
|
# Error, can't reshape (5, 5, 3) to (6, -1)
|
||||||
|
out = compiled_fun(x)
|
||||||
|
|
||||||
|
The second call to the ``compiled_fun`` fails because of the call to
|
||||||
|
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
|
||||||
|
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
def fun(x):
|
||||||
|
return x.flatten(0, 1)
|
||||||
|
|
||||||
|
compiled_fun = mx.compile(fun, shapeless=True)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(2, 3, 4))
|
||||||
|
|
||||||
|
out = compiled_fun(x)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(5, 5, 3))
|
||||||
|
|
||||||
|
# Ok
|
||||||
|
out = compiled_fun(x)
|
||||||
|
|||||||
@@ -5,21 +5,27 @@ Distributed Communication
|
|||||||
|
|
||||||
.. currentmodule:: mlx.core.distributed
|
.. currentmodule:: mlx.core.distributed
|
||||||
|
|
||||||
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
|
MLX supports distributed communication operations that allow the computational cost
|
||||||
provide distributed communication operations that allow the computational cost
|
of training or inference to be shared across many physical machines. At the
|
||||||
of training or inference to be shared across many physical machines. You can
|
moment we support two different communication backends:
|
||||||
see a list of the supported operations in the :ref:`API docs<distributed>`.
|
|
||||||
|
* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
|
||||||
|
full-featured and mature distributed communications library
|
||||||
|
* A **ring** backend of our own that uses native TCP sockets and should be
|
||||||
|
faster for thunderbolt connections.
|
||||||
|
|
||||||
|
The list of all currently supported operations and their documentation can be
|
||||||
|
seen in the :ref:`API docs<distributed>`.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
A lot of operations may not be supported or not as fast as they should be.
|
Some operations may not be supported or not as fast as they should be.
|
||||||
We are adding more and tuning the ones we have as we are figuring out the
|
We are adding more and tuning the ones we have as we are figuring out the
|
||||||
best way to do distributed computing on Macs using MLX.
|
best way to do distributed computing on Macs using MLX.
|
||||||
|
|
||||||
Getting Started
|
Getting Started
|
||||||
---------------
|
---------------
|
||||||
|
|
||||||
MLX already comes with the ability to "talk" to MPI if it is installed on the
|
A distributed program in MLX is as simple as:
|
||||||
machine. The minimal distributed program in MLX is as simple as:
|
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
@@ -30,74 +36,79 @@ machine. The minimal distributed program in MLX is as simple as:
|
|||||||
print(world.rank(), x)
|
print(world.rank(), x)
|
||||||
|
|
||||||
The program above sums the array ``mx.ones(10)`` across all
|
The program above sums the array ``mx.ones(10)`` across all
|
||||||
distributed processes. If simply run with ``python``, however, only one
|
distributed processes. However, when this script is run with ``python`` only
|
||||||
process is launched and no distributed communication takes place.
|
one process is launched and no distributed communication takes place. Namely,
|
||||||
|
all operations in ``mx.distributed`` are noops when the distributed group has a
|
||||||
|
size of one. This property allows us to avoid code that checks if we are in a
|
||||||
|
distributed setting similar to the one below:
|
||||||
|
|
||||||
To launch the program in distributed mode we need to use ``mpirun`` or
|
.. code:: python
|
||||||
``mpiexec`` depending on the MPI installation. The simplest possible way is the
|
|
||||||
following:
|
import mlx.core as mx
|
||||||
|
|
||||||
|
x = ...
|
||||||
|
world = mx.distributed.init()
|
||||||
|
# No need for the check we can simply do x = mx.distributed.all_sum(x)
|
||||||
|
if world.size() > 1:
|
||||||
|
x = mx.distributed.all_sum(x)
|
||||||
|
|
||||||
|
Running Distributed Programs
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
MLX provides ``mlx.launch`` a helper script to launch distributed programs.
|
||||||
|
Continuing with our initial example we can run it on localhost with 4 processes using
|
||||||
|
|
||||||
.. code:: shell
|
.. code:: shell
|
||||||
|
|
||||||
$ mpirun -np 2 python test.py
|
$ mlx.launch -n 4 my_script.py
|
||||||
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
|
||||||
The above launches two processes on the same (local) machine and we can see
|
We can also run it on some remote hosts by providing their IPs (provided that
|
||||||
both standard output streams. The processes send the array of 1s to each other
|
the script exists on all hosts and they are reachable by ssh)
|
||||||
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
|
|
||||||
print 4 etc.
|
|
||||||
|
|
||||||
Installing MPI
|
|
||||||
---------------
|
|
||||||
|
|
||||||
MPI can be installed with Homebrew, using the Anaconda package manager or
|
|
||||||
compiled from source. Most of our testing is done using ``openmpi`` installed
|
|
||||||
with the Anaconda package manager as follows:
|
|
||||||
|
|
||||||
.. code:: shell
|
.. code:: shell
|
||||||
|
|
||||||
$ conda install openmpi
|
$ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
|
||||||
|
3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
|
||||||
|
|
||||||
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
|
Consult the dedicated :doc:`usage guide<launching_distributed>` for more
|
||||||
so that MLX can find it and load it at runtime. This can simply be achieved by
|
information on using ``mlx.launch``.
|
||||||
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
|
|
||||||
|
|
||||||
.. code:: shell
|
Selecting Backend
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
|
You can select the backend you want to use when calling :func:`init` by passing
|
||||||
|
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
|
||||||
Setting up Remote Hosts
|
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
|
||||||
-----------------------
|
both fail then a singleton group is created.
|
||||||
|
|
||||||
MPI can automatically connect to remote hosts and set up the communication over
|
|
||||||
the network if the remote hosts can be accessed via ssh. A good checklist to
|
|
||||||
debug connectivity issues is the following:
|
|
||||||
|
|
||||||
* ``ssh hostname`` works from all machines to all machines without asking for
|
|
||||||
password or host confirmation
|
|
||||||
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
|
|
||||||
full path to force all machines to use a specific path.
|
|
||||||
* Ensure that the ``hostname`` used by MPI is the one that you have configured
|
|
||||||
in the ``.ssh/config`` files on all machines.
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
|
After a distributed backend is successfully initialized :func:`init` will
|
||||||
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
|
return **the same backend** if called without arguments or with backend set to
|
||||||
|
``any``.
|
||||||
|
|
||||||
An easy way to pass the host names to MPI is using a host file. A host file
|
The following examples aim to clarify the backend initialization logic in MLX:
|
||||||
looks like the following, where ``host1`` and ``host2`` should be the fully
|
|
||||||
qualified domain names or IPs for these hosts.
|
|
||||||
|
|
||||||
.. code::
|
.. code:: python
|
||||||
|
|
||||||
host1 slots=1
|
# Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
|
||||||
host2 slots=1
|
world = mx.distributed.init(backend="mpi")
|
||||||
|
world2 = mx.distributed.init() # subsequent calls return the MPI backend!
|
||||||
|
|
||||||
When using MLX, it is very likely that you want to use 1 slot per host, ie one
|
# Case 2: Initialize any backend
|
||||||
process per host. The hostfile also needs to contain the current
|
world = mx.distributed.init(backend="any") # equivalent to no arguments
|
||||||
host if you want to run on the local host. Passing the host file to
|
world2 = mx.distributed.init() # same as above
|
||||||
``mpirun`` is simply done using the ``--hostfile`` command line argument.
|
|
||||||
|
# Case 3: Initialize both backends at the same time
|
||||||
|
world_mpi = mx.distributed.init(backend="mpi")
|
||||||
|
world_ring = mx.distributed.init(backend="ring")
|
||||||
|
world_any = mx.distributed.init() # same as MPI because it was initialized first!
|
||||||
|
|
||||||
Training Example
|
Training Example
|
||||||
----------------
|
----------------
|
||||||
@@ -141,12 +152,13 @@ everything else remaining the same.
|
|||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
|
|
||||||
def all_reduce_grads(grads):
|
def all_reduce_grads(grads):
|
||||||
N = mx.distributed.init()
|
N = mx.distributed.init().size()
|
||||||
if N == 1:
|
if N == 1:
|
||||||
return grads
|
return grads
|
||||||
return tree_map(
|
return tree_map(
|
||||||
lambda x: mx.distributed.all_sum(x) / N,
|
lambda x: mx.distributed.all_sum(x) / N,
|
||||||
grads)
|
grads
|
||||||
|
)
|
||||||
|
|
||||||
def step(model, x, y):
|
def step(model, x, y):
|
||||||
loss, grads = loss_grad_fn(model, x, y)
|
loss, grads = loss_grad_fn(model, x, y)
|
||||||
@@ -154,13 +166,179 @@ everything else remaining the same.
|
|||||||
optimizer.update(model, grads)
|
optimizer.update(model, grads)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
Tuning All Reduce
|
Utilizing ``nn.average_gradients``
|
||||||
-----------------
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
We are working on improving the performance of all reduce on MLX but for now
|
Although the code example above works correctly; it performs one communication
|
||||||
the two main things one can do to extract the most out of distributed training with MLX are:
|
per gradient. It is significantly more efficient to aggregate several gradients
|
||||||
|
together and perform fewer communication steps.
|
||||||
|
|
||||||
1. Perform a few large reductions instead of many small ones to improve
|
This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
|
||||||
bandwidth and latency
|
almost identical to the example above:
|
||||||
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
|
|
||||||
connections between each host to improve bandwidth
|
.. code:: python
|
||||||
|
|
||||||
|
model = ...
|
||||||
|
optimizer = ...
|
||||||
|
dataset = ...
|
||||||
|
|
||||||
|
def step(model, x, y):
|
||||||
|
loss, grads = loss_grad_fn(model, x, y)
|
||||||
|
grads = mx.nn.average_gradients(grads) # <---- This line was added
|
||||||
|
optimizer.update(model, grads)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
for x, y in dataset:
|
||||||
|
loss = step(model, x, y)
|
||||||
|
mx.eval(loss, model.parameters())
|
||||||
|
|
||||||
|
|
||||||
|
Getting Started with MPI
|
||||||
|
------------------------
|
||||||
|
|
||||||
|
MLX already comes with the ability to "talk" to MPI if it is installed on the
|
||||||
|
machine. Launching distributed MLX programs that use MPI can be done with
|
||||||
|
``mpirun`` as expected. However, in the following examples we will be using
|
||||||
|
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
|
||||||
|
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
|
||||||
|
library.
|
||||||
|
|
||||||
|
The simplest possible usage is the following which, assuming the minimal
|
||||||
|
example in the beginning of this page, should result in:
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
$ mlx.launch --backend mpi -n 2 test.py
|
||||||
|
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||||
|
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
|
||||||
|
|
||||||
|
The above launches two processes on the same (local) machine and we can see
|
||||||
|
both standard output streams. The processes send the array of 1s to each other
|
||||||
|
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
|
||||||
|
print 4 etc.
|
||||||
|
|
||||||
|
Installing MPI
|
||||||
|
^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
MPI can be installed with Homebrew, using the Anaconda package manager or
|
||||||
|
compiled from source. Most of our testing is done using ``openmpi`` installed
|
||||||
|
with the Anaconda package manager as follows:
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
$ conda install conda-forge::openmpi
|
||||||
|
|
||||||
|
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
|
||||||
|
so that MLX can find it and load it at runtime. This can simply be achieved by
|
||||||
|
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
|
||||||
|
done automatically by ``mlx.launch``.
|
||||||
|
|
||||||
|
.. code:: shell
|
||||||
|
|
||||||
|
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
|
||||||
|
$ # or simply
|
||||||
|
$ mlx.launch -n 2 test.py
|
||||||
|
|
||||||
|
Setting up Remote Hosts
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
MPI can automatically connect to remote hosts and set up the communication over
|
||||||
|
the network if the remote hosts can be accessed via ssh. A good checklist to
|
||||||
|
debug connectivity issues is the following:
|
||||||
|
|
||||||
|
* ``ssh hostname`` works from all machines to all machines without asking for
|
||||||
|
password or host confirmation
|
||||||
|
* ``mpirun`` is accessible on all machines.
|
||||||
|
* Ensure that the ``hostname`` used by MPI is the one that you have configured
|
||||||
|
in the ``.ssh/config`` files on all machines.
|
||||||
|
|
||||||
|
Tuning MPI All Reduce
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
For faster all reduce consider using the ring backend either with Thunderbolt
|
||||||
|
connections or over Ethernet.
|
||||||
|
|
||||||
|
Configure MPI to use N tcp connections between each host to improve bandwidth
|
||||||
|
by passing ``--mca btl_tcp_links N``.
|
||||||
|
|
||||||
|
Force MPI to use the most performant network interface by setting ``--mca
|
||||||
|
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
|
||||||
|
to use.
|
||||||
|
|
||||||
|
Getting Started with Ring
-------------------------

The ring backend does not depend on any third-party library, so it is always
available. It uses TCP sockets, so the nodes need to be reachable over a
network. As the name suggests, the nodes are connected in a ring: rank 1 can
only communicate with ranks 0 and 2, rank 2 only with ranks 1 and 3, and so on.
As a result, :func:`send` and :func:`recv` with arbitrary sender and receiver
are not supported in the ring backend.

Defining a Ring
^^^^^^^^^^^^^^^

The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node, one
defines a hostname to ssh into to run commands on that node and one or more IPs
that the node will listen on for connections.

For example, the hostfile below defines a 4-node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1, and so on.

.. code:: json

   [
       {"ssh": "hostname1", "ips": ["123.123.123.1"]},
       {"ssh": "hostname2", "ips": ["123.123.123.2"]},
       {"ssh": "hostname3", "ips": ["123.123.123.3"]},
       {"ssh": "hostname4", "ips": ["123.123.123.4"]}
   ]

Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node and run the script, which will listen for connections on each of the
provided IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and
accept a connection from ``123.123.123.4``, and so on.

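Before running a real workload, the ring can be sanity-checked with a small
script launched via ``mlx.launch --hostfile ring-4.json``. The following is a
sketch (the file name ``check_ring.py`` is illustrative):

.. code-block:: python

   # check_ring.py: print this node's rank and run one all-sum over the ring
   import mlx.core as mx

   group = mx.distributed.init()
   total = mx.distributed.all_sum(mx.ones(1), group=group)
   # Every rank should print the world size, e.g. 4 for a 4-node ring
   print(group.rank(), total.item())
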
Thunderbolt Ring
^^^^^^^^^^^^^^^^

Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher-bandwidth communication.
Setting up such Thunderbolt rings can be done manually, but it is a relatively
tedious process. To simplify it, we provide the utility
``mlx.distributed_config``.

To use ``mlx.distributed_config``, your computers need to be accessible by ssh
via Ethernet or Wi-Fi. Subsequently, connect them via Thunderbolt cables and
then call the utility as follows:

.. code:: shell

   mlx.distributed_config --verbose --hosts host1,host2,host3,host4

By default the script will attempt to discover the Thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.

To validate your connection without configuring anything,
``mlx.distributed_config`` can also plot the ring using DOT format.

.. code:: shell

   mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
   dot -Tpng ring.dot >ring.png
   open ring.png

If you want to go through the process manually, the steps are as follows:

* Disable the Thunderbolt bridge interface.
* For the cable connecting rank ``i`` to rank ``i + 1``, find the interfaces
  corresponding to that cable on nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes on the corresponding
  interfaces. For instance, if the cable corresponds to ``en2`` on node ``i``
  and ``en2`` also on node ``i + 1``, then we may assign IPs ``192.168.0.1``
  and ``192.168.0.2`` respectively to the two nodes. For more details you can
  see the commands prepared by the utility script.

docs/src/usage/export.rst (new file, 288 lines)
@@ -0,0 +1,288 @@

.. _export_usage:

Exporting Functions
===================

.. currentmodule:: mlx.core

MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).

This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions, check out the :ref:`API documentation
<export>`.

Basics of Exporting
-------------------

Let's start with a simple example:

.. code-block:: python

   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)
   mx.export_function("add.mlxfn", fun, x, y)

To export a function, provide sample input arrays that the function
can be called with. The data doesn't matter, but the shapes and types of the
arrays do. In the above example we exported ``fun`` with two ``float32``
scalar arrays. We can then import the function and run it:

.. code-block:: python

   add_fun = mx.import_function("add.mlxfn")

   out, = add_fun(mx.array(1.0), mx.array(2.0))
   # Prints: array(3, dtype=float32)
   print(out)

   out, = add_fun(mx.array(1.0), mx.array(3.0))
   # Prints: array(4, dtype=float32)
   print(out)

   # Raises an exception
   add_fun(mx.array(1), mx.array(3.0))

   # Raises an exception
   add_fun(mx.array([1.0, 2.0]), mx.array(3.0))

Notice the third and fourth calls to ``add_fun`` raise exceptions because the
shapes and types of the inputs are different from the shapes and types of the
example inputs we exported the function with.

Also notice that even though the original ``fun`` returns a single output
array, the imported function always returns a tuple of one or more arrays.

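Since the imported function always returns a tuple, functions with multiple
outputs round-trip naturally. As a small sketch:

.. code-block:: python

   def fun(x, y):
       return x + y, x * y

   mx.export_function("add_mul.mlxfn", fun, mx.array(2.0), mx.array(3.0))

   imported_fun = mx.import_function("add_mul.mlxfn")
   add_out, mul_out = imported_fun(mx.array(2.0), mx.array(3.0))
   # Prints: array(5, dtype=float32) array(6, dtype=float32)
   print(add_out, mul_out)
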
The inputs to :func:`export_function` and to an imported function can be
specified as variable positional arguments or as a tuple of arrays:

.. code-block:: python

   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)

   # Both arguments to fun are positional
   mx.export_function("add.mlxfn", fun, x, y)

   # Same as above
   mx.export_function("add.mlxfn", fun, (x, y))

   imported_fun = mx.import_function("add.mlxfn")

   # Ok
   out, = imported_fun(x, y)

   # Also ok
   out, = imported_fun((x, y))

You can pass example inputs to functions as positional or keyword arguments. If
you use keyword arguments to export the function, then you have to use the same
keyword arguments when calling the imported function.

.. code-block:: python

   def fun(x, y):
       return x + y

   # One argument to fun is positional, the other is a kwarg
   mx.export_function("add.mlxfn", fun, x, y=y)

   imported_fun = mx.import_function("add.mlxfn")

   # Ok
   out, = imported_fun(x, y=y)

   # Also ok
   out, = imported_fun((x,), {"y": y})

   # Raises since the keyword argument is missing
   out, = imported_fun(x, y)

   # Raises since the keyword argument has the wrong key
   out, = imported_fun(x, z=y)

Exporting Modules
-----------------

An :obj:`mlx.nn.Module` can be exported with or without the parameters included
in the exported function. Here's an example:

.. code-block:: python

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x):
       return model(x)

   mx.export_function("model.mlxfn", call, mx.zeros(4))

In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
parameters are also saved to the ``model.mlxfn`` file.

.. note::

   For enclosed arrays inside an exported function, be extra careful to ensure
   they are evaluated. The computation graph that gets exported will include
   the computation that produces enclosed inputs.

   If the above example was missing ``mx.eval(model.parameters())``, the
   exported function would include the random initialization of the
   :obj:`mlx.nn.Module` parameters.

If you only want to export the ``Module.__call__`` function without the
parameters, pass them as inputs to the ``call`` wrapper:

.. code-block:: python

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x, **params):
       # Set the model's parameters to the input parameters
       model.update(tree_unflatten(list(params.items())))
       return model(x)

   params = tree_flatten(model.parameters(), destination={})
   mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

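When a module is exported this way, the parameters must be supplied again at
call time. A minimal usage sketch, reusing the ``call`` wrapper and the
flattened ``params`` from the example above:

.. code-block:: python

   imported_call = mx.import_function("model.mlxfn")

   # The parameters are passed as keyword arguments alongside the input
   out, = imported_call((mx.zeros(4),), params)
   print(out)

This keeps the parameters out of the ``.mlxfn`` file, which can be useful when
several exported functions share the same weights.
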
Shapeless Exports
-----------------

Just like :func:`compile`, functions can also be exported for dynamically
shaped inputs. Pass ``shapeless=True`` to :func:`export_function` or
:func:`exporter` to export a function which can be used for inputs with
variable shapes:

.. code-block:: python

   mx.export_function("fun.mlxfn", mx.abs, mx.array([0.0]), shapeless=True)
   imported_abs = mx.import_function("fun.mlxfn")

   # Ok
   out, = imported_abs(mx.array([-1.0]))

   # Also ok
   out, = imported_abs(mx.array([-1.0, -2.0]))

With ``shapeless=False`` (which is the default), the second call to
``imported_abs`` would raise an exception with a shape mismatch.

Shapeless exporting works the same as shapeless compilation and should be
used carefully. See the :ref:`documentation on shapeless compilation
<shapeless_compile>` for more information.

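A common use of shapeless exports is exporting a module for a variable batch
size. The following is a sketch reusing the ``call`` wrapper pattern from the
previous section; whether a given graph is safe to reuse across shapes depends
on the operations involved, so treat this as illustrative:

.. code-block:: python

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x):
       return model(x)

   # Trace with a batch of one row; other batch sizes reuse the same graph
   mx.export_function("model.mlxfn", call, mx.zeros((1, 4)), shapeless=True)

   imported_call = mx.import_function("model.mlxfn")
   out, = imported_call(mx.zeros((16, 4)))  # a batch of 16 also works
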
Exporting Multiple Traces
-------------------------

In some cases, functions build different computation graphs for different
input arguments. A simple way to manage this is to export to a new file with
each set of inputs. This is a fine option in many cases. But it can be
suboptimal if the exported functions have a large amount of duplicate constant
data (for example the parameters of a :obj:`mlx.nn.Module`).

The export API in MLX lets you export multiple traces of the same function to
a single file by creating an exporting context manager with :func:`exporter`:

.. code-block:: python

   def fun(x, y=None):
       constant = mx.array(3.0)
       if y is not None:
           x += y
       return x + constant

   with mx.exporter("fun.mlxfn", fun) as exporter:
       exporter(mx.array(1.0))
       exporter(mx.array(1.0), y=mx.array(0.0))

   imported_function = mx.import_function("fun.mlxfn")

   # Call the function with y=None
   out, = imported_function(mx.array(1.0))
   print(out)

   # Call the function with y specified
   out, = imported_function(mx.array(1.0), y=mx.array(1.0))
   print(out)

In the above example, the function's constant data (i.e. ``constant``) is only
saved once.

Transformations with Imported Functions
---------------------------------------

Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile`
work on imported functions just like on regular Python functions:

.. code-block:: python

   def fun(x):
       return mx.sin(x)

   x = mx.array(0.0)
   mx.export_function("sine.mlxfn", fun, x)

   imported_fun = mx.import_function("sine.mlxfn")

   # Take the derivative of the imported function
   dfdx = mx.grad(lambda x: imported_fun(x)[0])
   # Prints: array(1, dtype=float32)
   print(dfdx(x))

   # Compile the imported function
   compiled_fun = mx.compile(imported_fun)
   # Prints: array(0, dtype=float32)
   print(compiled_fun(x)[0])

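The same pattern applies to :func:`vmap`. As a sketch, batching the imported
sine function over a vector of inputs (the lambda unpacks the single output,
as above):

.. code-block:: python

   vmapped_fun = mx.vmap(lambda x: imported_fun(x)[0])
   xs = mx.array([0.0, 1.0, 2.0])
   # Same result as mx.sin(xs), computed element-wise
   print(vmapped_fun(xs))
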
Importing Functions in C++
--------------------------

Importing and running functions in C++ is basically the same as importing and
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
set up a simple C++ project that uses MLX as a library.

Next, export a simple function from Python:

.. code-block:: python

   def fun(x, y):
       return mx.exp(x + y)

   x = mx.array(1.0)
   y = mx.array(1.0)
   mx.export_function("fun.mlxfn", fun, x, y)

Import and run the function in C++ with only a few lines of code:

.. code-block:: c++

   auto fun = mx::import_function("fun.mlxfn");

   auto inputs = {mx::array(1.0), mx::array(1.0)};
   auto outputs = fun(inputs);

   // Prints: array(7.38906, dtype=float32)
   std::cout << outputs[0] << std::endl;

Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

More Examples
-------------

Here are a few more complete examples exporting more complex functions from
Python and importing and running them in C++:

* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_

@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
     ys = mx.random.uniform(shape=(100, 4096))

     def naive_add(xs, ys):
-        return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
+        return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

 Instead you can use :func:`vmap` to automatically vectorize the addition:

@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:

     # Vectorize over the second dimension of x and the
     # first dimension of y
-    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
+    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

 The ``in_axes`` parameter can be used to specify which dimensions of the
 corresponding input to vectorize over. Similarly, use ``out_axes`` to specify

@@ -184,8 +184,8 @@ Let's time these two different versions:
     print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
     print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

-On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
-vectorized version takes only ``0.025`` seconds, more than ten times faster.
+On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
+vectorized version takes only ``0.024`` seconds, more than 200 times faster.

 Of course, this operation is quite contrived. A better approach is to simply do
 ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
@@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
 kernel would be extremely inefficient.

 Indexing with boolean masks is something that MLX may support in the future. In
-general, MLX has limited support for operations for which outputs
+general, MLX has limited support for operations for which output
 *shapes* are dependent on input *data*. Other examples of these types of
 operations which MLX does not yet support include :func:`numpy.nonzero` and the
 single input version of :func:`numpy.where`.

@@ -107,6 +107,28 @@ same array:
    >>> a
    array([1, 2, 0], dtype=int32)

+Note that unlike NumPy, slicing an array creates a copy, not a view. So
+mutating it does not mutate the original array:
+
+.. code-block:: shell
+
+   >>> a = mx.array([1, 2, 3])
+   >>> b = a[:]
+   >>> b[2] = 0
+   >>> b
+   array([1, 2, 0], dtype=int32)
+   >>> a
+   array([1, 2, 3], dtype=int32)
+
+Also unlike NumPy, updates to the same location are nondeterministic:
+
+.. code-block:: shell
+
+   >>> a = mx.array([1, 2, 3])
+   >>> a[[0, 0]] = mx.array([4, 5])
+
+The first element of ``a`` could be ``4`` or ``5``.
+
 Transformations of functions which use in-place updates are allowed and work as
 expected. For example:
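When duplicate-index updates need well-defined semantics, the ``at`` property
on arrays can be used instead, which accumulates every update. A sketch based
on the in-place update API described in these docs:

.. code-block:: python

   import mlx.core as mx

   a = mx.array([1, 2, 3])
   # Both updates to index 0 are applied: 1 + 4 + 5 = 10
   a = a.at[mx.array([0, 0])].add(mx.array([4, 5]))
   print(a)  # array([10, 2, 3], dtype=int32)
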
docs/src/usage/launching_distributed.rst (new file, 105 lines)
@@ -0,0 +1,105 @@

:orphan:

.. _usage_launch_distributed:

Launching Distributed Programs
==============================

.. currentmodule:: mlx.core.distributed

Installing the MLX Python package provides a helper script ``mlx.launch`` that
can be used to run Python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.

Usage
-----

The minimal usage example of ``mlx.launch`` is simply

.. code:: shell

   mlx.launch --hosts ip1,ip2 my_script.py

or, for testing on localhost,

.. code:: shell

   mlx.launch -n 2 my_script.py

The ``mlx.launch`` command connects to the provided hosts and launches the
input script on each one. It monitors each of the launched processes and
terminates the rest if one of them fails unexpectedly or if ``mlx.launch`` is
terminated. It also takes care of forwarding the output of each remote process
to stdout and stderr.

Providing Hosts
^^^^^^^^^^^^^^^

Hosts can be provided as command line arguments, as above, but the most
flexible way to define a list of hosts is via a JSON hostfile. The hostfile has
a very simple schema: it is a list of objects, each defining a host via a
hostname to ssh to and a list of IPs to use for the communication.

.. code:: json

   [
       {"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
       {"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
   ]

You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
with IPs corresponding to the ``en0`` interface.

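Since the hostfile is plain JSON, it can also be generated programmatically
when scripting cluster setup. A small sketch with placeholder hostnames and
IPs:

.. code-block:: python

   import json

   # Each entry: a hostname to ssh into and the IPs it listens on
   hosts = [
       {"ssh": "hostname1", "ips": ["123.123.1.1"]},
       {"ssh": "hostname2", "ips": ["123.123.1.2"]},
   ]
   with open("hosts.json", "w") as f:
       json.dump(hosts, f, indent=2)
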
Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^

In order to launch the script on each host, we need to be able to connect to it
via ssh. Moreover, the input script and the Python binary need to be on each
host at the same path. A good checklist for debugging errors is the following:

* ``ssh hostname`` works without asking for a password or host confirmation
* the Python binary is available on all hosts at the same path. You can use
  ``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path

.. _mpi_specifics:

MPI Specifics
-------------

One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,

* The IPs in the hostfile are ignored
* The ssh connectivity requirement is stronger, as every node needs to be able
  to connect to every other node
* ``mpirun`` needs to be available on every node at the same path

Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For
instance, to choose a specific interface for the byte-transfer layer of MPI,
we can call ``mlx.launch`` as follows:

.. code:: shell

   mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py

.. _ring_specifics:

Ring Specifics
--------------

The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that differ from MPI:

* The argument ``--hosts`` only accepts IPs, not hostnames. If we need to
  ssh to a hostname that does not correspond to the IP we want to bind to, we
  have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
  Specifically, rank 0 for the first IP will use this port, and each subsequent
  IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
  between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
  ``mpirun``.

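Inside the launched script, the ring backend can also be requested explicitly
when initializing the distributed group. The following is a sketch assuming
the ``strict`` and ``backend`` arguments of :func:`init` in the Python API:

.. code-block:: python

   import mlx.core as mx

   # Fail loudly if the ring backend could not be set up, instead of
   # silently falling back to a single-process group
   group = mx.distributed.init(strict=True, backend="ring")
   print(group.rank(), group.size())
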
@@ -109,7 +109,7 @@ Here is a concrete example:

 An important behavior to be aware of is when the graph will be implicitly
 evaluated. Anytime you ``print`` an array, convert it to an
-:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
+:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
 the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
 saving functions) will also evaluate the array.
@@ -21,11 +21,13 @@ Let's convert an array to NumPy and back.

 .. note::

-    Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
-    ``np.array(a.astype(mx.float32))``.
-    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
+    Since NumPy does not support ``bfloat16`` arrays, you will need to convert
+    to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
+    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
+    buffer format string does not match the dtype V item size 0.``

-By default, NumPy copies data to a new array. This can be prevented by creating an array view:
+By default, NumPy copies data to a new array. This can be prevented by creating
+an array view:

 .. code-block:: python

@@ -35,10 +37,16 @@ By default, NumPy copies data to a new array. This can be prevented by creating
     a_view[0] = 1
     print(a[0].item()) # 1

-A NumPy array view is a normal NumPy array, except that it does not own its memory.
-This means writing to the view is reflected in the original array.
+.. note::
+
+    NumPy arrays with type ``float64`` will be default converted to MLX arrays
+    with type ``float32``.
+
+A NumPy array view is a normal NumPy array, except that it does not own its
+memory. This means writing to the view is reflected in the original array.

-While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.
+While this is quite powerful to prevent copying arrays, it should be noted that
+external changes to the memory of arrays cannot be reflected in gradients.

 Let's demonstrate this in an example:

@@ -56,11 +64,12 @@ Let's demonstrate this in an example:

 The function ``f`` indirectly modifies the array ``x`` through a memory view.
-However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
-representing the gradient of the sum operation alone.
-The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
-It's important to note that a similar issue arises during array conversion and copying.
-For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
+However, this modification is not reflected in the gradient, as seen in the
+last line outputting ``1.0``, representing the gradient of the sum operation
+alone. The squaring of ``x`` occurs externally to MLX, meaning that no
+gradient is incorporated. It's important to note that a similar issue arises
+during array conversion and copying. For instance, a function defined as
+``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
 even though no in-place operations on MLX memory are executed.

 PyTorch

@@ -71,7 +80,8 @@ PyTorch

 PyTorch Support for :obj:`memoryview` is experimental and can break for
 multi-dimensional arrays. Casting to NumPy first is advised for now.

-PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
+PyTorch supports the buffer protocol, but it requires an explicit
+:obj:`memoryview`.

 .. code-block:: python

@@ -82,7 +92,8 @@ PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryvi
     b = torch.tensor(memoryview(a))
     c = mx.array(b.numpy())

-Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
+Conversion from PyTorch tensors back to arrays must be done via intermediate
+NumPy arrays with ``numpy()``.

 JAX
 ---

@@ -100,7 +111,8 @@ JAX fully supports the buffer protocol.

 TensorFlow
 ----------

-TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
+TensorFlow supports the buffer protocol, but it requires an explicit
+:obj:`memoryview`.

 .. code-block:: python

examples/cmake_project/CMakeLists.txt (new file, 22 lines)
@@ -0,0 +1,22 @@

cmake_minimum_required(VERSION 3.27)

project(example LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Comment out the following two commands if the MLX C++ library is already
# installed, and set(MLX_ROOT "/path/to/mlx") directly if needed.
find_package(
  Python 3.9
  COMPONENTS Interpreter Development.Module
  REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)

find_package(MLX CONFIG REQUIRED)

add_executable(example example.cpp)
target_link_libraries(example PRIVATE mlx)
examples/cmake_project/README.md (new file, 26 lines)
@@ -0,0 +1,26 @@

## Build and Run

Install MLX with Python:

```bash
pip install "mlx>=0.22"
```

Build the C++ example:

```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```

Run the C++ example:

```
./build/example
```

which should output:

```
array([2, 4, 6], dtype=int32)
```
examples/cmake_project/example.cpp (new file, 14 lines)
@@ -0,0 +1,14 @@

// Copyright © 2024 Apple Inc.

#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto x = mx::array({1, 2, 3});
  auto y = mx::array({1, 2, 3});
  std::cout << x + y << std::endl;
  return 0;
}
@@ -4,19 +4,19 @@

 #include "mlx/mlx.h"

-using namespace mlx::core;
+namespace mx = mlx::core;

 int main() {
-  if (!distributed::is_available()) {
+  if (!mx::distributed::is_available()) {
     std::cout << "No communication backend found" << std::endl;
     return 1;
   }

-  auto global_group = distributed::init();
+  auto global_group = mx::distributed::init();
   std::cout << global_group.rank() << " / " << global_group.size() << std::endl;

-  array x = ones({10});
-  array out = distributed::all_sum(x, global_group);
+  mx::array x = mx::ones({10});
+  mx::array out = mx::distributed::all_sum(x, global_group);

   std::cout << out << std::endl;
 }
@@ -10,7 +10,7 @@
 /**
 * An example of linear regression with MLX.
 */
-using namespace mlx::core;
+namespace mx = mlx::core;

 int main() {
   int num_features = 100;

@@ -19,35 +19,35 @@ int main() {
   float learning_rate = 0.01;

   // True parameters
-  auto w_star = random::normal({num_features});
+  auto w_star = mx::random::normal({num_features});

   // The input examples (design matrix)
-  auto X = random::normal({num_examples, num_features});
+  auto X = mx::random::normal({num_examples, num_features});

   // Noisy labels
-  auto eps = 1e-2 * random::normal({num_examples});
-  auto y = matmul(X, w_star) + eps;
+  auto eps = 1e-2 * mx::random::normal({num_examples});
+  auto y = mx::matmul(X, w_star) + eps;

   // Initialize random parameters
-  array w = 1e-2 * random::normal({num_features});
+  mx::array w = 1e-2 * mx::random::normal({num_features});

-  auto loss_fn = [&](array w) {
-    auto yhat = matmul(X, w);
-    return (0.5f / num_examples) * sum(square(yhat - y));
+  auto loss_fn = [&](mx::array w) {
+    auto yhat = mx::matmul(X, w);
+    return (0.5f / num_examples) * mx::sum(mx::square(yhat - y));
   };

-  auto grad_fn = grad(loss_fn);
+  auto grad_fn = mx::grad(loss_fn);

   auto tic = timer::time();
   for (int it = 0; it < num_iters; ++it) {
-    auto grad = grad_fn(w);
-    w = w - learning_rate * grad;
-    eval(w);
+    auto grads = grad_fn(w);
+    w = w - learning_rate * grads;
+    mx::eval(w);
   }
   auto toc = timer::time();

   auto loss = loss_fn(w);
-  auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
+  auto error_norm = std::sqrt(mx::sum(mx::square(w - w_star)).item<float>());
   auto throughput = num_iters / timer::seconds(toc - tic);
   std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
             << ", Throughput " << throughput << " (it/s)." << std::endl;
@@ -10,7 +10,7 @@
 /**
 * An example of logistic regression with MLX.
 */
-using namespace mlx::core;
+namespace mx = mlx::core;

 int main() {
   int num_features = 100;

@@ -19,35 +19,35 @@ int main() {
   float learning_rate = 0.1;

   // True parameters
-  auto w_star = random::normal({num_features});
+  auto w_star = mx::random::normal({num_features});

   // The input examples
-  auto X = random::normal({num_examples, num_features});
+  auto X = mx::random::normal({num_examples, num_features});

   // Labels
-  auto y = matmul(X, w_star) > 0;
+  auto y = mx::matmul(X, w_star) > 0;

   // Initialize random parameters
-  array w = 1e-2 * random::normal({num_features});
+  mx::array w = 1e-2 * mx::random::normal({num_features});

-  auto loss_fn = [&](array w) {
-    auto logits = matmul(X, w);
+  auto loss_fn = [&](mx::array w) {
+    auto logits = mx::matmul(X, w);
     auto scale = (1.0f / num_examples);
-    return scale * sum(logaddexp(array(0.0f), logits) - y * logits);
+    return scale * mx::sum(mx::logaddexp(mx::array(0.0f), logits) - y * logits);
   };

-  auto grad_fn = grad(loss_fn);
+  auto grad_fn = mx::grad(loss_fn);

   auto tic = timer::time();
   for (int it = 0; it < num_iters; ++it) {
-    auto grad = grad_fn(w);
-    w = w - learning_rate * grad;
-    eval(w);
+    auto grads = grad_fn(w);
+    w = w - learning_rate * grads;
+    mx::eval(w);
   }
   auto toc = timer::time();

   auto loss = loss_fn(w);
-  auto acc = sum((matmul(X, w) > 0) == y) / num_examples;
+  auto acc = mx::sum((mx::matmul(X, w) > 0) == y) / num_examples;
   auto throughput = num_iters / timer::seconds(toc - tic);
   std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput "
             << throughput << " (it/s)." << std::endl;
@@ -5,27 +5,27 @@

 #include "mlx/mlx.h"

-using namespace mlx::core;
+namespace mx = mlx::core;

 int main() {
   // To use Metal debugging and profiling:
   // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
   // 2. Run with MTL_CAPTURE_ENABLED=1.
-  metal::start_capture("mlx_trace.gputrace");
+  mx::metal::start_capture("mlx_trace.gputrace");

   // Start at index two because the default GPU and CPU streams have indices
   // zero and one, respectively. This naming matches the label assigned to each
   // stream's command queue.
-  auto s2 = new_stream(Device::gpu);
-  auto s3 = new_stream(Device::gpu);
+  auto s2 = new_stream(mx::Device::gpu);
+  auto s3 = new_stream(mx::Device::gpu);

-  auto a = arange(1.f, 10.f, 1.f, float32, s2);
-  auto b = arange(1.f, 10.f, 1.f, float32, s3);
-  auto x = add(a, a, s2);
-  auto y = add(b, b, s3);
+  auto a = mx::arange(1.f, 10.f, 1.f, mx::float32, s2);
+  auto b = mx::arange(1.f, 10.f, 1.f, mx::float32, s3);
+  auto x = mx::add(a, a, s2);
+  auto y = mx::add(b, b, s3);

   // The multiply will happen on the default stream.
-  std::cout << multiply(x, y) << std::endl;
+  std::cout << mx::multiply(x, y) << std::endl;

-  metal::stop_capture();
+  mx::metal::stop_capture();
 }
@@ -5,23 +5,26 @@

 #include "mlx/mlx.h"

-using namespace mlx::core;
+namespace mx = mlx::core;

 void array_basics() {
   // Make a scalar array:
-  array x(1.0);
+  mx::array x(1.0);

   // Get the value out of it:
   auto s = x.item<float>();
   assert(s == 1.0);
+  (void)s;

   // Scalars have a size of 1:
-  size_t size = x.size();
+  int64_t size = x.size();
   assert(size == 1);
+  (void)size;

   // Scalars have 0 dimensions:
   int ndim = x.ndim();
   assert(ndim == 0);
+  (void)ndim;

   // The shape should be an empty vector:
   auto shape = x.shape();

@@ -29,31 +32,32 @@ void array_basics() {

   // The datatype should be float32:
   auto dtype = x.dtype();
-  assert(dtype == float32);
+  assert(dtype == mx::float32);
+  (void)dtype;

   // Specify the dtype when constructing the array:
-  x = array(1, int32);
-  assert(x.dtype() == int32);
+  x = mx::array(1, mx::int32);
+  assert(x.dtype() == mx::int32);
   x.item<int>(); // OK
   // x.item<float>(); // Undefined!

   // Make a multidimensional array:
-  x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+  x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
   // mlx is row-major by default so the first row of this array
   // is [1.0, 2.0] and the second row is [3.0, 4.0]

   // Make an array of shape {2, 2} filled with ones:
-  auto y = ones({2, 2});
+  auto y = mx::ones({2, 2});

   // Pointwise add x and y:
-  auto z = add(x, y);
+  auto z = mx::add(x, y);

   // Same thing:
   z = x + y;

   // mlx is lazy by default. At this point `z` only
   // has a shape and a type but no actual data:
-  assert(z.dtype() == float32);
+  assert(z.dtype() == mx::float32);
   assert(z.shape(0) == 2);
   assert(z.shape(1) == 2);

@@ -63,33 +67,33 @@ void array_basics() {
   // and inputs. When `eval` is called on an array (or arrays), the array and
   // all of its dependencies are recursively evaluated to produce the result.
   // Once an array is evaluated, it has data and is detached from its inputs.
-  eval(z);
+  mx::eval(z);

-  // Of course the array can still be an input to other operations. You can even
-  // call eval on the array again, this will just be a no-op:
-  eval(z); // no-op
+  // Of course the array can still be an input to other operations. You can
+  // even call eval on the array again, this will just be a no-op:
+  mx::eval(z); // no-op

   // Some functions or methods on arrays implicitly evaluate them. For example
   // accessing a value in an array or printing the array implicitly evaluate it:
-  z = ones({1});
+  z = mx::ones({1});
   z.item<float>(); // implicit evaluation

-  z = ones({2, 2});
+  z = mx::ones({2, 2});
   std::cout << z << std::endl; // implicit evaluation
 }

 void automatic_differentiation() {
-  auto fn = [](array x) { return square(x); };
+  auto fn = [](mx::array x) { return mx::square(x); };

   // Computing the derivative function of a function
-  auto grad_fn = grad(fn);
+  auto grad_fn = mx::grad(fn);
   // Call grad_fn on the input to get the derivative
-  auto x = array(1.5);
+  auto x = mx::array(1.5);
   auto dfdx = grad_fn(x);
   // dfdx is 2 * x

   // Get the second derivative by composing grad with grad
-  auto d2fdx2 = grad(grad(fn))(x);
+  auto d2fdx2 = mx::grad(mx::grad(fn))(x);
   // d2fdx2 is 2
 }
examples/export/CMakeLists.txt (new file, 22 lines)
@@ -0,0 +1,22 @@

cmake_minimum_required(VERSION 3.27)

project(import_mlx LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(
  Python 3.9
  COMPONENTS Interpreter Development.Module
  REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

add_executable(eval_mlp eval_mlp.cpp)
target_link_libraries(eval_mlp PRIVATE mlx)

add_executable(train_mlp train_mlp.cpp)
target_link_libraries(train_mlp PRIVATE mlx)
examples/export/README.md (new file, 49 lines)
@@ -0,0 +1,49 @@

## Setup

Install MLX:

```bash
pip install "mlx>=0.22"
```

Build the C++ examples:

```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```

## Run

### Eval MLP

Run the Python script to export the eval function:

```bash
python eval_mlp.py
```

Then run the C++ program to import and run the function:

```
./build/eval_mlp
```

The Python and C++ programs should output the same result.

### Train MLP

Run the Python script to export the model initialization and training
functions:

```bash
python train_mlp.py
```

Then run the C++ program to import and run the functions:

```
./build/train_mlp
```

The Python and C++ programs should output the same results.
examples/export/eval_mlp.cpp (new file, 25 lines)
@@ -0,0 +1,25 @@

// Copyright © 2024 Apple Inc.

#include <mlx/mlx.h>
#include <iostream>

namespace mx = mlx::core;

int main() {
  int batch_size = 8;
  int input_dim = 32;

  // Make the input
  mx::random::seed(42);
  auto example_x = mx::random::uniform({batch_size, input_dim});

  // Import the function
  auto forward = mx::import_function("eval_mlp.mlxfn");

  // Call the imported function
  auto out = forward({example_x})[0];

  std::cout << out << std::endl;

  return 0;
}
examples/export/eval_mlp.py (new file, 52 lines)
@@ -0,0 +1,52 @@

# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


if __name__ == "__main__":

    batch_size = 8
    input_dim = 32
    output_dim = 10

    # Load the model
    mx.random.seed(0)  # Seed for params
    model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
    mx.eval(model)

    # Note, the model parameters are saved in the export function
    def forward(x):
        return model(x)

    mx.random.seed(42)  # Seed for input
    example_x = mx.random.uniform(shape=(batch_size, input_dim))

    mx.export_function("eval_mlp.mlxfn", forward, example_x)

    # Import in Python
    imported_forward = mx.import_function("eval_mlp.mlxfn")
    expected = forward(example_x)
    (out,) = imported_forward(example_x)
    assert mx.allclose(expected, out)
    print(out)
examples/export/train_mlp.cpp (new file, 35 lines)
@@ -0,0 +1,35 @@

// Copyright © 2024 Apple Inc.

#include <mlx/mlx.h>
#include <iostream>

namespace mx = mlx::core;

int main() {
  int batch_size = 8;
  int input_dim = 32;
  int output_dim = 10;

  auto state = mx::import_function("init_mlp.mlxfn")({});

  // Make the input
  mx::random::seed(42);
  auto example_X = mx::random::normal({batch_size, input_dim});
  auto example_y = mx::random::randint(0, output_dim, {batch_size});

  // Import the function
  auto step = mx::import_function("train_mlp.mlxfn");

  // Call the imported function
  for (int it = 0; it < 100; ++it) {
    state.insert(state.end(), {example_X, example_y});
    state = step(state);
    eval(state);
    auto loss = state.back();
    state.pop_back();
    if (it % 10 == 0) {
      std::cout << "Loss " << loss.item<float>() << std::endl;
    }
  }
  return 0;
}
examples/export/train_mlp.py (new file, 76 lines)
@@ -0,0 +1,76 @@

# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


if __name__ == "__main__":

    batch_size = 8
    input_dim = 32
    output_dim = 10

    def init():
        # Seed for the parameter initialization
        mx.random.seed(0)
        model = MLP(
            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
        )
        optimizer = optim.SGD(learning_rate=1e-1)
        optimizer.init(model.parameters())
        state = [model.parameters(), optimizer.state]
        tree_structure, state = zip(*mlx.utils.tree_flatten(state))
        return model, optimizer, tree_structure, state

    # Export the model parameter initialization
    model, optimizer, tree_structure, state = init()
    mx.eval(state)
    mx.export_function("init_mlp.mlxfn", lambda: init()[-1])

    def loss_fn(params, X, y):
        model.update(params)
        return nn.losses.cross_entropy(model(X), y, reduction="mean")

    def step(*inputs):
        *state, X, y = inputs
        params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        params = optimizer.apply_gradients(grads, params)
        _, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
        return *state, loss

    # Make some random data
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))

    # Export one step of SGD
    mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)

    # Import the step back into Python and run it
    imported_step = mx.import_function("train_mlp.mlxfn")

    for it in range(100):
        *state, loss = imported_step(*state, example_X, example_y)
        if it % 10 == 0:
            print(f"Loss {loss.item():.6}")
@@ -10,7 +10,6 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

 # ----------------------------- Dependencies -----------------------------
-find_package(MLX CONFIG REQUIRED)
 find_package(
   Python 3.8
   COMPONENTS Interpreter Development.Module

@@ -18,10 +17,15 @@ find_package(
 execute_process(
   COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
   OUTPUT_STRIP_TRAILING_WHITESPACE
-  OUTPUT_VARIABLE NB_DIR)
-list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
+  OUTPUT_VARIABLE nanobind_ROOT)
 find_package(nanobind CONFIG REQUIRED)

+execute_process(
+  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+  OUTPUT_VARIABLE MLX_ROOT)
+find_package(MLX CONFIG REQUIRED)
+
 # ----------------------------- Extensions -----------------------------

 # Add library
|||||||
@@ -1,25 +1,34 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2023-2025 Apple Inc.

-#include <cassert>
+#include <dlfcn.h>
 #include <iostream>
 #include <sstream>

-#include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/utils.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/utils.h"

 #include "axpby/axpby.h"

-#ifdef ACCELERATE_NEW_LAPACK
-#include <vecLib/cblas_new.h>
-#endif
-
 #ifdef _METAL_
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/utils.h"
 #endif

-namespace mlx::core {
+namespace my_ext {
+
+// A helper function to find the location of the current binary on disk.
+// The Metal library ("mlx_ext.mtllib") should be in the same directory.
+std::string current_binary_dir() {
+  static std::string binary_dir = []() {
+    Dl_info info;
+    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+      throw std::runtime_error("Unable to get current binary dir.");
+    }
+    return std::filesystem::path(info.dli_fname).parent_path().string();
+  }();
+  return binary_dir;
+}

 ///////////////////////////////////////////////////////////////////////////////
 // Operation Implementation
@@ -32,24 +41,24 @@ namespace mlx::core {
  * Follow numpy style broadcasting between x and y
  * Inputs are upcasted to floats if needed
  **/
-array axpby(
-    const array& x, // Input array x
-    const array& y, // Input array y
+mx::array axpby(
+    const mx::array& x, // Input mx::array x
+    const mx::array& y, // Input mx::array y
     const float alpha, // Scaling factor for x
     const float beta, // Scaling factor for y
-    StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
+    mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
 ) {
   // Promote dtypes between x and y as needed
   auto promoted_dtype = promote_types(x.dtype(), y.dtype());

   // Upcast to float32 for non-floating point inputs x and y
-  auto out_dtype = issubdtype(promoted_dtype, float32)
+  auto out_dtype = mx::issubdtype(promoted_dtype, mx::float32)
       ? promoted_dtype
-      : promote_types(promoted_dtype, float32);
+      : promote_types(promoted_dtype, mx::float32);

   // Cast x and y up to the determined dtype (on the same stream s)
-  auto x_casted = astype(x, out_dtype, s);
-  auto y_casted = astype(y, out_dtype, s);
+  auto x_casted = mx::astype(x, out_dtype, s);
+  auto y_casted = mx::astype(y, out_dtype, s);

   // Broadcast the shapes of x and y (on the same stream s)
   auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
@@ -57,12 +66,12 @@ array axpby(

   // Construct the array as the output of the Axpby primitive
   // with the broadcasted and upcasted arrays as inputs
-  return array(
-      /* const std::vector<int>& shape = */ out_shape,
-      /* Dtype dtype = */ out_dtype,
-      /* std::unique_ptr<Primitive> primitive = */
+  return mx::array(
+      /* const mx::Shape& shape = */ out_shape,
+      /* mx::Dtype dtype = */ out_dtype,
+      /* std::shared_ptr<mx::Primitive> primitive = */
       std::make_shared<Axpby>(to_stream(s), alpha, beta),
-      /* const std::vector<array>& inputs = */ broadcasted_inputs);
+      /* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
 }

 ///////////////////////////////////////////////////////////////////////////////
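As a usage sketch before the backend details: once the extension is built and linked, the op defined in the hunks above composes with core MLX ops like any other lazy array operation. The names mirror the code above; the alpha/beta values are illustrative.

auto x = mx::ones({3, 4});
auto y = mx::ones({3, 4});
auto z = my_ext::axpby(x, y, /* alpha = */ 4.0f, /* beta = */ 2.0f);
mx::eval(z); // z == 4 * x + 2 * y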
@@ -71,140 +80,69 @@ array axpby(

 template <typename T>
 void axpby_impl(
-    const array& x,
-    const array& y,
-    array& out,
+    const mx::array& x,
+    const mx::array& y,
+    mx::array& out,
     float alpha_,
-    float beta_) {
-  // We only allocate memory when we are ready to fill the output
-  // malloc_or_wait synchronously allocates available memory
-  // There may be a wait executed here if the allocation is requested
-  // under memory-pressured conditions
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    float beta_,
+    mx::Stream stream) {
+  out.set_data(mx::allocator::malloc(out.nbytes()));

-  // Collect input and output data pointers
-  const T* x_ptr = x.data<T>();
-  const T* y_ptr = y.data<T>();
-  T* out_ptr = out.data<T>();
+  // Get the CPU command encoder and register input and output arrays
+  auto& encoder = mx::cpu::get_command_encoder(stream);
+  encoder.set_input_array(x);
+  encoder.set_input_array(y);
+  encoder.set_output_array(out);

+  // Launch the CPU kernel
+  encoder.dispatch([x_ptr = x.data<T>(),
+                    y_ptr = y.data<T>(),
+                    out_ptr = out.data<T>(),
+                    size = out.size(),
+                    shape = out.shape(),
+                    x_strides = x.strides(),
+                    y_strides = y.strides(),
+                    alpha_,
+                    beta_]() {
     // Cast alpha and beta to the relevant types
     T alpha = static_cast<T>(alpha_);
     T beta = static_cast<T>(beta_);

     // Do the element-wise operation for each output
-    for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
+    for (size_t out_idx = 0; out_idx < size; out_idx++) {
       // Map linear indices to offsets in x and y
-      auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
-      auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
+      auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
+      auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);

       // We allocate the output to be contiguous and regularly strided
       // (defaults to row major) and hence it doesn't need additional mapping
       out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
     }
+  });
 }

-/** Fall back implementation for evaluation on CPU */
-void Axpby::eval(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  // Check the inputs (registered in the op while constructing the out array)
-  assert(inputs.size() == 2);
+void Axpby::eval_cpu(
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& outputs) {
   auto& x = inputs[0];
   auto& y = inputs[1];
   auto& out = outputs[0];

   // Dispatch to the correct dtype
-  if (out.dtype() == float32) {
-    return axpby_impl<float>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == float16) {
-    return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == bfloat16) {
-    return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
-  } else if (out.dtype() == complex64) {
-    return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
+  if (out.dtype() == mx::float32) {
+    return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
+  } else if (out.dtype() == mx::float16) {
+    return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
+  } else if (out.dtype() == mx::bfloat16) {
+    return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
+  } else if (out.dtype() == mx::complex64) {
+    return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
   } else {
     throw std::runtime_error(
         "Axpby is only supported for floating point types.");
   }
 }

-///////////////////////////////////////////////////////////////////////////////
-// Primitive Accelerate Backend Implementation
-///////////////////////////////////////////////////////////////////////////////
-
-#ifdef ACCELERATE_NEW_LAPACK
-
-template <typename T>
-void axpby_impl_accelerate(
-    const array& x,
-    const array& y,
-    array& out,
-    float alpha_,
-    float beta_) {
-  // Accelerate library provides catlas_saxpby which does
-  // Y = (alpha * X) + (beta * Y) in place
-  // To use it, we first copy the data in y over to the output array
-
-  // This specialization requires both x and y be contiguous in the same mode
-  // i.e: corresponding linear indices in both point to corresponding elements
-  // The data in the output array is allocated to match the strides in y
-  // such that x, y, and out are contiguous in the same mode and
-  // no transposition is needed
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-  // We then copy over the elements using the contiguous vector specialization
-  copy_inplace(y, out, CopyType::Vector);
-
-  // Get x and y pointers for catlas_saxpby
-  const T* x_ptr = x.data<T>();
-  T* y_ptr = out.data<T>();
-
-  T alpha = static_cast<T>(alpha_);
-  T beta = static_cast<T>(beta_);
-
-  // Call the inplace accelerate operator
-  catlas_saxpby(
-      /* N = */ out.size(),
-      /* ALPHA = */ alpha,
-      /* X = */ x_ptr,
-      /* INCX = */ 1,
-      /* BETA = */ beta,
-      /* Y = */ y_ptr,
-      /* INCY = */ 1);
-}
-
-/** Evaluate primitive on CPU using accelerate specializations */
-void Axpby::eval_cpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  assert(inputs.size() == 2);
-  auto& x = inputs[0];
-  auto& y = inputs[1];
-  auto& out = outputs[0];
-
-  // Accelerate specialization for contiguous single precision float arrays
-  if (out.dtype() == float32 &&
-      ((x.flags().row_contiguous && y.flags().row_contiguous) ||
-       (x.flags().col_contiguous && y.flags().col_contiguous))) {
-    axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
-    return;
-  }
-
-  // Fall back to common backend if specializations are not available
-  eval(inputs, outputs);
-}
-
-#else // Accelerate not available
-
-/** Evaluate primitive on CPU falling back to common backend */
-void Axpby::eval_cpu(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs) {
-  eval(inputs, outputs);
-}
-
-#endif
-
 ///////////////////////////////////////////////////////////////////////////////
 // Primitive Metal Backend Implementation
 ///////////////////////////////////////////////////////////////////////////////
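A note on the indexing helper used above: elem_to_loc converts a linear (row-major) element index into an element offset under arbitrary strides. An illustrative re-implementation, a sketch rather than the MLX internal code:

#include <cstdint>
#include <vector>

int64_t elem_to_loc_ref(
    size_t idx,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  // Walk axes from fastest-moving (last) to slowest, peeling off
  // the coordinate along each axis and applying its stride.
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += static_cast<int64_t>(idx % shape[i]) * strides[i];
    idx /= shape[i];
  }
  return loc;
}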
@@ -213,10 +151,9 @@ void Axpby::eval_cpu(

 /** Evaluate primitive on GPU */
 void Axpby::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& outputs) {
   // Prepare inputs
-  assert(inputs.size() == 2);
   auto& x = inputs[0];
   auto& y = inputs[1];
   auto& out = outputs[0];
@@ -225,7 +162,7 @@ void Axpby::eval_gpu(
   // and each stream carries its device identifiers
   auto& s = stream();
   // We get the needed metal device using the stream
-  auto& d = metal::device(s.device);
+  auto& d = mx::metal::device(s.device);

   // Prepare to specialize based on contiguity
   bool contiguous_kernel =
@@ -235,29 +172,28 @@ void Axpby::eval_gpu(
   // Allocate output memory with strides based on specialization
   if (contiguous_kernel) {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * out.itemsize()),
+        mx::allocator::malloc(x.data_size() * out.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(mx::allocator::malloc(out.nbytes()));
   }

   // Resolve name of kernel (corresponds to axpby.metal)
-  std::ostringstream kname;
-  kname << "axpby_";
-  kname << (contiguous_kernel ? "contiguous_" : "general_");
-  kname << type_to_name(out);
+  std::string kname = "axpby_";
+  kname += (contiguous_kernel ? "contiguous_" : "general_");
+  kname += type_to_name(out);

-  // Make sure the metal library is available
-  d.register_library("mlx_ext");
+  // Load the metal library
+  auto lib = d.get_library("mlx_ext", current_binary_dir());

   // Make a kernel from this metal library
-  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
+  auto kernel = d.get_kernel(kname, lib);

   // Prepare to encode kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);

   // Kernel parameters are registered with buffer indices corresponding to
   // those in the kernel declaration at axpby.metal
@@ -272,15 +208,15 @@ void Axpby::eval_gpu(
   compute_encoder.set_output_array(out, 2);

   // Encode alpha and beta
-  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-  compute_encoder->setBytes(&beta_, sizeof(float), 4);
+  compute_encoder.set_bytes(alpha_, 3);
+  compute_encoder.set_bytes(beta_, 4);

   // Encode shape, strides and ndim if needed
   if (!contiguous_kernel) {
-    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    compute_encoder.set_vector_bytes(x.shape(), 5);
+    compute_encoder.set_vector_bytes(x.strides(), 6);
+    compute_encoder.set_vector_bytes(y.strides(), 7);
+    compute_encoder.set_bytes(ndim, 8);
   }

   // We launch 1 thread for each input and make sure that the number of
@@ -295,15 +231,15 @@ void Axpby::eval_gpu(

   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder.dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }

 #else // Metal is not available

 /** Fail evaluation on GPU */
 void Axpby::eval_gpu(
-    const std::vector<array>& inputs,
-    std::vector<array>& out) {
+    const std::vector<mx::array>& inputs,
+    std::vector<mx::array>& out) {
   throw std::runtime_error("Axpby has no GPU implementation.");
 }

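Between the hunks above, the elided lines compute the launch geometry. The pattern, sketched here from the surrounding context rather than quoted verbatim, is one thread per output element with the threadgroup capped at the pipeline's limit:

size_t nthreads = contiguous_kernel ? x.data_size() : out.size();
NS::UInteger group_size = kernel->maxTotalThreadsPerThreadgroup();
if (group_size > nthreads) {
  group_size = nthreads;  // Never ask for more threads than work items.
}
MTL::Size group_dims = MTL::Size(group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);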
@@ -314,9 +250,9 @@ void Axpby::eval_gpu(
 ///////////////////////////////////////////////////////////////////////////////

 /** The Jacobian-vector product. */
-std::vector<array> Axpby::jvp(
-    const std::vector<array>& primals,
-    const std::vector<array>& tangents,
+std::vector<mx::array> Axpby::jvp(
+    const std::vector<mx::array>& primals,
+    const std::vector<mx::array>& tangents,
     const std::vector<int>& argnums) {
   // Forward mode diff that pushes along the tangents
   // The jvp transform on the primitive can be built with ops
@@ -328,8 +264,8 @@ std::vector<array> Axpby::jvp(
   // scaled by beta
   if (argnums.size() > 1) {
     auto scale = argnums[0] == 0 ? alpha_ : beta_;
-    auto scale_arr = array(scale, tangents[0].dtype());
-    return {multiply(scale_arr, tangents[0], stream())};
+    auto scale_arr = mx::array(scale, tangents[0].dtype());
+    return {mx::multiply(scale_arr, tangents[0], stream())};
   }
   // If argnums = {0, 1}, we take contributions from both
   // which gives us jvp = tangent_x * alpha + tangent_y * beta
@@ -339,24 +275,24 @@ std::vector<array> Axpby::jvp(
 }

 /** The vector-Jacobian product. */
-std::vector<array> Axpby::vjp(
-    const std::vector<array>& primals,
-    const std::vector<array>& cotangents,
+std::vector<mx::array> Axpby::vjp(
+    const std::vector<mx::array>& primals,
+    const std::vector<mx::array>& cotangents,
     const std::vector<int>& argnums,
-    const std::vector<array>&) {
+    const std::vector<mx::array>&) {
   // Reverse mode diff
-  std::vector<array> vjps;
+  std::vector<mx::array> vjps;
   for (auto arg : argnums) {
     auto scale = arg == 0 ? alpha_ : beta_;
-    auto scale_arr = array(scale, cotangents[0].dtype());
-    vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
+    auto scale_arr = mx::array(scale, cotangents[0].dtype());
+    vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream()));
   }
   return vjps;
 }

 /** Vectorize primitive along given axis */
-std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
-    const std::vector<array>& inputs,
+std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap(
+    const std::vector<mx::array>& inputs,
     const std::vector<int>& axes) {
   throw std::runtime_error("Axpby has no vmap implementation.");
 }
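As a sanity check on the two transforms, the math is linear and one line. With the same alpha and beta as in the code:

f(x, y) = \alpha x + \beta y
\quad\Rightarrow\quad
\mathrm{jvp}(\dot{x}, \dot{y}) = \alpha\,\dot{x} + \beta\,\dot{y},
\qquad
\mathrm{vjp}(\bar{v}) = (\alpha\,\bar{v},\ \beta\,\bar{v}).

Both transforms reduce to scaling by the op's own constants, which is why each is implementable with a single multiply on the primitive's stream.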
@@ -367,4 +303,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
   return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
 }

-} // namespace mlx::core
+} // namespace my_ext
@@ -1,11 +1,13 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2025 Apple Inc.

 #pragma once

 #include "mlx/ops.h"
 #include "mlx/primitives.h"

-namespace mlx::core {
+namespace mx = mlx::core;
+
+namespace my_ext {

 ///////////////////////////////////////////////////////////////////////////////
 // Operation
@@ -18,22 +20,22 @@ namespace mlx::core {
  * Follow numpy style broadcasting between x and y
  * Inputs are upcasted to floats if needed
  **/
-array axpby(
-    const array& x, // Input array x
-    const array& y, // Input array y
+mx::array axpby(
+    const mx::array& x, // Input array x
+    const mx::array& y, // Input array y
     const float alpha, // Scaling factor for x
     const float beta, // Scaling factor for y
-    StreamOrDevice s = {} // Stream on which to schedule the operation
+    mx::StreamOrDevice s = {} // Stream on which to schedule the operation
 );

 ///////////////////////////////////////////////////////////////////////////////
 // Primitive
 ///////////////////////////////////////////////////////////////////////////////

-class Axpby : public Primitive {
+class Axpby : public mx::Primitive {
  public:
-  explicit Axpby(Stream stream, float alpha, float beta)
-      : Primitive(stream), alpha_(alpha), beta_(beta) {};
+  explicit Axpby(mx::Stream stream, float alpha, float beta)
+      : mx::Primitive(stream), alpha_(alpha), beta_(beta) {};

   /**
    * A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,23 +44,25 @@ class Axpby : public Primitive {
    * To avoid unnecessary allocations, the evaluation function
    * is responsible for allocating space for the array.
    */
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override;
-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override;
+  void eval_cpu(
+      const std::vector<mx::array>& inputs,
+      std::vector<mx::array>& outputs) override;
+  void eval_gpu(
+      const std::vector<mx::array>& inputs,
+      std::vector<mx::array>& outputs) override;

   /** The Jacobian-vector product. */
-  std::vector<array> jvp(
-      const std::vector<array>& primals,
-      const std::vector<array>& tangents,
+  std::vector<mx::array> jvp(
+      const std::vector<mx::array>& primals,
+      const std::vector<mx::array>& tangents,
       const std::vector<int>& argnums) override;

   /** The vector-Jacobian product. */
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
+  std::vector<mx::array> vjp(
+      const std::vector<mx::array>& primals,
+      const std::vector<mx::array>& cotangents,
       const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
+      const std::vector<mx::array>& outputs) override;

   /**
    * The primitive must know how to vectorize itself across
@@ -66,24 +70,21 @@ class Axpby : public Primitive {
    * representing the vectorized computation and the axis which
    * corresponds to the output vectorized dimension.
    */
-  std::pair<std::vector<array>, std::vector<int>> vmap(
-      const std::vector<array>& inputs,
+  std::pair<std::vector<mx::array>, std::vector<int>> vmap(
+      const std::vector<mx::array>& inputs,
       const std::vector<int>& axes) override;

-  /** Print the primitive. */
-  void print(std::ostream& os) override {
-    os << "Axpby";
+  /** The name of the primitive. */
+  const char* name() const override {
+    return "Axpby";
   }

   /** Equivalence check **/
-  bool is_equivalent(const Primitive& other) const override;
+  bool is_equivalent(const mx::Primitive& other) const override;

  private:
   float alpha_;
   float beta_;

-  /** Fall back implementation for evaluation on CPU */
-  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };

-} // namespace mlx::core
+} // namespace my_ext
@@ -1,8 +1,7 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2025 Apple Inc.

 #include <metal_stdlib>

-#include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"

 template <typename T>
@@ -13,8 +12,8 @@ template <typename T>
     constant const float& alpha [[buffer(3)]],
     constant const float& beta [[buffer(4)]],
     constant const int* shape [[buffer(5)]],
-    constant const size_t* x_strides [[buffer(6)]],
-    constant const size_t* y_strides [[buffer(7)]],
+    constant const int64_t* x_strides [[buffer(6)]],
+    constant const int64_t* y_strides [[buffer(7)]],
     constant const int& ndim [[buffer(8)]],
     uint index [[thread_position_in_grid]]) {
   auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
@@ -35,29 +34,14 @@ template <typename T>
       static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
 }

+// clang-format off
 #define instantiate_axpby(type_name, type)                                \
-  template [[host_name("axpby_general_" #type_name)]] [[kernel]] void     \
-  axpby_general<type>(                                                    \
-      device const type* x [[buffer(0)]],                                 \
-      device const type* y [[buffer(1)]],                                 \
-      device type* out [[buffer(2)]],                                     \
-      constant const float& alpha [[buffer(3)]],                          \
-      constant const float& beta [[buffer(4)]],                           \
-      constant const int* shape [[buffer(5)]],                            \
-      constant const size_t* x_strides [[buffer(6)]],                     \
-      constant const size_t* y_strides [[buffer(7)]],                     \
-      constant const int& ndim [[buffer(8)]],                             \
-      uint index [[thread_position_in_grid]]);                            \
-  template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void  \
-  axpby_contiguous<type>(                                                 \
-      device const type* x [[buffer(0)]],                                 \
-      device const type* y [[buffer(1)]],                                 \
-      device type* out [[buffer(2)]],                                     \
-      constant const float& alpha [[buffer(3)]],                          \
-      constant const float& beta [[buffer(4)]],                           \
-      uint index [[thread_position_in_grid]]);
+  instantiate_kernel("axpby_general_" #type_name, axpby_general, type)   \
+  instantiate_kernel(                                                    \
+      "axpby_contiguous_" #type_name, axpby_contiguous, type)

 instantiate_axpby(float32, float);
 instantiate_axpby(float16, half);
 instantiate_axpby(bfloat16, bfloat16_t);
 instantiate_axpby(complex64, complex64_t);
+// clang-format on
@@ -8,14 +8,12 @@
 namespace nb = nanobind;
 using namespace nb::literals;

-using namespace mlx::core;
-
 NB_MODULE(_ext, m) {
   m.doc() = "Sample extension for MLX";

   m.def(
       "axpby",
-      &axpby,
+      &my_ext::axpby,
       "x"_a,
       "y"_a,
       "alpha"_a,
@@ -1,8 +1,8 @@
 [build-system]
 requires = [
     "setuptools>=42",
-    "cmake>=3.24",
+    "cmake>=3.25",
     "mlx>=0.18.0",
-    "nanobind==2.2.0",
+    "nanobind==2.4.0",
 ]
 build-backend = "setuptools.build_meta"
@@ -1,4 +1,4 @@
 setuptools>=42
-cmake>=3.24
-mlx>=0.18.1
-nanobind==2.2.0
+cmake>=3.25
+mlx>=0.21.0
+nanobind==2.4.0
@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby

 a = mx.ones((3, 4))
 b = mx.ones((3, 4))
-c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)

-print(f"c shape: {c.shape}")
-print(f"c dtype: {c.dtype}")
-print(f"c correct: {mx.all(c == 6.0).item()}")
+print(f"c shape: {c_cpu.shape}")
+print(f"c dtype: {c_cpu.dtype}")
+print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
+print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
mlx.pc.in
@@ -28,10 +28,19 @@ endif()
 if (@MLX_BUILD_METAL@)
   set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
   set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
-  set_and_check(MLX_INCLUDE_DIRS
-    ${MLX_INCLUDE_DIRS}
+  set(MLX_INCLUDE_DIRS
+    "${MLX_INCLUDE_DIRS};"
     @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
   )
+  if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
+    set(MLX_INCLUDE_DIRS
+      "${MLX_INCLUDE_DIRS};"
+      @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
+  else()
+    set(MLX_INCLUDE_DIRS
+      "${MLX_INCLUDE_DIRS};"
+      @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
+  endif()
 endif()

 set_target_properties(mlx PROPERTIES
@@ -5,6 +5,8 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
@@ -18,24 +20,48 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)

+# Define MLX_VERSION only in the version.cpp file.
+add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
+target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
+target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
+
+if(MSVC)
+  # Disable some MSVC warnings to speed up compilation.
+  target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
+endif()
+
+if(WIN32)
+  # Export symbols by default to behave like macOS/linux.
+  set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
+endif()
+
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+
 if(MLX_BUILD_CPU)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu)
 else()
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
 endif()

 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
-if(MLX_BUILD_ACCELERATE)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
-elseif(MLX_BUILD_CPU)
-  target_sources(
-    mlx
-    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
-endif()
-
 if(MLX_BUILD_METAL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
 else()
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
+  target_sources(mlx
+    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
+endif()
+
+if(MLX_BUILD_CUDA)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
+else()
+  target_sources(mlx
+    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
+endif()
+
+if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
+else()
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
 endif()
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/scheduler.h"
|
|
||||||
|
|
||||||
namespace mlx::core::allocator {
|
namespace mlx::core::allocator {
|
||||||
|
|
||||||
Buffer malloc(size_t size) {
|
Buffer malloc(size_t size) {
|
||||||
auto buffer = allocator().malloc(size, /* allow_swap */ true);
|
auto buffer = allocator().malloc(size);
|
||||||
if (size && !buffer.ptr()) {
|
if (size && !buffer.ptr()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||||
@@ -19,48 +18,7 @@ Buffer malloc(size_t size) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void free(Buffer buffer) {
|
void free(Buffer buffer) {
|
||||||
return allocator().free(buffer);
|
allocator().free(buffer);
|
||||||
}
|
|
||||||
|
|
||||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
|
||||||
void* ptr = std::malloc(size + sizeof(size_t));
|
|
||||||
if (ptr != nullptr) {
|
|
||||||
*static_cast<size_t*>(ptr) = size;
|
|
||||||
}
|
|
||||||
return Buffer{ptr};
|
|
||||||
}
|
|
||||||
|
|
||||||
void CommonAllocator::free(Buffer buffer) {
|
|
||||||
std::free(buffer.ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t CommonAllocator::size(Buffer buffer) const {
|
|
||||||
if (buffer.ptr() == nullptr) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return *static_cast<size_t*>(buffer.ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
Buffer malloc_or_wait(size_t size) {
|
|
||||||
auto buffer = allocator().malloc(size);
|
|
||||||
|
|
||||||
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
|
|
||||||
scheduler::wait_for_one();
|
|
||||||
buffer = allocator().malloc(size);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try swapping if needed
|
|
||||||
if (size && !buffer.ptr()) {
|
|
||||||
buffer = allocator().malloc(size, /* allow_swap = */ true);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (size && !buffer.ptr()) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
return buffer;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::allocator
|
} // namespace mlx::core::allocator
|
||||||
|
|||||||
@@ -32,14 +32,10 @@ Buffer malloc(size_t size);

 void free(Buffer buffer);

-// Wait for running tasks to finish and free up memory
-// if allocation fails
-Buffer malloc_or_wait(size_t size);
-
 class Allocator {
   /** Abstract base class for a memory allocator. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
+  virtual Buffer malloc(size_t size) = 0;
   virtual void free(Buffer buffer) = 0;
   virtual size_t size(Buffer buffer) const = 0;

@@ -53,16 +49,4 @@ class Allocator {

 Allocator& allocator();

-class CommonAllocator : public Allocator {
-  /** A general CPU allocator. */
- public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
-  virtual void free(Buffer buffer) override;
-  virtual size_t size(Buffer buffer) const override;
-
- private:
-  CommonAllocator() = default;
-  friend Allocator& allocator();
-};
-
 } // namespace mlx::core::allocator
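Taken together, the two allocator diffs remove malloc_or_wait and the built-in CommonAllocator: allocator::malloc is now the single entry point and throws when it cannot allocate. For code built on top of MLX the migration is mechanical; a sketch (the out array stands in for any array whose storage is being allocated):

// Before: could block on the scheduler and fall back to swapping.
// out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));

// After: one call; failure surfaces as a std::runtime_error.
out.set_data(mx::allocator::malloc(out.nbytes()));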
mlx/array.cpp
@@ -1,5 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
 #include <functional>
+#include <unordered_map>

 #include "mlx/array.h"
 #include "mlx/ops.h"
@@ -9,28 +10,14 @@

 namespace mlx::core {

-namespace {
-
-/** Return true if we are currently performing a function transformation in
- * order to keep the graph when evaluating tracer arrays. */
-bool in_tracing() {
-  return detail::InTracing::in_tracing();
-}
-
-bool retain_graph() {
-  return detail::RetainGraph::retain_graph();
-}
-
-} // namespace
-
 array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
-    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
+    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
   auto cval = static_cast<complex64_t>(val);
   init(&cval);
 }

 array::array(
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype,
     std::shared_ptr<Primitive> primitive,
     std::vector<array> inputs)
@@ -38,19 +25,30 @@ array::array(
       std::move(shape),
       dtype,
       std::move(primitive),
-      std::move(inputs))) {}
+      std::move(inputs))) {
+  if (has_primitive() && this->primitive().stream().device == Device::gpu) {
+    for (auto& in : this->inputs()) {
+      if (in.dtype() == float64) {
+        throw std::invalid_argument("float64 is not supported on the GPU");
+      }
+    }
+    if (this->dtype() == float64) {
+      throw std::invalid_argument("float64 is not supported on the GPU");
+    }
+  }
+}

 std::vector<array> array::make_arrays(
-    std::vector<std::vector<int>> shapes,
+    std::vector<Shape> shapes,
     const std::vector<Dtype>& dtypes,
     const std::shared_ptr<Primitive>& primitive,
     const std::vector<array>& inputs) {
   std::vector<array> outputs;
-  for (size_t i = 0; i < shapes.size(); ++i) {
+  for (int i = 0; i < std::ssize(shapes); ++i) {
     outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
   }
   // For each node in |outputs|, its siblings are the other nodes.
-  for (size_t i = 0; i < outputs.size(); ++i) {
+  for (int i = 0; i < std::ssize(outputs); ++i) {
     auto siblings = outputs;
     siblings.erase(siblings.begin() + i);
     outputs[i].set_siblings(std::move(siblings), i);
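The constructor body added above is where double precision gets rejected on the GPU: building a graph node on a GPU stream whose inputs or output are float64 now throws at construction time. An illustrative sketch (assuming a machine where the GPU backend is available; the variable names are made up):

auto a = mx::array({1.0f, 2.0f});                      // float32
auto b = mx::astype(a, mx::float64, mx::Device::cpu);  // fine on the CPU
// mx::astype(a, mx::float64, mx::Device::gpu) would throw
//   std::invalid_argument("float64 is not supported on the GPU")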
@@ -58,47 +56,59 @@ std::vector<array> array::make_arrays(
   return outputs;
 }

+array array::unsafe_weak_copy(const array& other) {
+  auto cpy = array(other.shape(), other.dtype(), nullptr, {});
+  cpy.set_data(
+      other.buffer(),
+      other.data_size(),
+      other.strides(),
+      other.flags(),
+      [](auto) {});
+  cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
+  return cpy;
+}
+
 array::array(std::initializer_list<float> data)
     : array_desc_(std::make_shared<ArrayDesc>(
-          std::vector<int>{static_cast<int>(data.size())},
+          Shape{static_cast<ShapeElem>(data.size())},
           float32)) {
   init(data.begin());
 }

 array::array(std::initializer_list<int> data, Dtype dtype)
     : array_desc_(std::make_shared<ArrayDesc>(
-          std::vector<int>{static_cast<int>(data.size())},
+          Shape{static_cast<ShapeElem>(data.size())},
           dtype)) {
   init(data.begin());
 }

 /* Build an array from a shared buffer */
-array::array(
-    allocator::Buffer data,
-    std::vector<int> shape,
-    Dtype dtype,
-    deleter_t deleter)
+array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
   set_data(data, deleter);
 }

 void array::detach() {
+  array_desc_->primitive = nullptr;
+  for (auto& s : array_desc_->siblings) {
+    s.array_desc_->primitive = nullptr;
+  }
   for (auto& s : array_desc_->siblings) {
     s.array_desc_->inputs.clear();
     s.array_desc_->siblings.clear();
     s.array_desc_->position = 0;
-    s.array_desc_->primitive = nullptr;
   }
   array_desc_->inputs.clear();
   array_desc_->siblings.clear();
   array_desc_->position = 0;
-  array_desc_->primitive = nullptr;
 }

 bool array::is_available() const {
   if (status() == Status::available) {
     return true;
-  } else if (status() == Status::evaluated && event().is_signaled()) {
+  } else if (
+      status() == Status::evaluated &&
+      (!event().valid() || event().is_signaled())) {
     set_status(Status::available);
     return true;
   }
@@ -107,7 +117,10 @@ bool array::is_available() const {

 void array::wait() {
   if (!is_available()) {
+    if (event().valid()) {
       event().wait();
+      detach_event();
+    }
     set_status(Status::available);
   }
 }
@@ -122,25 +135,27 @@ void array::eval() {
 }

 bool array::is_tracer() const {
-  return array_desc_->is_tracer && in_tracing() || retain_graph();
+  return (array_desc_->is_tracer && detail::in_tracing()) ||
+      detail::retain_graph();
 }

-void array::set_data(allocator::Buffer buffer, deleter_t d) {
+void array::set_data(allocator::Buffer buffer, Deleter d) {
   array_desc_->data = std::make_shared<Data>(buffer, d);
   array_desc_->data_ptr = buffer.raw_ptr();
   array_desc_->data_size = size();
   array_desc_->flags.contiguous = true;
   array_desc_->flags.row_contiguous = true;
-  auto max_dim = std::max_element(shape().begin(), shape().end());
-  array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
+  auto max_dim =
+      static_cast<int64_t>(*std::max_element(shape().begin(), shape().end()));
+  array_desc_->flags.col_contiguous = size() <= 1 || size() == max_dim;
 }

 void array::set_data(
     allocator::Buffer buffer,
     size_t data_size,
-    std::vector<size_t> strides,
+    Strides strides,
     Flags flags,
-    deleter_t d) {
+    Deleter d) {
   array_desc_->data = std::make_shared<Data>(buffer, d);
   array_desc_->data_ptr = buffer.raw_ptr();
   array_desc_->data_size = data_size;
@@ -150,7 +165,7 @@ void array::set_data(

 void array::copy_shared_buffer(
     const array& other,
-    const std::vector<size_t>& strides,
+    const Strides& strides,
     Flags flags,
     size_t data_size,
     size_t offset /* = 0 */) {
@@ -167,37 +182,18 @@ void array::copy_shared_buffer(const array& other) {
   copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
 }

-void array::move_shared_buffer(
-    array other,
-    const std::vector<size_t>& strides,
-    Flags flags,
-    size_t data_size,
-    size_t offset /* = 0 */) {
-  array_desc_->data = std::move(other.array_desc_->data);
-  array_desc_->strides = strides;
-  array_desc_->flags = flags;
-  array_desc_->data_size = data_size;
-  auto char_offset = sizeof(char) * itemsize() * offset;
-  array_desc_->data_ptr = static_cast<void*>(
-      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
-}
-
-void array::move_shared_buffer(array other) {
-  move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
-}
-
 array::~array() {
   if (array_desc_ == nullptr) {
     return;
   }

-  // Ignore arrays that might be detached during eval
-  if (status() == array::Status::scheduled) {
+  // Detached/detaching
+  if (array_desc_->primitive == nullptr) {
     return;
   }

   // Break circular reference for non-detached arrays with siblings
-  if (auto n = siblings().size(); n > 0) {
+  if (auto n = std::ssize(siblings()); n > 0) {
     bool do_detach = true;
     // If all siblings have siblings.size() references except
     // the one we are currently destroying (which has siblings.size() + 1)
@@ -212,6 +208,8 @@ array::~array() {
     if (do_detach) {
       for (auto& s : siblings()) {
         for (auto& ss : s.siblings()) {
+          // Set to null here to avoid descending into array destructor
+          // for siblings
           ss.array_desc_ = nullptr;
         }
         s.array_desc_->siblings.clear();
@@ -232,20 +230,20 @@ void array::ArrayDesc::init() {
   }
 }

-array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
+array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
     : shape(std::move(shape)), dtype(dtype), status(Status::available) {
   init();
 }

 array::ArrayDesc::ArrayDesc(
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype,
     std::shared_ptr<Primitive> primitive,
     std::vector<array> inputs)
     : shape(std::move(shape)),
       dtype(dtype),
-      status(Status::unscheduled),
       primitive(std::move(primitive)),
+      status(Status::unscheduled),
       inputs(std::move(inputs)) {
   init();
 }
@@ -269,11 +267,26 @@ array::ArrayDesc::~ArrayDesc() {
|
|||||||
for (array& a : ad.inputs) {
|
for (array& a : ad.inputs) {
|
||||||
if (a.array_desc_) {
|
if (a.array_desc_) {
|
||||||
input_map.insert({a.id(), a});
|
input_map.insert({a.id(), a});
|
||||||
|
for (auto& s : a.siblings()) {
|
||||||
|
input_map.insert({s.id(), s});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ad.inputs.clear();
|
ad.inputs.clear();
|
||||||
for (auto& [_, a] : input_map) {
|
for (auto& [_, a] : input_map) {
|
||||||
if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
|
bool is_deletable =
|
||||||
|
(a.array_desc_.use_count() <= std::ssize(a.siblings()) + 1);
|
||||||
|
// An array with siblings is deletable only if all of its siblings
|
||||||
|
// are deletable
|
||||||
|
for (auto& s : a.siblings()) {
|
||||||
|
if (!is_deletable) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
int is_input = (input_map.find(s.id()) != input_map.end());
|
||||||
|
is_deletable &=
|
||||||
|
s.array_desc_.use_count() <= std::ssize(a.siblings()) + is_input;
|
||||||
|
}
|
||||||
|
if (is_deletable) {
|
||||||
for_deletion.push_back(std::move(a.array_desc_));
|
for_deletion.push_back(std::move(a.array_desc_));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -287,6 +300,14 @@ array::ArrayDesc::~ArrayDesc() {
|
|||||||
auto top = std::move(for_deletion.back());
|
auto top = std::move(for_deletion.back());
|
||||||
for_deletion.pop_back();
|
for_deletion.pop_back();
|
||||||
append_deletable_inputs(*top);
|
append_deletable_inputs(*top);
|
||||||
|
|
||||||
|
// Clear out possible siblings to break circular references
|
||||||
|
for (auto& s : top->siblings) {
|
||||||
|
// Set to null here to avoid descending into top-level
|
||||||
|
// array destructor for siblings
|
||||||
|
s.array_desc_ = nullptr;
|
||||||
|
}
|
||||||
|
top->siblings.clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,7 +319,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
|||||||
}
|
}
|
||||||
|
|
||||||
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
|
||||||
auto start = std::vector<int>(arr.ndim(), 0);
|
auto start = Shape(arr.ndim(), 0);
|
||||||
auto end = arr.shape();
|
auto end = arr.shape();
|
||||||
auto shape = arr.shape();
|
auto shape = arr.shape();
|
||||||
shape.erase(shape.begin());
|
shape.erase(shape.begin());
|
||||||
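
The deletability test above leans on shared_ptr reference counts: every output of a multi-output op holds a handle to each sibling's descriptor, so a group that nothing else references shows exactly siblings() + 1 counts from inside the destructor. A minimal standalone sketch of that invariant (stand-in types, not the MLX API):

#include <cassert>
#include <memory>
#include <vector>

int main() {
  // Pretend descriptor for one output of a 4-output op: our local handle
  // plus one copy held by each of the three siblings.
  auto desc = std::make_shared<int>(0);
  std::vector<std::shared_ptr<int>> sibling_refs(3, desc);

  // Only the group references it: use_count() == siblings + 1 -> deletable.
  assert(desc.use_count() == static_cast<long>(sibling_refs.size()) + 1);

  // A live user handle pushes the count past the bound -> not deletable.
  auto user_handle = desc;
  assert(desc.use_count() > static_cast<long>(sibling_refs.size()) + 1);
  return 0;
}
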

mlx/array.h (121 changes)
@@ -10,12 +10,17 @@
 #include "mlx/allocator.h"
 #include "mlx/dtype.h"
 #include "mlx/event.h"
+#include "mlx/small_vector.h"
 
 namespace mlx::core {
 
 // Forward declaration
 class Primitive;
-using deleter_t = std::function<void(allocator::Buffer)>;
+using Deleter = std::function<void(allocator::Buffer)>;
+using ShapeElem = int32_t;
+using Shape = SmallVector<ShapeElem>;
+using Strides = SmallVector<int64_t>;
 
 class array {
   /* An array is really a node in a graph. It contains a shared ArrayDesc
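
The new aliases above fix the element types: 32-bit dimensions and 64-bit strides. A rough sketch of how they compose, using std::vector as a stand-in for SmallVector (an assumption; the real container lives in mlx/small_vector.h):

#include <cstdint>
#include <vector>

template <typename T>
using SmallVector = std::vector<T>; // stand-in only, not the MLX container

using ShapeElem = int32_t;
using Shape = SmallVector<ShapeElem>;
using Strides = SmallVector<int64_t>;

int main() {
  Shape shape{2, 3, 4};      // per-dimension sizes, 32-bit
  Strides strides{12, 4, 1}; // per-dimension strides in elements, 64-bit
  return shape.size() == strides.size() ? 0 : 1;
}
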
@@ -31,33 +36,33 @@ class array {
   explicit array(const std::complex<float>& val, Dtype dtype = complex64);
 
   template <typename It>
-  array(
+  explicit array(
       It data,
-      std::vector<int> shape,
+      Shape shape,
       Dtype dtype =
           TypeToDtype<typename std::iterator_traits<It>::value_type>());
 
   template <typename T>
-  array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
+  explicit array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
 
   /* Special case so empty lists default to float32. */
-  array(std::initializer_list<float> data);
+  explicit array(std::initializer_list<float> data);
 
   /* Special case so array({}, type) is an empty array. */
-  array(std::initializer_list<int> data, Dtype dtype);
+  explicit array(std::initializer_list<int> data, Dtype dtype);
 
   template <typename T>
-  array(
+  explicit array(
       std::initializer_list<T> data,
-      std::vector<int> shape,
+      Shape shape,
       Dtype dtype = TypeToDtype<T>());
 
   /* Build an array from a buffer */
-  array(
+  explicit array(
       allocator::Buffer data,
-      std::vector<int> shape,
+      Shape shape,
       Dtype dtype,
-      deleter_t deleter = allocator::free);
+      Deleter deleter = allocator::free);
 
   /** Assignment to rvalue does not compile. */
   array& operator=(const array& other) && = delete;
@@ -76,27 +81,27 @@ class array {
   }
 
   /** The size of the array's datatype in bytes. */
-  size_t itemsize() const {
+  int itemsize() const {
     return size_of(dtype());
   }
 
   /** The number of elements in the array. */
-  size_t size() const {
+  int64_t size() const {
     return array_desc_->size;
   }
 
   /** The number of bytes in the array. */
-  size_t nbytes() const {
+  int64_t nbytes() const {
     return size() * itemsize();
   }
 
   /** The number of dimensions of the array. */
-  size_t ndim() const {
+  int ndim() const {
     return array_desc_->shape.size();
   }
 
   /** The shape of the array as a vector of integers. */
-  const std::vector<int>& shape() const {
+  const Shape& shape() const {
     return array_desc_->shape;
   }
 
@@ -105,12 +110,12 @@ class array {
    *
    * This function supports negative indexing and provides
    * bounds checking. */
-  int shape(int dim) const {
+  auto shape(int dim) const {
     return shape().at(dim < 0 ? dim + ndim() : dim);
   }
 
   /** The strides of the array. */
-  const std::vector<size_t>& strides() const {
+  const Strides& strides() const {
     return array_desc_->strides;
   }
 
@@ -119,7 +124,7 @@ class array {
    *
    * This function supports negative indexing and provides
    * bounds checking. */
-  size_t strides(int dim) const {
+  auto strides(int dim) const {
     return strides().at(dim < 0 ? dim + ndim() : dim);
   }
 
@@ -184,17 +189,24 @@ class array {
   */
 
   array(
-      std::vector<int> shape,
+      Shape shape,
       Dtype dtype,
       std::shared_ptr<Primitive> primitive,
       std::vector<array> inputs);
 
   static std::vector<array> make_arrays(
-      std::vector<std::vector<int>> shapes,
+      std::vector<Shape> shapes,
       const std::vector<Dtype>& dtypes,
       const std::shared_ptr<Primitive>& primitive,
       const std::vector<array>& inputs);
 
+  /**
+   * Get a new array that refers to the same data as the input but with a
+   * non-owning pointer to it. Note the array is detached from the graph and has
+   * no inputs, siblings or primitive.
+   */
+  static array unsafe_weak_copy(const array& other);
+
   /** A unique identifier for an array. */
   std::uintptr_t id() const {
     return reinterpret_cast<std::uintptr_t>(array_desc_.get());
@@ -207,12 +219,16 @@ class array {
 
   struct Data {
     allocator::Buffer buffer;
-    deleter_t d;
-    Data(allocator::Buffer buffer, deleter_t d = allocator::free)
+    Deleter d;
+    Data(allocator::Buffer buffer, Deleter d = allocator::free)
         : buffer(buffer), d(d) {}
     // Not copyable
     Data(const Data& d) = delete;
     Data& operator=(const Data& d) = delete;
+    Data(Data&& o) : buffer(o.buffer), d(o.d) {
+      o.buffer = allocator::Buffer(nullptr);
+      o.d = [](allocator::Buffer) {};
+    }
     ~Data() {
       d(buffer);
     }
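
The Data move constructor added above uses the usual steal-and-neutralize pattern: the moved-from object keeps a null buffer and a no-op deleter, so its destructor frees nothing. A condensed sketch with stand-in types (not the MLX allocator API):

#include <functional>
#include <utility>

struct Buffer { void* ptr = nullptr; }; // stand-in for allocator::Buffer
using Deleter = std::function<void(Buffer)>;

struct Data {
  Buffer buffer;
  Deleter d;
  Data(Buffer b, Deleter del) : buffer(b), d(std::move(del)) {}
  Data(Data&& o) : buffer(o.buffer), d(o.d) {
    o.buffer = Buffer{nullptr}; // strip the moved-from object ...
    o.d = [](Buffer) {};        // ... so its destructor is a no-op
  }
  ~Data() { d(buffer); }        // always safe to run
};
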
@@ -313,7 +329,7 @@ class array {
   * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
   * Note, ``data_size`` is in units of ``item_size`` (not bytes).
   **/
-  size_t data_size() const {
+  int64_t data_size() const {
     return array_desc_->data_size;
   }
 
@@ -324,15 +340,15 @@ class array {
     return array_desc_->data->buffer;
   }
 
-  size_t buffer_size() const {
+  int64_t buffer_size() const {
     return allocator::allocator().size(buffer());
   }
 
-  // Return a copy of the shared pointer
-  // to the array::Data struct
-  std::shared_ptr<Data> data_shared_ptr() const {
+  // Return the shared pointer to the array::Data struct
+  const std::shared_ptr<Data>& data_shared_ptr() const {
     return array_desc_->data;
   }
 
   // Return a raw pointer to the arrays data
   template <typename T>
   T* data() {
@@ -345,15 +361,10 @@ class array {
   }
 
   enum Status {
-    // The ouptut of a computation which has not been scheduled.
+    // The output of a computation which has not been scheduled.
     // For example, the status of `x` in `auto x = a + b`.
     unscheduled,
 
-    // The ouptut of a computation which has been scheduled but `eval_*` has
-    // not yet been called on the array's primitive. A possible
-    // status of `x` in `auto x = a + b; eval(x);`
-    scheduled,
-
     // The array's `eval_*` function has been run, but the computation is not
     // necessarily complete. The array will have memory allocated and if it is
     // not a tracer then it will be detached from the graph.
@@ -390,6 +401,10 @@ class array {
     array_desc_->event = std::move(e);
   }
 
+  void detach_event() const {
+    array_desc_->event = Event{};
+  }
+
   // Mark the array as a tracer array (true) or not.
   void set_tracer(bool is_tracer) {
     array_desc_->is_tracer = is_tracer;
@@ -397,33 +412,24 @@ class array {
   // Check if the array is a tracer array
   bool is_tracer() const;
 
-  void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
+  void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
 
   void set_data(
       allocator::Buffer buffer,
       size_t data_size,
-      std::vector<size_t> strides,
+      Strides strides,
       Flags flags,
-      deleter_t d = allocator::free);
+      Deleter d = allocator::free);
 
   void copy_shared_buffer(
       const array& other,
-      const std::vector<size_t>& strides,
+      const Strides& strides,
       Flags flags,
       size_t data_size,
       size_t offset = 0);
 
   void copy_shared_buffer(const array& other);
 
-  void move_shared_buffer(
-      array other,
-      const std::vector<size_t>& strides,
-      Flags flags,
-      size_t data_size,
-      size_t offset = 0);
-
-  void move_shared_buffer(array other);
-
   void overwrite_descriptor(const array& other) {
     array_desc_ = other.array_desc_;
   }
@@ -436,8 +442,8 @@ class array {
   void init(const It src);
 
   struct ArrayDesc {
-    std::vector<int> shape;
-    std::vector<size_t> strides;
+    Shape shape;
+    Strides strides;
     size_t size;
     Dtype dtype;
     std::shared_ptr<Primitive> primitive;
@@ -471,10 +477,10 @@ class array {
     // The arrays position in the output list
     uint32_t position{0};
 
-    explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
+    explicit ArrayDesc(Shape shape, Dtype dtype);
 
     explicit ArrayDesc(
-        std::vector<int> shape,
+        Shape shape,
         Dtype dtype,
         std::shared_ptr<Primitive> primitive,
        std::vector<array> inputs);
@@ -495,14 +501,14 @@ class array {
 
 template <typename T>
 array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
-    : array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
+    : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
   init(&val);
 }
 
 template <typename It>
 array::array(
     It data,
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
     array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
   init(data);
@@ -513,7 +519,7 @@ array::array(
     std::initializer_list<T> data,
     Dtype dtype /* = TypeToDtype<T>() */)
     : array_desc_(std::make_shared<ArrayDesc>(
-          std::vector<int>{static_cast<int>(data.size())},
+          Shape{static_cast<ShapeElem>(data.size())},
           dtype)) {
   init(data.begin());
 }
@@ -521,10 +527,10 @@ array::array(
 template <typename T>
 array::array(
     std::initializer_list<T> data,
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype /* = TypeToDtype<T>() */)
     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
-  if (data.size() != size()) {
+  if (std::ssize(data) != size()) {
     throw std::invalid_argument(
         "Data size and provided shape mismatch in array construction.");
   }
@@ -590,6 +596,9 @@ void array::init(It src) {
     case float32:
       std::copy(src, src + size(), data<float>());
      break;
+    case float64:
+      std::copy(src, src + size(), data<double>());
+      break;
     case bfloat16:
       std::copy(src, src + size(), data<bfloat16_t>());
       break;

(deleted file)
@@ -1,8 +0,0 @@
-target_sources(
-  mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)

(deleted file)
@@ -1,20 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include <cassert>
-
-#include <Accelerate/Accelerate.h>
-#include <simd/vector.h>
-
-#include "mlx/backend/common/copy.h"
-#include "mlx/primitives.h"
-#include "mlx/utils.h"
-
-namespace mlx::core {
-
-void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
-
-  // TODO: Add accelerate based optimizations for CPU conv
-}
-
-} // namespace mlx::core

(deleted file)
@@ -1,253 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include <cassert>
-
-#include <Accelerate/Accelerate.h>
-
-#include "mlx/backend/accelerate/utils.h"
-#include "mlx/backend/common/copy.h"
-#include "mlx/primitives.h"
-#include "mlx/utils.h"
-
-namespace mlx::core {
-
-namespace {
-
-std::tuple<bool, size_t, array> check_transpose(const array& arr) {
-  auto stx = arr.strides()[arr.ndim() - 2];
-  auto sty = arr.strides()[arr.ndim() - 1];
-  if (stx == arr.shape(-1) && sty == 1) {
-    return std::make_tuple(false, stx, arr);
-  } else if (stx == 1 && sty == arr.shape(-2)) {
-    return std::make_tuple(true, sty, arr);
-  } else {
-    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-    copy(arr, arr_copy, CopyType::General);
-    size_t stx = arr.shape(-1);
-    return std::make_tuple(false, stx, arr_copy);
-  }
-}
-
-inline void matmul_cblas_general(
-    const array& a_pre,
-    const array& b_pre,
-    array& out,
-    float alpha = 1.0f,
-    float beta = 0.0f) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[matmul_cblas] on CPU currently only supports float32");
-  }
-
-  auto [a_transposed, lda, a] = check_transpose(a_pre);
-  auto [b_transposed, ldb, b] = check_transpose(b_pre);
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
-
-  if (M == 0 || N == 0) {
-    return;
-  }
-  if (K == 0) {
-    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
-    return;
-  }
-
-  for (int i = 0; i < (a.size() / (M * K)); ++i) {
-    cblas_sgemm(
-        CblasRowMajor,
-        a_transposed ? CblasTrans : CblasNoTrans, // transA
-        b_transposed ? CblasTrans : CblasNoTrans, // transB
-        M,
-        N,
-        K,
-        alpha, // alpha
-        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
-        lda,
-        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
-        ldb,
-        beta, // beta
-        out.data<float>() + M * N * i,
-        out.shape(-1) // ldc
-    );
-  }
-}
-
-inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[matmul_cblas] on CPU currently only supports float32");
-  }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  return matmul_cblas_general(a_pre, b_pre, out);
-}
-
-inline void matmul_bnns_general(
-    const array& a_pre,
-    const array& b_pre,
-    array& out,
-    float alpha = 1.0f,
-    float beta = 0.0f) {
-  // TODO: Update to utilize BNNS broadcasting
-
-  auto [a_transposed, lda, a] = check_transpose(a_pre);
-  auto [b_transposed, ldb, b] = check_transpose(b_pre);
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
-
-  if (M == 0 || N == 0) {
-    return;
-  }
-  if (K == 0) {
-    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
-    return;
-  }
-
-  BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
-
-  const BNNSLayerParametersBroadcastMatMul gemm_params{
-      /* float alpha = */ alpha,
-      /* float beta = */ beta,
-      /* bool transA = */ a_transposed,
-      /* bool transB = */ b_transposed,
-      /* bool quadratic = */ false,
-      /* bool a_is_weights = */ false,
-      /* bool b_is_weights = */ false,
-      /* BNNSNDArrayDescriptor iA_desc = */
-      BNNSNDArrayDescriptor{
-          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
-          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
-
-          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
-          {lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
-          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
-          {1, lda, 0, 0, 0, 0, 0, 0},
-
-          /* void * _Nullable data = */ nullptr,
-          /* BNNSDataType data_type = */ bnns_dtype,
-
-          /* void * _Nullable table_data = */ nullptr,
-          /* BNNSDataType table_data_type = */ bnns_dtype,
-
-          /* float data_scale = */ 1.0,
-          /* float data_bias = */ 0.0,
-      },
-      /* BNNSNDArrayDescriptor iB_desc = */
-      BNNSNDArrayDescriptor{
-          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
-          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
-
-          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
-          {ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
-          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
-          {1, ldb, 0, 0, 0, 0, 0, 0},
-
-          /* void * _Nullable data = */ nullptr,
-          /* BNNSDataType data_type = */ bnns_dtype,
-
-          /* void * _Nullable table_data = */ nullptr,
-          /* BNNSDataType table_data_type = */ bnns_dtype,
-
-          /* float data_scale = */ 1.0,
-          /* float data_bias = */ 0.0,
-      },
-      /* BNNSNDArrayDescriptor o_desc = */
-      BNNSNDArrayDescriptor{
-          /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
-          /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
-
-          /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
-          {N, M, 0, 0, 0, 0, 0, 0},
-          /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
-          {1, N, 0, 0, 0, 0, 0, 0},
-
-          /* void * _Nullable data = */ nullptr,
-          /* BNNSDataType data_type = */ bnns_dtype,
-
-          /* void * _Nullable table_data = */ nullptr,
-          /* BNNSDataType table_data_type = */ bnns_dtype,
-
-          /* float data_scale = */ 1.0,
-          /* float data_bias = */ 0.0,
-      },
-  };
-
-  auto bnns_filter =
-      BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
-
-  for (int i = 0; i < (a.size() / (M * K)); ++i) {
-    BNNSFilterApplyTwoInput(
-        bnns_filter,
-        a.data<uint8_t>() +
-            elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
-        b.data<uint8_t>() +
-            elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
-        out.data<uint8_t>() + M * N * i * out.itemsize());
-  }
-
-  BNNSFilterDestroy(bnns_filter);
-}
-
-inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
-  // TODO: Update to utilize BNNS broadcasting
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  return matmul_bnns_general(a_pre, b_pre, out);
-}
-
-template <typename T>
-inline void mask_matrix(
-    T* data,
-    const bool* mask,
-    int tile_size,
-    const int X,
-    const int Y,
-    const size_t X_data_str,
-    const size_t Y_data_str,
-    const size_t X_mask_str,
-    const size_t Y_mask_str) {
-  int tX = (X + tile_size - 1) / tile_size;
-  int tY = (Y + tile_size - 1) / tile_size;
-
-  for (int i = 0; i < tX; i++) {
-    for (int j = 0; j < tY; j++) {
-      bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
-      if (!do_mask) {
-        int loc_x = i * tile_size;
-        int loc_y = j * tile_size;
-        T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
-
-        int size_x = std::min(tile_size, X - loc_x);
-        int size_y = std::min(tile_size, Y - loc_y);
-        for (int ii = 0; ii < size_x; ii++) {
-          for (int jj = 0; jj < size_y; jj++) {
-            data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
-          }
-        }
-      }
-    }
-  }
-}
-
-} // namespace
-
-void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  if (out.dtype() == float32) {
-    return matmul_cblas(inputs[0], inputs[1], out);
-  }
-  return matmul_bnns(inputs[0], inputs[1], out);
-}
-
-void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
-  // Fill output with C
-  auto& c = inputs[2];
-  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
-  copy(c, out, ctype);
-
-  if (out.dtype() == float32) {
-    return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
-  }
-  return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
-}
-
-} // namespace mlx::core
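
check_transpose in the deleted matmul file above is a stride test: a row-major (M, N) matrix has strides {N, 1}, a transposed view has strides {1, M}, and anything else gets copied to a contiguous layout first. A standalone sketch of the same test (a hypothetical helper, not part of MLX):

#include <cstddef>

// Returns 0 = row-major (lda = stx), 1 = transposed view (lda = sty),
// -1 = arbitrary strides, so the caller copies and uses lda = N.
int classify_layout(std::size_t stx, std::size_t sty, std::size_t M, std::size_t N) {
  if (stx == N && sty == 1) {
    return 0; // plain row-major
  }
  if (stx == 1 && sty == M) {
    return 1; // column-major view == transposed row-major matrix
  }
  return -1; // needs a contiguous copy
}
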

(deleted file)
@@ -1,600 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include <cassert>
-#include <cmath>
-
-#include <Accelerate/Accelerate.h>
-
-#include "mlx/allocator.h"
-#include "mlx/backend/common/binary.h"
-#include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/unary.h"
-#include "mlx/primitives.h"
-
-#define DEFAULT(primitive)                                                 \
-  void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
-    primitive::eval(inputs, out);                                          \
-  }
-
-#define DEFAULT_MULTI(primitive)                                       \
-  void primitive::eval_cpu(                                            \
-      const std::vector<array>& inputs, std::vector<array>& outputs) { \
-    primitive::eval(inputs, outputs);                                  \
-  }
-
-namespace mlx::core {
-
-// Use the default implementation for the following primitives
-DEFAULT(Arange)
-DEFAULT(ArgPartition)
-DEFAULT(ArgReduce)
-DEFAULT(ArgSort)
-DEFAULT(AsStrided)
-DEFAULT(BlockMaskedMM)
-DEFAULT(Broadcast)
-DEFAULT(Ceil)
-DEFAULT(Concatenate)
-DEFAULT(Conjugate)
-DEFAULT(Copy)
-DEFAULT_MULTI(CustomTransforms)
-DEFAULT_MULTI(Depends)
-DEFAULT_MULTI(DivMod)
-DEFAULT(NumberOfElements)
-DEFAULT(Equal)
-DEFAULT(Erf)
-DEFAULT(ErfInv)
-DEFAULT(FFT)
-DEFAULT(Floor)
-DEFAULT(Gather)
-DEFAULT(GatherMM)
-DEFAULT(GatherQMM)
-DEFAULT(Greater)
-DEFAULT(GreaterEqual)
-DEFAULT(Hadamard)
-DEFAULT(Less)
-DEFAULT(LessEqual)
-DEFAULT(Load)
-DEFAULT(LogicalNot)
-DEFAULT(LogicalAnd)
-DEFAULT(LogicalOr)
-DEFAULT(LogAddExp)
-DEFAULT(Maximum)
-DEFAULT(Minimum)
-DEFAULT(NotEqual)
-DEFAULT(Pad)
-DEFAULT(Partition)
-DEFAULT_MULTI(QRF)
-DEFAULT(RandomBits)
-DEFAULT(Reshape)
-DEFAULT(Remainder)
-DEFAULT(Round)
-DEFAULT(Scatter)
-DEFAULT(Select)
-DEFAULT(Sigmoid)
-DEFAULT(Sign)
-DEFAULT(Slice)
-DEFAULT(SliceUpdate)
-DEFAULT_MULTI(Split)
-DEFAULT(Sort)
-DEFAULT(StopGradient)
-DEFAULT_MULTI(SVD)
-DEFAULT(Transpose)
-DEFAULT(Inverse)
-DEFAULT(Cholesky)
-
-void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-  if (in.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
-  } else if (in.dtype() == int32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-
-  if (a.dtype() == float32) {
-    binary_op<float>(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return x + y; },
-        [](const auto* s, const auto* vec, auto* o, auto n) {
-          vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
-        },
-        [](const auto* vec, const auto* s, auto* o, auto n) {
-          vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
-        },
-        [](const auto* a, const auto* b, auto* o, auto n) {
-          vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
-        });
-  } else if (a.dtype() == int32) {
-    binary_op<int>(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return x + y; },
-        [](const auto* s, const auto* vec, auto* o, auto n) {
-          vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
-        },
-        [](const auto* vec, const auto* s, auto* o, auto n) {
-          vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
-        },
-        [](const auto* a, const auto* b, auto* o, auto n) {
-          vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
-        });
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvacosf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvacoshf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvasinf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvasinhf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvatanf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  if (out.dtype() == float32 && a.flags().row_contiguous &&
-      b.flags().row_contiguous) {
-    if (a.is_donatable()) {
-      out.copy_shared_buffer(a);
-    } else if (b.is_donatable()) {
-      out.copy_shared_buffer(b);
-    } else {
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    }
-    int size = a.data_size();
-    vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvatanhf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-
-  if (in.flags().contiguous) {
-    // Use accelerate functions if possible
-    if (in.dtype() == float32 && out.dtype() == uint32) {
-      set_unary_output_data(in, out);
-      vDSP_vfixu32(
-          in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
-      return;
-    } else if (in.dtype() == float32 && out.dtype() == int32) {
-      set_unary_output_data(in, out);
-      vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
-      return;
-    } else if (in.dtype() == uint32 && out.dtype() == float32) {
-      set_unary_output_data(in, out);
-      vDSP_vfltu32(
-          in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
-      return;
-    } else if (in.dtype() == int32 && out.dtype() == float32) {
-      set_unary_output_data(in, out);
-      vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
-      return;
-    }
-  }
-  eval(inputs, out);
-}
-
-void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvcosf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvcoshf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-
-  if (a.dtype() == int32) {
-    binary_op<int>(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return x / y; },
-        UseDefaultBinaryOp(),
-        [](const auto* vec, const auto* s, auto* o, auto n) {
-          vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
-        },
-        [](const auto* a, const auto* b, auto* o, auto n) {
-          vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
-        });
-  } else if (a.dtype() == float32) {
-    binary_op<float>(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return x / y; },
-        [](const auto* s, const auto* vec, auto* o, auto n) {
-          vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
-        },
-        [](const auto* vec, const auto* s, auto* o, auto n) {
-          vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
-        },
-        [](const auto* a, const auto* b, auto* o, auto n) {
-          vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
-        });
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    auto size = in.data_size();
-    vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    auto size = in.data_size();
-    vvexpm1f(
-        out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-  assert(in.dtype() == out.dtype());
-  if (in.data_size() == 1 && out.dtype() == float32) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    auto size = in.data_size();
-    switch (base_) {
-      case Base::e:
-        vvlogf(
-            out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-        break;
-      case Base::two:
-        vvlog2f(
-            out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-        break;
-      case Base::ten:
-        vvlog10f(
-            out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-        break;
-    }
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    auto size = in.data_size();
-    vvlog1pf(
-        out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-
-  if (a.dtype() == float32) {
-    binary_op<float>(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return x * y; },
-        [](const auto* s, const auto* vec, auto* o, auto n) {
-          vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
-        },
-        [](const auto* vec, const auto* s, auto* o, auto n) {
-          vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
-        },
-        [](const auto* a, const auto* b, auto* o, auto n) {
-          vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
-        });
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-  if (in.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-  if (out.dtype() == float32 && a.flags().row_contiguous &&
-      b.flags().row_contiguous) {
-    int size = a.size();
-    if (a.is_donatable() && a.itemsize() == out.itemsize()) {
-      out.copy_shared_buffer(a);
-    } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
-      out.copy_shared_buffer(b);
-    } else {
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    }
-    vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
-      in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    int stride = in.shape(axis_);
-    int count = in.size() / stride;
-    const float* input = in.data<float>();
-    float* output = out.data<float>();
-    float s = 1.0;
-    if (!reverse_) {
-      for (int i = 0; i < count; i++) {
-        vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
-        input += stride;
-        output += stride;
-      }
-    } else {
-      for (int i = 0; i < count; i++) {
-        input += stride - 1;
-        output += stride - 1;
-        vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
-        input++;
-        output++;
-      }
-    }
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvsinf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvsinhf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-  if (in.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    auto size = in.data_size();
-    vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-  if (in.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    if (recip_) {
-      vvrsqrtf(out.data<float>(), in.data<float>(), &size);
-    } else {
-      vvsqrtf(out.data<float>(), in.data<float>(), &size);
-    }
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 2);
-  auto& a = inputs[0];
-  auto& b = inputs[1];
-
-  if (a.dtype() == float32) {
-    binary_op<float>(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return x - y; },
-        [](const auto* s, const auto* vec, auto* o, auto n) {
-          float minus_1 = -1;
-          vDSP_vsmsa(
-              (const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
-        },
-        [](const auto* vec, const auto* s, auto* o, auto n) {
-          float val = -(*s);
-          vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
-        },
-        [](const auto* a, const auto* b, auto* o, auto n) {
-          vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
-        });
-  } else if (a.dtype() == int32) {
-    binary_op<int>(
-        a,
-        b,
-        out,
-        [](auto x, auto y) { return x - y; },
-        UseDefaultBinaryOp(),
-        [](const auto* vec, const auto* s, auto* o, auto n) {
-          int val = -(*s);
-          vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
-        },
-        UseDefaultBinaryOp());
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvtanf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.dtype() == float32 && in.flags().contiguous) {
-    set_unary_output_data(in, out);
-    int size = in.data_size();
-    vvtanhf(out.data<float>(), in.data<float>(), &size);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-} // namespace mlx::core
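
The Add/Subtract/Multiply/Divide kernels in the deleted file above hand binary_op three vDSP fast paths: scalar-vector, vector-scalar, and vector-vector. A sketch of the dispatch idea for addition (an illustration only, not the MLX binary_op template):

#include <Accelerate/Accelerate.h>
#include <vector>

void add_dispatch(
    const std::vector<float>& a,
    const std::vector<float>& b,
    std::vector<float>& out) {
  const vDSP_Length n = out.size();
  if (a.size() == 1) {
    vDSP_vsadd(b.data(), 1, a.data(), out.data(), 1, n); // scalar + vector
  } else if (b.size() == 1) {
    vDSP_vsadd(a.data(), 1, b.data(), out.data(), 1, n); // vector + scalar
  } else {
    vDSP_vadd(a.data(), 1, b.data(), 1, out.data(), 1, n); // vector + vector
  }
}
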

(deleted file)
@@ -1,101 +0,0 @@
-// Copyright © 2023 Apple Inc.
-
-#include <cassert>
-
-#include <simd/vector.h>
-
-#include "mlx/primitives.h"
-
-namespace mlx::core {
-
-namespace {
-
-void _qmm_t_4_64(
-    float* result,
-    const float* x,
-    const uint32_t* w,
-    const float* scales,
-    const float* biases,
-    int M,
-    int N,
-    int K) {
-  constexpr int bits = 4;
-  constexpr int group_size = 64;
-  constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = 32 / bits;
-  constexpr int packs_in_group = group_size / pack_factor;
-
-  for (int m = 0; m < M; m++) {
-    const uint32_t* w_local = w;
-    const float* scales_local = scales;
-    const float* biases_local = biases;
-
-    for (int n = 0; n < N; n++) {
-      const simd_float16* x_local = (simd_float16*)x;
-      simd_float16 sum = 0;
-      for (int k = 0; k < K; k += group_size) {
-        float scale = *scales_local++;
-        float bias = *biases_local++;
-
-        for (int kw = 0; kw < packs_in_group; kw += 2) {
-          // TODO: vectorize this properly
-          simd_uint16 wi;
-          for (int e = 0; e < 2; e++) {
-            uint32_t wii = *w_local++;
-            for (int p = 0; p < 8; p++) {
-              wi[e * 8 + p] = wii & bitmask;
-              wii >>= bits;
-            }
-          }
-          simd_float16 wf = simd_float(wi);
-          wf *= scale;
-          wf += bias;
-
-          sum += (*x_local) * wf;
-          x_local++;
-        }
-      }
-
-      *result = simd_reduce_add(sum);
-      result++;
-    }
-
-    x += K;
-  }
-}
-
-} // namespace
-
-void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 4);
-
-  auto& x = inputs[0];
-  auto& w = inputs[1];
-  auto& scales = inputs[2];
-  auto& biases = inputs[3];
-
-  bool condition =
-      (transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
-       scales.flags().row_contiguous && biases.flags().row_contiguous &&
-       x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
-
-  if (condition) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    int K = x.shape(-1);
-    int M = x.size() / K;
-    int N = out.shape(-1);
-    _qmm_t_4_64(
-        out.data<float>(),
-        x.data<float>(),
-        w.data<uint32_t>(),
-        scales.data<float>(),
-        biases.data<float>(),
-        M,
-        N,
-        K);
-  } else {
-    eval(inputs, out);
-  }
-}
-
-} // namespace mlx::core
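
_qmm_t_4_64 above assumes 4-bit weights packed eight to a uint32, with one (scale, bias) pair per group of 64 weights. A scalar sketch of the decode (a hypothetical helper, not MLX code; K is assumed to be a multiple of the pack factor):

#include <cstdint>
#include <vector>

std::vector<float> dequantize_row_4bit(
    const uint32_t* w, const float* scales, const float* biases, int K) {
  constexpr int bits = 4;
  constexpr int bitmask = (1 << bits) - 1; // 0xF
  constexpr int pack_factor = 32 / bits;   // 8 nibbles per uint32
  constexpr int group_size = 64;           // one scale/bias per 64 weights

  std::vector<float> out(K);
  for (int k = 0; k < K; k += pack_factor) {
    uint32_t packed = w[k / pack_factor];
    float scale = scales[k / group_size];
    float bias = biases[k / group_size];
    for (int p = 0; p < pack_factor; p++) {
      // w_float = nibble * scale + bias, same as the SIMD loop above
      out[k + p] = float(packed & bitmask) * scale + bias;
      packed >>= bits;
    }
  }
  return out;
}
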
@@ -1,139 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include <Accelerate/Accelerate.h>
|
|
||||||
#include <simd/vector.h>
|
|
||||||
|
|
||||||
#include "mlx/backend/common/reduce.h"
|
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
template <typename T, typename VT>
|
|
||||||
struct MinReduction {
|
|
||||||
T operator()(const T& a, const T& b) {
|
|
||||||
return std::min(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
VT operator()(VT a, VT b) {
|
|
||||||
return simd_min(a, b);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, typename VT>
|
|
||||||
struct MaxReduction {
|
|
||||||
T operator()(const T& a, const T& b) {
|
|
||||||
return std::max(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
VT operator()(VT a, VT b) {
|
|
||||||
return simd_max(a, b);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, typename VT>
|
|
||||||
struct SumReduction {
|
|
||||||
T operator()(const T& a, const T& b) {
|
|
||||||
return a + b;
|
|
||||||
}
|
|
||||||
|
|
||||||
VT operator()(VT a, VT b) {
|
|
||||||
return a + b;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, typename VT, int N, typename Reduction>
|
|
||||||
struct StridedReduce {
|
|
||||||
void operator()(const T* x, T* accum, int size, size_t stride) {
|
|
||||||
Reduction op;
|
|
||||||
|
|
||||||
for (int i = 0; i < size; i++) {
|
|
||||||
size_t s = stride;
|
|
||||||
T* a = accum;
|
|
||||||
while (s >= N) {
|
|
||||||
*(VT*)a = op((*(VT*)x), (*(VT*)a));
|
|
||||||
x += N;
|
|
||||||
a += N;
|
|
||||||
s -= N;
|
|
||||||
}
|
|
||||||
while (s-- > 0) {
|
|
||||||
*a = op(*a, *x);
|
|
||||||
a++;
|
|
||||||
x++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};

} // namespace

void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  if (in.dtype() == float32) {
    if (reduce_type_ == Reduce::Sum) {
      reduction_op<float, float>(
          in,
          out,
          axes_,
          0,
          StridedReduce<
              float,
              simd_float16,
              16,
              SumReduction<float, simd_float16>>(),
          [](const auto* x, auto* accum, int size) {
            float acc;
            vDSP_sve((const float*)x, 1, &acc, size);
            (*accum) += acc;
          },
          [](auto* accum, auto x) { *accum += x; });
      return;
    } else if (reduce_type_ == Reduce::Max) {
      reduction_op<float, float>(
          in,
          out,
          axes_,
          -std::numeric_limits<float>::infinity(),
          StridedReduce<
              float,
              simd_float16,
              16,
              MaxReduction<float, simd_float16>>(),
          [](const auto* x, auto* accum, int size) {
            float max;
            vDSP_maxv((const float*)x, 1, &max, size);
            (*accum) = (*accum < max) ? max : *accum;
          },
          [](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
      return;
    } else if (reduce_type_ == Reduce::Min) {
      reduction_op<float, float>(
          in,
          out,
          axes_,
          std::numeric_limits<float>::infinity(),
          StridedReduce<
              float,
              simd_float16,
              16,
              MinReduction<float, simd_float16>>(),
          [](const auto* x, auto* accum, int size) {
            float min;
            vDSP_minv((const float*)x, 1, &min, size);
            (*accum) = (*accum > min) ? min : *accum;
          },
          [](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
      return;
    }
  }
  // TODO: Add integer addition and min/max using the templates above and
  // simd_int16 and friends.
  eval(inputs, out);
}
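// Hedged, standalone sketch (not part of the original source) of the
// Accelerate vDSP routines used by the lambdas above; the buffer and its
// length are illustrative only, and <Accelerate/Accelerate.h> is assumed to
// be included at the top of the file.
inline void vdsp_reduction_example() {
  float x[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float sum, mx, mn;
  vDSP_sve(x, 1, &sum, 8); // sum of elements -> 36
  vDSP_maxv(x, 1, &mx, 8); // maximum element -> 8
  vDSP_minv(x, 1, &mn, 8); // minimum element -> 1
}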

} // namespace mlx::core
@@ -1,393 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>
#include <limits>

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#endif

#include <simd/math.h>
#include <simd/vector.h>

#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

/**
 * Compute exp(x) in an optimizer friendly way as follows:
 *
 * First change the problem to computing 2**y where y = x / ln(2).
 *
 * Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
 * `ipart` and y2 is the fractional part. For the integer part we perform bit
 * shifting and for the fractional part we use a polynomial approximation.
 *
 * The algorithm and the constants of the polynomial are taken from
 * https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took
 * them from the Cephes math library.
 *
 * Note: The implementation below is a general fast exp. There could be faster
 * implementations for numbers strictly < 0.
 */
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
  auto x = x_init * 1.442695; // multiply with log_2(e)
  simd_float16 ipart, fpart;
  simd_int16 epart;
  x = simd_clamp(x, -80, 80);
  ipart = simd::floor(x + 0.5);
  fpart = x - ipart;

  x = 1.535336188319500e-4f;
  x = x * fpart + 1.339887440266574e-3f;
  x = x * fpart + 9.618437357674640e-3f;
  x = x * fpart + 5.550332471162809e-2f;
  x = x * fpart + 2.402264791363012e-1f;
  x = x * fpart + 6.931472028550421e-1f;
  x = x * fpart + 1.000000000000000f;

  // generate 2**ipart in the floating point representation using integer
  // bitshifting
  epart = (simd_int(ipart) + 127) << 23;

  // Avoid suppressing NaNs
  simd_int16 eq = (x_init == x_init);
  return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
}
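// Hedged scalar sketch of the same trick (not part of the original source),
// assuming <cmath>, <cstdint> and <cstring> are available: split
// y = x * log2(e) into integer and fractional parts, evaluate the polynomial
// on the fractional part, and build 2**ipart by writing the biased exponent
// straight into a float's bit pattern.
inline float scalar_fast_exp(float x) {
  float y = x * 1.442695f; // multiply with log_2(e)
  y = std::fmin(std::fmax(y, -80.f), 80.f);
  float ipart = std::floor(y + 0.5f);
  float fpart = y - ipart; // in [-0.5, 0.5]
  float p = 1.535336188319500e-4f; // polynomial approximation of 2**fpart
  p = p * fpart + 1.339887440266574e-3f;
  p = p * fpart + 9.618437357674640e-3f;
  p = p * fpart + 5.550332471162809e-2f;
  p = p * fpart + 2.402264791363012e-1f;
  p = p * fpart + 6.931472028550421e-1f;
  p = p * fpart + 1.0f;
  uint32_t bits = (uint32_t)((int32_t)ipart + 127) << 23; // 2**ipart bits
  float two_ipart;
  std::memcpy(&two_ipart, &bits, sizeof(bits));
  return two_ipart * p;
}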

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
 * The ARM neon equivalent of the fast exp above.
 */
inline float16x8_t neon_fast_exp(float16x8_t x) {
  x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
  x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
  x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14

  float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
  float16x8_t fpart = vsubq_f16(x, ipart);

  x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
  x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
  x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);

  // generate 2**ipart in the floating point representation using integer
  // bitshifting
  int16x8_t epart = vcvtq_s16_f16(ipart);
  epart = vaddq_s16(epart, vdupq_n_s16(15));
  epart = vshlq_n_s16(epart, 10);

  return vmulq_f16(vreinterpretq_f16_s16(epart), x);
}

/**
 * Implementation of folding maximum for ARM neon. This should possibly be
 * refactored out of softmax.cpp at some point.
 */
inline float16_t neon_reduce_max(float16x8_t x) {
  float16x4_t y;
  y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
  y = vpmax_f16(y, y);
  y = vpmax_f16(y, y);
  return vget_lane_f16(y, 0);
}

/**
 * Implementation of folding sum for ARM neon. This should possibly be
 * refactored out of softmax.cpp at some point.
 */
inline float16_t neon_reduce_add(float16x8_t x) {
  float16x4_t y;
  float16x4_t zero = vdup_n_f16(0);
  y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
  y = vpadd_f16(y, zero);
  y = vpadd_f16(y, zero);
  return vget_lane_f16(y, 0);
}
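// Hedged illustration (not part of the original source): vpadd_f16(a, b)
// adds adjacent pairs of lanes, so the folds above proceed as
//   [x0..x7] -> [x0+x1, x2+x3, x4+x5, x6+x7]
//            -> [s01+s23, s45+s67, 0, 0] -> [total, 0, 0, 0]
// and vget_lane_f16(y, 0) extracts the final value; vpmax_f16 folds the
// maximum the same way.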

template <typename T, typename VT>
struct NeonFp16SimdOps {
  VT init(T a) {
    return vdupq_n_f16(a);
  }

  VT load(const T* a) {
    return vld1q_f16(a);
  }

  void store(T* dst, VT x) {
    vst1q_f16(dst, x);
  }

  VT max(VT a, VT b) {
    return vmaxq_f16(a, b);
  }

  VT exp(VT x) {
    return neon_fast_exp(x);
  }

  VT add(VT a, VT b) {
    return vaddq_f16(a, b);
  }

  VT sub(VT a, T b) {
    return vsubq_f16(a, vdupq_n_f16(b));
  }

  VT mul(VT a, VT b) {
    return vmulq_f16(a, b);
  }

  VT mul(VT a, T b) {
    return vmulq_f16(a, vdupq_n_f16(b));
  }

  T reduce_max(VT x) {
    return neon_reduce_max(x);
  }

  T reduce_add(VT x) {
    return neon_reduce_add(x);
  }
};

#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

template <typename T, typename VT>
struct AccelerateSimdOps {
  VT init(T a) {
    return a;
  }

  VT load(const T* a) {
    return *(VT*)a;
  }

  void store(T* dst, VT x) {
    *(VT*)dst = x;
  }

  VT max(VT a, VT b) {
    return simd_max(a, b);
  }

  VT exp(VT x) {
    return simd_fast_exp(x);
  }

  VT add(VT a, VT b) {
    return a + b;
  }

  VT sub(VT a, T b) {
    return a - b;
  }

  VT mul(VT a, VT b) {
    return a * b;
  }

  VT mul(VT a, T b) {
    return a * b;
  }

  T reduce_max(VT x) {
    return simd_reduce_max(x);
  }

  T reduce_add(VT x) {
    return simd_reduce_add(x);
  }
};

template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
  Ops ops;

  const T* in_ptr = in.data<T>();
  T* out_ptr = out.data<T>();
  int M = in.shape().back();
  int L = in.data_size() / M;
  const T* current_in_ptr;
  T* current_out_ptr;

  for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
    // Find the maximum
    current_in_ptr = in_ptr;
    VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
    size_t s = M;
    while (s >= N) {
      VT vals;
      if constexpr (std::is_same<T, AccT>::value) {
        vals = ops.load(current_in_ptr);
      } else {
        for (int j = 0; j < N; ++j) {
          vals[j] = static_cast<AccT>(current_in_ptr[j]);
        }
      }
      vmaximum = ops.max(vals, vmaximum);
      current_in_ptr += N;
      s -= N;
    }
    AccT maximum = ops.reduce_max(vmaximum);
    while (s-- > 0) {
      maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
      current_in_ptr++;
    }

    // Compute the normalizer and the exponentials
    VT vnormalizer = ops.init(0.0);
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    s = M;
    while (s >= N) {
      VT vexp;
      if constexpr (std::is_same<T, AccT>::value) {
        vexp = ops.load(current_in_ptr);
      } else {
        for (int j = 0; j < N; ++j) {
          vexp[j] = static_cast<AccT>(current_in_ptr[j]);
        }
      }
      vexp = ops.exp(ops.sub(vexp, maximum));
      if constexpr (std::is_same<T, AccT>::value) {
        ops.store(current_out_ptr, vexp);
      }
      vnormalizer = ops.add(vnormalizer, vexp);
      current_in_ptr += N;
      current_out_ptr += N;
      s -= N;
    }
    AccT normalizer = ops.reduce_add(vnormalizer);
    while (s-- > 0) {
      AccT _exp = std::exp(*current_in_ptr - maximum);
      if constexpr (std::is_same<T, AccT>::value) {
        *current_out_ptr = _exp;
      }
      normalizer += _exp;
      current_in_ptr++;
      current_out_ptr++;
    }
    normalizer = 1 / normalizer;

    // Normalize
    current_out_ptr = out_ptr;
    current_in_ptr = in_ptr;
    s = M;
    while (s >= N) {
      if constexpr (std::is_same<T, AccT>::value) {
        ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
      } else {
        VT vexp;
        for (int j = 0; j < N; ++j) {
          vexp[j] = static_cast<AccT>(current_in_ptr[j]);
        }
        vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
        for (int j = 0; j < N; ++j) {
          current_out_ptr[j] = vexp[j];
        }
        current_in_ptr += N;
      }
      current_out_ptr += N;
      s -= N;
    }
    while (s-- > 0) {
      if constexpr (std::is_same<T, AccT>::value) {
        *current_out_ptr *= normalizer;
      } else {
        AccT _exp = std::exp(*current_in_ptr - maximum);
        *current_out_ptr = static_cast<T>(_exp * normalizer);
        current_in_ptr++;
      }
      current_out_ptr++;
    }
  }
}
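// Hedged scalar reference (not part of the original source) of the same
// three passes over one contiguous row of length M, assuming <cmath>,
// <algorithm> and <limits> are available:
inline void softmax_row_reference(const float* x, float* y, int M) {
  // Pass 1: find the maximum for numerical stability
  float maximum = -std::numeric_limits<float>::infinity();
  for (int j = 0; j < M; ++j) {
    maximum = std::max(maximum, x[j]);
  }
  // Pass 2: exponentiate and accumulate the normalizer
  float normalizer = 0.f;
  for (int j = 0; j < M; ++j) {
    y[j] = std::exp(x[j] - maximum);
    normalizer += y[j];
  }
  // Pass 3: scale by the reciprocal of the normalizer
  float scale = 1.f / normalizer;
  for (int j = 0; j < M; ++j) {
    y[j] *= scale;
  }
}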

} // namespace

void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  // Make sure that the last dimension is contiguous
  auto check_input = [](array x) {
    bool no_copy = x.strides()[x.ndim() - 1] == 1;
    if (x.ndim() > 1) {
      auto s = x.strides()[x.ndim() - 2];
      no_copy &= (s == 0 || s == x.shape().back());
    }
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy(x, x_copy, CopyType::General);
      return x_copy;
    }
  };
  array in = check_input(std::move(inputs[0]));
  out.set_data(
      allocator::malloc_or_wait(in.data_size() * in.itemsize()),
      in.data_size(),
      in.strides(),
      in.flags());

  switch (in.dtype()) {
    case bool_:
    case uint8:
    case uint16:
    case uint32:
    case uint64:
    case int8:
    case int16:
    case int32:
    case int64:
      throw std::invalid_argument(
          "Softmax is defined only for floating point types");
      break;
    case float32:
      softmax<
          float,
          float,
          simd_float16,
          AccelerateSimdOps<float, simd_float16>,
          16>(in, out);
      break;
    case float16:
      if (precise_) {
        softmax<
            float16_t,
            float,
            simd_float16,
            AccelerateSimdOps<float, simd_float16>,
            16>(in, out);
      } else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        softmax<
            float16_t,
            float16_t,
            float16x8_t,
            NeonFp16SimdOps<float16_t, float16x8_t>,
            8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
      }
      break;
    case bfloat16:
      eval(inputs, out);
      break;
    case complex64:
      eval(inputs, out);
      break;
  }
}

} // namespace mlx::core
@@ -1,28 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <Accelerate/Accelerate.h>

#include "mlx/dtype.h"

namespace mlx::core {

BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
  uint32_t size_bits = size_of(mlx_dtype) * 8;
  switch (kindof(mlx_dtype)) {
    case Dtype::Kind::b:
      return BNNSDataTypeBoolean;
    case Dtype::Kind::u:
      return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
    case Dtype::Kind::i:
      return BNNSDataType(BNNSDataTypeIntBit | size_bits);
    case Dtype::Kind::f:
      return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
    case Dtype::Kind::V:
      return BNNSDataTypeBFloat16;
    case Dtype::Kind::c:
      throw std::invalid_argument("BNNS does not support complex types");
  }
}
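// Hedged worked example (not part of the original source): for float32,
// kindof() yields Dtype::Kind::f and size_of() is 4, so size_bits is 32 and
// the function returns BNNSDataType(BNNSDataTypeFloatBit | 32), which should
// correspond to BNNS's 32-bit float data type.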

} // namespace mlx::core
@@ -1,61 +1,9 @@
-if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
-  set(COMPILER ${CMAKE_C_COMPILER})
-  set(CLANG TRUE)
-else()
-  set(COMPILER ${CMAKE_CXX_COMPILER})
-endif()
-
-add_custom_command(
-  OUTPUT compiled_preamble.cpp
-  COMMAND
-    /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
-    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
-    ${PROJECT_SOURCE_DIR} ${CLANG}
-  DEPENDS make_compiled_preamble.sh
-          compiled_preamble.h
-          ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
-          ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
-          ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
-          ${PROJECT_SOURCE_DIR}/mlx/types/complex.h
-          ops.h)
-
-add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
-
-add_dependencies(mlx cpu_compiled_preamble)
-
 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
+          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
-          ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
+          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
-
-if(IOS)
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
-else()
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
-endif()
@@ -1,74 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <cassert>

#include "mlx/allocator.h"
#include "mlx/array.h"

namespace mlx::core {

namespace {

template <typename T>
void arange(T start, T next, array& out, size_t size) {
  auto ptr = out.data<T>();
  auto step_size = next - start;
  for (size_t i = 0; i < size; ++i) {
    ptr[i] = start;
    start += step_size;
  }
}
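// Hedged worked example (not part of the original source): with start = 0,
// next = 0.5 and size = 4, step_size is 0.5 and the buffer is filled with
// 0, 0.5, 1.0, 1.5. Accumulating `start += step_size` rather than computing
// start + i * step keeps the loop cheap, at the cost of some floating point
// drift for very large sizes.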

} // namespace

void arange(
    const std::vector<array>& inputs,
    array& out,
    double start,
    double step) {
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  switch (out.dtype()) {
    case bool_:
      throw std::runtime_error("Bool type unsupported for arange.");
      break;
    case uint8:
      arange<uint8_t>(start, start + step, out, out.size());
      break;
    case uint16:
      arange<uint16_t>(start, start + step, out, out.size());
      break;
    case uint32:
      arange<uint32_t>(start, start + step, out, out.size());
      break;
    case uint64:
      arange<uint64_t>(start, start + step, out, out.size());
      break;
    case int8:
      arange<int8_t>(start, start + step, out, out.size());
      break;
    case int16:
      arange<int16_t>(start, start + step, out, out.size());
      break;
    case int32:
      arange<int32_t>(start, start + step, out, out.size());
      break;
    case int64:
      arange<int64_t>(start, start + step, out, out.size());
      break;
    case float16:
      arange<float16_t>(start, start + step, out, out.size());
      break;
    case float32:
      arange<float>(start, start + step, out, out.size());
      break;
    case bfloat16:
      arange<bfloat16_t>(start, start + step, out, out.size());
      break;
    case complex64:
      arange<complex64_t>(start, start + step, out, out.size());
      break;
  }
}

} // namespace mlx::core
@@ -1,112 +0,0 @@
// Copyright © 2023 Apple Inc.

#include <cassert>

#include "mlx/primitives.h"
#include "utils.h"

namespace mlx::core {

namespace {

template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
  auto axis_size = in.shape()[axis];
  auto axis_stride = in.strides()[axis];
  std::vector<size_t> strides = in.strides();
  std::vector<int> shape = in.shape();
  strides.erase(strides.begin() + axis);
  shape.erase(shape.begin() + axis);
  for (uint32_t i = 0; i < out.size(); ++i) {
    auto loc = elem_to_loc(i, shape, strides);
    auto in_ptr = in.data<InT>() + loc;
    uint32_t ind_v = 0;
    InT v = (*in_ptr);
    for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
      op(j, (*in_ptr), &ind_v, &v);
    }
    out.data<uint32_t>()[i] = ind_v;
  }
}
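// Hedged illustration (not part of the original source): reducing a (2, 3)
// row-major array over axis = 1 leaves shape {2} and strides {3}, with
// axis_size = 3 and axis_stride = 1; output element i = 1 starts at loc = 3,
// the op scans offsets 3, 4, 5, and ind_v tracks the index of the running
// best value along the reduced axis.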

template <typename InT>
void arg_reduce_dispatch(
    const array& in,
    array& out,
    ArgReduce::ReduceType rtype,
    int axis) {
  switch (rtype) {
    case ArgReduce::ArgMin: {
      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
        if (x < (*y)) {
          (*y) = x;
          (*ind_y) = ind_x;
        }
      };
      arg_reduce<InT>(in, out, op, axis);
      break;
    }
    case ArgReduce::ArgMax: {
      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
        if (x > (*y)) {
          (*y) = x;
          (*ind_y) = ind_x;
        }
      };
      arg_reduce<InT>(in, out, op, axis);
      break;
    }
  }
}

} // namespace

void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  switch (in.dtype()) {
    case bool_:
      arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
      break;
    case uint8:
      arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
      break;
    case uint16:
      arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
      break;
    case uint32:
      arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
      break;
    case uint64:
      arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
      break;
    case int8:
      arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
      break;
    case int16:
      arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
      break;
    case int32:
      arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
      break;
    case int64:
      arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
      break;
    case float16:
      arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
      break;
    case float32:
      arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
      break;
    case bfloat16:
      arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
      break;
    case complex64:
      arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
      break;
  }
}

} // namespace mlx::core