Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-03 06:14:43 +08:00.
Compare commits: v0.0.6...split_logs (1022 commits)
.circleci/config.yml
@@ -1,5 +1,8 @@
version: 2.1

orbs:
  apple: ml-explore/pr-approval@0.1.0

parameters:
  nightly_build:
    type: boolean
@@ -7,8 +10,65 @@ parameters:
  weekly_build:
    type: boolean
    default: false
  test_release:
    type: boolean
    default: false
  linux_release:
    type: boolean
    default: false

jobs:
  build_documentation:
    parameters:
      upload-docs:
        type: boolean
        default: false
    macos:
      xcode: "16.2.0"
    resource_class: m2pro.medium
    steps:
      - checkout
      - run:
          name: Install
          command: |
            brew install python@3.9
            brew install doxygen
            python3.9 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install -r docs/requirements.txt
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
      - when:
          condition:
            not: << parameters.upload-docs >>
          steps:
            - run:
                name: Build documentation
                command: |
                  source env/bin/activate
                  cd docs && doxygen && make html O=-W
      - when:
          condition: << parameters.upload-docs >>
          steps:
            - add_ssh_keys:
                fingerprints:
                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
            - run:
                name: Upload documentation
                command: |
                  source env/bin/activate
                  git config user.email "mlx@group.apple.com"
                  git config user.name "CircleCI Docs"
                  git checkout gh-pages
                  git rebase main
                  cd docs
                  git rm -rf build/html
                  doxygen && make html O=-W
                  git add -f build/html
                  git commit -m "rebase"
                  git push -f origin gh-pages

  linux_build_and_test:
    docker:
      - image: cimg/python:3.9
@@ -25,176 +85,278 @@ jobs:
          name: Install dependencies
          command: |
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install nanobind==2.4.0
            pip install numpy
            sudo apt-get update
            sudo apt-get install libblas-dev
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
      - run:
          name: Build python package
          name: Install Python package
          command: |
            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            python3 setup.py build_ext --inplace
            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            python3 setup.py develop
      - run:
          name: Run the python tests
          name: Generate package stubs
          command: |
            python3 -m unittest discover python/tests
            echo "stubs"
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            python3 -m unittest discover python/tests -v
            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
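The two distributed lines above run the same checks once over MPI and once over MLX's ring backend via `mlx.launch`. For orientation, a minimal distributed smoke test might look like the sketch below; it is illustrative only and assumes the `mlx.core.distributed` API (`init`, `rank`, `size`, `all_sum`) as exposed in recent MLX releases, not a file from this repo:

```python
import mlx.core as mx

# Hypothetical minimal distributed check, launched e.g. via
#   mlx.launch -n 8 this_script.py
world = mx.distributed.init()  # join the process group

x = mx.ones(4)                     # each rank contributes a ones-vector
total = mx.distributed.all_sum(x)  # elementwise sum across all ranks

# After the reduction, every entry should equal the number of ranks.
assert mx.all(total == world.size()).item(), "all_sum mismatch"
print(f"rank {world.rank()}/{world.size()} OK")
```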
      - run:
          name: Build CPP only
          command: |
            mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
            mkdir -p build && cd build
            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
            make -j `nproc`
      - run:
          name: Run CPP tests
          command: ./build/tests/tests

  mac_build_and_test:
    machine: true
    resource_class: ml-explore/m-builder
    parameters:
      xcode_version:
        type: string
        default: "16.2.0"
      macosx_deployment_target:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    resource_class: m2pro.medium
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            eval "$(conda shell.bash hook)"
            rm -r $CONDA_PREFIX/envs/runner-env
            conda create -y -n runner-env python=3.9
            conda activate runner-env
            brew install python@3.9
            brew install openmpi
            python3.9 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install nanobind==2.4.0
            pip install numpy
            pip install torch
            pip install tensorflow
            pip install unittest-xml-reporting
      - run:
          name: Build python package
          name: Install Python package
          command: |
            eval "$(conda shell.bash hook)"
            conda activate runner-env
            CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
            CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
            source env/bin/activate
            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
            pip install -e . -v
      - run:
          name: Run the python tests
          name: Generate package stubs
          command: |
            eval "$(conda shell.bash hook)"
            conda activate runner-env
            DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
            source env/bin/activate
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            source env/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
      - run:
          name: Build example extension
          command: |
            source env/bin/activate
            cd examples/extensions
            pip install -r requirements.txt
            python setup.py build_ext -j8
      - store_test_results:
          path: test-results
      - run:
          name: Build CPP only
          command: |
            source env/bin/activate
            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
      - run:
          name: Run CPP tests
          command: |
            DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
      - run:
          name: Build small binary
          command: |
            source env/bin/activate
            cd build/
            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
              -DBUILD_SHARED_LIBS=ON \
              -DMLX_BUILD_CPU=OFF \
              -DMLX_BUILD_SAFETENSORS=OFF \
              -DMLX_BUILD_GGUF=OFF \
              -DMLX_METAL_JIT=ON
            make -j `sysctl -n hw.ncpu`
      - run:
          name: Run Python tests with JIT
          command: |
            source env/bin/activate
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
            pip install -e . -v
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
            METAL_DEBUG_ERROR_MODE=0 \
            python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

  build_release:
    machine: true
    resource_class: ml-explore/m-builder
    parameters:
      python_version:
        type: string
        default: "3.9"
      macos_version:
      xcode_version:
        type: string
        default: "14"
        default: "16.2.0"
      build_env:
        type: string
        default: ""
      macosx_deployment_target:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    resource_class: m2pro.medium
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            eval "$(conda shell.bash hook)"
            rm -r $CONDA_PREFIX/envs/runner-env
            conda create -y -n runner-env python=<< parameters.python_version >>
            conda activate runner-env
            brew install python@<< parameters.python_version >>
            brew install openmpi
            python<< parameters.python_version >> -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install twine
            pip install build
      - run:
          name: Build pacakge
          name: Install Python package
          command: |
            eval "$(conda shell.bash hook)"
            conda activate runner-env
            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
            PYPI_RELEASE=1 \
            CMAKE_BUILD_PARALLEL_LEVEL="" \
            python setup.py bdist_wheel
            twine upload dist/* --repository mlx
            source env/bin/activate
            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            pip install . -v
      - run:
          name: Generate package stubs
          command: |
            source env/bin/activate
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Build Python package
          command: |
            source env/bin/activate
            << parameters.build_env >> \
            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
            python -m build -w
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
                  source env/bin/activate
                  twine upload dist/*
      - store_artifacts:
          path: dist/

  build_dev_release:
    machine: true
    resource_class: ml-explore/m-builder
  build_linux_release:
    parameters:
      python_version:
        type: string
        default: "3.9"
      macos_version:
      extra_env:
        type: string
        default: "14"
        default: "DEV_RELEASE=1"
    docker:
      - image: ubuntu:20.04
    steps:
      - checkout
      - run:
          name: Install dependencies
          name: Build wheel
          command: |
            eval "$(conda shell.bash hook)"
            rm -r $CONDA_PREFIX/envs/runner-env
            conda create -y -n runner-env python=<< parameters.python_version >>
            conda activate runner-env
            PYTHON=python<< parameters.python_version >>
            apt-get update
            apt-get upgrade -y
            DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
            apt-get install -y apt-utils
            apt-get install -y software-properties-common
            add-apt-repository -y ppa:deadsnakes/ppa
            apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
            apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            apt-get install -y build-essential git
            $PYTHON -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
            << parameters.extra_env >> \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            pip install . -v
            pip install typing_extensions
            python setup.py generate_stubs
            << parameters.extra_env >> \
            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
            python -m build --wheel
            auditwheel show dist/*
            auditwheel repair dist/* --plat manylinux_2_31_x86_64
      - run:
          name: Build pacakge
          name: Upload package
          command: |
            eval "$(conda shell.bash hook)"
            conda activate runner-env
            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
            DEV_RELEASE=1 \
            CMAKE_BUILD_PARALLEL_LEVEL="" \
            python setup.py bdist_wheel
            twine upload dist/* --repository mlx
            source env/bin/activate
            twine upload wheelhouse/*
      - store_artifacts:
          path: dist/

  build_package:
    machine: true
    resource_class: ml-explore/m-builder
    parameters:
      python_version:
        type: string
        default: "3.9"
      macos_version:
        type: string
        default: "14"
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            eval "$(conda shell.bash hook)"
            rm -r $CONDA_PREFIX/envs/runner-env
            conda create -y -n runner-env python=<< parameters.python_version >>
            conda activate runner-env
            pip install --upgrade cmake
            pip install --upgrade pybind11[global]
            pip install numpy
            pip install twine
      - run:
          name: Build pacakge
          command: |
            eval "$(conda shell.bash hook)"
            conda activate runner-env
            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
            CMAKE_BUILD_PARALLEL_LEVEL="" \
            python setup.py bdist_wheel
      - store_artifacts:
          path: dist/
          path: wheelhouse/

workflows:
  build_and_test:
    when:
      and:
        - matches:
            pattern: "^(?!pull/)[-\\w]+$"
            value: << pipeline.git.branch >>
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - mac_build_and_test:
          matrix:
            parameters:
              macosx_deployment_target: ["13.5", "14.0"]
      - linux_build_and_test
      - mac_build_and_test
      - build_documentation

  build_pypi_release:
    when:
      and:
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.weekly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          filters:
            tags:
@@ -203,21 +365,236 @@ workflows:
              ignore: /.*/
          matrix:
            parameters:
              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
              macos_version: ["13", "14"]
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["PYPI_RELEASE=1"]
              xcode_version: ["16.2.0", "15.0.0"]
            exclude:
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.9"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.10"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.11"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.12"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.13"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
                build_env: "PYPI_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
                build_env: "PYPI_RELEASE=1"
      - build_documentation:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          upload-docs: true
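Tallying the release matrix above: 5 Python versions × 3 deployment targets × 2 Xcode versions gives 30 candidate jobs, and the exclude list pins each deployment target to a single Xcode (13.5 never builds with 16.2.0; 14.0 and 15.0 never build with 15.0.0), halving it to 15. A small illustrative Python check of that arithmetic (not part of the repo):

```python
from itertools import product

# Values copied from the build_pypi_release matrix above.
pythons = ["3.9", "3.10", "3.11", "3.12", "3.13"]
targets = ["13.5", "14.0", "15.0"]
xcodes = ["16.2.0", "15.0.0"]

def excluded(python, target, xcode):
    # The exclude list removes 13.5 + Xcode 16.2.0 and 14.0/15.0 +
    # Xcode 15.0.0, for every Python version.
    return (target == "13.5" and xcode == "16.2.0") or (
        target in ("14.0", "15.0") and xcode == "15.0.0"
    )

combos = list(product(pythons, targets, xcodes))
kept = [c for c in combos if not excluded(*c)]
print(len(combos), len(kept))  # 30 15
```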
  prb:
    when:
      matches:
        pattern: "^pull/\\d+(/head)?$"
        value: << pipeline.git.branch >>
    jobs:
      - hold:
          type: approval
      - apple/authenticate:
          context: pr-approval
      - mac_build_and_test:
          requires: [ hold ]
          matrix:
            parameters:
              macosx_deployment_target: ["13.5", "14.0"]
      - linux_build_and_test:
          requires: [ hold ]
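The `prb` pattern is the complement of the one gating `build_and_test` earlier: external PR branches arrive as `pull/<number>` (optionally `/head`), while the negative-lookahead pattern `^(?!pull/)[-\w]+$` accepts ordinary branch names (and, since `[-\w]+` cannot match `/`, any slash-containing branch name matches neither workflow). A quick illustrative check in Python:

```python
import re

# Patterns copied from the workflow `when` conditions above.
build_and_test = re.compile(r"^(?!pull/)[-\w]+$")
prb = re.compile(r"^pull/\d+(/head)?$")

for branch in ["main", "fix-metal-jit", "pull/1234", "pull/1234/head"]:
    print(branch, bool(build_and_test.match(branch)), bool(prb.match(branch)))
# main           True  False
# fix-metal-jit  True  False
# pull/1234      False True
# pull/1234/head False True
```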
  nightly_build:
    when: << pipeline.parameters.nightly_build >>
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.nightly_build >>
    jobs:
      - build_package:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
              macos_version: ["13", "14"]
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              xcode_version: ["16.2.0", "15.0.0"]
            exclude:
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.9"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.10"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.11"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.12"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.13"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
  weekly_build:
    when: << pipeline.parameters.weekly_build >>
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.weekly_build >>
    jobs:
      - build_dev_release:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
              macos_version: ["13", "14"]
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["DEV_RELEASE=1"]
              xcode_version: ["16.2.0", "15.0.0"]
            exclude:
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.9"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.10"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.11"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.12"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "13.5"
                xcode_version: "16.2.0"
                python_version: "3.13"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "14.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.9"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.10"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.11"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.12"
                build_env: "DEV_RELEASE=1"
              - macosx_deployment_target: "15.0"
                xcode_version: "15.0.0"
                python_version: "3.13"
                build_env: "DEV_RELEASE=1"
  linux_test_release:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.linux_release >>
    jobs:
      - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              extra_env: ["PYPI_RELEASE=1"]
.github/ISSUE_TEMPLATE/bug_report.md (vendored, new file, 28 lines)
@@ -0,0 +1,28 @@
---
name: Bug report
about: Create a report about an issue you've encountered
title: "[BUG] "
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**

Include code snippet
```python

```

**Expected behavior**
A clear and concise description of what you expected to happen.

**Desktop (please complete the following information):**
- OS Version: [e.g. MacOS 14.1.2]
- Version [e.g. 0.7.0]

**Additional context**
Add any other context about the problem here.
.github/workflows/pull_request.yml (vendored, 2 lines changed)
@@ -17,4 +17,4 @@ jobs:
          pip install pre-commit black isort clang-format
      - name: Run lint
        run: |
          pre-commit run --all-files
          pre-commit run --all-files
.gitignore (vendored, 8 lines changed)
@@ -6,6 +6,10 @@ __pycache__/
# C extensions
*.so

# tensor files
*.safe
*.safetensors

# Metal libraries
*.metallib
venv/
@@ -32,6 +36,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
uv.lock

# vim
*.swp
@@ -72,6 +77,9 @@ build/
*.out
*.app

# Debug symbols
*.pdb

# VSCode
.vscode/
.DS_Store
.pre-commit-config.yaml
@@ -1,16 +1,21 @@
repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v17.0.6
    rev: v19.1.7
    hooks:
      - id: clang-format
  # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
  - repo: https://github.com/psf/black-pre-commit-mirror
    rev: 22.10.0
    rev: 25.1.0
    hooks:
      - id: black

  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    rev: 6.0.0
    hooks:
      - id: isort
        args:
          - --profile=black
  - repo: https://github.com/cheshirekow/cmake-format-precommit
    rev: v0.6.13
    hooks:
      - id: cmake-format
ACKNOWLEDGMENTS.md
@@ -6,9 +6,23 @@ with a short description of your contribution(s) below. For example:
- Jane Smith: Added the `foo` and `bar` ops.

MLX was developed with contributions from the following individuals:

- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>

# Third-Party Software

@@ -245,4 +259,4 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
CITATION.cff (new file, 24 lines)
@@ -0,0 +1,24 @@
cff-version: 1.2.0
title: mlx
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Awni
    family-names: Hannun
    affiliation: Apple
  - given-names: Jagrit
    family-names: Digani
    affiliation: Apple
  - given-names: Angelos
    family-names: Katharopoulos
    affiliation: Apple
  - given-names: Ronan
    family-names: Collobert
    affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
  MLX: efficient and flexible machine learning on Apple
  silicon
license: MIT
CMakeLists.txt (343 lines changed)
@@ -1,6 +1,24 @@
cmake_minimum_required(VERSION 3.24)
cmake_minimum_required(VERSION 3.25)

project(mlx LANGUAGES CXX)
if(NOT MLX_VERSION)
  file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
  string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
  set(_major ${CMAKE_MATCH_1})
  string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
  set(_minor ${CMAKE_MATCH_1})
  string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
  set(_patch ${CMAKE_MATCH_1})
  set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
  set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
  string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
                       ${MLX_VERSION})
endif()

project(
  mlx
  LANGUAGES C CXX
  VERSION ${MLX_PROJECT_VERSION})
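When `MLX_VERSION` is not supplied, the new top of CMakeLists.txt scrapes the version triple out of the `MLX_VERSION_MAJOR/MINOR/PATCH` macros in `mlx/version.h` with regex matches. A rough Python equivalent of the same extraction (a hypothetical illustration, not code from the repo):

```python
import re

def read_mlx_version(header="mlx/version.h"):
    """Extract MAJOR.MINOR.PATCH from the MLX_VERSION_* macros."""
    text = open(header).read()
    fields = {}
    for key in ("MAJOR", "MINOR", "PATCH"):
        # Mirrors the CMake `string(REGEX MATCH ...)` calls above.
        fields[key] = re.search(rf"#define MLX_VERSION_{key} (\d+)", text).group(1)
    return "{MAJOR}.{MINOR}.{PATCH}".format(**fields)

# e.g. a header defining 0 / 23 / 1 yields "0.23.1"
```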
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
@@ -15,30 +33,39 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
||||
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.0.6)
|
||||
endif()
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
message(
|
||||
STATUS
|
||||
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
|
||||
)
|
||||
|
||||
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
||||
|
||||
set(MLX_BUILD_ARM OFF)
|
||||
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
|
||||
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
message(WARNING
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, "
|
||||
" make sure you are building for arm64.")
|
||||
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
set(MLX_BUILD_ARM ON)
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
if(NOT MLX_ENABLE_X64_MAC)
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, check the build"
|
||||
" documentation for possible fixes: "
|
||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
|
||||
)
|
||||
else()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
else()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||
endif()
|
||||
|
||||
@@ -50,105 +77,187 @@ cmake_policy(SET CMP0135 NEW)
|
||||
|
||||
add_library(mlx)
|
||||
|
||||
if (MLX_BUILD_METAL)
|
||||
find_library(METAL_LIB Metal)
|
||||
find_library(FOUNDATION_LIB Foundation)
|
||||
find_library(QUARTZ_LIB QuartzCore)
|
||||
if(MLX_BUILD_METAL)
|
||||
set(METAL_LIB "-framework Metal")
|
||||
set(FOUNDATION_LIB "-framework Foundation")
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
endif()
|
||||
|
||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
elseif (MLX_BUILD_METAL)
|
||||
set(MLX_METAL_DEBUG OFF)
|
||||
elseif(MLX_BUILD_METAL)
|
||||
message(STATUS "Building METAL sources")
|
||||
add_compile_definitions(_METAL_)
|
||||
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_VERSION
|
||||
COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||
|
||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
|
||||
else()
|
||||
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
|
||||
  if(MLX_METAL_DEBUG)
    add_compile_definitions(MLX_METAL_DEBUG)
  endif()

  FetchContent_Declare(
    metal_cpp
    URL ${METAL_CPP_URL}
  )
  # Throw an error if xcrun not found
  execute_process(
    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
    OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)

  if(${MACOS_SDK_VERSION} LESS 14.0)
    message(
      FATAL_ERROR
        "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
  endif()
  message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")

  set(METAL_CPP_URL
      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)

  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
    set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
  endif()
  execute_process(
    COMMAND
      zsh "-c"
      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
  FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})

  FetchContent_MakeAvailable(metal_cpp)
  target_include_directories(
    mlx PUBLIC
    $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
    $<INSTALL_INTERFACE:include/metal_cpp>
  )
  target_link_libraries(
    mlx
    ${METAL_LIB}
    ${FOUNDATION_LIB}
    ${QUARTZ_LIB})
    mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
               $<INSTALL_INTERFACE:include/metal_cpp>)
  target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()

find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
  message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
  set(MLX_BUILD_ACCELERATE ON)
  target_link_libraries(mlx ${ACCELERATE_LIBRARY})
  add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
  message(STATUS "Accelerate or arm neon not found, using default backend.")
  set(MLX_BUILD_ACCELERATE OFF)
  #set(BLA_VENDOR Generic)
  find_package(BLAS REQUIRED)
  if (NOT BLAS_FOUND)
    message(FATAL_ERROR "Must have BLAS installed")
if(WIN32)
  if(MSVC)
    # GGUF does not build with MSVC.
    set(MLX_BUILD_GGUF OFF)
    # There is no prebuilt OpenBLAS distribution for MSVC.
    set(MLX_BUILD_BLAS_FROM_SOURCE ON)
  endif()
  # TODO find a cleaner way to do this
  find_path(BLAS_INCLUDE_DIRS cblas.h
    /usr/include
    /usr/local/include
    $ENV{BLAS_HOME}/include)
  message(STATUS ${BLAS_LIBRARIES})
  message(STATUS ${BLAS_INCLUDE_DIRS})
  target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
  target_link_libraries(mlx ${BLAS_LIBRARIES})
  # Windows implementation of dlfcn.h APIs.
  FetchContent_Declare(
    dlfcn-win32
    GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
    GIT_TAG v1.4.1
    EXCLUDE_FROM_ALL)
  block()
    set(BUILD_SHARED_LIBS OFF)
    FetchContent_MakeAvailable(dlfcn-win32)
  endblock()
  target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
  target_link_libraries(mlx PRIVATE dl)
endif()

if(MLX_BUILD_CPU)
  find_library(ACCELERATE_LIBRARY Accelerate)
  if(ACCELERATE_LIBRARY)
    message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
    set(MLX_BUILD_ACCELERATE ON)
  else()
    message(STATUS "Accelerate or arm neon not found, using default backend.")
    set(MLX_BUILD_ACCELERATE OFF)
  endif()

  if(MLX_BUILD_ACCELERATE)
    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
    add_compile_definitions(MLX_USE_ACCELERATE)
    add_compile_definitions(ACCELERATE_NEW_LAPACK)
  elseif(MLX_BUILD_BLAS_FROM_SOURCE)
    # Download and build OpenBLAS from source code.
    FetchContent_Declare(
      openblas
      GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
      GIT_TAG v0.3.28
      EXCLUDE_FROM_ALL)
    set(BUILD_STATIC_LIBS ON) # link statically
    set(NOFORTRAN ON) # msvc has no fortran compiler
    FetchContent_MakeAvailable(openblas)
    target_link_libraries(mlx PRIVATE openblas)
    target_include_directories(
      mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
                  "${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
  else()
    if(${CMAKE_HOST_APPLE})
      # The blas shipped in macOS SDK is not supported, search homebrew for
      # openblas instead.
      set(BLA_VENDOR OpenBLAS)
      set(LAPACK_ROOT
          "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
    endif()
    # Search and link with lapack.
    find_package(LAPACK REQUIRED)
    if(NOT LAPACK_FOUND)
      message(FATAL_ERROR "Must have LAPACK installed")
    endif()
    find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
              /usr/local/opt/openblas/include)
    message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
    message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
    target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
    target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
    # List blas after lapack otherwise we may accidentally include an old
    # version of lapack.h from the include dirs of blas.
    find_package(BLAS REQUIRED)
    if(NOT BLAS_FOUND)
      message(FATAL_ERROR "Must have BLAS installed")
    endif()
    # TODO find a cleaner way to do this
    find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
              $ENV{BLAS_HOME}/include)
    message(STATUS "Blas lib " ${BLAS_LIBRARIES})
    message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
    target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
    target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
  endif()
else()
  set(MLX_BUILD_ACCELERATE OFF)
endif()

message(STATUS "Downloading json")
FetchContent_Declare(
  json
  URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
  mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

target_include_directories(
  mlx
  PUBLIC
    $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
    $<INSTALL_INTERFACE:include>
)
  mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
             $<INSTALL_INTERFACE:include>)

if (MLX_BUILD_PYTHON_BINDINGS)
FetchContent_Declare(
  fmt
  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
  GIT_TAG 10.2.1
  EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)

if(MLX_BUILD_PYTHON_BINDINGS)
  message(STATUS "Building Python bindings.")
  find_package(Python COMPONENTS Interpreter Development)
  find_package(pybind11 CONFIG REQUIRED)
  find_package(
    Python 3.8
    COMPONENTS Interpreter Development.Module
    REQUIRED)
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
    OUTPUT_STRIP_TRAILING_WHITESPACE
    OUTPUT_VARIABLE nanobind_ROOT)
  find_package(nanobind CONFIG REQUIRED)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

if (MLX_BUILD_TESTS)
if(MLX_BUILD_TESTS)
  include(CTest)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()

if (MLX_BUILD_EXAMPLES)
if(MLX_BUILD_EXAMPLES)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()

if (MLX_BUILD_BENCHMARKS)
if(MLX_BUILD_BENCHMARKS)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()

@@ -157,32 +266,31 @@ include(GNUInstallDirs)

# Install library
install(
  TARGETS mlx
  EXPORT MLXTargets
  LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
  ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
  RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
  INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)

  TARGETS mlx
  EXPORT MLXTargets
  LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
  ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
  RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
  INCLUDES
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})

# Install headers
install(
  DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
  COMPONENT headers
  FILES_MATCHING PATTERN "*.h"
)
  DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
  COMPONENT headers
  FILES_MATCHING
  PATTERN "*.h"
  PATTERN "backend/metal/kernels.h" EXCLUDE)

# Install metal dependencies
if (MLX_BUILD_METAL)
if(MLX_BUILD_METAL)

  # Install metal cpp
  install(
    DIRECTORY ${metal_cpp_SOURCE_DIR}/
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
    COMPONENT metal_cpp_source
  )
    DIRECTORY ${metal_cpp_SOURCE_DIR}/
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
    COMPONENT metal_cpp_source)

endif()

@@ -194,31 +302,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
install(
  EXPORT MLXTargets
  FILE MLXTargets.cmake
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

include(CMakePackageConfigHelpers)

write_basic_package_version_file(
  ${MLX_CMAKE_BUILD_VERSION_CONFIG}
  COMPATIBILITY SameMajorVersion
  VERSION ${MLX_VERSION}
)
  VERSION ${MLX_VERSION})

configure_package_config_file(
  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
  ${MLX_CMAKE_BUILD_CONFIG}
  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
  INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
  NO_CHECK_REQUIRED_COMPONENTS_MACRO
  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
)
  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
            MLX_CMAKE_INSTALL_MODULE_DIR)

install(
  FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

install(
  DIRECTORY ${CMAKE_MODULE_PATH}/
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
)
install(DIRECTORY ${CMAKE_MODULE_PATH}/
        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
@@ -5,26 +5,26 @@ possible.

## Pull Requests

1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before
   and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
6. For code formatting, install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
   This should install hooks for running `black` and `clang-format` to ensure
   consistent style for C++ and python code.

You can also run the formatters manually as follows:

```shell
clang-format -i file.cpp
```

```shell
black file.py
```

or run `pre-commit run --all-files` to check all files in the repo.

## Issues
@@ -1,3 +1,6 @@
include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ *
include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561
README.md
@@ -6,15 +6,17 @@

[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)

MLX is an array framework for machine learning on Apple silicon, brought to you
by Apple machine learning research.
MLX is an array framework for machine learning on Apple silicon,
brought to you by Apple machine learning research.

Some key features of MLX include:

- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
  MLX also has a fully featured C++ API, which closely mirrors the Python API.
  MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
  that closely follow PyTorch to simplify building more complex models.
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
  also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
  [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
  the Python API. MLX has higher-level packages like `mlx.nn` and
  `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
  more complex models.

- **Composable function transformations**: MLX supports composable function
  transformations for automatic differentiation, automatic vectorization,
@@ -53,7 +55,7 @@ variety of examples, including:

- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
- Large-scale text generation with
  [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and
  [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
  finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
@@ -61,31 +63,39 @@ variety of examples, including:
## Quickstart

See the [quick start
guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
in the documentation.

## Installation

MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:

**With `pip`**:

```
pip install mlx
```

**With `conda`**:

```
conda install -c conda-forge mlx
```
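For a sense of what the installed package provides, a minimal sketch of the Python API (operations are lazy until an explicit `mx.eval`; the printed output shown in the comment is indicative):

```python
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = mx.ones(a.shape)
c = a + b      # lazy: records the op, computes nothing yet
mx.eval(c)     # materializes the result on the default device
print(c)       # array([2, 3, 4], dtype=float32)
```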
Check out the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.

## Contributing

Check out the [contribution guidelines](CONTRIBUTING.md) for more information
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.

We are grateful for all of [our
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to to the list in your
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.

## Citing MLX
@@ -5,35 +5,35 @@
#include "mlx/mlx.h"
#include "time_utils.h"

using namespace mlx::core;
namespace mx = mlx::core;

void time_value_and_grad() {
  auto x = ones({200, 1000});
  eval(x);
  auto fn = [](array x) {
  auto x = mx::ones({200, 1000});
  mx::eval(x);
  auto fn = [](mx::array x) {
    for (int i = 0; i < 20; ++i) {
      x = log(exp(x));
      x = mx::log(mx::exp(x));
    }
    return sum(x);
    return mx::sum(x);
  };

  auto grad_fn = grad(fn);
  auto grad_fn = mx::grad(fn);
  auto independent_value_and_grad = [&]() {
    auto value = fn(x);
    auto dfdx = grad_fn(x);
    return std::vector<array>{value, dfdx};
    return std::vector<mx::array>{value, dfdx};
  };
  TIME(independent_value_and_grad);

  auto value_and_grad_fn = value_and_grad(fn);
  auto value_and_grad_fn = mx::value_and_grad(fn);
  auto combined_value_and_grad = [&]() {
    auto [value, dfdx] = value_and_grad_fn(x);
    return std::vector<array>{value, dfdx};
    return std::vector<mx::array>{value, dfdx};
  };
  TIME(combined_value_and_grad);
}

int main() {
  std::cout << "Benchmarks for " << default_device() << std::endl;
  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
  time_value_and_grad();
}
@@ -4,21 +4,21 @@
#include "mlx/mlx.h"
#include "time_utils.h"

using namespace mlx::core;
namespace mx = mlx::core;

void time_add_op() {
  std::vector<int> sizes(1, 1);
  for (int i = 0; i < 9; ++i) {
    sizes.push_back(10 * sizes.back());
  }
  set_default_device(Device::cpu);
  set_default_device(mx::Device::cpu);
  for (auto size : sizes) {
    auto a = random::uniform({size});
    auto b = random::uniform({size});
    eval(a, b);
    auto a = mx::random::uniform({size});
    auto b = mx::random::uniform({size});
    mx::eval(a, b);
    std::cout << "Size " << size << std::endl;
    TIMEM("cpu", add, a, b, Device::cpu);
    TIMEM("gpu", add, a, b, Device::gpu);
    TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
    TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
  }
}

@@ -6,105 +6,105 @@
#include "mlx/mlx.h"
#include "time_utils.h"

using namespace mlx::core;
namespace mx = mlx::core;

void time_irregular_binary_ops_1D() {
  auto device = default_device();
  auto device = mx::default_device();
  int size = 1000000;
  int step = 2;
  auto a = random::uniform({size});
  auto b = random::uniform({size});
  eval(a, b);
  auto a = mx::random::uniform({size});
  auto b = mx::random::uniform({size});
  mx::eval(a, b);
  a = slice(a, {0}, {size}, {step});
  b = slice(b, {0}, {size}, {step});
  TIMEM("1D strided", add, a, b, device);
  TIMEM("1D strided", mx::add, a, b, device);
}

void time_irregular_binary_ops_2D() {
  auto device = default_device();
  auto device = mx::default_device();
  int size = 2048;
  auto a = random::uniform({size, size});
  auto b = random::uniform({size, size});
  eval(a, b);
  TIMEM("2D regular", add, a, b, device);
  auto a = mx::random::uniform({size, size});
  auto b = mx::random::uniform({size, size});
  mx::eval(a, b);
  TIMEM("2D regular", mx::add, a, b, device);

  b = transpose(b);
  eval(b);
  TIMEM("2D transpose", add, a, b, device);
  b = mx::transpose(b);
  mx::eval(b);
  TIMEM("2D mx::transpose", mx::add, a, b, device);

  b = random::uniform({size});
  eval(b);
  TIMEM("2D broadcast dim 0", add, a, b, device);
  b = mx::random::uniform({size});
  mx::eval(b);
  TIMEM("2D broadcast dim 0", mx::add, a, b, device);

  b = reshape(b, {size, 1});
  eval(b);
  TIMEM("2D broadcast dim 1", add, a, b, device);
  b = mx::reshape(b, {size, 1});
  mx::eval(b);
  TIMEM("2D broadcast dim 1", mx::add, a, b, device);
}

void time_irregular_binary_ops_3D() {
  auto device = default_device();
  auto device = mx::default_device();
  int d0 = 32;
  int d1 = 512;
  int d2 = 512;
  auto a = random::uniform({d0, d1, d2});
  auto b = random::uniform({d0, d1, d2});
  TIMEM("3D regular", add, a, b, device);
  auto a = mx::random::uniform({d0, d1, d2});
  auto b = mx::random::uniform({d0, d1, d2});
  TIMEM("3D regular", mx::add, a, b, device);

  b = transpose(b, {0, 2, 1});
  TIMEM("3D transpose", add, a, b, device);
  b = mx::transpose(b, {0, 2, 1});
  TIMEM("3D mx::transpose", mx::add, a, b, device);

  b = random::uniform({d1, d2});
  TIMEM("3D broadcast dim 0", add, a, b, device);
  b = mx::random::uniform({d1, d2});
  TIMEM("3D broadcast dim 0", mx::add, a, b, device);

  b = random::uniform({d0, 1, d2});
  TIMEM("3D broadcast dim 1", add, a, b, device);
  b = mx::random::uniform({d0, 1, d2});
  TIMEM("3D broadcast dim 1", mx::add, a, b, device);

  b = random::uniform({d0, d1, 1});
  TIMEM("3D broadcast dim 2", add, a, b, device);
  b = mx::random::uniform({d0, d1, 1});
  TIMEM("3D broadcast dim 2", mx::add, a, b, device);

  b = random::uniform({d2});
  TIMEM("3D broadcast dims 0, 1", add, a, b, device);
  b = mx::random::uniform({d2});
  TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);

  b = random::uniform({d1, 1});
  TIMEM("3D broadcast dims 0, 2", add, a, b, device);
  b = mx::random::uniform({d1, 1});
  TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);

  b = random::uniform({d0, 1, 1});
  TIMEM("3D broadcast dims 1, 2", add, a, b, device);
  b = mx::random::uniform({d0, 1, 1});
  TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
}

void time_irregular_binary_ops_4D() {
  auto device = default_device();
  auto device = mx::default_device();
  std::vector<int> shape = {8, 8, 512, 512};
  auto a = random::uniform(shape);
  auto b = random::uniform(shape);
  auto a = mx::random::uniform(shape);
  auto b = mx::random::uniform(shape);

  TIMEM("4D regular", add, a, b, device);
  TIMEM("4D regular", mx::add, a, b, device);

  b = transpose(b, {0, 1, 3, 2});
  TIMEM("4D transpose", add, a, b, device);
  b = mx::transpose(b, {0, 1, 3, 2});
  TIMEM("4D mx::transpose", mx::add, a, b, device);

  std::string om = "4D broadcast dims ";
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = 1;
    b = random::uniform(shape);
    b = mx::random::uniform(shape);
    std::ostringstream msg;
    msg << om << i;
    TIMEM(msg.str(), add, a, b, device);
    TIMEM(msg.str(), mx::add, a, b, device);

    for (int j = i + 1; j < shape.size(); ++j) {
      shape[j] = 1;
      std::ostringstream msg;
      msg << om << i << ", " << j;
      b = random::uniform(shape);
      TIMEM(msg.str(), add, a, b, device);
      b = mx::random::uniform(shape);
      TIMEM(msg.str(), mx::add, a, b, device);
      shape[j] = a.shape(j);

      for (int k = j + 1; k < shape.size(); ++k) {
        shape[k] = 1;
        std::ostringstream msg;
        msg << om << i << ", " << j << ", " << k;
        b = random::uniform(shape);
        TIMEM(msg.str(), add, a, b, device);
        b = mx::random::uniform(shape);
        TIMEM(msg.str(), mx::add, a, b, device);
        shape[k] = a.shape(k);
      }
    }
@@ -113,83 +113,83 @@ void time_irregular_binary_ops_4D() {
}

void time_irregular_reshape() {
  auto device = default_device();
  auto device = mx::default_device();
  std::vector<int> shape;
  auto reshape_fn = [&shape, device](const array& a) {
    return reshape(a, shape, device);
  auto reshape_fn = [&shape, device](const mx::array& a) {
    return mx::reshape(a, shape, device);
  };

  int size = 64;
  int d = 2 * size;

  auto a = random::uniform({d, d, d});
  auto a = mx::random::uniform({d, d, d});

  shape = {8 * size, size, size};
  TIMEM("3D contiguous", reshape_fn, a);

  a = transpose(a);
  a = mx::transpose(a);
  shape = {8 * size, size, size};
  TIMEM("3D transpose", reshape_fn, a);
  TIMEM("3D mx::transpose", reshape_fn, a);

  a = transpose(a, {1, 2, 0});
  a = mx::transpose(a, {1, 2, 0});
  shape = {8 * size, size, size};
  TIMEM("3D transpose dims 1 2", reshape_fn, a);
  TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);

  a = broadcast_to(random::uniform({d, d}), {d, d, d});
  a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
  TIMEM("3D broadcast dim 0", reshape_fn, a);

  a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
  a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
  TIMEM("3D broadcast dim 1", reshape_fn, a);

  a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
  a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
  TIMEM("3D broadcast dim 2", reshape_fn, a);

  a = broadcast_to(random::uniform({d}), {d, d, d});
  a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
  TIMEM("3D broadcast dims 0, 1", reshape_fn, a);

  a = broadcast_to(random::uniform({d, 1}), {d, d, d});
  a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
  TIMEM("3D broadcast dims 0, 2", reshape_fn, a);

  a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
  a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
  TIMEM("3D broadcast dims 1, 2", reshape_fn, a);

  a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
  a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
  TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}

void time_irregular_astype_1D() {
  auto device = default_device();
  auto device = mx::default_device();
  int size = 1000000;
  int step = 2;
  auto a = random::uniform({size});
  auto a = mx::random::uniform({size});
  a = slice(a, {0}, {size}, {step});
  TIMEM("1D strided", astype, a, int32, device);
  TIMEM("1D strided", mx::astype, a, mx::int32, device);
}

void time_irregular_astype_2D() {
  auto device = default_device();
  auto device = mx::default_device();
  int size = 2048;
  std::vector<int> shape = {size, size};

  auto a = random::uniform(shape);
  TIMEM("2D regular", astype, a, int32, device);
  auto a = mx::random::uniform(shape);
  TIMEM("2D regular", mx::astype, a, mx::int32, device);

  a = transpose(a);
  TIMEM("2D transpose", astype, a, int32, device);
  a = mx::transpose(a);
  TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);

  a = broadcast_to(random::uniform({size}), shape);
  TIMEM("2D broadcast dim 0", astype, a, int32, device);
  a = mx::broadcast_to(mx::random::uniform({size}), shape);
  TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);

  a = broadcast_to(random::uniform({size, 1}), shape);
  TIMEM("2D broadcast dim 1", astype, a, int32, device);
  a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
  TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
}

int main(int argc, char** argv) {
  if (argc > 1) {
    bool use_gpu = !strcmp(argv[1], "gpu");
    set_default_device(use_gpu ? Device::gpu : Device::cpu);
    set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
  }
  std::cout << "Benchmarks for " << default_device() << std::endl;
  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
  time_irregular_binary_ops_1D();
  time_irregular_binary_ops_2D();
  time_irregular_binary_ops_3D();
@@ -3,20 +3,20 @@
#include "mlx/mlx.h"
#include "time_utils.h"

using namespace mlx::core;
namespace mx = mlx::core;

void time_creation_ops() {
  int M = 2000;
  int N = 500;
  auto shape = {M, N};
  auto full_fp32 = [&]() { return full(shape, 3.3f); };
  auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
  TIME(full_fp32);
  auto zeros_fp32 = [&]() { return zeros(shape, float32); };
  auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
  TIME(zeros_fp32);
  auto ones_fp32 = [&]() { return ones(shape, float32); };
  auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
  TIME(ones_fp32);

  auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); };
  auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
  TIME(arange_fp32);
}

@@ -24,188 +24,196 @@ void time_type_conversions() {
  int M = 2000;
  int N = 500;
  auto shape = {M, N};
  auto device = default_device();
  auto device = mx::default_device();

  auto a = zeros(shape, float32);
  eval(a);
  TIMEM("float32 to int32", astype, a, int32, device);
  TIMEM("float32 to uint32", astype, a, uint32, device);
  auto a = mx::zeros(shape, mx::float32);
  mx::eval(a);
  TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
  TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);

  a = zeros(shape, int32);
  eval(a);
  TIMEM("int32 to float32", astype, a, float32, device);
  a = mx::zeros(shape, mx::int32);
  mx::eval(a);
  TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);

  a = zeros(shape, bool_);
  eval(a);
  TIMEM("bool to float32", astype, a, float32, device);
  TIMEM("bool to int32", astype, a, int32, device);
  TIMEM("bool to uint32", astype, a, uint32, device);
  a = mx::zeros(shape, mx::bool_);
  mx::eval(a);
  TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
  TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
  TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
}

void time_random_generation() {
  int M = 2000;
  int N = 500;

  auto uniform = [&]() { return random::uniform({M, N}, float32); };
  auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
  TIME(uniform);
  auto normal = [&]() { return random::normal({M, N}, float32); };
  auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
  TIME(normal);
}

void time_unary_ops() {
  int M = 2000;
  int N = 500;
  auto device = default_device();
  auto device = mx::default_device();

  auto a = random::normal({M, N});
  eval(a);
  auto a = mx::random::normal({M, N});
  mx::eval(a);
  TIME(mlx::core::abs, a, device);
  TIME(negative, a, device);
  TIME(sign, a, device);
  TIME(square, a, device);
  TIME(mx::negative, a, device);
  TIME(mx::sign, a, device);
  TIME(mx::square, a, device);
  TIME(mlx::core::sqrt, a, device);
  TIME(rsqrt, a, device);
  TIME(mx::rsqrt, a, device);
  TIME(mlx::core::exp, a, device);

  a = random::uniform({M, N});
  a = mx::random::uniform({M, N});
  TIME(mlx::core::log, a, device);
}

void time_binary_ops() {
  int M = 1000, N = 100, K = 10;
  auto a = random::uniform({M, N, K});
  auto b = random::uniform({M, N, K});
  auto device = default_device();
  eval(a, b);
  auto condition = mx::random::randint(0, 2, {M, N, K});
  auto a = mx::random::uniform({M, N, K});
  auto b = mx::random::uniform({M, N, K});
  auto device = mx::default_device();
  mx::eval(a, b);

  TIME(add, a, b, device);
  TIME(subtract, a, b, device);
  TIME(multiply, a, b, device);
  TIME(divide, a, b, device);
  TIME(maximum, a, b, device);
  TIME(minimum, a, b, device);
  TIME(mx::add, a, b, device);
  TIME(mx::subtract, a, b, device);
  TIME(mx::multiply, a, b, device);
  TIME(mx::divide, a, b, device);
  TIME(mx::maximum, a, b, device);
  TIME(mx::minimum, a, b, device);
  TIME(mx::where, condition, a, b, device);

  b = random::uniform({1});
  eval(b);
  TIMEM("scalar", add, a, b, device);
  TIMEM("vector-scalar", subtract, a, b, device);
  TIMEM("scalar-vector", subtract, b, a, device);
  TIMEM("scalar", multiply, a, b, device);
  TIMEM("vector-scalar", divide, a, b, device);
  TIMEM("scalar-vector", divide, b, a, device);
  condition = mx::array({true});
  b = mx::random::uniform({1});
  mx::eval(b);
  TIMEM("scalar", mx::add, a, b, device);
  TIMEM("vector-scalar", mx::subtract, a, b, device);
  TIMEM("scalar-vector", mx::subtract, b, a, device);
  TIMEM("scalar", mx::multiply, a, b, device);
  TIMEM("vector-scalar", mx::divide, a, b, device);
  TIMEM("scalar-vector", mx::divide, b, a, device);
  TIMEM("scalar-vector", mx::where, condition, a, b, device);

  a = broadcast_to(random::uniform({1}), {1000, 100});
  b = broadcast_to(random::uniform({1}), {1000, 100});
  eval(a, b);
  TIMEM("scalar-scalar broadcast", add, a, b, device);
  TIMEM("scalar-scalar broadcast", subtract, a, b, device);
  TIMEM("scalar-scalar broadcast", multiply, a, b, device);
  TIMEM("scalar-scalar broadcast", divide, a, b, device);
  condition = mx::broadcast_to(mx::array({true}), {1000, 100});
  a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
  b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
  mx::eval(a, b);
  TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
  TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
  TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
  TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
  TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
}

void time_strided_ops() {
  int M = 50, N = 50, O = 50, P = 50;
  auto a = random::uniform({M, N, O, P});
  auto b = random::uniform({M, N, O, P});
  auto device = default_device();
  eval(a, b);
  TIMEM("non-strided", add, a, b, device);
  a = transpose(a, {1, 0, 2, 3});
  b = transpose(b, {3, 2, 0, 1});
  eval(a, b);
  TIMEM("strided", add, a, b, device);
  auto a = mx::random::uniform({M, N, O, P});
  auto b = mx::random::uniform({M, N, O, P});
  auto device = mx::default_device();
  mx::eval(a, b);
  TIMEM("non-strided", mx::add, a, b, device);
  a = mx::transpose(a, {1, 0, 2, 3});
  b = mx::transpose(b, {3, 2, 0, 1});
  mx::eval(a, b);
  TIMEM("strided", mx::add, a, b, device);
}

void time_comparisons() {
  int M = 1000, N = 100, K = 10;
  auto a = random::uniform({M, N, K});
  auto b = random::uniform({M, N, K});
  auto device = default_device();
  eval(a, b);
  TIME(equal, a, b, device);
  TIME(greater, a, b, device);
  TIME(greater_equal, a, b, device);
  TIME(less, a, b, device);
  TIME(less_equal, a, b, device);
  auto a = mx::random::uniform({M, N, K});
  auto b = mx::random::uniform({M, N, K});
  auto device = mx::default_device();
  mx::eval(a, b);
  TIME(mx::equal, a, b, device);
  TIME(mx::greater, a, b, device);
  TIME(mx::greater_equal, a, b, device);
  TIME(mx::less, a, b, device);
  TIME(mx::less_equal, a, b, device);
}

void time_matvec() {
  int M = 2000, N = 200;
  auto a = random::uniform({M, N});
  auto b = random::uniform({N});
  auto c = random::uniform({M});
  eval(a, b, c);
  auto matvec = [&]() { return matmul(a, b); };
  auto a = mx::random::uniform({M, N});
  auto b = mx::random::uniform({N});
  auto c = mx::random::uniform({M});
  mx::eval(a, b, c);
  auto matvec = [&]() { return mx::matmul(a, b); };
  TIME(matvec);

  auto matvec_transpose = [&]() { return matmul(transpose(a), c); };
  auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
  TIME(matvec_transpose);
}

void time_matmul() {
  int M = 1000, N = 1000, K = 1000;
  auto a = random::uniform({M, K});
  auto b = random::uniform({K, N});
  auto device = default_device();
  eval(a, b);
  TIME(matmul, a, b, device);
  auto a = mx::random::uniform({M, K});
  auto b = mx::random::uniform({K, N});
  auto device = mx::default_device();
  mx::eval(a, b);
  TIME(mx::matmul, a, b, device);

  auto transpose_matmul = [&]() { return matmul(transpose(a), b); };
  auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
  TIME(transpose_matmul);
}

void time_reductions() {
  auto a = random::normal({10000, 1000});
  eval(a);
  auto sum_all = [&a]() { return sum(a, false); };
  auto a = mx::random::normal({10000, 1000});
  mx::eval(a);
  auto sum_all = [&a]() { return mx::sum(a, false); };
  TIME(sum_all);

  auto sum_along_0 = [&a]() { return sum(a, 0, false); };
  auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
  TIME(sum_along_0);

  auto sum_along_1 = [&a]() { return sum(a, 1, false); };
  auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
  TIME(sum_along_1);

  auto prod_all = [&a]() { return prod(a, false); };
  auto prod_all = [&a]() { return mx::prod(a, false); };
  TIME(prod_all);

  auto all_true = [&a]() { return all(a, false); };
  auto all_true = [&a]() { return mx::all(a, false); };
  TIME(all_true);

  auto all_along_0 = [&a]() { return all(a, 0, false); };
  auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
  TIME(all_along_0);

  auto all_along_1 = [&a]() { return all(a, 1, false); };
  auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
  TIME(all_along_1);

  auto any_true = [&a]() { return any(a, false); };
  auto any_true = [&a]() { return mx::any(a, false); };
  TIME(any_true);

  auto argmin_along_0 = [&a]() { return argmin(a, 0, false); };
  auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
  TIME(argmin_along_0);

  auto argmin_along_1 = [&a]() { return argmin(a, 1, false); };
  auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
  TIME(argmin_along_1);
}

void time_gather_scatter() {
  auto a = random::normal({1000, 768});
  eval(a);
  auto indices = random::randint(0, 1000, {256});
  eval(indices);
  auto a = mx::random::normal({1000, 768});
  mx::eval(a);
  auto indices = mx::random::randint(0, 1000, {256});
  mx::eval(indices);

  auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); };
  auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
  TIME(embedding_lookup);

  indices = random::randint(0, 768 * 1000, {256 * 768});
  eval(indices);
  indices = mx::random::randint(0, 768 * 1000, {256 * 768});
  mx::eval(indices);

  auto single_element_lookup = [&a, &indices]() { return take(a, indices); };
  auto single_element_lookup = [&a, &indices]() {
    return mx::take(a, indices);
  };
  TIME(single_element_lookup);

  indices = random::randint(0, 1000, {256});
  auto updates = random::normal({256, 1, 768});
  eval(indices, updates);
  indices = mx::random::randint(0, 1000, {256});
  auto updates = mx::random::normal({256, 1, 768});
  mx::eval(indices, updates);

  auto embedding_update = [&a, &indices, &updates]() {
    return scatter(a, indices, updates, 0);
@@ -217,10 +225,10 @@ void time_gather_scatter() {
  };
  TIME(embedding_add);

  a = reshape(a, {-1});
  indices = random::randint(0, 768 * 1000, {768 * 256});
  updates = random::normal({256 * 768, 1});
  eval(a, indices, updates);
  a = mx::reshape(a, {-1});
  indices = mx::random::randint(0, 768 * 1000, {768 * 256});
  updates = mx::random::normal({256 * 768, 1});
  mx::eval(a, indices, updates);

  auto single_element_update = [&a, &indices, &updates]() {
    return scatter(a, indices, updates, 0);
@@ -233,8 +241,22 @@ void time_gather_scatter() {
  TIME(single_element_add);
}

void time_divmod() {
  auto a = mx::random::normal({1000});
  auto b = mx::random::normal({1000});
  mx::eval({a, b});

  auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
  TIME(divmod_fused);

  auto divmod_separate = [&a, &b]() {
    return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
  };
  TIME(divmod_separate);
}

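The fused and separate paths compared above are meant to agree numerically; a hedged sketch of the same check through the Python API, on the assumption that `mx.divmod` mirrors the C++ `mx::divmod`:

```python
import mlx.core as mx

a = mx.array([7.0, -7.0])
b = mx.array([2.0, 2.0])
q, r = mx.divmod(a, b)  # one fused kernel for quotient and remainder
assert mx.allclose(q, mx.floor_divide(a, b)).item()
assert mx.allclose(r, mx.remainder(a, b)).item()
```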
int main() {
  std::cout << "Benchmarks for " << default_device() << std::endl;
  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
  time_creation_ops();
  time_type_conversions();
  time_unary_ops();
@@ -246,4 +268,5 @@ int main() {
  time_matmul();
  time_reductions();
  time_gather_scatter();
  time_divmod();
}

@@ -17,14 +17,13 @@
      << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
      << std::endl;

#define TIMEM(MSG, FUNC, ...) \
  std::cout << "Timing " \
            << "(" << MSG << ") " << #FUNC << " ... " << std::flush \
            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
            << std::endl;
#define TIMEM(MSG, FUNC, ...) \
  std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
            << std::flush << std::setprecision(5) \
            << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;

template <typename F, typename... Args>
double time_fn(F fn, Args... args) {
double time_fn(F fn, Args&&... args) {
  // warmup
  for (int i = 0; i < 5; ++i) {
    eval(fn(std::forward<Args>(args)...));
@@ -166,13 +166,13 @@ if __name__ == "__main__":
    dtypes = ("float32", "float16")
    transposes = ("nn", "nt", "tn")
    shapes = (
        (16, 234, 768, 3072),
        (1, 64, 64, 25344),
        (16, 1024, 1024, 1024),
        (1, 1024, 1024, 2048),
        (4, 1024, 1024, 4096),
        (4, 1024, 4096, 1024),
        (1, 4096, 4096, 4096),
        (15, 1023, 1023, 1023),
        (17, 1025, 1025, 1025),
    )

    for dtype in dtypes:
@@ -133,7 +133,7 @@ def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
    return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)


def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
    np_dtype = getattr(np, dtype)
    mlx_gb_s = []
    mlx_gflops = []
@@ -164,7 +164,7 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, tranpose):
    ax.legend()


def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, tranpose):
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
    np_dtype = getattr(np, dtype)
    mlx_gb_s = []
    mlx_gflops = []
@@ -4,6 +4,7 @@ import argparse
import math
import os
import time
from functools import partial

import mlx.core as mx
import mlx.nn as nn
@@ -59,15 +60,63 @@ def matmul(x, y):
    mx.eval(ys)


def quant_matmul(x, w, s, b):
    groups = x.shape[-1] // s.shape[-1]
    width = 32 // (x.shape[-1] // w.shape[0])
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
    ys = []
    for i in range(10):
        ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width))
        ys.append(
            mx.quantized_matmul(
                x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
            )
        )
    mx.eval(ys)


quant_matmul = {
    "quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
    "quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
    "quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
    "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
    "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
    "quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
    "quant_matmul_128_2": partial(
        _quant_matmul, transpose=False, group_size=128, bits=2
    ),
    "quant_matmul_128_4": partial(
        _quant_matmul, transpose=False, group_size=128, bits=4
    ),
    "quant_matmul_128_8": partial(
        _quant_matmul, transpose=False, group_size=128, bits=8
    ),
    "quant_matmul_t_32_2": partial(
        _quant_matmul, transpose=True, group_size=32, bits=2
    ),
    "quant_matmul_t_32_4": partial(
        _quant_matmul, transpose=True, group_size=32, bits=4
    ),
    "quant_matmul_t_32_8": partial(
        _quant_matmul, transpose=True, group_size=32, bits=8
    ),
    "quant_matmul_t_64_2": partial(
        _quant_matmul, transpose=True, group_size=64, bits=2
    ),
    "quant_matmul_t_64_4": partial(
        _quant_matmul, transpose=True, group_size=64, bits=4
    ),
    "quant_matmul_t_64_8": partial(
        _quant_matmul, transpose=True, group_size=64, bits=8
    ),
    "quant_matmul_t_128_2": partial(
        _quant_matmul, transpose=True, group_size=128, bits=2
    ),
    "quant_matmul_t_128_4": partial(
        _quant_matmul, transpose=True, group_size=128, bits=4
    ),
    "quant_matmul_t_128_8": partial(
        _quant_matmul, transpose=True, group_size=128, bits=8
    ),
}

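A hedged usage sketch: each key in this table becomes a selectable benchmark name, and the stored `partial` leaves only the array arguments open. The inputs `x`, `w`, `s`, `b` here are hypothetical stand-ins for pre-quantized arrays produced by the harness:

```python
# Sketch only: pick the transposed, 4-bit, group-size-64 variant.
fn = quant_matmul["quant_matmul_t_64_4"]
fn(x, w, s, b)  # same as _quant_matmul(x, w, s, b, transpose=True, group_size=64, bits=4)
```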
def conv1d(x, y):
    ys = []
    for i in range(10):
@@ -95,6 +144,13 @@ def reduction(op, axis, x):
    mx.eval(ys)


def sum_and_add(axis, x, y):
    z = x.sum(axis=axis, keepdims=True)
    for i in range(50):
        z = (z + y).sum(axis=axis, keepdims=True)
    mx.eval(z)


def softmax(axis, x):
    ys = []
    for i in range(100):
@@ -220,6 +276,13 @@ def linear(w, b, x):
    mx.eval(ys)


def linear_fused(w, b, x):
    ys = []
    for i in range(10):
        ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
    mx.eval(ys)

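The fused path above relies on `mx.addmm(c, a, b)` computing `c + a @ b` in a single op, where the unfused `linear` does a matmul followed by a separate add; a hedged equivalence sketch with arbitrary shapes:

```python
import mlx.core as mx

x = mx.random.normal((4, 8))
w = mx.random.normal((16, 8))
b = mx.zeros((16,))
fused = mx.addmm(b, x, mx.transpose(w, (1, 0)))  # one fused kernel
naive = x @ mx.transpose(w, (1, 0)) + b          # matmul, then add
assert mx.allclose(fused, naive).item()
```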
def rope(x):
    *_, N, D = x.shape
    ys = []
@@ -324,10 +387,6 @@ if __name__ == "__main__":
    if len(args.axis) > 1:
        args.axis.pop(0)

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    if args.cpu:
        mx.set_default_device(mx.cpu)
    else:
@@ -350,17 +409,24 @@ if __name__ == "__main__":
    x = xs[0]
    axis = args.axis[0]

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    if args.benchmark == "matmul_square":
        print(bench(matmul_square, x))

    elif args.benchmark == "matmul":
        print(bench(matmul, *xs))

    elif args.benchmark == "quant_matmul":
        print(bench(quant_matmul, *xs))
    elif args.benchmark.startswith("quant_matmul"):
        print(bench(quant_matmul[args.benchmark], *xs))

    elif args.benchmark == "linear":
        print(bench(linear, *xs))
        if args.fused:
            print(bench(linear_fused, *xs))
        else:
            print(bench(linear, *xs))

    elif args.benchmark == "sum_axis":
        print(bench(reduction, "sum", axis, x))
@@ -446,5 +512,8 @@ if __name__ == "__main__":
    elif args.benchmark == "selu":
        print(bench(selu, x))

    elif args.benchmark == "sum_and_add":
        print(bench(sum_and_add, axis, *xs))

    else:
        raise ValueError("Unknown benchmark")
@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
def mish(x: torch.Tensor) -> torch.Tensor:
    y = x
    for _ in range(100):
        return torch.nn.functional.mish(y)
        y = torch.nn.functional.mish(y)
    sync_if_needed(x)


@@ -283,6 +283,14 @@ def topk(axis, x):
    sync_if_needed(x)


@torch.no_grad()
def step_function(x):
    y = x
    for i in range(100):
        y = torch.where(y < 0, 0, 1)
    sync_if_needed(x)

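The added `step` benchmark times a chain of elementwise thresholds; a one-off illustration of the kernel it exercises (hedged, values arbitrary; `torch.where` with scalar branches needs a reasonably recent PyTorch):

```python
import torch

x = torch.tensor([-1.5, 0.0, 2.0])
y = torch.where(x < 0, 0, 1)  # Heaviside-style step: tensor([0, 1, 1])
```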
@torch.no_grad()
def selu(x):
    y = x
@@ -331,10 +339,6 @@ if __name__ == "__main__":
    if len(args.axis) > 1:
        args.axis.pop(0)

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    torch.set_num_threads(1)
    device = "cpu" if args.cpu else "mps"

@@ -354,6 +358,10 @@ if __name__ == "__main__":
    x = xs[0]
    axis = args.axis[0]

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    if args.benchmark == "matmul_square":
        print(bench(matmul_square, x))

@@ -446,5 +454,11 @@ if __name__ == "__main__":
    elif args.benchmark == "topk":
        print(bench(topk, axis, x))

    elif args.benchmark == "step":
        print(bench(step_function, x))

    elif args.benchmark == "selu":
        print(bench(selu, x))

    else:
        raise ValueError("Unknown benchmark")
        raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
@@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
        result = run(*args, capture_output=True, **kwargs)
        return float(result.stdout)
    except ValueError:
        raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
        raise ValueError(
            f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
        )

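Worth noting: with `capture_output=True`, `result.stdout` and `result.stderr` are `bytes`, so interpolating them directly prints `b'...'` literals with escaped newlines; decoding first restores readable multi-line output. A minimal sketch of the difference, assuming any command that writes to stdout:

```python
from subprocess import run

result = run(["echo", "3.14"], capture_output=True)
print(f"{result.stdout}")           # b'3.14\n'
print(f"{result.stdout.decode()}")  # 3.14
```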
def compare(args):
@@ -62,7 +64,7 @@ def make_predicate(positive_filter, negative_filter):


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch")
    parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
    parser.add_argument(
        "--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
    )
@@ -80,10 +82,8 @@ if __name__ == "__main__":
    _filter = make_predicate(args.filter, args.negative_filter)

    if args.mlx_dtypes:
        compare_filtered = (
            lambda x: compare_mlx_dtypes(
                x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
            )
        compare_filtered = lambda x: (
            compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
            if _filter(x)
            else None
        )
@@ -125,6 +125,14 @@ if __name__ == "__main__":
    compare_filtered("sum_axis --size 16x128x1024 --axis 1")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
    compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
    compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
    compare_filtered("argmax --size 10x1024x128 --axis 1")
    compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
benchmarks/python/compile_bench.py
@@ -0,0 +1,107 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import math
import random

import mlx.core as mx
from time_utils import time_fn


def bench_gelu():
    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

    x = mx.random.uniform(shape=(1000, 1024))

    def gen_fun(fun):
        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    time_fn(gen_fun(gelu), x, msg="fixed gelu")
    time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")

    def randint():
        return random.randint(1, x.shape[0])

    def gen_fun(fun):
        def bench_fun(x, y):
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
                y = fun(y)
            return x, y

        return bench_fun

    y = mx.random.uniform(shape=(1000, 1024))
    time_fn(gen_fun(gelu), x, y, msg="variable gelu")
    time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
    time_fn(
        gen_fun(mx.compile(gelu, shapeless=True)),
        x,
        y,
        msg="shapeless variable gelu",
    )

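A quick note on `shapeless=True` in the run above: it asks `mx.compile` to trace a graph that is reused across input shapes (here the random slice changes the leading dimension every call), trading shape-specialized optimization for avoiding recompilation. A hedged toy sketch:

```python
import mlx.core as mx

f = mx.compile(lambda x: x * 2 + 1, shapeless=True)
print(f(mx.ones((3,))))    # first call traces the graph
print(f(mx.ones((5, 2))))  # different shape, same compiled graph
```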
def bench_layernorm():
|
||||
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
mx.eval(weight, bias)
|
||||
|
||||
def layernorm(x):
|
||||
x = x.astype(mx.float32)
|
||||
means = mx.mean(x, axis=-1, keepdims=True)
|
||||
var = mx.var(x, axis=-1, keepdims=True)
|
||||
x = (x - means) * mx.rsqrt(var + 1e-4)
|
||||
x = x.astype(mx.float16)
|
||||
return weight * x + bias
|
||||
|
||||
x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)
|
||||
|
||||
def gen_fun(fun):
|
||||
def bench_fun(x):
|
||||
for _ in range(10):
|
||||
x = fun(x)
|
||||
return x
|
||||
|
||||
return bench_fun
|
||||
|
||||
time_fn(gen_fun(layernorm), x, msg="fixed layernorm")
|
||||
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm")
|
||||
|
||||
def randint():
|
||||
return random.randint(1, x.shape[0])
|
||||
|
||||
def gen_fun(fun):
|
||||
def bench_fun(x):
|
||||
x = x[: randint()]
|
||||
for _ in range(10):
|
||||
x = fun(x)
|
||||
return x
|
||||
|
||||
return bench_fun
|
||||
|
||||
random.seed(0)
|
||||
time_fn(gen_fun(layernorm), x, msg="variable layernorm")
|
||||
random.seed(0)
|
||||
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm")
|
||||
random.seed(0)
|
||||
time_fn(
|
||||
gen_fun(mx.compile(layernorm, shapeless=True)),
|
||||
x,
|
||||
msg="shapeless variable layernorm",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Compile benchmarks.")
|
||||
args = parser.parse_args()
|
||||
|
||||
bench_gelu()
|
||||
bench_layernorm()
|
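A note on the harness: `time_fn` above comes from the repo's `time_utils` module, which is not part of this diff. A minimal stand-in consistent with how it is called here (positional arguments plus an optional `msg` keyword, printing an average per-call time) could look like the sketch below; the warmup and iteration counts are assumptions, not the upstream values.

# Hypothetical sketch of time_utils.time_fn as assumed by these benchmarks.
import time

import mlx.core as mx


def time_fn(fn, *args, msg=None, **kwargs):
    print(f"Timing {msg or fn.__name__} ...", end=" ")
    for _ in range(5):  # warmup (count is an assumption)
        mx.eval(fn(*args, **kwargs))
    num_iters = 100  # assumption
    tic = time.perf_counter()
    for _ in range(num_iters):
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()
    print(f"{1e3 * (toc - tic) / num_iters:.5f} msec")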
123  benchmarks/python/conv1d_bench.py  Normal file
@@ -0,0 +1,123 @@
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_1D(strides=1, padding=0, groups=1):
    def mx_conv_1D(a, b):
        ys = []
        for _ in range(N_iter_func):
            y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_1D


def make_pt_conv_1D(strides=1, padding=0, groups=1):
    @torch.no_grad()
    def pt_conv_1D(a, b):
        ys = []
        for _ in range(N_iter_func):
            y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_1D


def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
    scale = 1.0 / math.sqrt(wH * C)
    a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_1D(strides, padding, groups)
    f_pt = make_pt_conv_1D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv1d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 5, 32, 1, 2, 1),
        (4, 32, 32, 5, 32, 1, 2, 2),
        (4, 32, 32, 5, 32, 1, 2, 4),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 16),
        (4, 32, 32, 5, 32, 1, 2, 32),
        (4, 32, 256, 5, 512, 1, 2, 2),
        (4, 32, 256, 5, 512, 1, 2, 128),
        (4, 32, 256, 5, 512, 1, 2, 256),
    )

    for dtype in dtypes:
        print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
        for N, iH, C, wH, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, iH, C, wH, O, strides, padding, np_dtype, groups
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
            )

            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
127  benchmarks/python/conv2d_bench_cpu.py  Normal file
@@ -0,0 +1,127 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
143  benchmarks/python/conv2d_train_bench_cpu.py  Normal file
@@ -0,0 +1,143 @@
import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal([10, 256, 256, 3])

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(10, 3, 256, 256, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 20
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX:     {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX:       {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch:   {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()
129  benchmarks/python/conv2d_transpose_bench_cpu.py  Normal file
@@ -0,0 +1,129 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_2D


def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_transpose_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
110  benchmarks/python/conv3d_bench_cpu.py  Normal file
@@ -0,0 +1,110 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
143  benchmarks/python/conv3d_train_bench_cpu.py  Normal file
@@ -0,0 +1,143 @@
import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal(shape)

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 10
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX:     {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX:       {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch:   {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()
116  benchmarks/python/conv3d_transpose_bench_cpu.py  Normal file
@@ -0,0 +1,116 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose3d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
135  benchmarks/python/conv_bench.py  Normal file
@@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
135  benchmarks/python/conv_transpose_bench.py  Normal file
@@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_2D


def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_transpose_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
66  benchmarks/python/distributed_bench.py  Normal file
@@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.

"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""

import time

import mlx.core as mx


def time_fn(fn, *args, **kwargs):
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    if world.rank() == 0:
        if msg:
            print(f"Timing {msg} ...", end=" ")
        else:
            print(f"Timing {fn.__name__} ...", end=" ")

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        x = mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    if world.rank() == 0:
        print(f"{msec:.5f} msec")


def time_all_sum():
    shape = (4096,)
    x = mx.random.uniform(shape=shape)
    mx.eval(x)

    def sine(x):
        for _ in range(20):
            x = mx.sin(x)
        return x

    time_fn(sine, x)

    def all_sum_plain(x):
        for _ in range(20):
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_plain, x)

    def all_sum_with_sine(x):
        for _ in range(20):
            x = mx.sin(x)
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_with_sine, x)


if __name__ == "__main__":
    time_all_sum()
84  benchmarks/python/einsum_bench.py  Normal file
@@ -0,0 +1,84 @@
# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx
import numpy as np


def timeit(fn, its=100, args=[]):
    for _ in range(5):
        fn(*args)
    tic = time.perf_counter()
    for _ in range(its):
        fn(*args)
    toc = time.perf_counter()
    return 1e3 * (toc - tic) / its


def time_little_einsum_path():
    subscripts = "ik,kj->ij"
    x = mx.ones((32, 32))
    y = mx.ones((32, 32))
    mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))

    x = np.array(x)
    y = np.array(y)
    np_time = timeit(np.einsum_path, args=(subscripts, x, y))
    print("Timing little einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_big_einsum_path():
    chars = list("abcdefgh")
    char_to_dim = {c: v for v, c in enumerate(chars)}

    num_inputs = 10
    inputs = []
    subscripts = []
    for _ in range(num_inputs):
        subscript = np.random.choice(chars, size=5, replace=False).tolist()
        subscripts.append("".join(subscript))
        inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
    subscripts = ",".join(subscripts)

    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))

    inputs = [mx.array(x) for x in inputs]
    mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
    print("Timing big einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_attention():
    def regular_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
        scores = mx.softmax(scores, axis=-1)
        output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
        mx.eval(output)

    def einsum_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
        scores = mx.softmax(scores, axis=-1)
        output = mx.einsum("ijtu,iujk->itjk", scores, values)
        mx.eval(output)

    x = mx.random.uniform(shape=(8, 512, 32, 128))

    regular_time = timeit(regular_attention, args=(x,))
    ein_time = timeit(einsum_attention, args=(x,))
    print("Timing einsum attention...")
    print(f"Regular ... {regular_time:.3f} ms")
    print(f"Einsum ... {ein_time:.3f} ms")


if __name__ == "__main__":
    time_little_einsum_path()
    time_big_einsum_path()
    time_attention()
118  benchmarks/python/fft_bench.py  Normal file
@@ -0,0 +1,118 @@
# Copyright © 2024 Apple Inc.

import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def bandwidth_gb(runtime_ms, system_size):
    bytes_per_fft = np.dtype(np.complex64).itemsize * 2
    bytes_per_gb = 1e9
    ms_per_s = 1e3
    return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb


def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
    def fft_mlx(x):
        if dim == 1:
            out = mx.fft.fft(x)
        elif dim == 2:
            out = mx.fft.fft2(x)
        mx.eval(out)
        return out

    def fft_mps(x):
        if dim == 1:
            out = torch.fft.fft(x)
        elif dim == 2:
            out = torch.fft.fft2(x)
        torch.mps.synchronize()
        return out

    bandwidths = []
    for n in fft_sizes:
        batch_size = system_size // n**dim
        shape = [batch_size] + [n for _ in range(dim)]
        if backend == "mlx":
            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
            x = mx.array(x_np)
            mx.eval(x)
            fft = fft_mlx
        elif backend == "mps":
            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
            x = torch.tensor(x_np, device="mps")
            torch.mps.synchronize()
            fft = fft_mps
        else:
            raise NotImplementedError()
        runtime_ms = measure_runtime(fft, x=x)
        bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
        print(n, bandwidth)
        bandwidths.append(bandwidth)

    return np.array(bandwidths)


def time_fft():
    x = np.array(range(2, 512))
    system_size = int(2**26)

    print("MLX GPU")
    with mx.stream(mx.gpu):
        gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)

    print("MPS GPU")
    mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")

    print("CPU")
    system_size = int(2**20)
    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)

    x = np.array(x)

    all_indices = x - x[0]
    radix_2to13 = (
        np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
    )
    bluesteins = (
        np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
    )

    for indices, name in [
        (all_indices, "All"),
        (radix_2to13, "Radix 2-13"),
        (bluesteins, "Bluestein's"),
    ]:
        # plot bandwidths
        print(name)
        plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
        plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
        plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
        plt.title(f"MLX FFT Benchmark -- {name}")
        plt.xlabel("N")
        plt.ylabel("Bandwidth (GB/s)")
        plt.legend()
        plt.savefig(f"{name}.png")
        plt.clf()

    av_gpu_bandwidth = np.mean(gpu_bandwidths)
    av_mps_bandwidth = np.mean(mps_bandwidths)
    av_cpu_bandwidth = np.mean(cpu_bandwidths)
    print("Average bandwidths:")
    print("GPU:", av_gpu_bandwidth)
    print("MPS:", av_mps_bandwidth)
    print("CPU:", av_cpu_bandwidth)

    portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
    print("Percent MLX faster than MPS: ", portion_faster * 100)


if __name__ == "__main__":
    time_fft()
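`measure_runtime` is likewise imported from `time_utils` and not shown in this diff. Judging from its call sites (`measure_runtime(fft, x=x)` above, and `measure_runtime(gather, x=x, idx=idx, device=device)` below), it calls a function with keyword arguments repeatedly and returns an average runtime in milliseconds. A plausible sketch, with assumed warmup and iteration counts:

# Hypothetical sketch of time_utils.measure_runtime as assumed by these benchmarks.
import time


def measure_runtime(fn, **kwargs):
    for _ in range(5):  # warmup (count is an assumption)
        fn(**kwargs)
    num_iters = 100  # assumption
    tic = time.perf_counter()
    for _ in range(num_iters):
        fn(**kwargs)
    toc = time.perf_counter()
    return 1e3 * (toc - tic) / num_iters  # milliseconds per call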
52  benchmarks/python/gather_bench.py  Normal file
@@ -0,0 +1,52 @@
# Copyright © 2023-2024 Apple Inc.

import argparse

import mlx.core as mx
import torch
from time_utils import measure_runtime


def benchmark_gather_mlx(x_shape, idx_shape):
    def gather(x, idx):
        mx.eval(x[idx])

    idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
    x = mx.random.normal(x_shape).astype(mx.float32)

    runtime = measure_runtime(gather, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")


def benchmark_gather_torch(x_shape, idx_shape, device):
    def gather(x, idx, device):
        _ = x[idx]
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
    x = torch.randn(x_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(gather, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    idx_shapes = [(1_000_000,), (100_000,), ()]
    x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]

    for x_shape, idx_shape in zip(x_shapes, idx_shapes):
        print("=" * 20)
        print(f"X {x_shape}, Indices {idx_shape}")
        benchmark_gather_mlx(x_shape, idx_shape)
        benchmark_gather_torch(x_shape, idx_shape, device=device)
74  benchmarks/python/gather_mm_bench.py  Normal file
@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_mm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = x @ w1.T
        x = x @ w2.T
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_mm()
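The sort/unsort pair above is the core trick in this benchmark: rows are reordered so that tokens hitting the same expert are contiguous, and `inv_order` (the inverse permutation of `order`) restores the original layout afterwards. A small round-trip check, hypothetical and not part of the diff, using the helpers defined above:

# Round-trip sanity check for gather_sort / scatter_unsort (assumed usage).
x = mx.random.normal((8, 1, 1, 16))
indices = (mx.random.uniform(shape=(8, 4)) * 32).astype(mx.uint32)
xs, idx, inv_order = gather_sort(x, indices)
print(mx.all(idx[:-1] <= idx[1:]))  # True: expert indices are now sorted
restored = scatter_unsort(xs, inv_order, indices.shape)
print(restored.shape)  # (8, 4, 1, 16): one row per (token, selected expert)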
84  benchmarks/python/gather_qmm_bench.py  Normal file
@@ -0,0 +1,84 @@
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate(
            [
                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
                for i, j in enumerate(idx.tolist())
            ],
            axis=0,
        )
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_qmm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = mx.quantized_matmul(x, *w1, transpose=True)
        x = mx.quantized_matmul(x, *w2, transpose=True)
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_qmm()
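The `*w1` / `*w2` splats above work because `mx.quantize` returns a `(w_q, scales, biases)` tuple, which lines up with the weight arguments of `mx.quantized_matmul` and `mx.gather_qmm`. A quick illustration, with shapes chosen only for the example:

# mx.quantize returns (quantized weights, scales, biases); the defaults are
# group_size=64 and bits=4.
w = mx.random.normal((1024, 1024))
w_q, scales, biases = mx.quantize(w)
x = mx.random.normal((1, 1024))
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
print(y.shape)  # (1, 1024)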
benchmarks/python/hadamard_bench.py (new file, +70)
@@ -0,0 +1,70 @@
import argparse

import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def had(x):
    y = mx.hadamard_transform(x)
    mx.eval(y)


def copy(x):
    y = x + 1.0
    mx.eval(y)


def run(dtype):
    system_size = 2**26
    outputs = {}
    for test_fn in (had, copy):
        for m in [1, 12, 20, 28]:
            if test_fn == copy:
                key = "copy"
            elif m == 1:
                key = "had_2^k"
            else:
                key = "had_m*2^k"
            outputs.setdefault(key, {})
            for k in range(7, 14):
                n = m * 2**k
                if n > 2**15:
                    continue
                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
                x = mx.array(x_np)
                runtime_ms = measure_runtime(test_fn, x=x)
                bytes_per_gb = 1e9
                ms_per_s = 1e3
                bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
                bandwidth_gb = (
                    system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
                )
                print(n, bandwidth_gb)
                outputs[key][n] = bandwidth_gb

    colors = {
        "copy": "black",
        "had_2^k": "steelblue",
        "had_m*2^k": "skyblue",
    }
    for key, output in outputs.items():
        plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
    plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
    plt.xlabel("N")
    plt.ylabel("Bandwidth (GB/s)")
    plt.legend()
    plt.savefig(f"bench_{dtype.__name__}.png")
    plt.clf()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp16", action="store_true")
    args = parser.parse_args()
    dtype = np.float16 if args.fp16 else np.float32
    run(dtype)
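To make the bandwidth arithmetic above concrete, a worked example (the runtime is hypothetical; `bytes_per_had` counts one read and one write per element):

    system_size = 2**26
    bytes_per_had = 4 * 2        # float32: 4 bytes read + 4 bytes written
    runtime_ms = 10.0            # hypothetical measured runtime
    bandwidth_gb = system_size * bytes_per_had / runtime_ms * 1e3 / 1e9
    print(f"{bandwidth_gb:.1f} GB/s")  # ~53.7 GB/s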
benchmarks/python/layer_norm_bench.py (new file, +68)
@@ -0,0 +1,68 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def layer_norm(x, w, b, eps):
    ot = x.dtype
    x = x.astype(mx.float32)
    mu = mx.mean(x, -1, keepdims=True)
    v = mx.var(x, -1, keepdims=True)
    y = (x - mu) * mx.rsqrt(v + eps)
    if w is not None:
        y = y * w
    if b is not None:
        y = y + b
    return y


def time_layer_norm():
    f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
    f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1, 2))
    g2 = mx.grad(f2, argnums=(0, 1, 2))

    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    mx.eval(x, w, b, y)

    def layer_norm_loop(g, x, w, b):
        gx, gw, gb = x, w, b
        for _ in range(32):
            gx, gw, gb = g(gx, gw, gb, y)
        return gx, gw, gb

    time_fn(layer_norm_loop, g1, x, w, b)
    time_fn(layer_norm_loop, g2, x, w, b)
    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)

    f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
    f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0,))
    g2 = mx.grad(f2, argnums=(0,))

    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    mx.eval(x, w, b, y)

    def layer_norm_loop(g, x):
        gx = x
        for _ in range(32):
            gx = g(gx, y)
        return gx

    time_fn(layer_norm_loop, g1, x)
    time_fn(layer_norm_loop, g2, x)
    time_fn(layer_norm_loop, mx.compile(g1), x)
    time_fn(layer_norm_loop, mx.compile(g2), x)


if __name__ == "__main__":
    time_layer_norm()
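As a sanity check (a sketch, not part of the benchmark), the reference `layer_norm` above should track `mx.fast.layer_norm` to within float16 tolerance:

    import mlx.core as mx

    x = mx.random.uniform(shape=(2, 8, 64)).astype(mx.float16)
    w = mx.ones((64,), dtype=mx.float16)
    b = mx.zeros((64,), dtype=mx.float16)
    ref = layer_norm(x, w, b, 1e-5)           # returns float32 (no cast back)
    fast = mx.fast.layer_norm(x, w, b, 1e-5)  # returns float16
    assert mx.allclose(ref, fast.astype(mx.float32), atol=1e-2).item()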
@@ -1,198 +0,0 @@ (deleted file)
# Copyright © 2023 Apple Inc.

import math
import time

import jax
import jax.numpy as jnp
from flax import linen as nn


class RoPE(nn.Module):
    dims: int
    traditional: bool = False

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = jnp.concatenate([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        dtype=jnp.float32,
    ):
        D = D // 2
        positions = jnp.arange(offset, N, dtype=dtype)
        freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
        theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        return costheta, sintheta

    @nn.compact
    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = x.reshape((-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return rx.reshape(shape)


class LlamaAttention(nn.Module):
    dims: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        num_heads = self.num_heads
        dims = self.dims

        self.rope = RoPE(dims // num_heads, True)
        self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
        keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
        values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = jnp.concatenate([key_cache, keys], axis=2)
            values = jnp.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = jax.nn.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    dims: int
    mlp_dims: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        dims = self.dims
        mlp_dims = self.mlp_dims
        num_heads = self.num_heads

        self.attention = LlamaAttention(dims, num_heads, self.dtype)

        self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
        self.norm2 = nn.RMSNorm(param_dtype=self.dtype)

        self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
        self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
        self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = jax.nn.silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    jax.block_until_ready((y, c))

    start = time.time()
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    jax.block_until_ready((y, c))

    end = time.time()
    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    dtype = jnp.float16

    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

    x = jax.random.normal(k1, (1, 1, D), dtype)
    cache = [
        jax.random.normal(k2, [1, H, C, D // H], dtype),
        jax.random.normal(k3, [1, H, C, D // H], dtype),
    ]

    layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
    params = layer.init(k4, x, mask=None, cache=cache)["params"]

    @jax.jit
    def model_fn(x, mask, cache):
        return layer.apply({"params": params}, x, mask=mask, cache=cache)

    T = measure(model_fn, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
@@ -1,118 +0,0 @@ (deleted file)
# Copyright © 2023 Apple Inc.

import math
import time

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, False)
        self.key_proj = nn.Linear(dims, dims, False)
        self.value_proj = nn.Linear(dims, dims, False)
        self.out_proj = nn.Linear(dims, dims, False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
        keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
        values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
        scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, False)
        self.linear2 = nn.Linear(dims, mlp_dims, False)
        self.linear3 = nn.Linear(mlp_dims, dims, False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    mx.eval(y, c)

    start = time.time()
    rs = []
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        rs.append((y, c))
    mx.eval(rs)
    end = time.time()

    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    mx.set_default_device(mx.gpu)
    dtype = mx.float16

    layer = LlamaEncoderLayer(D, F, H)
    layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
    k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
    x = mx.random.normal([1, 1, D], dtype=dtype)
    cache = [
        mx.random.normal([1, H, C, D // H], dtype=dtype),
        mx.random.normal([1, H, C, D // H], dtype=dtype),
    ]
    mx.eval(x, cache)

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
@@ -1,199 +0,0 @@ (deleted file)
# Copyright © 2023 Apple Inc.

import math
import time

import torch
import torch.mps
import torch.nn as nn


def sync_if_needed(x):
    if x.device != torch.device("cpu"):
        torch.mps.synchronize()


class RoPE(nn.Module):
    def __init__(self, dims: int, traditional: bool = False):
        super().__init__()
        self.dims = dims
        self.traditional = traditional

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
        else:
            rx = torch.cat([rx1, rx2], dim=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)

        return rx

    def forward(self, x, offset: int = 0):
        shape = x.shape
        x = x.view(-1, shape[-2], shape[-1])
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, device=x.device, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return rx.view(*shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        device="cpu",
        dtype=torch.float32,
    ):
        D = D // 2
        positions = torch.arange(offset, N, dtype=dtype, device=device)
        freqs = torch.exp(
            -torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
        )
        theta = positions.view(-1, 1) * freqs.view(1, -1)
        costheta = torch.cos(theta)
        sintheta = torch.sin(theta)

        return costheta, sintheta


class RMSNorm(nn.Module):
    def __init__(self, dims: int, epsilon: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones((dims,)))
        self.epsilon = epsilon

    def forward(self, x):
        n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
        return self.gamma * x * n


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def forward(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
        keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
        values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = torch.cat([key_cache, keys], dim=2)
            values = torch.cat([value_cache, values], dim=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = torch.softmax(scores, dim=-1)
        values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = RMSNorm(dims)
        self.norm2 = RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def forward(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = torch.nn.functional.silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


@torch.no_grad()
def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    sync_if_needed(x)

    start = time.time()
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    sync_if_needed(x)
    end = time.time()
    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    device = torch.device("mps")
    dtype = torch.float16

    layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
    x = torch.randn(1, 1, D).to(device).to(dtype)
    cache = [
        torch.randn(1, H, C, D // H).to(device).to(dtype),
        torch.randn(1, H, C, D // H).to(device).to(dtype),
    ]

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
benchmarks/python/rms_norm_bench.py (new file, +63)
@@ -0,0 +1,63 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def rms_norm(x, w, eps):
    ot = x.dtype
    x = x.astype(mx.float32)
    n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    y = (x * n).astype(ot)
    if w is not None:
        y = y * w
    return y


def time_rms_norm():
    f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
    f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1))
    g2 = mx.grad(f2, argnums=(0, 1))

    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    mx.eval(x, w, y)

    def rms_norm_loop(g, x, w):
        gx, gw = x, w
        for _ in range(32):
            gx, gw = g(gx, gw, y)
        return gx, gw

    time_fn(rms_norm_loop, g1, x, w)
    time_fn(rms_norm_loop, g2, x, w)
    time_fn(rms_norm_loop, mx.compile(g1), x, w)
    time_fn(rms_norm_loop, mx.compile(g2), x, w)

    f1 = lambda x, y: (rms_norm(x, None, 1e-5) * y).sum()
    f2 = lambda x, y: (mx.fast.rms_norm(x, None, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0,))
    g2 = mx.grad(f2, argnums=(0,))

    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    mx.eval(x, w, y)

    def rms_norm_loop(g, x):
        gx = x
        for _ in range(32):
            gx = g(gx, y)
        return gx

    time_fn(rms_norm_loop, g1, x)
    time_fn(rms_norm_loop, g2, x)
    time_fn(rms_norm_loop, mx.compile(g1), x)
    time_fn(rms_norm_loop, mx.compile(g2), x)


if __name__ == "__main__":
    time_rms_norm()
benchmarks/python/rope_bench.py (new file, +35)
@@ -0,0 +1,35 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
    rope = nn.RoPE(64)

    # vec
    x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
    mx.eval(x)

    def rope_vec(x):
        for _ in range(32):
            x = rope(x, offset=100)
        return x

    time_fn(rope_vec, x)

    # matrix
    x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
    mx.eval(x)

    def rope_mat(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_mat, x)


if __name__ == "__main__":
    time_rope()
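A quick way to see what `offset` does (a sketch, not part of the benchmark): with `offset=100`, a single position is rotated as if it were token 100, which is what incremental decoding with a KV cache needs:

    import mlx.core as mx
    import mlx.nn as nn

    rope = nn.RoPE(64)
    x = mx.random.uniform(shape=(1, 1, 1, 64))
    full = rope(mx.broadcast_to(x, (1, 1, 101, 64)))  # positions 0..100
    step = rope(x, offset=100)                        # position 100 only
    assert mx.allclose(full[..., -1:, :], step, atol=1e-5).item()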
benchmarks/python/scatter_bench.py (new file, +96)
@@ -0,0 +1,96 @@
# Copyright © 2023-2024 Apple Inc.

import argparse

import mlx.core as mx
import torch
from time_utils import measure_runtime


def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
    def scatter(dst, x, idx):
        dst[tuple(idx)] = x
        mx.eval(dst)

    idx = []
    for idx_shape in idx_shapes:
        idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
    x = mx.random.normal(x_shape).astype(mx.float32)
    dst = mx.random.normal(dst_shape).astype(mx.float32)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")


def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
    def scatter(dst, x, idx, device):
        dst[tuple(idx)] = x
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = []
    for idx_shape in idx_shapes:
        idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
    x = torch.randn(x_shape, dtype=torch.float32).to(device)
    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    dst_shapes = [
        (10, 64),
        (100_000, 64),
        (1_000_000, 64),
        (100_000,),
        (200_000,),
        (20_000_000,),
        (10000, 64),
        (100, 64),
        (100, 10_000, 64),
        (10, 100, 100, 21),
        (1_000, 1_000, 10),
    ]
    idx_shapes = [
        [(1_000_000,)],
        [(1_000_000,)],
        [(100_000,)],
        [(1_000_000,)],
        [(20_000_000,)],
        [(20_000_000,)],
        [(1000000,)],
        [(10000000,)],
        [(1_000,)],
        [(10_000,)],
        [(1_000,), (1_000,)],
    ]
    x_shapes = [
        (1_000_000, 64),
        (1_000_000, 64),
        (100_000, 64),
        (1_000_000,),
        (20_000_000,),
        (20_000_000,),
        (1000000, 64),
        (10000000, 64),
        (1_000, 10_000, 64),
        (10_000, 100, 100, 21),
        (1_000, 10),
    ]

    for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
        print("=" * 20)
        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
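In miniature, the operation both benchmarks time is integer-index assignment (a sketch, not part of the diff; when indices repeat, one of the competing writes wins):

    import mlx.core as mx

    dst = mx.zeros((4, 3))
    idx = [mx.array([2, 0])]
    x = mx.ones((2, 3))
    dst[tuple(idx)] = x   # rows 0 and 2 of dst become ones
    mx.eval(dst)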
benchmarks/python/sdpa_bench.py (new file, +223)
@@ -0,0 +1,223 @@
# Copyright © 2024 Apple Inc.

import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 5
N_iter_bench = 40
N_iter_func = 8


def bench(f, *args):
    for i in range(N_warmup):
        f(*args)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(*args)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
    np_dtype = getattr(np, dtype)

    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)

    scale = 1.0 / math.sqrt(D)

    q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
    k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
    v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)

    q_mx = mx.array(q_np)
    k_mx = mx.array(k_np)
    v_mx = mx.array(v_np)

    if mask is not None:
        if mask == "additive":
            mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
            mask = mx.array(mask_np)
        elif mask == "bool":
            mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
            mask = mx.array(mask_np)

    return q_mx, k_mx, v_mx, scale, mask


def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
    q_dtype = q.dtype
    q = q * mx.array(scale, q_dtype)
    n_q_heads = q.shape[-3]
    n_kv_heads = k.shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    B = q.shape[0]
    L = q.shape[2]
    kL = k.shape[2]

    if n_repeats > 1:
        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
        k = mx.expand_dims(k, 2)
        v = mx.expand_dims(v, 2)

    scores = q @ mx.swapaxes(k, -1, -2)

    if mask is not None:

        if mask == "causal":
            q_offset = max(0, kL - L)
            q_indices = mx.arange(q_offset, q_offset + L)
            k_indices = mx.arange(kL)
            mask = q_indices[:, None] >= k_indices[None]

        if n_repeats > 1 and mask.ndim >= 3:
            if mask.shape[-3] == 1:
                mask = mx.expand_dims(mask, -3)
            else:
                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))

        if mask.dtype == mx.bool_:
            scores = mx.where(mask, scores, -np.float32(np.inf))
        else:
            scores += mask

    scores = mx.softmax(scores, axis=-1, precise=True)

    out = scores @ v
    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, -1])

    return out


def mlx_fused_attn(q, k, v, scale, mask):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)


def do_attention(f, q, k, v, scale, mask=None, transpose=False):
    if transpose:
        q_t = mx.transpose(q, (0, 2, 1, 3))
        k_t = mx.transpose(k, (0, 2, 1, 3))
        v_t = mx.transpose(v, (0, 2, 1, 3))
        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
        return mx.transpose(o_t, (0, 2, 1, 3))
    else:
        return f(q, k, v, scale=scale, mask=mask)


def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
    q_out = q

    for i in range(N_iter_func):
        q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)

    mx.eval(q_out)
    return q_out


def bench_shape(
    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
):
    q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
        B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
    )

    time_mlx_unfused = bench(
        do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
    )
    time_mlx_fused = bench(
        do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
    )

    o_mlx_unfused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
    o_mlx_fused = do_attention(mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose)

    atol = 1e-5 if dtype == "float32" else 2e-4

    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
        print(
            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
        )

    return time_mlx_fused, time_mlx_unfused


def get_gflop_count(B, M, N, K):
    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

    dtypes = ("float16", "float32")[:1]
    transposes = (False,)

    # fmt: off
    shapes_64 = (
        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
        (    1,    32,    32,       64,   32,    32),
        (    1,    64,    64,       64,   32,    32),
        (    1,   128,   128,       64,   32,    32),
        (    1,   256,   256,       64,   32,    32),
        (    1,   512,   512,       64,   32,    32),
        (    1,  1024,  1024,       64,   32,     8),
        (    1,  2048,  2048,       64,   32,     8),
        (    1,  4096,  4096,       64,   32,     8),
    )

    shapes_80 = (
        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
        (    1,  1024,  1024,       80,   32,     8),
        (    1,  2048,  2048,       80,   32,     8),
        (    1,  4096,  4096,       80,   32,     8),
    )

    shapes_128 = (
        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
        (    1,  1024,  1024,      128,   32,     8),
        (    1,  2048,  2048,      128,   32,     8),
        (    1,  4096,  4096,      128,   32,     8),
    )
    # fmt: on

    shapes = shapes_64 + shapes_80 + shapes_128

    masks = [None, "bool", "causal"]

    print(
        "  B,   qsl,   ksl, hdim, n_qh, n_kvh, t,   dtype,     mask, t_unfs, t_fuse, diff%"
    )

    for dtype in dtypes:
        for transpose in transposes:
            for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
                for mask_in in masks:
                    time_mlx_fused, time_mlx_unfused = bench_shape(
                        B,
                        qsl,
                        ksl,
                        head_dim,
                        n_q_heads,
                        n_kv_heads,
                        dtype,
                        transpose,
                        mask_in,
                    )
                    diff = time_mlx_unfused / time_mlx_fused - 1.0
                    t_str = 1 if transpose else 0
                    print(
                        f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
                    )
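The GQA bookkeeping in `mlx_ref_attn` can be seen in isolation (a sketch with small shapes, not part of the diff): with 8 query heads and 2 KV heads, queries are grouped 4-per-KV-head and the keys broadcast over the group dimension:

    import mlx.core as mx

    B, Hq, Hk, L, D = 1, 8, 2, 5, 16
    q = mx.zeros((B, Hq, L, D))
    q = mx.reshape(q, (B, Hk, Hq // Hk, L, D))
    k = mx.expand_dims(mx.zeros((B, Hk, L, D)), 2)  # broadcasts over the group dim
    scores = q @ mx.swapaxes(k, -1, -2)
    print(scores.shape)  # (1, 2, 4, 5, 5)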
benchmarks/python/sdpa_vector_bench.py (new file, +95)
@@ -0,0 +1,95 @@
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

L = 16384
H = 32
H_k = H // 4
D = 128
V = 128
dtype = mx.float16
loops = 10


def upproject(x, w):
    if w is None:
        return x
    else:
        return x @ w.T


def attention(q, k, v, mask=None, w=None):
    def _sdpa(q, k, v):
        B, Hq, L, D = q.shape
        _, Hk, S, _ = k.shape
        _, _, _, V = v.shape
        q = q.reshape(B, Hk, Hq // Hk, L, D)
        k = k[:, :, None, :, :]
        v = v[:, :, None, :, :]
        s = q @ k.transpose(0, 1, 2, 4, 3)
        if mask is not None:
            m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
            s = mx.where(m, s, mx.finfo(s.dtype).min)
        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
        o = p @ v
        return o.reshape(B, Hq, L, V)

    for i in range(loops):
        q = _sdpa(q, k, v)
        q = upproject(q, w)
    return q


def sdpa(q, k, v, mask=None, w=None):
    for i in range(loops):
        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
        q = upproject(q, w)
    return q


def time_self_attention_primitives():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
    mx.eval(q, k, v, w)
    time_fn(attention, q, k, v, w=w)


def time_self_attention_sdpa():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
    mx.eval(q, k, v, w)
    time_fn(sdpa, q, k, v, w=w)


def time_self_attention_sdpa_with_mask():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
    mask = mx.full((L,), True)
    mask[L // 2 :] = False
    mx.eval(q, k, v, mask, w)

    def sdpa_mask(*args):
        return sdpa(*args, mask=mask, w=w)

    def attention_mask(*args):
        return attention(*args, mask=mask, w=w)

    time_fn(attention_mask, q, k, v)
    time_fn(sdpa_mask, q, k, v)


if __name__ == "__main__":
    time_self_attention_sdpa()
    time_self_attention_primitives()
    time_self_attention_sdpa_with_mask()
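The boolean mask used above, in miniature (a sketch, not part of the diff): a keep-mask over key positions, where `False` entries are pushed to the dtype minimum before the softmax so they get effectively zero weight:

    import mlx.core as mx

    L = 8
    mask = mx.full((L,), True)
    mask[L // 2 :] = False
    print(mask)  # first half True (kept), second half False (masked out)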
@@ -44,6 +44,13 @@ def time_matmul():
     time_fn(mx.matmul, a, b)


+def time_maximum():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    b = mx.random.uniform(shape=(32, 1024, 1024))
+    mx.eval(a, b)
+    time_fn(mx.maximum, a, b)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
@@ -101,6 +108,7 @@ if __name__ == "__main__":

     time_add()
     time_matmul()
+    time_maximum()
     time_exp()
     time_negative()
     time_logsumexp()
benchmarks/python/synchronize_bench.py (new file, +55)
@@ -0,0 +1,55 @@
import time

import mlx.core as mx

rank = mx.distributed.init().rank()


def timeit(fn, a):

    # warmup
    for _ in range(5):
        mx.eval(fn(a))

    its = 10
    tic = time.perf_counter()
    for _ in range(its):
        mx.eval(fn(a))
    toc = time.perf_counter()
    ms = 1000 * (toc - tic) / its
    return ms


def all_reduce_benchmark():
    a = mx.ones((5, 5), mx.int32)

    its_per_eval = 100

    def fn(x):
        for _ in range(its_per_eval):
            x = mx.distributed.all_sum(x)
            x = x - 1
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All Reduce: time per iteration {ms:.6f} (ms)")


def all_gather_benchmark():
    a = mx.ones((5, 5), mx.int32)
    its_per_eval = 100

    def fn(x):
        for _ in range(its_per_eval):
            x = mx.distributed.all_gather(x)[0]
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All gather: time per iteration {ms:.6f} (ms)")


if __name__ == "__main__":
    all_reduce_benchmark()
    all_gather_benchmark()
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import time

@@ -6,7 +6,11 @@ import mlx.core as mx


 def time_fn(fn, *args, **kwargs):
-    print(f"Timing {fn.__name__} ...", end=" ")
+    msg = kwargs.pop("msg", None)
+    if msg:
+        print(f"Timing {msg} ...", end=" ")
+    else:
+        print(f"Timing {fn.__name__} ...", end=" ")

     # warmup
     for _ in range(5):
@@ -20,3 +24,15 @@ def time_fn(fn, *args, **kwargs):

     msec = 1e3 * (toc - tic) / num_iters
     print(f"{msec:.5f} msec")
+
+
+def measure_runtime(fn, **kwargs):
+    # Warmup
+    for _ in range(5):
+        fn(**kwargs)
+
+    tic = time.time()
+    iters = 100
+    for _ in range(iters):
+        fn(**kwargs)
+    return (time.time() - tic) * 1000 / iters
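Usage sketch for the new `msg` keyword (not part of the diff; assumes an evaluated input, as elsewhere in these benchmarks):

    import mlx.core as mx
    from time_utils import time_fn

    a = mx.random.uniform(shape=(1024, 1024))
    mx.eval(a)
    time_fn(mx.abs, a, msg="abs on 1024x1024")  # prints: Timing abs on 1024x1024 ...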
@@ -1,56 +1,45 @@
 include(CMakeParseArguments)

-###############################################################################
+# clang format off
+#
+# ##############################################################################
 # Build metal library
 #
 # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
 # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
 #
-# Args:
-#   TARGET: Custom target to be added for the metal library
-#   TITLE: Name of the .metallib
-#   OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
-#   SOURCES: List of source files
-#   INCLUDE_DIRS: List of include dirs
-#   DEPS: List of depedency files (like headers)
+# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
+# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
+# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
+# files (like headers)
 #
+# clang format on

 macro(mlx_build_metallib)
   # Parse args
   set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
   set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
-  cmake_parse_arguments(
-    MTLLIB
-    ""
-    "${oneValueArgs}"
-    "${multiValueArgs}"
-    ${ARGN}
-  )
+  cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

   # Set output
   set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")

   # Collect compile options
-  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
+  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)

-  # Prepare metllib build command
+  # Prepare metallib build command
   add_custom_command(
     OUTPUT ${MTLLIB_BUILD_TARGET}
-    COMMAND xcrun -sdk macosx metal
-            "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
-            ${MTLLIB_COMPILE_OPTIONS}
-            ${MTLLIB_SOURCES}
-            -o ${MTLLIB_BUILD_TARGET}
+    COMMAND
+      xcrun -sdk macosx metal
+      "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
+      ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
     DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
     COMMAND_EXPAND_LISTS
     COMMENT "Building ${MTLLIB_TITLE}.metallib"
-    VERBATIM
-  )
+    VERBATIM)

   # Add metallib custom target
-  add_custom_target(
-    ${MTLLIB_TARGET}
-    DEPENDS
-    ${MTLLIB_BUILD_TARGET}
-  )
+  add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})

 endmacro(mlx_build_metallib)
docs/.gitignore (vendored, +1)
@@ -1,2 +1,3 @@
 src/python/_autosummary*/
 src/python/nn/_autosummary*/
+src/python/optimizers/_autosummary*/
docs/Doxyfile (new file, +50)
@@ -0,0 +1,50 @@
################################################################################
# Primary project setup.                                                       #
################################################################################

PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = NO
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES

################################################################################
# Doxygen preprocessor / parser control.                                       #
################################################################################

ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO

################################################################################
# Compound extraction control.                                                 #
################################################################################

EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO

################################################################################
# Docstring control / customization.                                           #
################################################################################

JAVADOC_AUTOBRIEF = YES

################################################################################
# Warning suppression.                                                         #
################################################################################

QUIET = YES
WARN_IF_UNDOCUMENTED = NO
@@ -2,12 +2,16 @@

 ### Setup (do once)

-Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
-for example with `conda`:
+Install Doxygen:

 ```
-conda install sphinx
-pip install sphinx-book-theme
+brew install doxygen
+```
+
+Install Python packages:
+
+```
+pip install -r requirements.txt
 ```

 ### Build
@@ -15,7 +19,7 @@ pip install sphinx-book-theme
 Build the docs from `mlx/docs/`

 ```
-make html
+doxygen && make html
 ```

 View the docs by running a server in `mlx/docs/build/html/`:
@@ -26,7 +30,7 @@ python -m http.server <port>

 and point your browser to `http://localhost:<port>`.

-### Push to Github Pages
+### Push to GitHub Pages

 Check-out the `gh-pages` branch (`git switch gh-pages`) and build
 the docs. Then force add the `build/html` directory:
docs/requirements.txt (new file, +4)
@@ -0,0 +1,4 @@
sphinx
breathe
sphinx-book-theme
mlx
BIN docs/src/_static/metal_debugger/capture.png (new file, 1.2 MiB; binary not shown)
BIN docs/src/_static/metal_debugger/schema.png (new file, 746 KiB; binary not shown)
BIN (modified binary, not shown; 7.2 KiB before, 76 KiB after)
BIN docs/src/_static/mlx_logo_dark.png (new file, 48 KiB; binary not shown)
docs/src/_templates/module-base-class.rst (new file, +33)
@@ -0,0 +1,33 @@
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. add toctree option to make autodoc generate the pages

.. autoclass:: {{ objname }}

  {% block attributes %}
  {% if attributes %}
  .. rubric:: Attributes

  .. autosummary::
    :toctree: .
  {% for item in attributes %}
    ~{{ fullname }}.{{ item }}
  {%- endfor %}
  {% endif %}
  {% endblock %}

  {% block methods %}
  {% if methods %}
  .. rubric:: Methods

  .. autosummary::
    :toctree: .
  {% for item in methods %}
    {%- if item not in inherited_members and item != '__init__' %}
    ~{{ fullname }}.{{ item }}
    {%- endif -%}
  {%- endfor %}
  {% endif %}
  {% endblock %}
@@ -4,16 +4,17 @@

 .. autoclass:: {{ objname }}

-{#{% block methods %}
+{% block methods %}

   {% if methods %}
   .. rubric:: {{ _('Methods') }}

   .. autosummary::
   {% for item in methods %}
-    {%- if item not in inherited_members and item != '__init__' %}
+    {%- if item not in inherited_members and item != "__init__" %}
     ~{{ name }}.{{ item }}
     {%- endif %}
   {%- endfor %}
   {% endif %}
-{% endblock %}#}
+{% endblock %}
@@ -5,13 +5,15 @@
 import os
 import subprocess

+import mlx.core as mx
+
 # -- Project information -----------------------------------------------------

 project = "MLX"
 copyright = "2023, MLX Contributors"
 author = "MLX Contributors"
-version = "0.0.6"
-release = "0.0.6"
+version = ".".join(mx.__version__.split(".")[:3])
+release = version

 # -- General configuration ---------------------------------------------------

@@ -20,22 +22,28 @@ extensions = [
     "sphinx.ext.autosummary",
     "sphinx.ext.intersphinx",
     "sphinx.ext.napoleon",
+    "breathe",
 ]

 python_use_unqualified_type_names = True
 autosummary_generate = True
 autosummary_filename_map = {"mlx.core.Stream": "stream_class"}

 intersphinx_mapping = {
-    "https://docs.python.org/3": None,
-    "https://numpy.org/doc/stable/": None,
+    "python": ("https://docs.python.org/3", None),
+    "numpy": ("https://numpy.org/doc/stable/", None),
 }

+breathe_projects = {"mlx": "../build/xml"}
+breathe_default_project = "mlx"
+
 templates_path = ["_templates"]
 html_static_path = ["_static"]
 source_suffix = ".rst"
-master_doc = "index"
+main_doc = "index"
 highlight_language = "python"
 pygments_style = "sphinx"
 add_module_names = False

 # -- Options for HTML output -------------------------------------------------

@@ -46,11 +54,45 @@ html_theme_options = {
     "repository_url": "https://github.com/ml-explore/mlx",
     "use_repository_button": True,
     "navigation_with_keys": False,
+    "logo": {
+        "image_light": "_static/mlx_logo.png",
+        "image_dark": "_static/mlx_logo_dark.png",
+    },
 }

 html_logo = "_static/mlx_logo.png"

+html_favicon = html_theme_options["logo"]["image_light"]
+
 # -- Options for HTMLHelp output ---------------------------------------------

 htmlhelp_basename = "mlx_doc"


+def setup(app):
+    from sphinx.util import inspect
+
+    wrapped_isfunc = inspect.isfunction
+
+    def isfunc(obj):
+        type_name = str(type(obj))
+        if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
+            return True
+        return wrapped_isfunc(obj)
+
+    inspect.isfunction = isfunc
+
+
 # -- Options for LaTeX output ------------------------------------------------

+latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
+latex_elements = {
+    "preamble": r"""
+\usepackage{enumitem}
+\setlistdepth{5}
+\setlist[itemize,1]{label=$\bullet$}
+\setlist[itemize,2]{label=$\bullet$}
+\setlist[itemize,3]{label=$\bullet$}
+\setlist[itemize,4]{label=$\bullet$}
+\setlist[itemize,5]{label=$\bullet$}
+\renewlist{itemize}{itemize}{5}
+""",
+}
@@ -3,4 +3,5 @@
 Operations
 ==========

 .. doxygengroup:: ops
+   :content-only:
docs/src/dev/custom_metal_kernels.rst (new file, +427)
@@ -0,0 +1,427 @@
.. _custom_metal_kernels:

Custom Metal Kernels
====================

MLX supports writing custom Metal kernels through the Python and C++ APIs.

Simple Example
--------------

Let's write a custom kernel that computes ``exp`` elementwise:

.. code-block:: python

  def exp_elementwise(a: mx.array):
      source = """
          uint elem = thread_position_in_grid.x;
          T tmp = inp[elem];
          out[elem] = metal::exp(tmp);
      """

      kernel = mx.fast.metal_kernel(
          name="myexp",
          input_names=["inp"],
          output_names=["out"],
          source=source,
      )
      outputs = kernel(
          inputs=[a],
          template=[("T", mx.float32)],
          grid=(a.size, 1, 1),
          threadgroup=(256, 1, 1),
          output_shapes=[a.shape],
          output_dtypes=[a.dtype],
      )
      return outputs[0]

  a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
  b = exp_elementwise(a)
  assert mx.allclose(b, mx.exp(a))

.. note::
   We are only required to pass the body of the Metal kernel in ``source``.

The full function signature will be generated using:

* The shapes/dtypes of ``inputs``:
  in the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``,
  so we will add ``const device float16_t* inp`` to the signature.
  ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
  in ``source``.
* The list of ``output_dtypes``:
  in the above, ``out`` is an ``mx.array`` of type ``mx.float16``,
  so we add ``device float16_t* out``.
* Template parameters passed using ``template``:
  in the above, ``template=[("T", mx.float32)]`` adds ``template <typename T>`` to the function
  and instantiates the template as ``custom_kernel_myexp_float<float>``.
  Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``:
  these will be added as function arguments.
  All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.

Putting this all together, the generated function signature for ``myexp`` is as follows:

.. code-block:: cpp

  template <typename T>
  [[kernel]] void custom_kernel_myexp_float(
    const device float16_t* inp [[buffer(0)]],
    device float16_t* out [[buffer(1)]],
    uint3 thread_position_in_grid [[thread_position_in_grid]]) {

        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);

  }

  template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.

Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
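For example, a call with ``verbose=True`` (reusing the names from the snippet above):

.. code-block:: python

  outputs = kernel(
      inputs=[a],
      template=[("T", mx.float32)],
      grid=(a.size, 1, 1),
      threadgroup=(256, 1, 1),
      output_shapes=[a.shape],
      output_dtypes=[a.dtype],
      verbose=True,  # print the generated Metal source
  )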
|
||||
Using Shape/Strides
|
||||
-------------------
|
||||
|
||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||
when indexing.
|
||||
|
||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||
input array ``a`` if any are present in ``source``.
|
||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes=[a.shape],
|
||||
output_dtypes=[a.dtype],
|
||||
ensure_row_contiguous=False,
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
# make non-contiguous
|
||||
a = a[::2]
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))

Complex Example
---------------

Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.

We'll start with the following MLX implementation using standard ops:

.. code-block:: python

   def grid_sample_ref(x, grid):
       N, H_in, W_in, _ = x.shape
       ix = ((grid[..., 0] + 1) * W_in - 1) / 2
       iy = ((grid[..., 1] + 1) * H_in - 1) / 2

       ix_nw = mx.floor(ix).astype(mx.int32)
       iy_nw = mx.floor(iy).astype(mx.int32)

       ix_ne = ix_nw + 1
       iy_ne = iy_nw

       ix_sw = ix_nw
       iy_sw = iy_nw + 1

       ix_se = ix_nw + 1
       iy_se = iy_nw + 1

       nw = (ix_se - ix) * (iy_se - iy)
       ne = (ix - ix_sw) * (iy_sw - iy)
       sw = (ix_ne - ix) * (iy - iy_ne)
       se = (ix - ix_nw) * (iy - iy_nw)

       I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
       I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
       I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
       I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

       mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
       mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
       mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
       mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

       I_nw *= mask_nw[..., None]
       I_ne *= mask_ne[..., None]
       I_sw *= mask_sw[..., None]
       I_se *= mask_se[..., None]

       output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

       return output

Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:

.. code-block:: python

   @mx.custom_function
   def grid_sample(x, grid):

       assert x.ndim == 4, "`x` must be 4D."
       assert grid.ndim == 4, "`grid` must be 4D."

       B, _, _, C = x.shape
       _, gN, gM, D = grid.shape
       out_shape = (B, gN, gM, C)

       assert D == 2, "Last dim of `grid` must be size 2."

       source = """
           uint elem = thread_position_in_grid.x;
           int H = x_shape[1];
           int W = x_shape[2];
           int C = x_shape[3];
           int gH = grid_shape[1];
           int gW = grid_shape[2];

           int w_stride = C;
           int h_stride = W * w_stride;
           int b_stride = H * h_stride;

           uint grid_idx = elem / C * 2;
           float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
           float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

           int ix_nw = floor(ix);
           int iy_nw = floor(iy);

           int ix_ne = ix_nw + 1;
           int iy_ne = iy_nw;

           int ix_sw = ix_nw;
           int iy_sw = iy_nw + 1;

           int ix_se = ix_nw + 1;
           int iy_se = iy_nw + 1;

           T nw = (ix_se - ix) * (iy_se - iy);
           T ne = (ix - ix_sw) * (iy_sw - iy);
           T sw = (ix_ne - ix) * (iy - iy_ne);
           T se = (ix - ix_nw) * (iy - iy_nw);

           int batch_idx = elem / C / gH / gW * b_stride;
           int channel_idx = elem % C;
           int base_idx = batch_idx + channel_idx;

           T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
           T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
           T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
           T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

           I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
           I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
           I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
           I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

           out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
       """
       kernel = mx.fast.metal_kernel(
           name="grid_sample",
           input_names=["x", "grid"],
           output_names=["out"],
           source=source,
       )
       outputs = kernel(
           inputs=[x, grid],
           template=[("T", x.dtype)],
           output_shapes=[out_shape],
           output_dtypes=[x.dtype],
           grid=(np.prod(out_shape), 1, 1),
           threadgroup=(256, 1, 1),
       )
       return outputs[0]
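Before benchmarking, it's worth checking the fused kernel against the reference implementation. A quick sanity check might look like this (the shapes here are arbitrary and only for illustration):

.. code-block:: python

   # Compare the custom kernel with grid_sample_ref on a small input
   x = mx.random.uniform(shape=(2, 32, 32, 8))
   grid = mx.random.uniform(low=-1, high=1, shape=(2, 16, 16, 2))
   assert mx.allclose(grid_sample(x, grid), grid_sample_ref(x, grid), atol=1e-5)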

For a reasonably sized input such as:

.. code-block:: python

   x.shape = (8, 1024, 1024, 64)
   grid.shape = (8, 256, 256, 2)

On an M1 Max, we see a big performance improvement:

``55.7ms -> 6.7ms => 8x speed up``

Grid Sample VJP
---------------

Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.

The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features:

* ``init_value=0``
  Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.

* ``atomic_outputs=True``
  Designate all of the kernel outputs as ``atomic`` in the function signature.
  This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
  See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.

We can then implement the backwards pass as follows:

.. code-block:: python

   @grid_sample.vjp
   def grid_sample_vjp(primals, cotangent, _):
       x, grid = primals
       B, _, _, C = x.shape
       _, gN, gM, D = grid.shape

       assert D == 2, "Last dim of `grid` must be size 2."

       source = """
           uint elem = thread_position_in_grid.x;
           int H = x_shape[1];
           int W = x_shape[2];
           int C = x_shape[3];
           // Pad C to the nearest larger simdgroup size multiple
           int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

           int gH = grid_shape[1];
           int gW = grid_shape[2];

           int w_stride = C;
           int h_stride = W * w_stride;
           int b_stride = H * h_stride;

           uint grid_idx = elem / C_padded * 2;
           float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
           float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

           int ix_nw = floor(ix);
           int iy_nw = floor(iy);

           int ix_ne = ix_nw + 1;
           int iy_ne = iy_nw;

           int ix_sw = ix_nw;
           int iy_sw = iy_nw + 1;

           int ix_se = ix_nw + 1;
           int iy_se = iy_nw + 1;

           T nw = (ix_se - ix) * (iy_se - iy);
           T ne = (ix - ix_sw) * (iy_sw - iy);
           T sw = (ix_ne - ix) * (iy - iy_ne);
           T se = (ix - ix_nw) * (iy - iy_nw);

           int batch_idx = elem / C_padded / gH / gW * b_stride;
           int channel_idx = elem % C_padded;
           int base_idx = batch_idx + channel_idx;

           T gix = T(0);
           T giy = T(0);
           if (channel_idx < C) {
               int cot_index = elem / C_padded * C + channel_idx;
               T cot = cotangent[cot_index];
               if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
                   int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
                   atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

                   T I_nw = x[offset];
                   gix -= I_nw * (iy_se - iy) * cot;
                   giy -= I_nw * (ix_se - ix) * cot;
               }
               if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
                   int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
                   atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

                   T I_ne = x[offset];
                   gix += I_ne * (iy_sw - iy) * cot;
                   giy -= I_ne * (ix - ix_sw) * cot;
               }
               if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
                   int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
                   atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

                   T I_sw = x[offset];
                   gix -= I_sw * (iy - iy_ne) * cot;
                   giy += I_sw * (ix_ne - ix) * cot;
               }
               if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
                   int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
                   atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

                   T I_se = x[offset];
                   gix += I_se * (iy - iy_nw) * cot;
                   giy += I_se * (ix - ix_nw) * cot;
               }
           }

           T gix_mult = W / 2;
           T giy_mult = H / 2;

           // Reduce across each simdgroup first.
           // This is much faster than relying purely on atomics.
           gix = simd_sum(gix);
           giy = simd_sum(giy);

           if (thread_index_in_simdgroup == 0) {
               atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
               atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
           }
       """
       kernel = mx.fast.metal_kernel(
           name="grid_sample_grad",
           input_names=["x", "grid", "cotangent"],
           output_names=["x_grad", "grid_grad"],
           source=source,
           atomic_outputs=True,
       )
       # pad the output channels to simd group size
       # so that our `simd_sum`s don't overlap.
       simdgroup_size = 32
       C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
       grid_size = B * gN * gM * C_padded
       outputs = kernel(
           inputs=[x, grid, cotangent],
           template=[("T", x.dtype)],
           output_shapes=[x.shape, grid.shape],
           output_dtypes=[x.dtype, x.dtype],
           grid=(grid_size, 1, 1),
           threadgroup=(256, 1, 1),
           init_value=0,
       )
       return outputs[0], outputs[1]
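With the vjp registered, the custom op composes with MLX's gradient transforms like any built-in op. A minimal sketch (``loss_fn`` is just an illustrative scalar reduction):

.. code-block:: python

   # Differentiate a scalar loss through the custom grid_sample
   def loss_fn(x, grid):
       return grid_sample(x, grid).sum()

   x = mx.random.uniform(shape=(2, 32, 32, 8))
   grid = mx.random.uniform(low=-1, high=1, shape=(2, 16, 16, 2))
   dx, dgrid = mx.grad(loss_fn, argnums=(0, 1))(x, grid)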

There's an even larger speed up for the vjp:

``676.4ms -> 16.7ms => 40x speed up``
File diff suppressed because it is too large.

docs/src/dev/metal_debugger.rst (new file, 68 lines)
@@ -0,0 +1,68 @@

Metal Debugger
==============

.. currentmodule:: mlx.core

Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:

* Records source during Metal compilation, for later inspection while
  debugging.
* Labels Metal objects such as command queues, improving capture readability.

To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.
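For example, an editable source build with debugging enabled might look like this (assuming you are building the Python package from a clone of the repo):

.. code-block:: shell

   CMAKE_ARGS="-DMLX_METAL_DEBUG=ON" pip install -e .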

The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.

.. note::

   To capture a GPU trace you must run the application with
   ``MTL_CAPTURE_ENABLED=1``.

.. code-block:: python

   import mlx.core as mx

   a = mx.random.uniform(shape=(512, 512))
   b = mx.random.uniform(shape=(512, 512))
   mx.eval(a, b)

   trace_file = "mlx_trace.gputrace"

   # Make sure to run with MTL_CAPTURE_ENABLED=1 and
   # that the path trace_file does not already exist.
   mx.metal.start_capture(trace_file)

   for _ in range(10):
       mx.eval(mx.add(a, b))

   mx.metal.stop_capture()

You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Check out the `Metal debugger
documentation`_ for more information.

.. image:: ../_static/metal_debugger/capture.png
   :class: dark-light

Xcode Workflow
--------------

You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.

.. code-block::

   mkdir build && cd build
   cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
   open mlx.xcodeproj

Select the ``metal_capture`` example schema and run.

.. image:: ../_static/metal_debugger/schema.png
   :class: dark-light

.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

docs/src/dev/mlx_in_cpp.rst (new file, 121 lines)
@@ -0,0 +1,121 @@

.. _mlx_in_cpp:

Using MLX in C++
================

You can use MLX in a C++ project with CMake.

.. note::

   This guide is based on the following `example using MLX in C++
   <https://github.com/ml-explore/mlx/tree/main/examples/cmake_project>`_

First install MLX:

.. code-block:: bash

   pip install -U mlx

You can also install the MLX Python package from source or just the C++
library. For more information see the :ref:`documentation on installing MLX
<build_and_install>`.

Next make an example program in ``example.cpp``:

.. code-block:: C++

   #include <iostream>

   #include "mlx/mlx.h"

   namespace mx = mlx::core;

   int main() {
     auto x = mx::array({1, 2, 3});
     auto y = mx::array({1, 2, 3});
     std::cout << x + y << std::endl;
     return 0;
   }

The next step is to set up a CMake file in ``CMakeLists.txt``:

.. code-block:: cmake

   cmake_minimum_required(VERSION 3.27)

   project(example LANGUAGES CXX)

   set(CMAKE_CXX_STANDARD 17)
   set(CMAKE_CXX_STANDARD_REQUIRED ON)

Depending on how you installed MLX, you may need to tell CMake where to
find it.

If you installed MLX with Python, then add the following to the CMake file:

.. code-block:: cmake

   find_package(
     Python 3.9
     COMPONENTS Interpreter Development.Module
     REQUIRED)
   execute_process(
     COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
     OUTPUT_STRIP_TRAILING_WHITESPACE
     OUTPUT_VARIABLE MLX_ROOT)

If you installed the MLX C++ package to a system path, then CMake should be
able to find it. If you installed it to a non-standard location or CMake can't
find MLX then set ``MLX_ROOT`` to the location where MLX is installed:

.. code-block:: cmake

   set(MLX_ROOT "/path/to/mlx/")

Next, instruct CMake to find MLX:

.. code-block:: cmake

   find_package(MLX CONFIG REQUIRED)

Finally, add the ``example.cpp`` program as an executable and link MLX.

.. code-block:: cmake

   add_executable(example example.cpp)
   target_link_libraries(example PRIVATE mlx)

You can build the example with:

.. code-block:: bash

   cmake -B build -DCMAKE_BUILD_TYPE=Release
   cmake --build build

And run it with:

.. code-block:: bash

   ./build/example

Note ``find_package(MLX CONFIG REQUIRED)`` sets the following variables:

.. list-table:: Package Variables
   :widths: 20 20
   :header-rows: 1

   * - Variable
     - Description
   * - MLX_FOUND
     - ``True`` if MLX is found
   * - MLX_INCLUDE_DIRS
     - Include directory
   * - MLX_LIBRARIES
     - Libraries to link against
   * - MLX_CXX_FLAGS
     - Additional compiler flags
   * - MLX_BUILD_ACCELERATE
     - ``True`` if MLX was built with Accelerate
   * - MLX_BUILD_METAL
     - ``True`` if MLX was built with Metal
@@ -15,7 +15,7 @@ module to concisely define the model architecture.

 Attention layer
 ^^^^^^^^^^^^^^^^

-We will start with the llama attention layer which notably uses the RoPE
+We will start with the Llama attention layer which notably uses the RoPE
 positional encoding. [1]_ In addition, our attention layer will optionally use a
 key/value cache that will be concatenated with the provided keys and values to
 support efficient inference.

@@ -371,7 +371,7 @@ Scripts

 The full example code is available in `mlx-examples`_.

-.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llama
+.. _mlx-examples: https://github.com/ml-explore/mlx-examples/tree/main/llms/llama

 .. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021.
    Roformer: Enhanced transformer with rotary position embedding. arXiv
@@ -64,7 +64,7 @@ set:

 Next, setup the problem parameters and load the data. To load the data, you need our
 `mnist data loader
 <https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
-we will import as `mnist`.
+we will import as ``mnist``.

 .. code-block:: python
@@ -19,7 +19,7 @@ The main differences between MLX and NumPy are:

 The design of MLX is inspired by frameworks like `PyTorch
 <https://pytorch.org/>`_, `Jax <https://github.com/google/jax>`_, and
-`ArrayFire <https://arrayfire.org/>`_. A noteable difference from these
+`ArrayFire <https://arrayfire.org/>`_. A notable difference from these
 frameworks and MLX is the *unified memory model*. Arrays in MLX live in shared
 memory. Operations on MLX arrays can be performed on any of the supported
 device types without performing data copies. Currently supported device types
@@ -35,9 +35,17 @@ are the CPU and GPU.
    :caption: Usage
    :maxdepth: 1

-   quick_start
-   unified_memory
-   using_streams
+   usage/quick_start
+   usage/lazy_evaluation
+   usage/unified_memory
+   usage/indexing
+   usage/saving_and_loading
+   usage/function_transforms
+   usage/compile
+   usage/numpy
+   usage/distributed
+   usage/using_streams
+   usage/export

 .. toctree::
    :caption: Examples
@@ -52,13 +60,20 @@ are the CPU and GPU.
    :maxdepth: 1

    python/array
+   python/data_types
    python/devices_and_streams
+   python/export
    python/ops
    python/random
    python/transforms
+   python/fast
    python/fft
    python/linalg
+   python/metal
+   python/memory_management
    python/nn
    python/optimizers
+   python/distributed
    python/tree_utils

 .. toctree::
@@ -72,3 +87,6 @@ are the CPU and GPU.
    :maxdepth: 1

    dev/extensions
+   dev/metal_debugger
+   dev/custom_metal_kernels
+   dev/mlx_in_cpp
@@ -1,8 +1,10 @@
+.. _build_and_install:
+
 Build and Install
 =================

-Install from PyPI
------------------
+Python Installation
+-------------------

 MLX is available on PyPI. All you have to do to use MLX with your own Apple
 silicon computer is
@@ -14,13 +16,21 @@ silicon computer is
 To install from PyPI you must meet the following requirements:

 - Using an M series chip (Apple silicon)
-- Using a native Python >= 3.8
-- macOS >= 13.3
+- Using a native Python >= 3.9
+- macOS >= 13.5

 .. note::
-   MLX is only available on devices running macOS >= 13.3
+   MLX is only available on devices running macOS >= 13.5
    It is highly recommended to use macOS 14 (Sonoma)

+MLX is also available on conda-forge. To install MLX with conda do:
+
+.. code-block:: shell
+
+   conda install conda-forge::mlx
+

 Troubleshooting
 ^^^^^^^^^^^^^^^

@@ -45,9 +55,12 @@ Build Requirements
 ^^^^^^^^^^^^^^^^^^

 - A C++ compiler with C++17 support (e.g. Clang >= 5.0)
-- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
-- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
+- `cmake <https://cmake.org/>`_ -- version 3.25 or later, and ``make``
+- Xcode >= 15.0 and macOS SDK >= 14.0

 .. note::
    Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
    the output of ``uname -p`` is ``x86``, see the :ref:`troubleshooting section <build shell>` below.

 Python API
 ^^^^^^^^^^
@@ -59,39 +72,36 @@ To build and install the MLX python library from source, first, clone MLX from

    git clone git@github.com:ml-explore/mlx.git mlx && cd mlx

-Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
-installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
+Then simply build and install MLX using pip:

 .. code-block:: shell

-   pip install "pybind11[global]"
-   conda install pybind11
-   brew install pybind11
+   CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .

-Then simply build and install it using pip:
+For developing, install the package with development dependencies, and use an
+editable install:

 .. code-block:: shell

-   env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
+   CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"

-For developing use an editable install:
+Once the development dependencies are installed, you can build faster with:

 .. code-block:: shell

-   env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
+   CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace

-To make sure the install is working run the tests with:
+Run the tests with:

 .. code-block:: shell

-   pip install ".[testing]"
    python -m unittest discover python/tests

-Optional: Install stubs to enable auto completions and type checking from your IDE:
+Optional: Install stubs to enable auto completions and type checking from your
+IDE:

 .. code-block:: shell

-   pip install ".[dev]"
    python setup.py generate_stubs

C++ API
@@ -112,7 +122,7 @@ Create a build directory and run CMake and make:

.. code-block:: shell

   mkdir -p build && cd build
   cmake .. && make -j

Run tests with:

@@ -131,7 +141,7 @@ directory as the executable statically linked to ``libmlx.a`` or the
preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.

.. list-table:: Build Options
   :widths: 25 8
   :header-rows: 1
@@ -145,27 +155,64 @@ should point to the path to the built metal library.
     - OFF
   * - MLX_BUILD_METAL
     - ON
   * - MLX_BUILD_CPU
     - ON
   * - MLX_BUILD_PYTHON_BINDINGS
     - OFF
   * - MLX_METAL_DEBUG
     - OFF
   * - MLX_BUILD_SAFETENSORS
     - ON
   * - MLX_BUILD_GGUF
     - ON
   * - MLX_METAL_JIT
     - OFF

.. note::

   If you have multiple Xcode installations and wish to use
   a specific one while building, you can do so by adding the
   following environment variable before building

   .. code-block:: shell

      export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"

   Further, you can use the following command to find out which
   macOS SDK will be used

   .. code-block:: shell

      xcrun -sdk macosx --show-sdk-version

Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~

To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.

The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:

.. code-block:: shell

   cmake .. \
     -DCMAKE_BUILD_TYPE=MinSizeRel \
     -DBUILD_SHARED_LIBS=ON \
     -DMLX_BUILD_CPU=OFF \
     -DMLX_BUILD_SAFETENSORS=OFF \
     -DMLX_BUILD_GGUF=OFF \
     -DMLX_METAL_JIT=ON

The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anywhere from a few hundred milliseconds to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.

Troubleshooting
^^^^^^^^^^^^^^^
@@ -189,3 +236,34 @@ Then set the active developer directory:

.. code-block:: shell

   sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer

x86 Shell
~~~~~~~~~

.. _build shell:

If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.

To fix this, find the application in Finder (``/Applications`` for iTerm,
``/Applications/Utilities`` for Terminal), right-click, and click “Get Info”.
Uncheck “Open using Rosetta”, close the “Get Info” window, and restart your
terminal.

Verify the terminal is now running natively with the following command:

.. code-block:: shell

   $ uname -p
   arm

Also check that cmake is using the correct architecture:

.. code-block:: shell

   $ cmake --system-information | grep CMAKE_HOST_SYSTEM_PROCESSOR
   CMAKE_HOST_SYSTEM_PROCESSOR "arm64"

If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cache with ``rm -rf build/`` and try again.
@@ -10,27 +10,40 @@ Array

   array
   array.astype
   array.at
   array.item
   array.tolist
   array.dtype
   array.itemsize
   array.nbytes
   array.ndim
   array.shape
   array.size
   Dtype
   array.abs
   array.all
   array.any
   array.argmax
   array.argmin
   array.conj
   array.cos
   array.dtype
   array.cummax
   array.cummin
   array.cumprod
   array.cumsum
   array.diag
   array.diagonal
   array.exp
   array.flatten
   array.log
   array.log10
   array.log1p
   array.log2
   array.logcumsumexp
   array.logsumexp
   array.max
   array.mean
   array.min
   array.moveaxis
   array.prod
   array.reciprocal
   array.reshape
@@ -40,7 +53,11 @@ Array
   array.split
   array.sqrt
   array.square
   array.squeeze
   array.std
   array.sum
   array.swapaxes
   array.transpose
   array.T
   array.var
   array.view
@@ -1,7 +1,5 @@
 .. _data_types:

-:orphan:
-
 Data Types
 ==========

@@ -29,9 +27,9 @@ The default floating point type is ``float32`` and the default integer type is
    * - ``uint32``
      - 4
      - 32-bit unsigned integer
-   * - ``uint32``
+   * - ``uint64``
      - 8
-     - 32-bit unsigned integer
+     - 64-bit unsigned integer
    * - ``int8``
      - 1
      - 8-bit signed integer
@@ -44,9 +42,37 @@ The default floating point type is ``float32`` and the default integer type is
    * - ``int64``
      - 8
      - 64-bit signed integer
    * - ``bfloat16``
      - 2
      - 16-bit brain float (e8, m7)
    * - ``float16``
      - 2
-     - 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
+     - 16-bit IEEE float (e5, m10)
    * - ``float32``
      - 4
      - 32-bit float
    * - ``float64``
      - 8
      - 64-bit double
    * - ``complex64``
      - 8
      - 64-bit complex float

 .. note::

    Arrays with type ``float64`` only work with CPU operations. Using
    ``float64`` arrays on the GPU will result in an exception.

 Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
 documentation for more information. Use :func:`issubdtype` to determine if one
 ``dtype`` (or category) is a subtype of another category.
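For instance, a few checks along these lines hold (a minimal sketch using categories from the hierarchy):

.. code-block:: python

   import mlx.core as mx

   # Concrete dtypes are subtypes of their categories
   assert mx.issubdtype(mx.float32, mx.floating)
   assert mx.issubdtype(mx.int32, mx.integer)
   # ...but not of unrelated categories
   assert not mx.issubdtype(mx.float32, mx.integer)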

.. autosummary::
   :toctree: _autosummary

   Dtype
   DtypeCategory
   issubdtype
   finfo
|
@@ -9,9 +9,11 @@ Devices and Streams
|
||||
:toctree: _autosummary
|
||||
|
||||
Device
|
||||
Stream
|
||||
default_device
|
||||
set_default_device
|
||||
Stream
|
||||
default_stream
|
||||
new_stream
|
||||
set_default_stream
|
||||
stream
|
||||
synchronize
|
||||
|
docs/src/python/distributed.rst (new file, 22 lines)
@@ -0,0 +1,22 @@

.. _distributed:

.. currentmodule:: mlx.core.distributed

Distributed Communication
=========================

MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
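As a minimal sketch of usage (assuming the processes are launched with an MPI runner such as ``mpirun``):

.. code-block:: python

   import mlx.core as mx

   world = mx.distributed.init()
   # Sum a vector of ones across all processes
   x = mx.distributed.all_sum(mx.ones(4))
   print(world.rank(), x)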

.. autosummary::
   :toctree: _autosummary

   Group
   is_available
   init
   all_sum
   all_gather
   send
   recv
   recv_like
docs/src/python/export.rst (new file, 14 lines)
@@ -0,0 +1,14 @@

.. _export:

Export Functions
================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   export_function
   import_function
   exporter
   export_to_dot
docs/src/python/fast.rst
Normal file
15
docs/src/python/fast.rst
Normal file
@@ -0,0 +1,15 @@
|
||||
.. _fast:
|
||||
|
||||
Fast
|
||||
====
|
||||
|
||||
.. currentmodule:: mlx.core.fast
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
rms_norm
|
||||
layer_norm
|
||||
rope
|
||||
scaled_dot_product_attention
|
||||
metal_kernel
|
@@ -20,3 +20,5 @@ FFT

    irfft2
    rfftn
    irfftn
+   fftshift
+   ifftshift
docs/src/python/linalg.rst (new file, 25 lines)
@@ -0,0 +1,25 @@

.. _linalg:

Linear Algebra
==============

.. currentmodule:: mlx.core.linalg

.. autosummary::
   :toctree: _autosummary

   inv
   tri_inv
   norm
   cholesky
   cholesky_inv
   cross
   qr
   svd
   eigvalsh
   eigh
   lu
   lu_factor
   pinv
   solve
   solve_triangular
docs/src/python/memory_management.rst (new file, 16 lines)
@@ -0,0 +1,16 @@

Memory Management
=================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   get_active_memory
   get_peak_memory
   reset_peak_memory
   get_cache_memory
   set_memory_limit
   set_cache_limit
   set_wired_limit
   clear_cache
docs/src/python/metal.rst (new file, 12 lines)
@@ -0,0 +1,12 @@

Metal
=====

.. currentmodule:: mlx.core.metal

.. autosummary::
   :toctree: _autosummary

   is_available
   device_info
   start_capture
   stop_capture
@@ -123,7 +123,7 @@ To get more detailed information on the arrays in a :class:`Module` you can use
all the parameters in a :class:`Module` do:

.. code-block:: python

   from mlx.utils import tree_map
   shapes = tree_map(lambda p: p.shape, mlp.parameters())

@@ -131,7 +131,7 @@ As another example, you can count the number of parameters in a :class:`Module`
with:

.. code-block:: python

   from mlx.utils import tree_flatten
   num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))

@@ -170,14 +170,16 @@ In detail:
   :meth:`mlx.core.value_and_grad`

.. autosummary::
   :recursive:
   :toctree: _autosummary

   value_and_grad
   Module
   quantize
   average_gradients

.. toctree::

   nn/module
   nn/layers
   nn/functions
   nn/losses
   nn/init
@@ -12,12 +12,28 @@ simple functions.
   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   elu
   celu
   gelu
   gelu_approx
   gelu_fast_approx
   relu
   prelu
   silu
   step
   selu
   glu
   hard_shrink
   hard_tanh
   hardswish
   leaky_relu
   log_sigmoid
   log_softmax
   mish
   prelu
   relu
   relu6
   selu
   sigmoid
   silu
   softmax
   softmin
   softplus
   softshrink
   step
   tanh
docs/src/python/nn/init.rst (new file, 45 lines)
@@ -0,0 +1,45 @@

.. _init:

.. currentmodule:: mlx.nn.init

Initializers
------------

The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.

For example:

.. code:: python

   import mlx.core as mx
   import mlx.nn as nn

   init_fn = nn.init.uniform()

   # Produces a [2, 2] uniform matrix
   param = init_fn(mx.zeros((2, 2)))

To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a uniform
distribution, you can do:

.. code:: python

   import mlx.nn as nn
   model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
   init_fn = nn.init.uniform(low=-0.1, high=0.1)
   model.apply(init_fn)

.. autosummary::
   :toctree: _autosummary

   constant
   normal
   uniform
   identity
   glorot_normal
   glorot_uniform
   he_normal
   he_uniform
@@ -9,21 +9,61 @@ Layers
   :toctree: _autosummary
   :template: nn-module-template.rst

   Embedding
   ReLU
   PReLU
   GELU
   SiLU
   Step
   SELU
   Mish
   Linear
   ALiBi
   AvgPool1d
   AvgPool2d
   AvgPool3d
   BatchNorm
   CELU
   Conv1d
   Conv2d
   LayerNorm
   RMSNorm
   Conv3d
   ConvTranspose1d
   ConvTranspose2d
   ConvTranspose3d
   Dropout
   Dropout2d
   Dropout3d
   Embedding
   ELU
   GELU
   GLU
   GroupNorm
   RoPE
   GRU
   HardShrink
   HardTanh
   Hardswish
   InstanceNorm
   LayerNorm
   LeakyReLU
   Linear
   LogSigmoid
   LogSoftmax
   LSTM
   MaxPool1d
   MaxPool2d
   MaxPool3d
   Mish
   MultiHeadAttention
   Sequential
   PReLU
   QuantizedEmbedding
   QuantizedLinear
   RMSNorm
   ReLU
   ReLU6
   RNN
   RoPE
   SELU
   Sequential
   Sigmoid
   SiLU
   SinusoidalPositionalEncoding
   Softmin
   Softshrink
   Softsign
   Softmax
   Softplus
   Step
   Tanh
   Transformer
   Upsample
@@ -10,9 +10,15 @@ Loss Functions
   :template: nn-module-template.rst

   binary_cross_entropy
+  cosine_similarity_loss
   cross_entropy
+  gaussian_nll_loss
+  hinge_loss
+  huber_loss
   kl_div_loss
   l1_loss
+  log_cosh_loss
+  margin_ranking_loss
   mse_loss
   nll_loss
   smooth_l1_loss
docs/src/python/nn/module.rst (new file, 38 lines)
@@ -0,0 +1,38 @@

Module
======

.. currentmodule:: mlx.nn

.. autoclass:: Module

   .. rubric:: Attributes

   .. autosummary::
      :toctree: _autosummary

      Module.training
      Module.state

   .. rubric:: Methods

   .. autosummary::
      :toctree: _autosummary

      Module.apply
      Module.apply_to_modules
      Module.children
      Module.eval
      Module.filter_and_map
      Module.freeze
      Module.leaf_modules
      Module.load_weights
      Module.modules
      Module.named_modules
      Module.parameters
      Module.save_weights
      Module.set_dtype
      Module.train
      Module.trainable_parameters
      Module.unfreeze
      Module.update
      Module.update_modules
@@ -5,13 +5,14 @@ Operations

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   abs
   add
   addmm
   all
   allclose
   any
   arange
   arccos
@@ -19,36 +20,80 @@ Operations
   arcsin
   arcsinh
   arctan
   arctan2
   arctanh
   argmax
   argmin
   argpartition
   argsort
   array_equal
   as_strided
   atleast_1d
   atleast_2d
   atleast_3d
   bitwise_and
   bitwise_invert
   bitwise_or
   bitwise_xor
   block_masked_mm
   broadcast_arrays
   broadcast_to
   ceil
   clip
   concatenate
   contiguous
   conj
   conjugate
   convolve
   conv1d
   conv2d
   conv3d
   conv_transpose1d
   conv_transpose2d
   conv_transpose3d
   conv_general
   cos
   cosh
   cummax
   cummin
   cumprod
   cumsum
   degrees
   dequantize
   diag
   diagonal
   divide
   divmod
   einsum
   einsum_path
   equal
   erf
   erfinv
   exp
   expm1
   expand_dims
   eye
   flatten
   floor
   floor_divide
   full
   gather_mm
   gather_qmm
   greater
   greater_equal
   hadamard_transform
   identity
   imag
   inner
   isfinite
   isclose
   isinf
   isnan
   isneginf
   isposinf
   issubdtype
   kron
   left_shift
   less
   less_equal
   linspace
@@ -58,35 +103,54 @@ Operations
   log10
   log1p
   logaddexp
   logcumsumexp
   logical_not
   logical_and
   logical_or
   logsumexp
   matmul
   max
   maximum
   mean
   meshgrid
   min
   minimum
   moveaxis
   multiply
   nan_to_num
   negative
   not_equal
   ones
   ones_like
   outer
   partition
   pad
   power
   prod
   put_along_axis
   quantize
   quantized_matmul
   radians
   real
   reciprocal
   remainder
   repeat
   reshape
   right_shift
   roll
   round
   rsqrt
   save
   savez
   savez_compressed
   save_gguf
   save_safetensors
   sigmoid
   sign
   sin
   sinh
   slice
   slice_update
   softmax
   sort
   split
@@ -94,6 +158,7 @@ Operations
   square
   squeeze
   stack
   std
   stop_gradient
   subtract
   sum
@@ -102,11 +167,17 @@ Operations
   take_along_axis
   tan
   tanh
   tensordot
   tile
   topk
   trace
   transpose
   tri
   tril
   triu
   unflatten
   var
   view
   where
   zeros
   zeros_like
@@ -1,5 +1,7 @@
.. _optimizers:

.. currentmodule:: mlx.optimizers

Optimizers
==========

@@ -29,19 +31,48 @@ model's parameters and the **optimizer state**.
   # Compute the new parameters but also the optimizer state.
   mx.eval(model.parameters(), optimizer.state)

Saving and Loading
------------------

To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:

.. code-block:: python

   import mlx.core as mx
   from mlx.utils import tree_flatten, tree_unflatten
   import mlx.optimizers as optim

   optimizer = optim.Adam(learning_rate=1e-2)

   # Perform some updates with the optimizer
   model = {"w" : mx.zeros((5, 5))}
   grads = {"w" : mx.ones((5, 5))}
   optimizer.update(model, grads)

   # Save the state
   state = tree_flatten(optimizer.state)
   mx.save_safetensors("optimizer.safetensors", dict(state))

   # Later on, for example when loading from a checkpoint,
   # recreate the optimizer and load the state
   optimizer = optim.Adam(learning_rate=1e-2)

   state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
   optimizer.state = state

Note, not every optimizer configuration parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is if the parameter can be scheduled
then it will be included in the optimizer state.
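If in doubt, you can flatten and print the state to see exactly what is saved; a small sketch continuing the example above (the exact key names may differ between versions):

.. code-block:: python

   # List the quantities the optimizer tracks, e.g. the step count,
   # the learning rate, and any per-parameter moments
   for key, _ in tree_flatten(optimizer.state):
       print(key)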

.. toctree::

   optimizers/optimizer
   optimizers/common_optimizers
   optimizers/schedulers

.. autosummary::
   :toctree: _autosummary
   :template: optimizers-template.rst

   OptimizerState
   Optimizer
   SGD
   RMSprop
   Adagrad
   AdaDelta
   Adam
   AdamW
   Adamax
   Lion
   clip_grad_norm
docs/src/python/optimizers/common_optimizers.rst (new file, 21 lines)
@@ -0,0 +1,21 @@

.. _common_optimizers:

Common Optimizers
=================

.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary
   :template: optimizers-template.rst

   SGD
   RMSprop
   Adagrad
   Adafactor
   AdaDelta
   Adam
   AdamW
   Adamax
   Lion
   MultiOptimizer
docs/src/python/optimizers/optimizer.rst (new file, 23 lines)
@@ -0,0 +1,23 @@

Optimizer
=========

.. currentmodule:: mlx.optimizers

.. autoclass:: Optimizer

   .. rubric:: Attributes

   .. autosummary::
      :toctree: _autosummary

      Optimizer.state

   .. rubric:: Methods

   .. autosummary::
      :toctree: _autosummary

      Optimizer.apply_gradients
      Optimizer.init
      Optimizer.update
docs/src/python/optimizers/schedulers.rst (new file, 15 lines)
@@ -0,0 +1,15 @@

.. _schedulers:

Schedulers
==========

.. currentmodule:: mlx.optimizers

.. autosummary::
   :toctree: _autosummary

   cosine_decay
   exponential_decay
   join_schedules
   linear_schedule
   step_decay
@@ -33,13 +33,16 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
.. autosummary::
   :toctree: _autosummary

   seed
   key
   split
   bernoulli
   categorical
   gumbel
   key
   normal
   multivariate_normal
   randint
   uniform
   seed
   split
   truncated_normal
   uniform
   laplace
   permutation
@@ -9,9 +9,13 @@ Transforms
   :toctree: _autosummary

   eval
   async_eval
   compile
   custom_function
   disable_compile
   enable_compile
   grad
   value_and_grad
   jvp
   vjp
   vmap
   simplify
@@ -19,3 +19,5 @@ return python trees will be using the default python ``dict``, ``list`` and

    tree_flatten
    tree_unflatten
    tree_map
+   tree_map_with_path
+   tree_reduce
docs/src/usage/compile.rst (new file, 497 lines)
@@ -0,0 +1,497 @@

.. _compile:

Compilation
===========

.. currentmodule:: mlx.core

MLX has a :func:`compile` function transformation which compiles computation
graphs. Function compilation results in smaller graphs by merging common work
and fusing certain operations. In many cases this can lead to big improvements
in run-time and memory use.

Getting started with :func:`compile` is simple, but there are some edge cases
that are good to be aware of for more complex graphs and advanced usage.

Basics of Compile
-----------------

Let's start with a simple example:

.. code-block:: python

   def fun(x, y):
       return mx.exp(-x) + y

   x = mx.array(1.0)
   y = mx.array(2.0)

   # Regular call, no compilation
   # Prints: array(2.36788, dtype=float32)
   print(fun(x, y))

   # Compile the function
   compiled_fun = mx.compile(fun)

   # Prints: array(2.36788, dtype=float32)
   print(compiled_fun(x, y))

The output of both the regular function and the compiled function is the same
up to numerical precision.

The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
function multiple times will not initiate a new compilation. This means you
should typically compile functions that you plan to use more than once.

.. code-block:: python

   def fun(x, y):
       return mx.exp(-x) + y

   x = mx.array(1.0)
   y = mx.array(2.0)

   compiled_fun = mx.compile(fun)

   # Compiled here
   compiled_fun(x, y)

   # Not compiled again
   compiled_fun(x, y)

   # Not compiled again
   mx.compile(fun)(x, y)

There are some important cases to be aware of that can cause a function to
be recompiled:

* Changing the shape or number of dimensions
* Changing the type of any of the inputs
* Changing the number of inputs to the function

In certain cases only some of the compilation stack will be rerun (for
example when changing the shapes) and in other cases the full compilation
stack will be rerun (for example when changing the types). In general you
should avoid compiling functions too frequently.

Another idiom to watch out for is compiling functions which get created and
destroyed frequently. This can happen, for example, when compiling an anonymous
function in a loop:

.. code-block:: python

   a = mx.array(1.0)
   # Don't do this, compiles lambda at each iteration
   for _ in range(5):
       mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
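A better pattern is to create the compiled function once, outside the loop, and reuse it, for example:

.. code-block:: python

   # Compile once, then call the cached compiled function in the loop
   fun = mx.compile(lambda x: mx.exp(mx.abs(x)))
   for _ in range(5):
       fun(a)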

Example Speedup
---------------

The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
Transformer-based models. The implementation involves several unary and binary
element-wise operations:

.. code-block:: python

   def gelu(x):
       return x * (1 + mx.erf(x / math.sqrt(2))) / 2

If you use this function with small arrays, it will be overhead bound. If you
use it with large arrays it will be memory bandwidth bound. However, all of
the operations in the ``gelu`` are fusible into a single kernel with
:func:`compile`. This can speed up both cases considerably.

Let's compare the runtime of the regular function versus the compiled
function. We'll use the following timing helper which does a warm up and
handles synchronization:

.. code-block:: python

   import time

   def timeit(fun, x):
       # warm up
       for _ in range(10):
           mx.eval(fun(x))

       tic = time.perf_counter()
       for _ in range(100):
           mx.eval(fun(x))
       toc = time.perf_counter()
       tpi = 1e3 * (toc - tic) / 100
       print(f"Time per iteration {tpi:.3f} (ms)")

Now make an array, and benchmark both functions:

.. code-block:: python

   x = mx.random.uniform(shape=(32, 1000, 4096))
   timeit(nn.gelu, x)
   timeit(mx.compile(nn.gelu), x)

On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

Debugging
---------

When a compiled function is first called, it is traced with placeholder
inputs. This means you can't evaluate arrays (for example to print their
contents) inside compiled functions.

.. code-block:: python

   @mx.compile
   def fun(x):
       z = -x
       print(z)  # Crash
       return mx.exp(z)

   fun(mx.array(5.0))

For debugging, inspecting arrays can be helpful. One way to do that is to
globally disable compilation using the :func:`disable_compile` function or
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
``fun`` is compiled:

.. code-block:: python

   @mx.compile
   def fun(x):
       z = -x
       print(z)  # Okay
       return mx.exp(z)

   mx.disable_compile()
   fun(mx.array(5.0))

Pure Functions
--------------

Compiled functions are intended to be *pure*; that is they should not have side
effects. For example:

.. code-block:: python

   state = []

   @mx.compile
   def fun(x, y):
       z = x + y
       state.append(z)
       return mx.exp(z)

   fun(mx.array(1.0), mx.array(2.0))
   # Crash!
   print(state)

After the first call of ``fun``, the ``state`` list will hold a placeholder
array. The placeholder does not have any data; it is only used to build the
computation graph. Printing such an array results in a crash.

You have two options to deal with this. The first option is to simply return
``state`` as an output:

.. code-block:: python

   state = []

   @mx.compile
   def fun(x, y):
       z = x + y
       state.append(z)
       return mx.exp(z), state

   _, state = fun(mx.array(1.0), mx.array(2.0))
   # Prints [array(3, dtype=float32)]
   print(state)

In some cases returning updated state can be pretty inconvenient. Hence,
:func:`compile` has a parameter to capture implicit outputs:

.. code-block:: python

   from functools import partial

   state = []

   # Tell compile to capture state as an output
   @partial(mx.compile, outputs=state)
   def fun(x, y):
       z = x + y
       state.append(z)
       return mx.exp(z), state

   fun(mx.array(1.0), mx.array(2.0))
   # Prints [array(3, dtype=float32)]
   print(state)

This is particularly useful for compiling a function which includes an update
to a container of arrays, as is commonly done when training the parameters of a
:class:`mlx.nn.Module`.

Compiled functions will also treat any inputs not in the parameter list as
constants. For example:

.. code-block:: python

   state = [mx.array(1.0)]

   @mx.compile
   def fun(x):
       return x + state[0]

   # Prints array(2, dtype=float32)
   print(fun(mx.array(1.0)))

   # Update state
   state[0] = mx.array(5.0)

   # Still prints array(2, dtype=float32)
   print(fun(mx.array(1.0)))

In order to have the change of state reflected in the outputs of ``fun`` you
again have two options. The first option is to simply pass ``state`` as input
to the function. In some cases this can be pretty inconvenient. Hence,
:func:`compile` also has a parameter to capture implicit inputs:

.. code-block:: python

   from functools import partial
   state = [mx.array(1.0)]

   # Tell compile to capture state as an input
   @partial(mx.compile, inputs=state)
   def fun(x):
       return x + state[0]

   # Prints array(2, dtype=float32)
   print(fun(mx.array(1.0)))

   # Update state
   state[0] = mx.array(5.0)

   # Prints array(6, dtype=float32)
   print(fun(mx.array(1.0)))
||||
|
||||
|
||||
Compiling Training Graphs
|
||||
-------------------------
|
||||
|
||||
This section will step through how to use :func:`compile` with a simple example
|
||||
of a common setup: training a model with :obj:`mlx.nn.Module` using an
|
||||
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
|
||||
full forward, backward, and update with :func:`compile`.
|
||||
|
||||
To start, here is the simple example without any compilation:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
|
||||
# 4 examples with 10 features each
|
||||
x = mx.random.uniform(shape=(4, 10))
|
||||
|
||||
# 0, 1 targets
|
||||
y = mx.array([0, 1, 0, 1])
|
||||
|
||||
# Simple linear model
|
||||
model = nn.Linear(10, 1)
|
||||
|
||||
# SGD with momentum
|
||||
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
|
||||
|
||||
def loss_fn(model, x, y):
|
||||
logits = model(x).squeeze()
|
||||
return nn.losses.binary_cross_entropy(logits, y)
|
||||
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
|
||||
# Perform 10 steps of gradient descent
|
||||
for it in range(10):
|
||||
loss, grads = loss_and_grad_fn(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
|
||||
To compile the update we can put it all in a function and compile it with the
|
||||
appropriate input and output captures. Here's the same example but compiled:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
from functools import partial
|
||||
|
||||
# 4 examples with 10 features each
|
||||
x = mx.random.uniform(shape=(4, 10))
|
||||
|
||||
# 0, 1 targets
|
||||
y = mx.array([0, 1, 0, 1])
|
||||
|
||||
# Simple linear model
|
||||
model = nn.Linear(10, 1)
|
||||
|
||||
# SGD with momentum
|
||||
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
|
||||
|
||||
def loss_fn(model, x, y):
|
||||
logits = model(x).squeeze()
|
||||
return nn.losses.binary_cross_entropy(logits, y)
|
||||
|
||||
# The state that will be captured as input and output
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def step(x, y):
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
loss, grads = loss_and_grad_fn(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
# Perform 10 steps of gradient descent
|
||||
for it in range(10):
|
||||
loss = step(x, y)
|
||||
# Evaluate the model and optimizer state
|
||||
mx.eval(state)
|
||||
print(loss)

.. note::

   If you are using a module which performs random sampling such as
   :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
   ``state`` captured by :func:`compile`, i.e. ``state = [model.state,
   optimizer.state, mx.random.state]``.
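
For instance, a model with dropout could be compiled as follows (a sketch
reusing ``loss_fn`` and the imports from the example above; only the extra
``mx.random.state`` entry is new):

.. code-block:: python

   model = nn.Sequential(nn.Linear(10, 16), nn.Dropout(0.5), nn.Linear(16, 1))
   optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)

   # Include the global PRNG state so the random dropout masks
   # drawn inside the compiled step change from call to call
   state = [model.state, optimizer.state, mx.random.state]

   @partial(mx.compile, inputs=state, outputs=state)
   def step(x, y):
       loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
       optimizer.update(model, grads)
       return loss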

.. note::

   For more examples of compiling full training graphs check out the `MLX
   Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.

Transformations with Compile
----------------------------

In MLX function transformations are composable. You can apply any function
transformation to the output of any other function transformation. For more on
this, see the documentation on :ref:`function transforms
<function_transforms>`.

Compiling transformed functions works just as expected:

.. code-block:: python

   grad_fn = mx.grad(mx.exp)

   compiled_grad_fn = mx.compile(grad_fn)

   # Prints: array(2.71828, dtype=float32)
   print(grad_fn(mx.array(1.0)))

   # Also prints: array(2.71828, dtype=float32)
   print(compiled_grad_fn(mx.array(1.0)))

.. note::

   In order to compile as much as possible, a transformation of a compiled
   function will not by default be compiled. To compile the transformed
   function simply pass it through :func:`compile`.

You can also compile functions which themselves call compiled functions. A
good practice is to compile the outermost function to give :func:`compile`
the most opportunity to optimize the computation graph:

.. code-block:: python

   @mx.compile
   def inner(x):
       return mx.exp(-mx.abs(x))

   def outer(x):
       return inner(inner(x))

   # Compiling the outer function is good to do as it will likely
   # be faster even though the inner functions are compiled
   fun = mx.compile(outer)

.. _shapeless_compile:

Shapeless Compilation
---------------------

When the shape of an input to a compiled function changes, the function is
recompiled. You can compile a function once and run it on inputs with
variable shapes by specifying ``shapeless=True`` to :func:`compile`. In this
case changes to the shapes of the inputs do not cause the function to be
recompiled.

.. code-block:: python

   def fun(x, y):
       return mx.abs(x + y)

   compiled_fun = mx.compile(fun, shapeless=True)

   x = mx.array(1.0)
   y = mx.array(-2.0)

   # First call compiles the function
   print(compiled_fun(x, y))

   # Second call with different shapes
   # does not recompile the function
   x = mx.array([1.0, -6.0])
   y = mx.array([-2.0, 3.0])
   print(compiled_fun(x, y))

Use shapeless compilations carefully. Since compilation is not triggered when
shapes change, any graphs which are conditional on the input shapes will not
work as expected. Shape-dependent computations are common and sometimes subtle
to detect. For example:

.. code-block:: python

   def fun(x):
       return x.reshape(x.shape[0] * x.shape[1], -1)

   compiled_fun = mx.compile(fun, shapeless=True)

   x = mx.random.uniform(shape=(2, 3, 4))

   out = compiled_fun(x)

   x = mx.random.uniform(shape=(5, 5, 3))

   # Error, can't reshape (5, 5, 3) to (6, -1)
   out = compiled_fun(x)

The second call to the ``compiled_fun`` fails because of the call to
:func:`reshape` which uses the static shape of ``x`` in the first call. We can
fix this by using :func:`flatten` to avoid hardcoding the shape of ``x``:

.. code-block:: python

   def fun(x):
       return x.flatten(0, 1)

   compiled_fun = mx.compile(fun, shapeless=True)

   x = mx.random.uniform(shape=(2, 3, 4))

   out = compiled_fun(x)

   x = mx.random.uniform(shape=(5, 5, 3))

   # Ok
   out = compiled_fun(x)

docs/src/usage/distributed.rst (new file, +344 lines)

.. _usage_distributed:

Distributed Communication
=========================

.. currentmodule:: mlx.core.distributed

MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:

* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_, a
  full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
  faster for Thunderbolt connections.

The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.

.. note::
   Some operations may not be supported or not as fast as they should be.
   We are adding more and tuning the ones we have as we are figuring out the
   best way to do distributed computing on Macs using MLX.

Getting Started
---------------

A distributed program in MLX is as simple as:

.. code:: python

   import mlx.core as mx

   world = mx.distributed.init()
   x = mx.distributed.all_sum(mx.ones(10))
   print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all
distributed processes. However, when this script is run with ``python`` only
one process is launched and no distributed communication takes place. Namely,
all operations in ``mx.distributed`` are noops when the distributed group has a
size of one. This property allows us to avoid code that checks if we are in a
distributed setting similar to the one below:

.. code:: python

   import mlx.core as mx

   x = ...
   world = mx.distributed.init()
   # No need for the check, we can simply do x = mx.distributed.all_sum(x)
   if world.size() > 1:
       x = mx.distributed.all_sum(x)

Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

MLX provides ``mlx.launch``, a helper script to launch distributed programs.
Continuing with our initial example, we can run it on localhost with 4 processes using

.. code:: shell

   $ mlx.launch -n 4 my_script.py
   3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh)

.. code:: shell

   $ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
   3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)

Consult the dedicated :doc:`usage guide<launching_distributed>` for more
information on using ``mlx.launch``.

Selecting Backend
^^^^^^^^^^^^^^^^^

You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and, if it fails, the ``mpi`` backend. If they
both fail then a singleton group is created.

.. note::
   After a distributed backend is successfully initialized :func:`init` will
   return **the same backend** if called without arguments or with backend set to
   ``any``.

The following examples aim to clarify the backend initialization logic in MLX:

.. code:: python

   # Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
   world = mx.distributed.init(backend="mpi")
   world2 = mx.distributed.init()  # subsequent calls return the MPI backend!

   # Case 2: Initialize any backend
   world = mx.distributed.init(backend="any")  # equivalent to no arguments
   world2 = mx.distributed.init()  # same as above

   # Case 3: Initialize both backends at the same time
   world_mpi = mx.distributed.init(backend="mpi")
   world_ring = mx.distributed.init(backend="ring")
   world_any = mx.distributed.init()  # same as MPI because it was initialized first!
Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())

All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
have to :func:`mlx.utils.tree_map` the gradients with the following function.

.. code:: python

   def all_avg(x):
       return mx.distributed.all_sum(x) / mx.distributed.init().size()
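
Applied with :func:`mlx.utils.tree_map` this averages every gradient array in
the tree (a sketch; ``grads`` is the gradient tree returned by
``loss_grad_fn``):

.. code:: python

   from mlx.utils import tree_map

   grads = tree_map(all_avg, grads)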

Putting everything together, our training loop step looks as follows with
everything else remaining the same.

.. code:: python

   from mlx.utils import tree_map

   def all_reduce_grads(grads):
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       return tree_map(
           lambda x: mx.distributed.all_sum(x) / N,
           grads
       )

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = all_reduce_grads(grads)  # <--- This line was added
       optimizer.update(model, grads)
       return loss

Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.

This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = mlx.nn.average_gradients(grads)  # <---- This line was added
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())


Getting Started with MPI
------------------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.

The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:

.. code:: shell

   $ mlx.launch --backend mpi -n 2 test.py
   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...``
would print 4s instead, and so on.

Installing MPI
^^^^^^^^^^^^^^

MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

   $ conda install conda-forge::openmpi

Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``, and it is
done automatically by ``mlx.launch``.

.. code:: shell

   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
   $ # or simply
   $ mlx.launch -n 2 test.py

Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines.

Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^

.. note::

   For faster all reduce consider using the ring backend either with Thunderbolt
   connections or over Ethernet.

Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.

Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.

Getting Started with Ring
-------------------------

The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests, the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3,
and so on and so forth. As a result, :func:`send` and :func:`recv` with
arbitrary sender and receiver are not supported in the ring backend.

Defining a Ring
^^^^^^^^^^^^^^^

The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.

For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.

.. code:: json

   [
       {"ssh": "hostname1", "ips": ["123.123.123.1"]},
       {"ssh": "hostname2", "ips": ["123.123.123.2"]},
       {"ssh": "hostname3", "ips": ["123.123.123.3"]},
       {"ssh": "hostname4", "ips": ["123.123.123.4"]}
   ]

Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node and run the script, which will listen for connections on each of the provided
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4``, and so on and so forth.

Thunderbolt Ring
^^^^^^^^^^^^^^^^

Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such Thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.

To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via Thunderbolt cables and then call the
utility as follows:

.. code:: shell

   mlx.distributed_config --verbose --hosts host1,host2,host3,host4

By default the script will attempt to discover the Thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.

To validate your connection without configuring anything,
``mlx.distributed_config`` can also plot the ring using DOT format.

.. code:: shell

   mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
   dot -Tpng ring.dot >ring.png
   open ring.png

If you want to go through the process manually, the steps are as follows:

* Disable the Thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
  corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
  interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
  and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
  ``192.168.0.2`` respectively to the two nodes. For more details you can see
  the commands prepared by the utility script.

docs/src/usage/export.rst (new file, +288 lines)

.. _export_usage:

Exporting Functions
===================

.. currentmodule:: mlx.core

MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).

This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check out the :ref:`API documentation
<export>`.

Basics of Exporting
-------------------

Let's start with a simple example:

.. code-block:: python

   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)
   mx.export_function("add.mlxfn", fun, x, y)

To export a function, provide sample input arrays that the function
can be called with. The data doesn't matter, but the shapes and types of the
arrays do. In the above example we exported ``fun`` with two ``float32``
scalar arrays. We can then import the function and run it:

.. code-block:: python

   add_fun = mx.import_function("add.mlxfn")

   out, = add_fun(mx.array(1.0), mx.array(2.0))
   # Prints: array(3, dtype=float32)
   print(out)

   out, = add_fun(mx.array(1.0), mx.array(3.0))
   # Prints: array(4, dtype=float32)
   print(out)

   # Raises an exception
   add_fun(mx.array(1), mx.array(3.0))

   # Raises an exception
   add_fun(mx.array([1.0, 2.0]), mx.array(3.0))

Notice the third and fourth calls to ``add_fun`` raise exceptions because the
shapes and types of the inputs are different than the shapes and types of the
example inputs we exported the function with.

Also notice that even though the original ``fun`` returns a single output
array, the imported function always returns a tuple of one or more arrays.

The inputs to :func:`export_function` and to an imported function can be
specified as variable positional arguments or as a tuple of arrays:

.. code-block:: python

   def fun(x, y):
       return x + y

   x = mx.array(1.0)
   y = mx.array(1.0)

   # Both arguments to fun are positional
   mx.export_function("add.mlxfn", fun, x, y)

   # Same as above
   mx.export_function("add.mlxfn", fun, (x, y))

   imported_fun = mx.import_function("add.mlxfn")

   # Ok
   out, = imported_fun(x, y)

   # Also ok
   out, = imported_fun((x, y))

You can pass example inputs to functions as positional or keyword arguments. If
you use keyword arguments to export the function, then you have to use the same
keyword arguments when calling the imported function.

.. code-block:: python

   def fun(x, y):
       return x + y

   # One argument to fun is positional, the other is a kwarg
   mx.export_function("add.mlxfn", fun, x, y=y)

   imported_fun = mx.import_function("add.mlxfn")

   # Ok
   out, = imported_fun(x, y=y)

   # Also ok
   out, = imported_fun((x,), {"y": y})

   # Raises since the keyword argument is missing
   out, = imported_fun(x, y)

   # Raises since the keyword argument has the wrong key
   out, = imported_fun(x, z=y)


Exporting Modules
-----------------

An :obj:`mlx.nn.Module` can be exported with or without the parameters included
in the exported function. Here's an example:

.. code-block:: python

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x):
       return model(x)

   mx.export_function("model.mlxfn", call, mx.zeros(4))

In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
parameters are also saved to the ``model.mlxfn`` file.

.. note::

   For enclosed arrays inside an exported function, be extra careful to ensure
   they are evaluated. The computation graph that gets exported will include
   the computation that produces enclosed inputs.

   If the above example was missing ``mx.eval(model.parameters())``, the
   exported function would include the random initialization of the
   :obj:`mlx.nn.Module` parameters.

If you only want to export the ``Module.__call__`` function without the
parameters, pass them as inputs to the ``call`` wrapper:

.. code-block:: python

   from mlx.utils import tree_flatten, tree_unflatten

   model = nn.Linear(4, 4)
   mx.eval(model.parameters())

   def call(x, **params):
       # Set the model's parameters to the input parameters
       model.update(tree_unflatten(list(params.items())))
       return model(x)

   params = dict(tree_flatten(model.parameters()))
   mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)


Shapeless Exports
-----------------

Just like :func:`compile`, functions can also be exported for dynamically shaped
inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
to export a function which can be used for inputs with variable shapes:

.. code-block:: python

   mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
   imported_abs = mx.import_function("fun.mlxfn")

   # Ok
   out, = imported_abs(mx.array(-1.0))

   # Also ok
   out, = imported_abs(mx.array([-1.0, -2.0]))

With ``shapeless=False`` (which is the default), the second call to
``imported_abs`` would raise an exception with a shape mismatch.

Shapeless exporting works the same as shapeless compilation and should be
used carefully. See the :ref:`documentation on shapeless compilation
<shapeless_compile>` for more information.

Exporting Multiple Traces
-------------------------

In some cases, functions build different computation graphs for different
input arguments. A simple way to manage this is to export to a new file with
each set of inputs. This is a fine option in many cases. But it can be
suboptimal if the exported functions have a large amount of duplicate constant
data (for example the parameters of a :obj:`mlx.nn.Module`).

The export API in MLX lets you export multiple traces of the same function to
a single file by creating an exporting context manager with :func:`exporter`:

.. code-block:: python

   def fun(x, y=None):
       constant = mx.array(3.0)
       if y is not None:
           x += y
       return x + constant

   with mx.exporter("fun.mlxfn", fun) as exporter:
       exporter(mx.array(1.0))
       exporter(mx.array(1.0), y=mx.array(0.0))

   imported_function = mx.import_function("fun.mlxfn")

   # Call the function with y=None
   out, = imported_function(mx.array(1.0))
   print(out)

   # Call the function with y specified
   out, = imported_function(mx.array(1.0), y=mx.array(1.0))
   print(out)

In the above example the function's constant data (i.e. ``constant``) is only
saved once.

Transformations with Imported Functions
---------------------------------------

Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
on imported functions just like regular Python functions:

.. code-block:: python

   def fun(x):
       return mx.sin(x)

   x = mx.array(0.0)
   mx.export_function("sine.mlxfn", fun, x)

   imported_fun = mx.import_function("sine.mlxfn")

   # Take the derivative of the imported function
   dfdx = mx.grad(lambda x: imported_fun(x)[0])
   # Prints: array(1, dtype=float32)
   print(dfdx(x))

   # Compile the imported function
   compiled_fun = mx.compile(imported_fun)
   # Prints: array(0, dtype=float32)
   print(compiled_fun(x)[0])


Importing Functions in C++
--------------------------

Importing and running functions in C++ is basically the same as importing and
running them in Python. First, follow the :ref:`instructions <mlx_in_cpp>` to
set up a simple C++ project that uses MLX as a library.

Next, export a simple function from Python:

.. code-block:: python

   def fun(x, y):
       return mx.exp(x + y)

   x = mx.array(1.0)
   y = mx.array(1.0)
   mx.export_function("fun.mlxfn", fun, x, y)

Import and run the function in C++ with only a few lines of code:

.. code-block:: c++

   auto fun = mx::import_function("fun.mlxfn");

   auto inputs = {mx::array(1.0), mx::array(1.0)};
   auto outputs = fun(inputs);

   // Prints: array(7.38906, dtype=float32)
   std::cout << outputs[0] << std::endl;

Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.

More Examples
-------------

Here are a few more complete examples exporting more complex functions from
Python and importing and running them in C++:

* `Inference and training a multi-layer perceptron <https://github.com/ml-explore/mlx/tree/main/examples/export>`_

docs/src/usage/function_transforms.rst (new file, +191 lines)

.. _function_transforms:

Function Transforms
===================

.. currentmodule:: mlx.core

MLX uses composable function transformations for automatic differentiation,
vectorization, and compute graph optimizations. To see the complete list of
function transformations check out the :ref:`API documentation <transforms>`.

The key idea behind composable function transformations is that every
transformation returns a function which can be further transformed.

Here is a simple example:

.. code-block:: shell

   >>> dfdx = mx.grad(mx.sin)
   >>> dfdx(mx.array(mx.pi))
   array(-1, dtype=float32)
   >>> mx.cos(mx.array(mx.pi))
   array(-1, dtype=float32)

The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:

.. code-block:: shell

   >>> d2fdx2 = mx.grad(mx.grad(mx.sin))
   >>> d2fdx2(mx.array(mx.pi / 2))
   array(-1, dtype=float32)
   >>> mx.sin(mx.array(mx.pi / 2))
   array(1, dtype=float32)

Using :func:`grad` on the output of :func:`grad` is always ok. You keep
getting higher order derivatives.

Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.

Automatic Differentiation
-------------------------

.. _auto diff:

Automatic differentiation in MLX works on functions rather than on implicit
graphs.

.. note::

   If you are coming to MLX from PyTorch, you no longer need functions like
   ``backward``, ``zero_grad``, and ``detach``, or properties like
   ``requires_grad``.

The most basic example is taking the gradient of a scalar-valued function as we
saw above. You can use the :func:`grad` and :func:`value_and_grad` functions to
compute gradients of more complex functions. By default these functions compute
the gradient with respect to the first argument:

.. code-block:: python

   def loss_fn(w, x, y):
       return mx.mean(mx.square(w * x - y))

   w = mx.array(1.0)
   x = mx.array([0.5, -0.5])
   y = mx.array([1.5, -1.5])

   # Computes the gradient of loss_fn with respect to w:
   grad_fn = mx.grad(loss_fn)
   dloss_dw = grad_fn(w, x, y)
   # Prints array(-1, dtype=float32)
   print(dloss_dw)

   # To get the gradient with respect to x we can do:
   grad_fn = mx.grad(loss_fn, argnums=1)
   dloss_dx = grad_fn(w, x, y)
   # Prints array([-1, 1], dtype=float32)
   print(dloss_dx)

One way to get the loss and gradient is to call ``loss_fn`` followed by
``grad_fn``, but this can result in a lot of redundant work. Instead, you
should use :func:`value_and_grad`. Continuing the above example:

.. code-block:: python

   # Computes the gradient of loss_fn with respect to w:
   loss_and_grad_fn = mx.value_and_grad(loss_fn)
   loss, dloss_dw = loss_and_grad_fn(w, x, y)

   # Prints array(1, dtype=float32)
   print(loss)

   # Prints array(-1, dtype=float32)
   print(dloss_dw)

You can also take the gradient with respect to arbitrarily nested Python
containers of arrays (specifically any of :obj:`list`, :obj:`tuple`, or
:obj:`dict`).

Suppose we wanted a weight and a bias parameter in the above example. A nice
way to do that is the following:

.. code-block:: python

   def loss_fn(params, x, y):
       w, b = params["weight"], params["bias"]
       h = w * x + b
       return mx.mean(mx.square(h - y))

   params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
   x = mx.array([0.5, -0.5])
   y = mx.array([1.5, -1.5])

   # Computes the gradient of loss_fn with respect to both the
   # weight and bias:
   grad_fn = mx.grad(loss_fn)
   grads = grad_fn(params, x, y)

   # Prints
   # {'weight': array(-1, dtype=float32), 'bias': array(0, dtype=float32)}
   print(grads)

Notice the tree structure of the parameters is preserved in the gradients.

In some cases you may want to stop gradients from propagating through a
part of the function. You can use :func:`stop_gradient` for that.
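
For example, :func:`stop_gradient` makes its argument behave like a constant
under differentiation. A minimal sketch:

.. code-block:: python

   def fun(x):
       # The second factor is treated as a constant when differentiating
       return x * mx.stop_gradient(x)

   # d/dx (x * c) = c, where c holds the value of x
   # Prints array(2, dtype=float32)
   print(mx.grad(fun)(mx.array(2.0)))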


Automatic Vectorization
-----------------------

.. _vmap:

Use :func:`vmap` to automate vectorizing complex functions. Here we'll go
through a basic and contrived example for the sake of clarity, but :func:`vmap`
can be quite powerful for more complex functions which are difficult to optimize
by hand.

.. warning::

   Some operations are not yet supported with :func:`vmap`. If you encounter an error
   like: ``ValueError: Primitive's vmap not implemented.`` file an `issue
   <https://github.com/ml-explore/mlx/issues>`_ and include your function.
   We will prioritize including it.

A naive way to add the elements from two sets of vectors is with a loop:

.. code-block:: python

   xs = mx.random.uniform(shape=(4096, 100))
   ys = mx.random.uniform(shape=(100, 4096))

   def naive_add(xs, ys):
       return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

Instead you can use :func:`vmap` to automatically vectorize the addition:

.. code-block:: python

   # Vectorize over the first dimension of xs and the
   # second dimension of ys
   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
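
For instance, ``out_axes=1`` places the mapped axis at the second dimension of
the output. A sketch reusing ``xs`` and ``ys`` from above:

.. code-block:: python

   # The mapped axis (size 4096) ends up at axis 1, so the output
   # has shape (100, 4096) instead of the default (4096, 100)
   vmap_add_t = mx.vmap(lambda x, y: x + y, in_axes=(0, 1), out_axes=1)
   print(vmap_add_t(xs, ys).shape)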

Let's time these two different versions:

.. code-block:: python

   import timeit

   print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
   print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.024`` seconds, more than 200 times faster.

Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.

docs/src/usage/indexing.rst (new file, +123 lines)

.. _indexing:

Indexing Arrays
===============

.. currentmodule:: mlx.core

For the most part, indexing an MLX :obj:`array` works the same as indexing a
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
how that works.

For example, you can use regular integers and slices (:obj:`slice`) to index arrays:

.. code-block:: shell

   >>> arr = mx.arange(10)
   >>> arr[3]
   array(3, dtype=int32)
   >>> arr[-2]  # negative indexing works
   array(8, dtype=int32)
   >>> arr[2:8:2]  # start, stop, stride
   array([2, 4, 6], dtype=int32)

For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:

.. code-block:: shell

   >>> arr = mx.arange(8).reshape(2, 2, 2)
   >>> arr[:, :, 0]
   array([[0, 2],
          [4, 6]], dtype=int32)
   >>> arr[..., 0]
   array([[0, 2],
          [4, 6]], dtype=int32)

You can index with ``None`` to create a new axis:

.. code-block:: shell

   >>> arr = mx.arange(8)
   >>> arr.shape
   [8]
   >>> arr[None].shape
   [1, 8]

You can also use an :obj:`array` to index another :obj:`array`:

.. code-block:: shell

   >>> arr = mx.arange(10)
   >>> idx = mx.array([5, 7])
   >>> arr[idx]
   array([5, 7], dtype=int32)

Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
works just as in NumPy.

Other functions which may be useful for indexing arrays are :func:`take` and
:func:`take_along_axis`.
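
For example, a brief sketch of the two:

.. code-block:: python

   a = mx.array([[1, 2], [3, 4]])

   # take selects whole slices along an axis by index
   # Prints array([[3, 4]], dtype=int32)
   print(mx.take(a, mx.array([1]), axis=0))

   # take_along_axis gathers one element per row using an index array
   # Prints array([[1], [4]], dtype=int32)
   print(mx.take_along_axis(a, mx.array([[0], [1]]), axis=1))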

Differences from NumPy
----------------------

.. note::

   MLX indexing is different from NumPy indexing in two important ways:

   * Indexing does not perform bounds checking. Indexing out of bounds is
     undefined behavior.
   * Boolean mask based indexing is not yet supported.

The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.

Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which output
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.

In Place Updates
----------------

In place updates to indexed arrays are possible in MLX. For example:

.. code-block:: shell

   >>> a = mx.array([1, 2, 3])
   >>> a[2] = 0
   >>> a
   array([1, 2, 0], dtype=int32)

Just as in NumPy, in place updates will be reflected in all references to the
same array:

.. code-block:: shell

   >>> a = mx.array([1, 2, 3])
   >>> b = a
   >>> b[2] = 0
   >>> b
   array([1, 2, 0], dtype=int32)
   >>> a
   array([1, 2, 0], dtype=int32)

Transformations of functions which use in-place updates are allowed and work as
expected. For example:

.. code-block:: python

   def fun(x, idx):
       x[idx] = 2.0
       return x.sum()

   dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
   print(dfdx)  # Prints: array([1, 0, 1], dtype=float32)

In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
and ones elsewhere.