Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-08 10:14:43 +08:00)

Compare commits: `compile-te` ... `v0.21.1` (607 commits)
Commit table: 607 commits, `50fa705125` through `b207c2c86b` (full Author | SHA1 | Date listing omitted; only abbreviated SHAs survive in this mirror).
.circleci/config.yml

@@ -1,5 +1,8 @@
 version: 2.1

+orbs:
+  apple: ml-explore/pr-approval@0.1.0
+
 parameters:
   nightly_build:
     type: boolean
@@ -7,8 +10,65 @@ parameters:
   weekly_build:
     type: boolean
     default: false
+  test_release:
+    type: boolean
+    default: false
+  linux_release:
+    type: boolean
+    default: false

 jobs:
+  build_documentation:
+    parameters:
+      upload-docs:
+        type: boolean
+        default: false
+    macos:
+      xcode: "15.2.0"
+    resource_class: macos.m1.medium.gen1
+    steps:
+      - checkout
+      - run:
+          name: Install
+          command: |
+            brew install python@3.9
+            brew install doxygen
+            python3.9 -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install --upgrade cmake
+            pip install -r docs/requirements.txt
+            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
+      - when:
+          condition:
+            not: << parameters.upload-docs >>
+          steps:
+            - run:
+                name: Build documentation
+                command: |
+                  source env/bin/activate
+                  cd docs && doxygen && make html O=-W
+      - when:
+          condition: << parameters.upload-docs >>
+          steps:
+            - add_ssh_keys:
+                fingerprints:
+                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
+            - run:
+                name: Upload documentation
+                command: |
+                  source env/bin/activate
+                  git config user.email "mlx@group.apple.com"
+                  git config user.name "CircleCI Docs"
+                  git checkout gh-pages
+                  git rebase main
+                  cd docs
+                  git rm -rf build/html
+                  doxygen && make html O=-W
+                  git add -f build/html
+                  git commit -m "rebase"
+                  git push -f origin gh-pages
+
   linux_build_and_test:
     docker:
       - image: cimg/python:3.9
@@ -25,196 +85,262 @@ jobs:
           name: Install dependencies
           command: |
             pip install --upgrade cmake
-            pip install --upgrade pybind11[global]
+            pip install nanobind==2.2.0
             pip install numpy
             sudo apt-get update
-            sudo apt-get install libblas-dev
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
       - run:
-          name: Build python package
+          name: Install Python package
           command: |
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
-            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
+            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
+              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+              python3 setup.py build_ext --inplace
+            CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
+              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+              python3 setup.py develop
       - run:
-          name: Run the python tests
+          name: Generate package stubs
           command: |
-            python3 -m unittest discover python/tests
-      # TODO: Reenable when extension api becomes stable
-      # - run:
-      #     name: Build example extension
-      #     command: |
-      #       cd examples/extensions && python3 -m pip install .
+            echo "stubs"
+            pip install typing_extensions
+            python setup.py generate_stubs
+      - run:
+          name: Run Python tests
+          command: |
+            python3 -m unittest discover python/tests -v
       - run:
           name: Build CPP only
           command: |
-            mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
+            mkdir -p build && cd build
+            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
+            make -j `nproc`
       - run:
           name: Run CPP tests
           command: ./build/tests/tests

   mac_build_and_test:
-    machine: true
-    resource_class: ml-explore/m-builder
+    parameters:
+      xcode_version:
+        type: string
+        default: "15.2.0"
+    macos:
+      xcode: << parameters.xcode_version >>
+    resource_class: macos.m1.medium.gen1
     steps:
       - checkout
       - run:
           name: Install dependencies
           command: |
-            eval "$(conda shell.bash hook)"
-            rm -r $CONDA_PREFIX/envs/runner-env
-            conda create -y -n runner-env python=3.9
-            conda activate runner-env
+            brew install python@3.9
+            brew install openmpi
+            python3.9 -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
             pip install --upgrade cmake
-            pip install --upgrade pybind11[global]
+            pip install nanobind==2.2.0
             pip install numpy
             pip install torch
             pip install tensorflow
             pip install unittest-xml-reporting
       - run:
-          name: Build python package
+          name: Install Python package
           command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
-            CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
+            source env/bin/activate
+            DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
       - run:
-          name: Run the python tests
+          name: Generate package stubs
           command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
-            DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
-      # TODO: Reenable when extension api becomes stable
-      # - run:
-      #     name: Build example extension
-      #     command: |
-      #       eval "$(conda shell.bash hook)"
-      #       conda activate runner-env
-      #       cd examples/extensions && python -m pip install .
+            source env/bin/activate
+            pip install typing_extensions
+            python setup.py generate_stubs
+      - run:
+          name: Run Python tests
+          command: |
+            source env/bin/activate
+            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
+            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
+            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
+      - run:
+          name: Build example extension
+          command: |
+            source env/bin/activate
+            cd examples/extensions
+            pip install -r requirements.txt
+            python setup.py build_ext -j8
       - store_test_results:
           path: test-results
       - run:
           name: Build CPP only
           command: |
-            mkdir -p build && cd build && cmake .. && make -j
+            source env/bin/activate
+            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
       - run:
           name: Run CPP tests
-          command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
+          command: |
+            DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
+      - run:
+          name: Build small binary
+          command: |
+            source env/bin/activate
+            cd build/
+            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
+              -DBUILD_SHARED_LIBS=ON \
+              -DMLX_BUILD_CPU=OFF \
+              -DMLX_BUILD_SAFETENSORS=OFF \
+              -DMLX_BUILD_GGUF=OFF \
+              -DMLX_METAL_JIT=ON
+            make -j `sysctl -n hw.ncpu`
+      - run:
+          name: Run Python tests with JIT
+          command: |
+            source env/bin/activate
+            CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+              CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
+              pip install -e . -v
+            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
+              METAL_DEBUG_ERROR_MODE=0 \
+              python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

   build_release:
-    machine: true
-    resource_class: ml-explore/m-builder
     parameters:
       python_version:
         type: string
         default: "3.9"
-      macos_version:
+      xcode_version:
         type: string
-        default: "14"
+        default: "15.2.0"
+      build_env:
+        type: string
+        default: ""
+    macos:
+      xcode: << parameters.xcode_version >>
+    resource_class: macos.m1.medium.gen1
     steps:
       - checkout
       - run:
           name: Install dependencies
           command: |
-            eval "$(conda shell.bash hook)"
-            rm -r $CONDA_PREFIX/envs/runner-env
-            conda create -y -n runner-env python=<< parameters.python_version >>
-            conda activate runner-env
+            brew install python@<< parameters.python_version >>
+            brew install openmpi
+            python<< parameters.python_version >> -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
             pip install --upgrade cmake
-            pip install --upgrade pybind11[global]
+            pip install nanobind==2.2.0
+            pip install --upgrade setuptools
             pip install numpy
             pip install twine
+            pip install build
       - run:
-          name: Build package
+          name: Install Python package
           command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-            PYPI_RELEASE=1 \
-            CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py bdist_wheel
-            twine upload dist/* --repository mlx
+            source env/bin/activate
+            DEV_RELEASE=1 \
+              CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+              pip install . -v
+      - run:
+          name: Generate package stubs
+          command: |
+            source env/bin/activate
+            pip install typing_extensions
+            python setup.py generate_stubs
+      - run:
+          name: Build Python package
+          command: |
+            source env/bin/activate
+            << parameters.build_env >> \
+              CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+              python -m build -w
+      - when:
+          condition: << parameters.build_env >>
+          steps:
+            - run:
+                name: Upload package
+                command: |
+                  source env/bin/activate
+                  twine upload dist/*
       - store_artifacts:
           path: dist/

-  build_dev_release:
-    machine: true
-    resource_class: ml-explore/m-builder
+  build_linux_release:
     parameters:
       python_version:
         type: string
         default: "3.9"
-      macos_version:
+      extra_env:
         type: string
-        default: "14"
+        default: "DEV_RELEASE=1"
+    docker:
+      - image: ubuntu:20.04
     steps:
       - checkout
       - run:
-          name: Install dependencies
+          name: Build wheel
           command: |
-            eval "$(conda shell.bash hook)"
-            rm -r $CONDA_PREFIX/envs/runner-env
-            conda create -y -n runner-env python=<< parameters.python_version >>
-            conda activate runner-env
+            PYTHON=python<< parameters.python_version >>
+            apt-get update
+            apt-get upgrade -y
+            DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
+            apt-get install -y apt-utils
+            apt-get install -y software-properties-common
+            add-apt-repository -y ppa:deadsnakes/ppa
+            apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
+            apt-get install -y libblas-dev liblapack-dev liblapacke-dev
+            apt-get install -y build-essential git
+            $PYTHON -m venv env
+            source env/bin/activate
+            pip install --upgrade pip
             pip install --upgrade cmake
-            pip install --upgrade pybind11[global]
+            pip install nanobind==2.2.0
+            pip install --upgrade setuptools
             pip install numpy
+            pip install auditwheel
+            pip install patchelf
+            pip install build
             pip install twine
+            << parameters.extra_env >> \
+              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+              pip install . -v
+            pip install typing_extensions
+            python setup.py generate_stubs
+            << parameters.extra_env >> \
+              CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+              python -m build --wheel
+            auditwheel show dist/*
+            auditwheel repair dist/* --plat manylinux_2_31_x86_64
       - run:
-          name: Build package
+          name: Upload package
           command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-            DEV_RELEASE=1 \
-            CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py bdist_wheel
-            twine upload dist/* --repository mlx
+            source env/bin/activate
+            twine upload wheelhouse/*
       - store_artifacts:
-          path: dist/
-
-  build_package:
-    machine: true
-    resource_class: ml-explore/m-builder
-    parameters:
-      python_version:
-        type: string
-        default: "3.9"
-      macos_version:
-        type: string
-        default: "14"
-    steps:
-      - checkout
-      - run:
-          name: Install dependencies
-          command: |
-            eval "$(conda shell.bash hook)"
-            rm -r $CONDA_PREFIX/envs/runner-env
-            conda create -y -n runner-env python=<< parameters.python_version >>
-            conda activate runner-env
-            pip install --upgrade cmake
-            pip install --upgrade pybind11[global]
-            pip install numpy
-            pip install twine
-      - run:
-          name: Build package
-          command: |
-            eval "$(conda shell.bash hook)"
-            conda activate runner-env
-            DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
-            CMAKE_BUILD_PARALLEL_LEVEL="" \
-            python setup.py bdist_wheel
-      - store_artifacts:
-          path: dist/
+          path: wheelhouse/

 workflows:
   build_and_test:
     when:
       and:
+        - matches:
+            pattern: "^(?!pull/)[-\\w]+$"
+            value: << pipeline.git.branch >>
         - not: << pipeline.parameters.nightly_build >>
         - not: << pipeline.parameters.weekly_build >>
+        - not: << pipeline.parameters.test_release >>
     jobs:
+      - mac_build_and_test:
+          matrix:
+            parameters:
+              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
       - linux_build_and_test
-      - mac_build_and_test
+      - build_documentation

+  build_pypi_release:
+    when:
+      and:
+        - not: << pipeline.parameters.nightly_build >>
+        - not: << pipeline.parameters.weekly_build >>
+        - not: << pipeline.parameters.test_release >>
+    jobs:
       - build_release:
           filters:
             tags:
@@ -223,21 +349,65 @@ workflows:
               ignore: /.*/
           matrix:
             parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              macos_version: ["13", "14"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              xcode_version: ["15.0.0", "15.2.0"]
+              build_env: ["PYPI_RELEASE=1"]
+      - build_documentation:
+          filters:
+            tags:
+              only: /^v.*/
+            branches:
+              ignore: /.*/
+          upload-docs: true
+
+  prb:
+    when:
+      matches:
+        pattern: "^pull/\\d+(/head)?$"
+        value: << pipeline.git.branch >>
+    jobs:
+      - hold:
+          type: approval
+      - apple/authenticate:
+          context: pr-approval
+      - mac_build_and_test:
+          requires: [ hold ]
+          matrix:
+            parameters:
+              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+      - linux_build_and_test:
+          requires: [ hold ]
   nightly_build:
-    when: << pipeline.parameters.nightly_build >>
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.nightly_build >>
     jobs:
-      - build_package:
+      - build_release:
          matrix:
            parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              macos_version: ["13", "14"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              xcode_version: ["15.0.0", "15.2.0"]
   weekly_build:
-    when: << pipeline.parameters.weekly_build >>
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.weekly_build >>
     jobs:
-      - build_dev_release:
+      - build_release:
          matrix:
            parameters:
-              python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
-              macos_version: ["13", "14"]
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+              build_env: ["DEV_RELEASE=1"]
+
+  linux_test_release:
+    when:
+      and:
+        - equal: [ main, << pipeline.git.branch >> ]
+        - << pipeline.parameters.linux_release >>
+    jobs:
+      - build_linux_release:
+          matrix:
+            parameters:
+              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+              extra_env: ["PYPI_RELEASE=1"]
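The updated CI configuration doubles as a local build recipe. Below is a minimal sketch of reproducing the `linux_build_and_test` job outside CircleCI, using only commands that appear in the diff above; the Ubuntu host, `sudo`, and a checkout of the repository root are assumptions:

```
# Dependencies, as in the "Install dependencies" step
pip install --upgrade cmake
pip install nanobind==2.2.0
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev

# Build without the Metal backend, then run the test suite,
# as in the "Install Python package" and "Run Python tests" steps
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
  CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
  python3 setup.py develop
python3 -m unittest discover python/tests -v
```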
.github/workflows/pull_request.yml (vendored; 2 lines changed)

@@ -17,4 +17,4 @@ jobs:
           pip install pre-commit black isort clang-format
       - name: Run lint
         run: |
           pre-commit run --all-files
.pre-commit-config.yaml

@@ -1,16 +1,21 @@
 repos:
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v17.0.6
+    rev: v19.1.4
     hooks:
       - id: clang-format
   # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 23.12.1
+    rev: 24.10.0
     hooks:
       - id: black
+
   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
         args:
           - --profile=black
+  - repo: https://github.com/cheshirekow/cmake-format-precommit
+    rev: v0.6.13
+    hooks:
+      - id: cmake-format
ACKNOWLEDGMENTS.md

@@ -7,11 +7,18 @@ with a short description of your contribution(s) below. For example:

 MLX was developed with contributions from the following individuals:

-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
-- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
+- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
+- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
+- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
+- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
+- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
+- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
+- Paul Paczuski: Improved stability of BCE loss calculation
+- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

@@ -252,4 +259,4 @@ Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
CITATION.cff (new file, 24 lines)

@@ -0,0 +1,24 @@
+cff-version: 1.2.0
+title: mlx
+message: >-
+  If you use this software, please cite it using the
+  metadata from this file.
+type: software
+authors:
+  - given-names: Awni
+    family-names: Hannun
+    affiliation: Apple
+  - given-names: Jagrit
+    family-names: Digani
+    affiliation: Apple
+  - given-names: Angelos
+    family-names: Katharopoulos
+    affiliation: Apple
+  - given-names: Ronan
+    family-names: Collobert
+    affiliation: Apple
+repository-code: 'https://github.com/ml-explore'
+abstract: >-
+  MLX: efficient and flexible machine learning on Apple
+  silicon
+license: MIT
302
CMakeLists.txt
302
CMakeLists.txt
@@ -15,36 +15,43 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
|||||||
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||||
|
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||||
|
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||||
|
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||||
|
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||||
|
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||||
|
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
|
||||||
if(NOT MLX_VERSION)
|
if(NOT MLX_VERSION)
|
||||||
set(MLX_VERSION 0.0.10)
|
set(MLX_VERSION 0.21.1)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# --------------------- Processor tests -------------------------
|
# --------------------- Processor tests -------------------------
|
||||||
|
|
||||||
message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")
|
message(
|
||||||
|
STATUS
|
||||||
|
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
|
||||||
|
)
|
||||||
|
|
||||||
set(MLX_BUILD_ARM OFF)
|
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||||
|
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
if(NOT MLX_ENABLE_X64_MAC)
|
||||||
|
message(
|
||||||
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
|
FATAL_ERROR
|
||||||
message(FATAL_ERROR
|
"Building for x86_64 on macOS is not supported."
|
||||||
"Building for x86_64 on macOS is not supported."
|
" If you are on an Apple silicon system, check the build"
|
||||||
" If you are on an Apple silicon system, check the build"
|
" documentation for possible fixes: "
|
||||||
" documentation for possible fixes: "
|
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
|
||||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
|
)
|
||||||
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
else()
|
||||||
message(WARNING
|
set(MLX_BUILD_METAL OFF)
|
||||||
"Building for x86_64 on macOS is not supported."
|
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||||
" If you are on an Apple silicon system, "
|
endif()
|
||||||
" make sure you are building for arm64.")
|
|
||||||
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
|
|
||||||
set(MLX_BUILD_ARM ON)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
else()
|
else()
|
||||||
|
set(MLX_BUILD_METAL OFF)
|
||||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@@ -56,141 +63,199 @@ cmake_policy(SET CMP0135 NEW)
|
|||||||
|
|
||||||
add_library(mlx)
|
add_library(mlx)
|
||||||
|
|
||||||
if (MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
find_library(METAL_LIB Metal)
|
set(METAL_LIB "-framework Metal")
|
||||||
find_library(FOUNDATION_LIB Foundation)
|
set(FOUNDATION_LIB "-framework Foundation")
|
||||||
find_library(QUARTZ_LIB QuartzCore)
|
set(QUARTZ_LIB "-framework QuartzCore")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||||
message(STATUS "Metal not found. Unable to build GPU")
|
message(STATUS "Metal not found. Unable to build GPU")
|
||||||
set(MLX_BUILD_METAL OFF)
|
set(MLX_BUILD_METAL OFF)
|
||||||
elseif (MLX_BUILD_METAL)
|
set(MLX_METAL_DEBUG OFF)
|
||||||
|
elseif(MLX_BUILD_METAL)
|
||||||
message(STATUS "Building METAL sources")
|
message(STATUS "Building METAL sources")
|
||||||
add_compile_definitions(_METAL_)
|
|
||||||
|
|
||||||
# Throw an error if xcrun not found
|
if(MLX_METAL_DEBUG)
|
||||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
add_compile_definitions(MLX_METAL_DEBUG)
|
||||||
OUTPUT_VARIABLE MACOS_VERSION
|
|
||||||
COMMAND_ERROR_IS_FATAL ANY)
|
|
||||||
|
|
||||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
|
||||||
|
|
||||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
|
||||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
|
|
||||||
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
|
|
||||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
FetchContent_Declare(
|
# Throw an error if xcrun not found
|
||||||
metal_cpp
|
execute_process(
|
||||||
URL ${METAL_CPP_URL}
|
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||||
|
OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
|
||||||
|
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||||
|
message(
|
||||||
|
FATAL_ERROR
|
||||||
|
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
|
||||||
|
endif()
|
||||||
|
message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
|
||||||
|
|
||||||
|
set(METAL_CPP_URL
|
||||||
|
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
|
||||||
|
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||||
|
endif()
|
||||||
|
execute_process(
|
||||||
|
COMMAND
|
||||||
|
zsh "-c"
|
||||||
|
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||||
|
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||||
|
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||||
|
|
||||||
FetchContent_MakeAvailable(metal_cpp)
|
FetchContent_MakeAvailable(metal_cpp)
|
||||||
target_include_directories(
|
target_include_directories(
|
||||||
mlx PUBLIC
|
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
||||||
$<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
|
$<INSTALL_INTERFACE:include/metal_cpp>)
|
||||||
$<INSTALL_INTERFACE:include/metal_cpp>
|
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
|
||||||
)
|
|
||||||
target_link_libraries(
|
|
||||||
mlx
|
|
||||||
${METAL_LIB}
|
|
||||||
${FOUNDATION_LIB}
|
|
||||||
${QUARTZ_LIB})
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
if(MLX_BUILD_CPU)
|
||||||
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
|
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||||
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
if(ACCELERATE_LIBRARY)
|
||||||
set(MLX_BUILD_ACCELERATE ON)
|
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
|
||||||
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
|
set(MLX_BUILD_ACCELERATE ON)
|
||||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
|
||||||
else()
|
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||||
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
else()
|
||||||
set(MLX_BUILD_ACCELERATE OFF)
|
message(STATUS "Accelerate or arm neon not found, using default backend.")
|
||||||
#set(BLA_VENDOR Generic)
|
set(MLX_BUILD_ACCELERATE OFF)
|
||||||
find_package(BLAS REQUIRED)
|
if(${CMAKE_HOST_APPLE})
|
||||||
if (NOT BLAS_FOUND)
|
# The blas shipped in macOS SDK is not supported, search homebrew for
|
||||||
message(FATAL_ERROR "Must have BLAS installed")
|
# openblas instead.
|
||||||
|
set(BLA_VENDOR OpenBLAS)
|
||||||
|
set(LAPACK_ROOT
|
||||||
|
"${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
|
||||||
|
endif()
|
||||||
|
# Search and link with lapack.
|
||||||
|
find_package(LAPACK REQUIRED)
|
||||||
|
if(NOT LAPACK_FOUND)
|
||||||
|
message(FATAL_ERROR "Must have LAPACK installed")
|
||||||
|
endif()
|
||||||
|
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
|
||||||
|
/usr/local/opt/openblas/include)
|
||||||
|
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||||
|
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||||
|
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
|
||||||
|
# List blas after lapack otherwise we may accidentally incldue an old
|
||||||
|
# version of lapack.h from the include dirs of blas.
|
||||||
|
find_package(BLAS REQUIRED)
|
||||||
|
if(NOT BLAS_FOUND)
|
||||||
|
message(FATAL_ERROR "Must have BLAS installed")
|
||||||
|
endif()
|
||||||
|
# TODO find a cleaner way to do this
|
||||||
|
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
|
||||||
|
$ENV{BLAS_HOME}/include)
|
||||||
|
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||||
|
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||||
|
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
|
||||||
|
|
||||||
|
if(WIN32)
|
||||||
|
find_package(dlfcn-win32 REQUIRED)
|
||||||
|
message(STATUS "dlfcn-win32 lib " ${dlfcn-win32_LIBRARIES})
|
||||||
|
message(STATUS "dlfcn-win32 include " ${dlfcn-win32_INCLUDE_DIRS})
|
||||||
|
target_link_libraries(mlx PUBLIC ${dlfcn-win32_LIBRARIES})
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
set(MLX_BUILD_ACCELERATE OFF)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
find_package(MPI)
|
||||||
|
if(MPI_FOUND)
|
||||||
|
execute_process(
|
||||||
|
COMMAND zsh "-c" "mpirun --version"
|
||||||
|
OUTPUT_VARIABLE MPI_VERSION
|
||||||
|
ERROR_QUIET)
|
||||||
|
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
|
||||||
|
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
||||||
|
elseif(MPI_VERSION STREQUAL "")
|
||||||
|
set(MPI_FOUND FALSE)
|
||||||
|
message(
|
||||||
|
WARNING "MPI found but mpirun is not available. Building without MPI.")
|
||||||
|
else()
|
||||||
|
set(MPI_FOUND FALSE)
|
||||||
|
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
|
||||||
endif()
|
endif()
|
||||||
# TODO find a cleaner way to do this
|
|
||||||
find_path(BLAS_INCLUDE_DIRS cblas.h
|
|
||||||
/usr/include
|
|
||||||
/usr/local/include
|
|
||||||
$ENV{BLAS_HOME}/include)
|
|
||||||
message(STATUS ${BLAS_LIBRARIES})
|
|
||||||
message(STATUS ${BLAS_INCLUDE_DIRS})
|
|
||||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
|
||||||
target_link_libraries(mlx ${BLAS_LIBRARIES})
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||||
|
|
||||||
target_include_directories(
|
target_include_directories(
|
||||||
mlx
|
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||||
PUBLIC
|
$<INSTALL_INTERFACE:include>)
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
|
||||||
$<INSTALL_INTERFACE:include>
|
|
||||||
)
|
|
||||||
|
|
||||||
if (MLX_BUILD_PYTHON_BINDINGS)
|
FetchContent_Declare(
|
||||||
|
fmt
|
||||||
|
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||||
|
GIT_TAG 10.2.1
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
FetchContent_MakeAvailable(fmt)
|
||||||
|
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||||
|
|
||||||
|
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||||
message(STATUS "Building Python bindings.")
|
message(STATUS "Building Python bindings.")
|
||||||
find_package(Python COMPONENTS Interpreter Development)
|
find_package(
|
||||||
find_package(pybind11 CONFIG REQUIRED)
|
Python 3.8
|
||||||
|
COMPONENTS Interpreter Development.Module
|
||||||
|
REQUIRED)
|
||||||
|
execute_process(
|
||||||
|
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
OUTPUT_VARIABLE NB_DIR)
|
||||||
|
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
|
||||||
|
find_package(nanobind CONFIG REQUIRED)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_TESTS)
|
if(MLX_BUILD_TESTS)
|
||||||
include(CTest)
|
include(CTest)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_EXAMPLES)
|
if(MLX_BUILD_EXAMPLES)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (MLX_BUILD_BENCHMARKS)
|
if(MLX_BUILD_BENCHMARKS)
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------- Installation -----------------------------
|
# ----------------------------- Installation -----------------------------
|
||||||
include(GNUInstallDirs)
|
include(GNUInstallDirs)
|
||||||
|
|
||||||
# Install library
|
# Install library
|
||||||
install(
|
install(
|
||||||
TARGETS mlx
|
TARGETS mlx
|
||||||
EXPORT MLXTargets
|
EXPORT MLXTargets
|
||||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||||
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
INCLUDES
|
||||||
)
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||||
|
|
||||||
|
|
||||||
# Install headers
|
# Install headers
|
||||||
install(
|
install(
|
||||||
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
|
||||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||||
COMPONENT headers
|
COMPONENT headers
|
||||||
FILES_MATCHING PATTERN "*.h"
|
FILES_MATCHING
|
||||||
)
|
PATTERN "*.h"
|
||||||
|
PATTERN "backend/metal/kernels.h" EXCLUDE)
|
||||||
|
|
||||||
# Install metal dependencies
|
# Install metal dependencies
|
||||||
if (MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
|
|
||||||
# Install metal cpp
|
# Install metal cpp
|
||||||
install(
|
install(
|
||||||
DIRECTORY ${metal_cpp_SOURCE_DIR}/
|
DIRECTORY ${metal_cpp_SOURCE_DIR}/
|
||||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
|
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
|
||||||
COMPONENT metal_cpp_source
|
COMPONENT metal_cpp_source)
|
||||||
)
|
|
||||||
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@@ -202,31 +267,24 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
 install(
   EXPORT MLXTargets
   FILE MLXTargets.cmake
-  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
-)
+  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

 include(CMakePackageConfigHelpers)

 write_basic_package_version_file(
   ${MLX_CMAKE_BUILD_VERSION_CONFIG}
   COMPATIBILITY SameMajorVersion
-  VERSION ${MLX_VERSION}
-)
+  VERSION ${MLX_VERSION})

 configure_package_config_file(
-  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
-  ${MLX_CMAKE_BUILD_CONFIG}
+  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
   INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
   NO_CHECK_REQUIRED_COMPONENTS_MACRO
-  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
-)
+  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
+            MLX_CMAKE_INSTALL_MODULE_DIR)

-install(
-  FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
-  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
-)
+install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
+        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

-install(
-  DIRECTORY ${CMAKE_MODULE_PATH}/
-  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
-)
+install(DIRECTORY ${CMAKE_MODULE_PATH}/
+        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
MANIFEST.in
@@ -1,3 +1,4 @@
 include CMakeLists.txt
 recursive-include mlx/ *
 include python/src/*
+include python/mlx/py.typed # support type hinting as in PEP-561
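The `py.typed` marker added above is the PEP 561 signal that a package ships its own inline type annotations. A minimal sketch of what this enables once the marker is installed (the `double` helper here is hypothetical, purely for illustration):

    # with py.typed present, mypy/pyright check calls against mlx's annotations
    import mlx.core as mx

    def double(x: mx.array) -> mx.array:
        return x * 2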
README.md
@@ -6,15 +6,17 @@
 [![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)

-MLX is an array framework for machine learning on Apple silicon, brought to you
-by Apple machine learning research.
+MLX is an array framework for machine learning on Apple silicon,
+brought to you by Apple machine learning research.

 Some key features of MLX include:

-- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
-  MLX also has a fully featured C++ API, which closely mirrors the Python API.
-  MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
-  that closely follow PyTorch to simplify building more complex models.
+- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
+  also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
+  [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
+  the Python API. MLX has higher-level packages like `mlx.nn` and
+  `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
+  more complex models.

 - **Composable function transformations**: MLX supports composable function
   transformations for automatic differentiation, automatic vectorization,
@@ -68,23 +70,31 @@ in the documentation.
 MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:

+**With `pip`**:
+
 ```
 pip install mlx
 ```

+**With `conda`**:
+
+```
+conda install -c conda-forge mlx
+```
+
 Checkout the
 [documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
 for more information on building the C++ and Python APIs from source.

 ## Contributing

-Check out the [contribution guidelines](CONTRIBUTING.md) for more information
+Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
 on contributing to MLX. See the
 [docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
 information on building from source, and running tests.

 We are grateful for all of [our
-contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
+contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
 to MLX and wish to be acknowledged, please add your name to the list in your
 pull request.
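The feature bullets in the README above (NumPy-style API plus composable transformations) can be seen in a few lines; a minimal sketch, not part of the diff itself:

    import mlx.core as mx

    # NumPy-style array ops
    a = mx.array([1.0, 2.0, 3.0])
    print(a + mx.ones_like(a))

    # a composable transformation: automatic differentiation
    def f(x):
        return (x**2).sum()

    print(mx.grad(f)(a))  # gradient is 2 * a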
@@ -73,6 +73,7 @@ void time_unary_ops() {
 void time_binary_ops() {
   int M = 1000, N = 100, K = 10;
+  auto condition = random::randint(0, 2, {M, N, K});
   auto a = random::uniform({M, N, K});
   auto b = random::uniform({M, N, K});
   auto device = default_device();
@@ -84,7 +85,9 @@ void time_binary_ops() {
   TIME(divide, a, b, device);
   TIME(maximum, a, b, device);
   TIME(minimum, a, b, device);
+  TIME(where, condition, a, b, device);

+  condition = array({true});
   b = random::uniform({1});
   eval(b);
   TIMEM("scalar", add, a, b, device);
@@ -93,7 +96,9 @@ void time_binary_ops() {
   TIMEM("scalar", multiply, a, b, device);
   TIMEM("vector-scalar", divide, a, b, device);
   TIMEM("scalar-vector", divide, b, a, device);
+  TIMEM("scalar-vector", where, condition, a, b, device);

+  condition = broadcast_to(array({true}), {1000, 100});
   a = broadcast_to(random::uniform({1}), {1000, 100});
   b = broadcast_to(random::uniform({1}), {1000, 100});
   eval(a, b);
@@ -101,6 +106,7 @@ void time_binary_ops() {
   TIMEM("scalar-scalar broadcast", subtract, a, b, device);
   TIMEM("scalar-scalar broadcast", multiply, a, b, device);
   TIMEM("scalar-scalar broadcast", divide, a, b, device);
+  TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
 }

 void time_strided_ops() {
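The new timings above exercise ternary selection (`where`) including its broadcast cases. A minimal Python sketch of the same shapes, assuming the usual mlx broadcasting rules:

    import mlx.core as mx

    cond = mx.random.randint(0, 2, (1000, 100, 10)).astype(mx.bool_)
    a = mx.random.uniform(shape=(1000, 100, 10))
    b = mx.random.uniform(shape=(1,))  # the "scalar-vector" case
    out = mx.where(cond, a, b)         # b broadcasts against cond and a
    mx.eval(out)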
@@ -17,14 +17,13 @@
       << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
       << std::endl;

-#define TIMEM(MSG, FUNC, ...)                                                  \
-  std::cout << "Timing "                                                       \
-            << "(" << MSG << ") " << #FUNC << " ... " << std::flush            \
-            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
-            << std::endl;
+#define TIMEM(MSG, FUNC, ...)                                       \
+  std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... "  \
+            << std::flush << std::setprecision(5)                   \
+            << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;

 template <typename F, typename... Args>
-double time_fn(F fn, Args... args) {
+double time_fn(F fn, Args&&... args) {
   // warmup
   for (int i = 0; i < 5; ++i) {
     eval(fn(std::forward<Args>(args)...));
@@ -72,6 +72,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
 quant_matmul = {
+    "quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
+    "quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
+    "quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
     "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
     "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
     "quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
@@ -84,6 +87,15 @@ quant_matmul = {
     "quant_matmul_128_8": partial(
         _quant_matmul, transpose=False, group_size=128, bits=8
     ),
+    "quant_matmul_t_32_2": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=2
+    ),
+    "quant_matmul_t_32_4": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=4
+    ),
+    "quant_matmul_t_32_8": partial(
+        _quant_matmul, transpose=True, group_size=32, bits=8
+    ),
     "quant_matmul_t_64_2": partial(
         _quant_matmul, transpose=True, group_size=64, bits=2
     ),
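The new entries sweep group size 32 alongside the existing 64 and 128 configurations, with and without a transposed weight. A minimal sketch of the call path they exercise, assuming `_quant_matmul` wraps `mx.quantized_matmul` the way the surrounding entries suggest:

    import mlx.core as mx

    w = mx.random.uniform(shape=(512, 512))
    w_q, scales, biases = mx.quantize(w, group_size=32, bits=4)
    x = mx.random.uniform(shape=(1, 512))
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=False, group_size=32, bits=4
    )
    mx.eval(y)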
@@ -132,6 +144,13 @@ def reduction(op, axis, x):
     mx.eval(ys)


+def sum_and_add(axis, x, y):
+    z = x.sum(axis=axis, keepdims=True)
+    for i in range(50):
+        z = (z + y).sum(axis=axis, keepdims=True)
+    mx.eval(z)
+
+
 def softmax(axis, x):
     ys = []
     for i in range(100):
@@ -368,10 +387,6 @@ if __name__ == "__main__":
     if len(args.axis) > 1:
         args.axis.pop(0)

-    if args.print_pid:
-        print(os.getpid())
-        input("Press enter to run")
-
     if args.cpu:
         mx.set_default_device(mx.cpu)
     else:
@@ -394,6 +409,10 @@ if __name__ == "__main__":
     x = xs[0]
     axis = args.axis[0]

+    if args.print_pid:
+        print(os.getpid())
+        input("Press enter to run")
+
     if args.benchmark == "matmul_square":
         print(bench(matmul_square, x))
@@ -493,5 +512,8 @@ if __name__ == "__main__":
     elif args.benchmark == "selu":
         print(bench(selu, x))

+    elif args.benchmark == "sum_and_add":
+        print(bench(sum_and_add, axis, *xs))
+
     else:
         raise ValueError("Unknown benchmark")
@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
 def mish(x: torch.Tensor) -> torch.Tensor:
     y = x
     for _ in range(100):
-        return torch.nn.functional.mish(y)
+        y = torch.nn.functional.mish(y)
     sync_if_needed(x)
@@ -283,6 +283,14 @@ def topk(axis, x):
     sync_if_needed(x)


+@torch.no_grad()
+def step_function(x):
+    y = x
+    for i in range(100):
+        y = torch.where(y < 0, 0, 1)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def selu(x):
     y = x
@@ -331,10 +339,6 @@ if __name__ == "__main__":
     if len(args.axis) > 1:
         args.axis.pop(0)

-    if args.print_pid:
-        print(os.getpid())
-        input("Press enter to run")
-
     torch.set_num_threads(1)
     device = "cpu" if args.cpu else "mps"
@@ -354,6 +358,10 @@ if __name__ == "__main__":
     x = xs[0]
     axis = args.axis[0]

+    if args.print_pid:
+        print(os.getpid())
+        input("Press enter to run")
+
     if args.benchmark == "matmul_square":
         print(bench(matmul_square, x))
@@ -446,5 +454,11 @@ if __name__ == "__main__":
     elif args.benchmark == "topk":
         print(bench(topk, axis, x))

+    elif args.benchmark == "step":
+        print(bench(step_function, x))
+
+    elif args.benchmark == "selu":
+        print(bench(selu, x))
+
     else:
-        raise ValueError("Unknown benchmark")
+        raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
@@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
         result = run(*args, capture_output=True, **kwargs)
         return float(result.stdout)
     except ValueError:
-        raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
+        raise ValueError(
+            f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
+        )


 def compare(args):
@@ -80,10 +82,8 @@ if __name__ == "__main__":
     _filter = make_predicate(args.filter, args.negative_filter)

     if args.mlx_dtypes:
-        compare_filtered = (
-            lambda x: compare_mlx_dtypes(
-                x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
-            )
+        compare_filtered = lambda x: (
+            compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
             if _filter(x)
             else None
         )
benchmarks/python/compile_bench.py (new file, 107 lines)

# Copyright © 2023-2024 Apple Inc.

import argparse
import math
import random

import mlx.core as mx
from time_utils import time_fn


def bench_gelu():
    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

    x = mx.random.uniform(shape=(1000, 1024))

    def gen_fun(fun):
        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    time_fn(gen_fun(gelu), x, msg="fixed gelu")
    time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")

    def randint():
        return random.randint(1, x.shape[0])

    def gen_fun(fun):
        def bench_fun(x, y):
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
                y = fun(y)
            return x, y

        return bench_fun

    y = mx.random.uniform(shape=(1000, 1024))
    time_fn(gen_fun(gelu), x, y, msg="variable gelu")
    time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
    time_fn(
        gen_fun(mx.compile(gelu, shapeless=True)),
        x,
        y,
        msg="shapeless variable gelu",
    )


def bench_layernorm():
    weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    mx.eval(weight, bias)

    def layernorm(x):
        x = x.astype(mx.float32)
        means = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + 1e-4)
        x = x.astype(mx.float16)
        return weight * x + bias

    x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)

    def gen_fun(fun):
        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    time_fn(gen_fun(layernorm), x, msg="fixed layernorm")
    time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm")

    def randint():
        return random.randint(1, x.shape[0])

    def gen_fun(fun):
        def bench_fun(x):
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    random.seed(0)
    time_fn(gen_fun(layernorm), x, msg="variable layernorm")
    random.seed(0)
    time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm")
    random.seed(0)
    time_fn(
        gen_fun(mx.compile(layernorm, shapeless=True)),
        x,
        msg="shapeless variable layernorm",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Compile benchmarks.")
    args = parser.parse_args()

    bench_gelu()
    bench_layernorm()
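compile_bench.py contrasts eager, compiled, and shapeless-compiled closures while the leading dimension changes on every call. A minimal sketch of the `shapeless=True` knob it sweeps (retracing behavior is version-dependent, so treat this as illustrative):

    import mlx.core as mx

    def f(x):
        return mx.exp(-x).sum()

    g = mx.compile(f, shapeless=True)  # trace once; reuse as input shapes vary
    print(g(mx.zeros((8,))), g(mx.zeros((16,))))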
benchmarks/python/conv1d_bench.py (new file, 123 lines)

import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_1D(strides=1, padding=0, groups=1):
    def mx_conv_1D(a, b):
        ys = []
        for _ in range(N_iter_func):
            y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_1D


def make_pt_conv_1D(strides=1, padding=0, groups=1):
    @torch.no_grad()
    def pt_conv_1D(a, b):
        ys = []
        for _ in range(N_iter_func):
            y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_1D


def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
    scale = 1.0 / math.sqrt(wH * C)
    a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_1D(strides, padding, groups)
    f_pt = make_pt_conv_1D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv1d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 5, 32, 1, 2, 1),
        (4, 32, 32, 5, 32, 1, 2, 2),
        (4, 32, 32, 5, 32, 1, 2, 4),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 16),
        (4, 32, 32, 5, 32, 1, 2, 32),
        (4, 32, 256, 5, 512, 1, 2, 2),
        (4, 32, 256, 5, 512, 1, 2, 128),
        (4, 32, 256, 5, 512, 1, 2, 256),
    )

    for dtype in dtypes:
        print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
        for N, iH, C, wH, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, iH, C, wH, O, strides, padding, np_dtype, groups
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
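A layout note on the benchmark above: mlx's conv1d is channels-last, (N, L, C), while torch's is (N, C, L), which is why the script transposes before handing tensors to torch. A minimal CPU-only sketch of the same correspondence:

    import mlx.core as mx
    import numpy as np
    import torch

    a = np.random.uniform(size=(4, 32, 16)).astype(np.float32)  # (N, L, C)
    w = np.random.uniform(size=(8, 5, 16)).astype(np.float32)   # (O, kL, C)
    y_mx = mx.conv1d(mx.array(a), mx.array(w), padding=2)
    y_pt = torch.conv1d(
        torch.from_numpy(a.transpose(0, 2, 1)),
        torch.from_numpy(w.transpose(0, 2, 1)),
        padding=2,
    ).permute(0, 2, 1).numpy()
    print(np.allclose(y_pt, y_mx, atol=2e-5))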
benchmarks/python/conv2d_bench_cpu.py (new file, 127 lines)

import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
benchmarks/python/conv2d_train_bench_cpu.py (new file, 143 lines)

import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal([10, 256, 256, 3])

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(10, 3, 256, 256, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 20
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX: {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch: {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()
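The training benchmark above follows the standard mlx pattern: a functional loss, `value_and_grad`, an optimizer update, then one `mx.eval` on the model and optimizer state per step. The same pattern in miniature (a sketch, not taken from the diff):

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as opt

    model = nn.Linear(4, 4)
    optim = opt.Adam(learning_rate=1e-3)

    def loss_fn(model, x):
        return (model(x) - x).abs().mean()

    x = mx.random.normal((8, 4))
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x)
    optim.update(model, grads)
    mx.eval(model.parameters(), optim.state)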
benchmarks/python/conv2d_transpose_bench_cpu.py (new file, 129 lines)

import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_2D


def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_transpose_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
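A weight-layout note for the transposed-conv comparisons: torch.conv_transpose2d expects weights as (C_in, C_out/groups, kH, kW), so the script builds the channels-last numpy weight and moves axis 3 to the front before handing it to torch. A minimal sketch of that axis move:

    import numpy as np

    w_np = np.zeros((8, 5, 5, 16), dtype=np.float32)  # mlx layout: (O, kH, kW, C)
    w_pt = w_np.transpose((3, 0, 1, 2))               # torch layout: (C, O, kH, kW)
    print(w_pt.shape)  # (16, 8, 5, 5)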
benchmarks/python/conv3d_bench_cpu.py (new file, 110 lines)

import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
benchmarks/python/conv3d_train_bench_cpu.py (new file, 143 lines)

import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal(shape)

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 10
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX: {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch: {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()
benchmarks/python/conv3d_transpose_bench_cpu.py (new file, 116 lines)

import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose3d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
benchmarks/python/conv_bench.py (new file, 135 lines)

import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
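conv_bench.py is the only variant above that sweeps groups > 1 on the GPU path. A grouped-convolution sketch matching its (O, kH, kW, C // groups) weight shape (assuming a build where mx.conv2d supports groups, which is exactly what the benchmark exercises):

    import mlx.core as mx

    x = mx.random.uniform(shape=(1, 8, 8, 16))
    w = mx.random.uniform(shape=(32, 3, 3, 4))  # 16 input channels / groups=4
    y = mx.conv2d(x, w, padding=1, groups=4)
    print(y.shape)  # (1, 8, 8, 32)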
benchmarks/python/conv_transpose_bench.py (new file, 135 lines)
@@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_2D


def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_transpose_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
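Both conv benchmarks print only a relative diff, but absolute per-op latency is easy to recover: bench() times N_iter_bench calls and each call runs N_iter_func convolutions. A minimal sketch using the constants above (the 0.75 s figure is a made-up example, not a measured result):

# Per-convolution latency from the totals returned by bench() above:
# N_iter_bench timed calls, each running N_iter_func convolutions.
def per_conv_ms(total_s, n_iter_bench=100, n_iter_func=5):
    return 1e3 * total_s / (n_iter_bench * n_iter_func)

print(per_conv_ms(0.75))  # a 0.75 s total corresponds to 1.5 ms per conv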
benchmarks/python/distributed_bench.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.

"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""

import time

import mlx.core as mx


def time_fn(fn, *args, **kwargs):
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    if world.rank() == 0:
        if msg:
            print(f"Timing {msg} ...", end=" ")
        else:
            print(f"Timing {fn.__name__} ...", end=" ")

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        x = mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    if world.rank() == 0:
        print(f"{msec:.5f} msec")


def time_all_sum():
    shape = (4096,)
    x = mx.random.uniform(shape=shape)
    mx.eval(x)

    def sine(x):
        for _ in range(20):
            x = mx.sin(x)
        return x

    time_fn(sine, x)

    def all_sum_plain(x):
        for _ in range(20):
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_plain, x)

    def all_sum_with_sine(x):
        for _ in range(20):
            x = mx.sin(x)
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_with_sine, x)


if __name__ == "__main__":
    time_all_sum()
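For scale, each all_sum above moves a (4096,) float32 vector and every timed call chains 20 of them; a back-of-the-envelope payload per timed iteration:

# 20 chained all_sums of a (4096,) float32 array per timed call:
print(20 * 4096 * 4 / 1e3, "KB")  # 327.68 KB moved per iteration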
benchmarks/python/einsum_bench.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx
import numpy as np


def timeit(fn, its=100, args=[]):
    for _ in range(5):
        fn(*args)
    tic = time.perf_counter()
    for _ in range(its):
        fn(*args)
    toc = time.perf_counter()
    return 1e3 * (toc - tic) / its


def time_little_einsum_path():
    subscripts = "ik,kj->ij"
    x = mx.ones((32, 32))
    y = mx.ones((32, 32))
    mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))

    x = np.array(x)
    y = np.array(y)
    np_time = timeit(np.einsum_path, args=(subscripts, x, y))
    print("Timing little einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_big_einsum_path():
    chars = list("abcdefgh")
    char_to_dim = {c: v for v, c in enumerate(chars)}

    num_inputs = 10
    inputs = []
    subscripts = []
    for _ in range(num_inputs):
        subscript = np.random.choice(chars, size=5, replace=False).tolist()
        subscripts.append("".join(subscript))
        inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
    subscripts = ",".join(subscripts)

    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))

    inputs = [mx.array(x) for x in inputs]
    mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
    print("Timing big einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_attention():
    def regular_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
        scores = mx.softmax(scores, axis=-1)
        output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
        mx.eval(output)

    def einsum_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
        scores = mx.softmax(scores, axis=-1)
        output = mx.einsum("ijtu,iujk->itjk", scores, values)
        mx.eval(output)

    x = mx.random.uniform(shape=(8, 512, 32, 128))

    regular_time = timeit(regular_attention, args=(x,))
    ein_time = timeit(einsum_attention, args=(x,))
    print("Timing einsum attention...")
    print(f"Regular ... {regular_time:.3f} ms")
    print(f"Einsum ... {ein_time:.3f} ms")


if __name__ == "__main__":
    time_little_einsum_path()
    time_big_einsum_path()
    time_attention()
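The two attention formulations in time_attention are meant to be numerically interchangeable; a quick sketch checking the score computation agrees on a small input (shapes follow the [batch, sequence, num_heads, head_dim] comment in the file):

import mlx.core as mx

x = mx.random.uniform(shape=(2, 16, 4, 8))  # [batch, sequence, heads, head_dim]
q = k = x
s1 = q.transpose(0, 2, 1, 3) @ k.transpose(0, 2, 3, 1)
s2 = mx.einsum("itjk,iujk->ijtu", q, k)
print(mx.allclose(s1, s2))  # expected: True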
benchmarks/python/fft_bench.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# Copyright © 2024 Apple Inc.

import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def bandwidth_gb(runtime_ms, system_size):
    bytes_per_fft = np.dtype(np.complex64).itemsize * 2
    bytes_per_gb = 1e9
    ms_per_s = 1e3
    return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb


def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
    def fft_mlx(x):
        if dim == 1:
            out = mx.fft.fft(x)
        elif dim == 2:
            out = mx.fft.fft2(x)
        mx.eval(out)
        return out

    def fft_mps(x):
        if dim == 1:
            out = torch.fft.fft(x)
        elif dim == 2:
            out = torch.fft.fft2(x)
        torch.mps.synchronize()
        return out

    bandwidths = []
    for n in fft_sizes:
        batch_size = system_size // n**dim
        shape = [batch_size] + [n for _ in range(dim)]
        if backend == "mlx":
            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
            x = mx.array(x_np)
            mx.eval(x)
            fft = fft_mlx
        elif backend == "mps":
            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
            x = torch.tensor(x_np, device="mps")
            torch.mps.synchronize()
            fft = fft_mps
        else:
            raise NotImplementedError()
        runtime_ms = measure_runtime(fft, x=x)
        bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
        print(n, bandwidth)
        bandwidths.append(bandwidth)

    return np.array(bandwidths)


def time_fft():
    x = np.array(range(2, 512))
    system_size = int(2**26)

    print("MLX GPU")
    with mx.stream(mx.gpu):
        gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)

    print("MPS GPU")
    mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")

    print("CPU")
    system_size = int(2**20)
    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)

    x = np.array(x)

    all_indices = x - x[0]
    radix_2to13 = (
        np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
    )
    bluesteins = (
        np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
    )

    for indices, name in [
        (all_indices, "All"),
        (radix_2to13, "Radix 2-13"),
        (bluesteins, "Bluestein's"),
    ]:
        # plot bandwidths
        print(name)
        plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
        plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
        plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
        plt.title(f"MLX FFT Benchmark -- {name}")
        plt.xlabel("N")
        plt.ylabel("Bandwidth (GB/s)")
        plt.legend()
        plt.savefig(f"{name}.png")
        plt.clf()

    av_gpu_bandwidth = np.mean(gpu_bandwidths)
    av_mps_bandwidth = np.mean(mps_bandwidths)
    av_cpu_bandwidth = np.mean(cpu_bandwidths)
    print("Average bandwidths:")
    print("GPU:", av_gpu_bandwidth)
    print("MPS:", av_mps_bandwidth)
    print("CPU:", av_cpu_bandwidth)

    portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
    print("Percent MLX faster than MPS: ", portion_faster * 100)


if __name__ == "__main__":
    time_fft()
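bandwidth_gb charges each complex64 element twice (one read plus one write), so the reported figure is an effective memory bandwidth rather than FLOPs. A worked example with a made-up 10 ms runtime:

system_size = 2**26
bytes_per_fft = 8 * 2                   # complex64 itemsize * (read + write)
gb = system_size * bytes_per_fft / 1e9  # ~1.074 GB per pass
print(gb / (10 / 1e3))                  # a 10 ms runtime -> ~107.4 GB/s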
benchmarks/python/gather_bench.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
from time import time

import mlx.core as mx
import torch
from time_utils import measure_runtime


def benchmark_gather_mlx(x_shape, idx_shape):
    def gather(x, idx):
        mx.eval(x[idx])

    idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
    x = mx.random.normal(x_shape).astype(mx.float32)

    runtime = measure_runtime(gather, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")


def benchmark_gather_torch(x_shape, idx_shape, device):
    def gather(x, idx, device):
        _ = x[idx]
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
    x = torch.randn(x_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(gather, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    idx_shapes = [(1_000_000,), (100_000,), ()]
    x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]

    for x_shape, idx_shape in zip(x_shapes, idx_shapes):
        print("=" * 20)
        print(f"X {x_shape}, Indices {idx_shape}")
        benchmark_gather_mlx(x_shape, idx_shape)
        benchmark_gather_torch(x_shape, idx_shape, device=device)
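The MLX side above times gather through fancy indexing. A self-contained sketch (shapes borrowed from the benchmark, not part of the diff) showing the same operation written with mx.take:

import mlx.core as mx

x = mx.random.normal((100, 64)).astype(mx.float32)
idx = mx.random.randint(0, 99, (1_000_000,))
print(mx.allclose(x[idx], mx.take(x, idx, axis=0)))  # expected: True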
benchmarks/python/hadamard_bench.py (new file, 70 lines)
@@ -0,0 +1,70 @@
import argparse

import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def had(x):
    y = mx.hadamard_transform(x)
    mx.eval(y)


def copy(x):
    y = x + 1.0
    mx.eval(y)


def run(dtype):
    system_size = 2**26
    outputs = {}
    for test_fn in (had, copy):
        for m in [1, 12, 20, 28]:
            if test_fn == copy:
                key = "copy"
            elif m == 1:
                key = "had_2^k"
            else:
                key = "had_m*2^k"
            outputs.setdefault(key, {})
            for k in range(7, 14):
                n = m * 2**k
                if n > 2**15:
                    continue
                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
                x = mx.array(x_np)
                runtime_ms = measure_runtime(test_fn, x=x)
                bytes_per_gb = 1e9
                ms_per_s = 1e3
                bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
                bandwidth_gb = (
                    system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
                )
                print(n, bandwidth_gb)
                outputs[key][n] = bandwidth_gb

    colors = {
        "copy": "black",
        "had_2^k": "steelblue",
        "had_m*2^k": "skyblue",
    }
    for key, output in outputs.items():
        plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
    plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
    plt.xlabel("N")
    plt.ylabel("Bandwidth (GB/s)")
    plt.legend()
    plt.savefig(f"bench_{dtype.__name__}.png")
    plt.clf()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp16", action="store_true")
    args = parser.parse_args()
    dtype = np.float16 if args.fp16 else np.float32
    run(dtype)
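The size grid in run() covers n = m * 2**k for m in {1, 12, 20, 28}, capped at 2**15, presumably the transform lengths mx.hadamard_transform handles (a power of two, optionally scaled by a small base factor). A sketch enumerating the tested sizes:

sizes = sorted(
    m * 2**k
    for m in (1, 12, 20, 28)
    for k in range(7, 14)
    if m * 2**k <= 2**15
)
print(sizes[0], "...", sizes[-1])  # 128 ... 28672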
benchmarks/python/layer_norm_bench.py (new file, 41 lines)
@@ -0,0 +1,41 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def layer_norm(x, w, b, eps):
    ot = x.dtype
    x = x.astype(mx.float32)
    mu = mx.mean(x, -1, keepdims=True)
    v = mx.var(x, -1, keepdims=True)
    return (x - mu) * mx.rsqrt(v + eps) * w + b


def time_layer_norm():
    f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
    f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1, 2))
    g2 = mx.grad(f2, argnums=(0, 1, 2))

    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    mx.eval(x, w, b, y)

    def layer_norm_loop(g, x, w, b):
        gx, gw, gb = x, w, b
        for _ in range(32):
            gx, gw, gb = g(gx, gw, gb, y)
        return gx, gw, gb

    time_fn(layer_norm_loop, g1, x, w, b)
    time_fn(layer_norm_loop, g2, x, w, b)
    time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
    time_fn(layer_norm_loop, mx.compile(g2), x, w, b)


if __name__ == "__main__":
    time_layer_norm()
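Assuming the layer_norm reference above is in scope, an illustrative agreement check against mx.fast.layer_norm (a loose tolerance, since the inputs are fp16 while the reference accumulates in fp32; not part of the benchmark):

x = mx.random.uniform(shape=(2, 8, 32)).astype(mx.float16)
w = mx.random.uniform(shape=(32,)).astype(mx.float16)
b = mx.random.uniform(shape=(32,)).astype(mx.float16)
print(mx.allclose(layer_norm(x, w, b, 1e-5),
                  mx.fast.layer_norm(x, w, b, 1e-5), atol=1e-2))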
(deleted file)
@@ -1,198 +0,0 @@
# Copyright © 2023 Apple Inc.

import math
import time

import jax
import jax.numpy as jnp
from flax import linen as nn


class RoPE(nn.Module):
    dims: int
    traditional: bool = False

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = jnp.concatenate([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        dtype=jnp.float32,
    ):
        D = D // 2
        positions = jnp.arange(offset, N, dtype=dtype)
        freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
        theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        return costheta, sintheta

    @nn.compact
    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = x.reshape((-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return rx.reshape(shape)


class LlamaAttention(nn.Module):
    dims: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        num_heads = self.num_heads
        dims = self.dims

        self.rope = RoPE(dims // num_heads, True)
        self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
        keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
        values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = jnp.concatenate([key_cache, keys], axis=2)
            values = jnp.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = jax.nn.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    dims: int
    mlp_dims: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        dims = self.dims
        mlp_dims = self.mlp_dims
        num_heads = self.num_heads

        self.attention = LlamaAttention(dims, num_heads, dtype)

        self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
        self.norm2 = nn.RMSNorm(param_dtype=self.dtype)

        self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
        self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
        self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = jax.nn.silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    jax.block_until_ready((y, c))

    start = time.time()
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    jax.block_until_ready((y, c))

    end = time.time()
    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    dtype = jnp.float16

    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

    x = jax.random.normal(k1, (1, 1, D), dtype)
    cache = [
        jax.random.normal(k2, [1, H, C, D // H], dtype),
        jax.random.normal(k3, [1, H, C, D // H], dtype),
    ]

    layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
    params = layer.init(k4, x, mask=None, cache=cache)["params"]

    @jax.jit
    def model_fn(x, mask, cache):
        return layer.apply({"params": params}, x, mask=mask, cache=cache)

    T = measure(model_fn, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
(deleted file)
@@ -1,118 +0,0 @@
# Copyright © 2023 Apple Inc.

import math
import time

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, False)
        self.key_proj = nn.Linear(dims, dims, False)
        self.value_proj = nn.Linear(dims, dims, False)
        self.out_proj = nn.Linear(dims, dims, False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
        keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
        values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
        scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, False)
        self.linear2 = nn.Linear(dims, mlp_dims, False)
        self.linear3 = nn.Linear(mlp_dims, dims, False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        mx.eval(y, c)

    start = time.time()
    rs = []
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        rs.append((y, c))
    mx.eval(rs)
    end = time.time()

    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    mx.set_default_device(mx.gpu)
    dtype = mx.float16

    layer = LlamaEncoderLayer(D, F, H)
    layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
    k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
    x = mx.random.normal([1, 1, D], dtype=dtype)
    cache = [
        mx.random.normal([1, H, C, D // H], dtype=dtype),
        mx.random.normal([1, H, C, D // H], dtype=dtype),
    ]
    mx.eval(x, cache)

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
(deleted file)
@@ -1,199 +0,0 @@
# Copyright © 2023 Apple Inc.

import math
import time

import torch
import torch.mps
import torch.nn as nn


def sync_if_needed(x):
    if x.device != torch.device("cpu"):
        torch.mps.synchronize()


class RoPE(nn.Module):
    def __init__(self, dims: int, traditional: bool = False):
        super().__init__()
        self.dims = dims
        self.traditional = traditional

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
        else:
            rx = torch.cat([rx1, rx2], dim=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)

        return rx

    def forward(self, x, offset: int = 0):
        shape = x.shape
        x = x.view(-1, shape[-2], shape[-1])
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, device=x.device, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return rx.view(*shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        device="cpu",
        dtype=torch.float32,
    ):
        D = D // 2
        positions = torch.arange(offset, N, dtype=dtype, device=device)
        freqs = torch.exp(
            -torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
        )
        theta = positions.view(-1, 1) * freqs.view(1, -1)
        costheta = torch.cos(theta)
        sintheta = torch.sin(theta)

        return costheta, sintheta


class RMSNorm(nn.Module):
    def __init__(self, dims: int, epsilon: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones((dims,)))
        self.epsilon = epsilon

    def forward(self, x):
        n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
        return self.gamma * x * n


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def forward(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
        keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
        values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = torch.cat([key_cache, keys], dim=2)
            values = torch.cat([value_cache, values], dim=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = torch.softmax(scores, dim=-1)
        values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = RMSNorm(dims)
        self.norm2 = RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def forward(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = torch.nn.functional.silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


@torch.no_grad()
def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    sync_if_needed(x)

    start = time.time()
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    sync_if_needed(x)
    end = time.time()
    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    device = torch.device("mps")
    dtype = torch.float16

    layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
    x = torch.randn(1, 1, D).to(device).to(dtype)
    cache = [
        torch.randn(1, H, C, D // H).to(device).to(dtype),
        torch.randn(1, H, C, D // H).to(device).to(dtype),
    ]

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
benchmarks/python/rms_norm_bench.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def rms_norm(x, w, eps):
    ot = x.dtype
    x = x.astype(mx.float32)
    n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return (x * n).astype(ot) * w


def time_rms_norm():
    f1 = lambda x, w, y: (rms_norm(x, w, 1e-5) * y).sum()
    f2 = lambda x, w, y: (mx.fast.rms_norm(x, w, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1))
    g2 = mx.grad(f2, argnums=(0, 1))

    x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
    mx.eval(x, w, y)

    def rms_norm_loop(g, x, w):
        gx, gw = x, w
        for _ in range(32):
            gx, gw = g(gx, gw, y)
        return gx, gw

    time_fn(rms_norm_loop, g1, x, w)
    time_fn(rms_norm_loop, g2, x, w)
    time_fn(rms_norm_loop, mx.compile(g1), x, w)
    time_fn(rms_norm_loop, mx.compile(g2), x, w)


if __name__ == "__main__":
    time_rms_norm()
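As with the layer norm benchmark, the reference rms_norm above should track mx.fast.rms_norm closely; an illustrative check under fp16 tolerances (assuming the definitions above are in scope; not part of the benchmark):

x = mx.random.uniform(shape=(2, 8, 32)).astype(mx.float16)
w = mx.random.uniform(shape=(32,)).astype(mx.float16)
print(mx.allclose(rms_norm(x, w, 1e-5), mx.fast.rms_norm(x, w, 1e-5), atol=1e-2))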
benchmarks/python/rope_bench.py (new file, 35 lines)
@@ -0,0 +1,35 @@
# Copyright © 2023-2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def time_rope():
    rope = nn.RoPE(64)

    # vec
    x = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.float16)
    mx.eval(x)

    def rope_vec(x):
        for _ in range(32):
            x = rope(x, offset=100)
        return x

    time_fn(rope_vec, x)

    # matrix
    x = mx.random.uniform(shape=(1, 32, 1024, 128)).astype(mx.float16)
    mx.eval(x)

    def rope_mat(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_mat, x)


if __name__ == "__main__":
    time_rope()
benchmarks/python/scatter_bench.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# Copyright © 2023-2024 Apple Inc.

import argparse

import mlx.core as mx
import torch
from time_utils import measure_runtime


def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
    def scatter(dst, x, idx):
        dst[tuple(idx)] = x
        mx.eval(dst)

    idx = []
    for idx_shape in idx_shapes:
        idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
    x = mx.random.normal(x_shape).astype(mx.float32)
    dst = mx.random.normal(dst_shape).astype(mx.float32)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")


def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
    def scatter(dst, x, idx, device):
        dst[tuple(idx)] = x
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = []
    for idx_shape in idx_shapes:
        idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
    x = torch.randn(x_shape, dtype=torch.float32).to(device)
    dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    dst_shapes = [
        (10, 64),
        (100_000, 64),
        (1_000_000, 64),
        (100_000,),
        (200_000,),
        (20_000_000,),
        (10000, 64),
        (100, 64),
        (100, 10_000, 64),
        (10, 100, 100, 21),
        (1_000, 1_000, 10),
    ]
    idx_shapes = [
        [(1_000_000,)],
        [(1_000_000,)],
        [(100_000,)],
        [(1_000_000,)],
        [(20_000_000,)],
        [(20_000_000,)],
        [(1000000,)],
        [(10000000,)],
        [(1_000,)],
        [(10_000,)],
        [(1_000,), (1_000,)],
    ]
    x_shapes = [
        (1_000_000, 64),
        (1_000_000, 64),
        (100_000, 64),
        (1_000_000,),
        (20_000_000,),
        (20_000_000,),
        (1000000, 64),
        (10000000, 64),
        (1_000, 10_000, 64),
        (10_000, 100, 100, 21),
        (1_000, 10),
    ]

    for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
        print("=" * 20)
        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
        benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
        benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
benchmarks/python/sdpa_bench.py (new file, 189 lines)
@@ -0,0 +1,189 @@
# Copyright © 2024 Apple Inc.

import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 5
N_iter_bench = 40
N_iter_func = 8


def bench(f, *args):
    for i in range(N_warmup):
        f(*args)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(*args)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def mlx_sdpa_fused_inner(q, k, v, scale):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)


def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
    q_dtype = q.dtype
    q = q * mx.array(scale, q_dtype)
    n_q_heads = q.shape[-3]
    n_kv_heads = k.shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    B = q.shape[0]
    L = q.shape[2]

    if n_repeats > 1:
        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
        k = mx.expand_dims(k, 2)
        v = mx.expand_dims(v, 2)

    scores = q @ mx.swapaxes(k, -1, -2)
    if f32softmax:
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
    else:
        scores = mx.softmax(scores, axis=-1)

    out = scores @ v
    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, -1])

    return out


def mlx_spda_unfused(q, k, v, scale, transpose):
    q_out = q
    if transpose:
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))

    for i in range(N_iter_func):
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))
        q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))

    mx.eval(q_out)
    return q_out


def mlx_spda_fused(q, k, v, scale, transpose):
    q_out = q
    if transpose:
        k = mx.transpose(k, (0, 2, 1, 3))
        v = mx.transpose(v, (0, 2, 1, 3))

    for i in range(N_iter_func):
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))
        q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
        if transpose:
            q_out = mx.transpose(q_out, (0, 2, 1, 3))

    mx.eval(q_out)
    return q_out


def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
    shape_q = (
        (B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
    )
    shape_kv = (
        (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
    )

    q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
    k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
    v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)

    scale = math.sqrt(1.0 / head_dim)

    q_mx = mx.array(q_np)
    k_mx = mx.array(k_np)
    v_mx = mx.array(v_np)

    time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
    time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)

    if transpose:
        q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
        k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
        v_mx = mx.transpose(v_mx, (0, 2, 1, 3))

    o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
    o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)

    atol = 1e-5 if np_dtype == np.float32 else 1e-4

    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
        print(
            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
        )

    return time_mlx_fused, time_mlx_unfused


def get_gflop_count(B, M, N, K):
    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

    dtypes = ("float16", "float32")[:1]
    transposes = (False,)

    # fmt: off
    shapes_64 = (
        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
        (    1,    32,    32,       64,   32,    32),
        (    1,    64,    64,       64,   32,    32),
        (    1,   128,   128,       64,   32,    32),
        (    1,   256,   256,       64,   32,    32),
        (    1,   512,   512,       64,   32,    32),
        (    1,  1024,  1024,       64,   32,    32),
        (    1,  2048,  2048,       64,   32,    32),
        (    1,  4096,  4096,       64,   32,    32),
    )

    shapes_80 = (
        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
        (    1,  1024,  1024,       80,   32,    32),
        (    1,  2048,  2048,       80,   32,    32),
        (    1,  4096,  4096,       80,   32,    32),
    )

    shapes_128 = (
        # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
        (    1,  1024,  1024,      128,   32,    32),
        (    1,  2048,  2048,      128,   32,    32),
        (    1,  4096,  4096,      128,   32,    32),
    )
    # fmt: on

    shapes = shapes_64 + shapes_80 + shapes_128

    print("  B,   qsl,   ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")

    for dtype in dtypes:
        for transpose in transposes:
            for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
                np_dtype = getattr(np, dtype)
                time_mlx_fused, time_mlx_unfused = bench_shape(
                    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
                )
                diff = time_mlx_unfused / time_mlx_fused - 1.0
                t_str = 1 if transpose else 0
                print(
                    f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
                )
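A note on the unfused path above: with n_repeats = n_q_heads // n_kv_heads, queries are viewed as [B, n_kv_heads, n_repeats, L, D] so one broadcast matmul covers every query head that shares a KV head. A shape-only sketch with hypothetical sizes (not part of the benchmark):

import mlx.core as mx

B, n_q_heads, n_kv_heads, L, D = 1, 32, 8, 64, 32  # hypothetical sizes
q = mx.zeros((B, n_q_heads, L, D))
k = mx.zeros((B, n_kv_heads, L, D))
q = mx.reshape(q, [B, n_kv_heads, n_q_heads // n_kv_heads, L, D])
k = mx.expand_dims(k, 2)  # broadcast over the n_repeats axis
print((q @ mx.swapaxes(k, -1, -2)).shape)  # (1, 8, 4, 64, 64): scores per group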
benchmarks/python/sdpa_vector_bench.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

L = 16384
H = 32
H_k = H // 4
D = 128
dtype = mx.float16
loops = 10


def attention(q, k, v):
    def _sdpa(q, k, v):
        B, Hq, L, D = q.shape
        _, Hk, S, _ = k.shape
        q = q.reshape(B, Hk, Hq // Hk, L, D)
        k = k[:, :, None, :, :]
        v = v[:, :, None, :, :]
        s = q @ k.transpose(0, 1, 2, 4, 3)
        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
        o = p @ v
        return o.reshape(B, Hq, L, D)

    for i in range(loops):
        q = _sdpa(q, k, v)
    return q


def sdpa(q, k, v):
    for i in range(loops):
        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
    return q


def time_self_attention_primitives():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    mx.eval(q, k, v)
    time_fn(attention, q, k, v)


def time_self_attention_sdpa():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    mx.eval(q, k, v)
    time_fn(sdpa, q, k, v)


if __name__ == "__main__":
    time_self_attention_sdpa()
    time_self_attention_primitives()
@@ -44,6 +44,13 @@ def time_matmul():
     time_fn(mx.matmul, a, b)


+def time_maximum():
+    a = mx.random.uniform(shape=(32, 1024, 1024))
+    b = mx.random.uniform(shape=(32, 1024, 1024))
+    mx.eval(a, b)
+    time_fn(mx.maximum, a, b)
+
+
 def time_negative():
     a = mx.random.uniform(shape=(10000, 1000))
     mx.eval(a)
@@ -101,6 +108,7 @@ if __name__ == "__main__":

     time_add()
     time_matmul()
+    time_maximum()
     time_exp()
     time_negative()
     time_logsumexp()
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import time

@@ -6,7 +6,11 @@ import mlx.core as mx


 def time_fn(fn, *args, **kwargs):
-    print(f"Timing {fn.__name__} ...", end=" ")
+    msg = kwargs.pop("msg", None)
+    if msg:
+        print(f"Timing {msg} ...", end=" ")
+    else:
+        print(f"Timing {fn.__name__} ...", end=" ")

     # warmup
     for _ in range(5):
@@ -20,3 +24,15 @@ def time_fn(fn, *args, **kwargs):

     msec = 1e3 * (toc - tic) / num_iters
     print(f"{msec:.5f} msec")
+
+
+def measure_runtime(fn, **kwargs):
+    # Warmup
+    for _ in range(5):
+        fn(**kwargs)
+
+    tic = time.time()
+    iters = 100
+    for _ in range(iters):
+        fn(**kwargs)
+    return (time.time() - tic) * 1000 / iters
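This is evidently the timing helper the new benchmarks import ("from time_utils import measure_runtime" above), and the new msg kwarg lets a caller label a run without renaming the function. A hypothetical call site, borrowing names from sdpa_vector_bench.py:

time_fn(sdpa, q, k, v, msg="fused sdpa, vector kernel")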
@@ -1,56 +1,41 @@
 include(CMakeParseArguments)

-###############################################################################
+# ##############################################################################
 # Build metal library
 #
 # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
 # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
 #
-# Args:
-# TARGET: Custom target to be added for the metal library
-# TITLE: Name of the .metallib
-# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
-# SOURCES: List of source files
-# INCLUDE_DIRS: List of include dirs
-# DEPS: List of dependency files (like headers)
+# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
+# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
+# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
+# files (like headers)
 #
 macro(mlx_build_metallib)
   # Parse args
   set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
   set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
-  cmake_parse_arguments(
-    MTLLIB
-    ""
-    "${oneValueArgs}"
-    "${multiValueArgs}"
-    ${ARGN}
-  )
+  cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

   # Set output
   set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")

   # Collect compile options
   set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)

   # Prepare metallib build command
   add_custom_command(
     OUTPUT ${MTLLIB_BUILD_TARGET}
-    COMMAND xcrun -sdk macosx metal
-            "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
-            ${MTLLIB_COMPILE_OPTIONS}
-            ${MTLLIB_SOURCES}
-            -o ${MTLLIB_BUILD_TARGET}
+    COMMAND
+      xcrun -sdk macosx metal
+      "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
+      ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
     DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
     COMMAND_EXPAND_LISTS
     COMMENT "Building ${MTLLIB_TITLE}.metallib"
-    VERBATIM
-  )
+    VERBATIM)

   # Add metallib custom target
-  add_custom_target(
-    ${MTLLIB_TARGET}
-    DEPENDS
-    ${MTLLIB_BUILD_TARGET}
-  )
+  add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})

 endmacro(mlx_build_metallib)
docs/.gitignore (vendored)
@@ -1,2 +1,3 @@
 src/python/_autosummary*/
 src/python/nn/_autosummary*/
+src/python/optimizers/_autosummary*/
docs/Doxyfile (new file, 50 lines)
@@ -0,0 +1,50 @@
################################################################################
# Primary project setup.                                                       #
################################################################################

PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES

################################################################################
# Doxygen preprocessor / parser control.                                       #
################################################################################

ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO

################################################################################
# Compound extraction control.                                                 #
################################################################################

EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO

################################################################################
# Docstring control / customization.                                           #
################################################################################

JAVADOC_AUTOBRIEF = YES

################################################################################
# Warning suppression.                                                         #
################################################################################

QUIET = YES
WARN_IF_UNDOCUMENTED = NO
@@ -2,12 +2,16 @@

 ### Setup (do once)

-Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
-for example with `conda`:
+Install Doxygen:

 ```
-conda install sphinx
-pip install sphinx-book-theme
+brew install doxygen
+```
+
+Install Python packages:
+
+```
+pip install -r requirements.txt
 ```

 ### Build
@@ -15,7 +19,7 @@ pip install sphinx-book-theme
 Build the docs from `mlx/docs/`

 ```
-make html
+doxygen && make html
 ```

 View the docs by running a server in `mlx/docs/build/html/`:
docs/requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
sphinx
breathe
sphinx-book-theme
mlx
docs/src/_static/metal_debugger/capture.png (new binary file, 1.2 MiB)
docs/src/_static/metal_debugger/schema.png (new binary file, 746 KiB)
(unnamed binary file modified: 7.2 KiB before, 76 KiB after)
docs/src/_static/mlx_logo_dark.png (new binary file, 48 KiB)
@@ -4,16 +4,17 @@

 .. autoclass:: {{ objname }}

-   {#{% block methods %}
+   {% block methods %}

    {% if methods %}
    .. rubric:: {{ _('Methods') }}

    .. autosummary::
    {% for item in methods %}
-      {%- if item not in inherited_members and item != '__init__' %}
+      {%- if item not in inherited_members and item != "__init__" %}
       ~{{ name }}.{{ item }}
      {%- endif %}
    {%- endfor %}
    {% endif %}
-   {% endblock %}#}
+   {% endblock %}
@@ -12,7 +12,7 @@ import mlx.core as mx
 project = "MLX"
 copyright = "2023, MLX Contributors"
 author = "MLX Contributors"
-version = ".".join(mx.__version__.split()[:-1])
+version = ".".join(mx.__version__.split(".")[:3])
 release = version

 # -- General configuration ---------------------------------------------------
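The motivation for this fix: the version string is dot-separated, so splitting on whitespace yields a single chunk and the old expression produced an empty string. With an illustrative value (not an actual release string):

v = "0.21.1.dev20241105"           # illustrative version string
print(".".join(v.split()[:-1]))    # old behavior: "" (split() gives one chunk, [:-1] drops it)
print(".".join(v.split(".")[:3]))  # new behavior: "0.21.1"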
@@ -22,22 +22,28 @@ extensions = [
     "sphinx.ext.autosummary",
     "sphinx.ext.intersphinx",
     "sphinx.ext.napoleon",
+    "breathe",
 ]

 python_use_unqualified_type_names = True
 autosummary_generate = True
+autosummary_filename_map = {"mlx.core.Stream": "stream_class"}

 intersphinx_mapping = {
-    "https://docs.python.org/3": None,
-    "https://numpy.org/doc/stable/": None,
+    "python": ("https://docs.python.org/3", None),
+    "numpy": ("https://numpy.org/doc/stable/", None),
 }

+breathe_projects = {"mlx": "../build/xml"}
+breathe_default_project = "mlx"
+
 templates_path = ["_templates"]
 html_static_path = ["_static"]
 source_suffix = ".rst"
-master_doc = "index"
+main_doc = "index"
 highlight_language = "python"
 pygments_style = "sphinx"
+add_module_names = False

 # -- Options for HTML output -------------------------------------------------
@@ -48,11 +54,45 @@ html_theme_options = {
     "repository_url": "https://github.com/ml-explore/mlx",
     "use_repository_button": True,
     "navigation_with_keys": False,
+    "logo": {
+        "image_light": "_static/mlx_logo.png",
+        "image_dark": "_static/mlx_logo_dark.png",
+    },
 }

-html_logo = "_static/mlx_logo.png"
+html_favicon = html_theme_options["logo"]["image_light"]


 # -- Options for HTMLHelp output ---------------------------------------------

 htmlhelp_basename = "mlx_doc"
+
+
+def setup(app):
+    from sphinx.util import inspect
+
+    wrapped_isfunc = inspect.isfunction
+
+    def isfunc(obj):
+        type_name = str(type(obj))
+        if "nanobind.nb_method" in type_name or "nanobind.nb_func" in type_name:
+            return True
+        return wrapped_isfunc(obj)
+
+    inspect.isfunction = isfunc
+
+
+# -- Options for LaTeX output ------------------------------------------------
+
+latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
+latex_elements = {
+    "preamble": r"""
+    \usepackage{enumitem}
+    \setlistdepth{5}
+    \setlist[itemize,1]{label=$\bullet$}
+    \setlist[itemize,2]{label=$\bullet$}
+    \setlist[itemize,3]{label=$\bullet$}
+    \setlist[itemize,4]{label=$\bullet$}
+    \setlist[itemize,5]{label=$\bullet$}
+    \renewlist{itemize}{itemize}{5}
+    """,
+}
@@ -3,4 +3,5 @@
 Operations
 ==========

+.. doxygengroup:: ops
+   :content-only:
docs/src/dev/custom_metal_kernels.rst (new file, 427 lines)
@@ -0,0 +1,427 @@
.. _custom_metal_kernels:

Custom Metal Kernels
====================

MLX supports writing custom Metal kernels through the Python and C++ APIs.

Simple Example
--------------

Let's write a custom kernel that computes ``exp`` elementwise:

.. code-block:: python

  def exp_elementwise(a: mx.array):
      source = """
          uint elem = thread_position_in_grid.x;
          T tmp = inp[elem];
          out[elem] = metal::exp(tmp);
      """

      kernel = mx.fast.metal_kernel(
          name="myexp",
          input_names=["inp"],
          output_names=["out"],
          source=source,
      )
      outputs = kernel(
          inputs=[a],
          template=[("T", mx.float32)],
          grid=(a.size, 1, 1),
          threadgroup=(256, 1, 1),
          output_shapes=[a.shape],
          output_dtypes=[a.dtype],
      )
      return outputs[0]

  a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
  b = exp_elementwise(a)
  assert mx.allclose(b, mx.exp(a))
.. note::
   We are only required to pass the body of the Metal kernel in ``source``.

The full function signature will be generated using:

* The shapes/dtypes of ``inputs``

  In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it
  with the key ``inp`` so we will add ``const device float16_t* inp`` to the
  signature. ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for
  convenience if they are present in ``source``.

* The list of ``output_dtypes``

  In the above, ``out`` is an ``mx.array`` of type ``mx.float16`` so we add
  ``device float16_t* out``.

* Template parameters passed using ``template``

  In the above, ``template=[("T", mx.float32)]`` adds a template of
  ``template <typename T>`` to the function and instantiates the template with
  ``custom_kernel_myexp_float<float>``. Template parameters can be
  ``mx.core.Dtype``, ``int`` or ``bool``.

* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``

  These will be added as function arguments. All the attributes defined in
  Table 5.8 of the `Metal Shading Language Specification
  <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_
  are supported.
Putting this all together, the generated function signature for ``myexp`` is as follows:

.. code-block:: cpp

  template <typename T>
  [[kernel]] void custom_kernel_myexp_float(
    const device float16_t* inp [[buffer(0)]],
    device float16_t* out [[buffer(1)]],
    uint3 thread_position_in_grid [[thread_position_in_grid]]) {

        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);

  }

  template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.

Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
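For example, clamping the threadgroup to the grid keeps that constraint while reusing the kernel above (a hedged sketch; ``launch_dims`` is a hypothetical helper, not part of the API):

.. code-block:: python

   def launch_dims(n, max_threads=256):
       # Each threadgroup dimension must not exceed the grid dimension.
       return (n, 1, 1), (min(n, max_threads), 1, 1)

   grid, threadgroup = launch_dims(a.size)
   outputs = kernel(
       inputs=[a],
       template=[("T", mx.float32)],
       grid=grid,
       threadgroup=threadgroup,
       output_shapes=[a.shape],
       output_dtypes=[a.dtype],
       verbose=True,  # print the generated kernel source
   )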
Using Shape/Strides
-------------------

``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.

If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread, as sketched below.
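For intuition, here is a Python sketch of the index arithmetic ``elem_to_loc`` performs (a hypothetical re-implementation for illustration, not MLX's actual code):

.. code-block:: python

   def elem_to_loc(elem, shape, strides, ndim):
       # Walk dimensions from innermost to outermost, converting a row-major
       # linear index into a storage offset via the strides.
       loc = 0
       for i in range(ndim - 1, -1, -1):
           loc += (elem % shape[i]) * strides[i]
           elem //= shape[i]
       return loc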
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:

.. code-block:: python

  def exp_elementwise(a: mx.array):
      source = """
          uint elem = thread_position_in_grid.x;
          // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
          uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
          T tmp = inp[loc];
          // Output arrays are always row contiguous
          out[elem] = metal::exp(tmp);
      """

      kernel = mx.fast.metal_kernel(
          name="myexp_strided",
          input_names=["inp"],
          output_names=["out"],
          source=source
      )
      outputs = kernel(
          inputs=[a],
          template=[("T", mx.float32)],
          grid=(a.size, 1, 1),
          threadgroup=(256, 1, 1),
          output_shapes=[a.shape],
          output_dtypes=[a.dtype],
          ensure_row_contiguous=False,
      )
      return outputs[0]

  a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
  # make non-contiguous
  a = a[::2]
  b = exp_elementwise(a)
  assert mx.allclose(b, mx.exp(a))
Complex Example
---------------

Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.

We'll start with the following MLX implementation using standard ops:

.. code-block:: python

  def grid_sample_ref(x, grid):
      N, H_in, W_in, _ = x.shape
      ix = ((grid[..., 0] + 1) * W_in - 1) / 2
      iy = ((grid[..., 1] + 1) * H_in - 1) / 2

      ix_nw = mx.floor(ix).astype(mx.int32)
      iy_nw = mx.floor(iy).astype(mx.int32)

      ix_ne = ix_nw + 1
      iy_ne = iy_nw

      ix_sw = ix_nw
      iy_sw = iy_nw + 1

      ix_se = ix_nw + 1
      iy_se = iy_nw + 1

      nw = (ix_se - ix) * (iy_se - iy)
      ne = (ix - ix_sw) * (iy_sw - iy)
      sw = (ix_ne - ix) * (iy - iy_ne)
      se = (ix - ix_nw) * (iy - iy_nw)

      I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
      I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
      I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
      I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

      mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
      mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
      mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
      mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

      I_nw *= mask_nw[..., None]
      I_ne *= mask_ne[..., None]
      I_sw *= mask_sw[..., None]
      I_se *= mask_se[..., None]

      output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

      return output
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:

.. code-block:: python

  @mx.custom_function
  def grid_sample(x, grid):

      assert x.ndim == 4, "`x` must be 4D."
      assert grid.ndim == 4, "`grid` must be 4D."

      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape
      out_shape = (B, gN, gM, C)

      assert D == 2, "Last dim of `grid` must be size 2."

      source = """
          uint elem = thread_position_in_grid.x;
          int H = x_shape[1];
          int W = x_shape[2];
          int C = x_shape[3];
          int gH = grid_shape[1];
          int gW = grid_shape[2];

          int w_stride = C;
          int h_stride = W * w_stride;
          int b_stride = H * h_stride;

          uint grid_idx = elem / C * 2;
          float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
          float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

          int ix_nw = floor(ix);
          int iy_nw = floor(iy);

          int ix_ne = ix_nw + 1;
          int iy_ne = iy_nw;

          int ix_sw = ix_nw;
          int iy_sw = iy_nw + 1;

          int ix_se = ix_nw + 1;
          int iy_se = iy_nw + 1;

          T nw = (ix_se - ix) * (iy_se - iy);
          T ne = (ix - ix_sw) * (iy_sw - iy);
          T sw = (ix_ne - ix) * (iy - iy_ne);
          T se = (ix - ix_nw) * (iy - iy_nw);

          int batch_idx = elem / C / gH / gW * b_stride;
          int channel_idx = elem % C;
          int base_idx = batch_idx + channel_idx;

          T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
          T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
          T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
          T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

          I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
          I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
          I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
          I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

          out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
      """
      kernel = mx.fast.metal_kernel(
          name="grid_sample",
          input_names=["x", "grid"],
          output_names=["out"],
          source=source,
      )
      outputs = kernel(
          inputs=[x, grid],
          template=[("T", x.dtype)],
          output_shapes=[out_shape],
          output_dtypes=[x.dtype],
          grid=(np.prod(out_shape), 1, 1),
          threadgroup=(256, 1, 1),
      )
      return outputs[0]

For a reasonably sized input such as:

.. code-block:: python

  x.shape = (8, 1024, 1024, 64)
  grid.shape = (8, 256, 256, 2)

On an M1 Max, we see a big performance improvement:

``55.7ms -> 6.7ms => 8x speed up``
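Such timings can be reproduced with a simple harness; a hedged sketch (the ``measure`` helper is illustrative, and absolute numbers depend on your hardware):

.. code-block:: python

   import time

   def measure(fn, *args, iters=20):
       for _ in range(5):          # warmup
           mx.eval(fn(*args))
       tic = time.time()
       for _ in range(iters):
           mx.eval(fn(*args))
       return (time.time() - tic) * 1000 / iters

   x = mx.random.uniform(shape=(8, 1024, 1024, 64))
   grid = mx.random.uniform(shape=(8, 256, 256, 2)) * 2 - 1
   mx.eval(x, grid)
   print("reference:", measure(grid_sample_ref, x, grid), "msec")
   print("fused:    ", measure(grid_sample, x, grid), "msec")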
Grid Sample VJP
---------------

Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.

The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra ``mx.fast.metal_kernel`` features:

* ``init_value=0``
  Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.

* ``atomic_outputs=True``
  Designate all of the kernel outputs as ``atomic`` in the function signature.
  This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
  See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.

We can then implement the backwards pass as follows:

.. code-block:: python

  @grid_sample.vjp
  def grid_sample_vjp(primals, cotangent, _):
      x, grid = primals
      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape

      assert D == 2, "Last dim of `grid` must be size 2."

      source = """
          uint elem = thread_position_in_grid.x;
          int H = x_shape[1];
          int W = x_shape[2];
          int C = x_shape[3];
          // Pad C to the nearest larger simdgroup size multiple
          int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

          int gH = grid_shape[1];
          int gW = grid_shape[2];

          int w_stride = C;
          int h_stride = W * w_stride;
          int b_stride = H * h_stride;

          uint grid_idx = elem / C_padded * 2;
          float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
          float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

          int ix_nw = floor(ix);
          int iy_nw = floor(iy);

          int ix_ne = ix_nw + 1;
          int iy_ne = iy_nw;

          int ix_sw = ix_nw;
          int iy_sw = iy_nw + 1;

          int ix_se = ix_nw + 1;
          int iy_se = iy_nw + 1;

          T nw = (ix_se - ix) * (iy_se - iy);
          T ne = (ix - ix_sw) * (iy_sw - iy);
          T sw = (ix_ne - ix) * (iy - iy_ne);
          T se = (ix - ix_nw) * (iy - iy_nw);

          int batch_idx = elem / C_padded / gH / gW * b_stride;
          int channel_idx = elem % C_padded;
          int base_idx = batch_idx + channel_idx;

          T gix = T(0);
          T giy = T(0);
          if (channel_idx < C) {
              int cot_index = elem / C_padded * C + channel_idx;
              T cot = cotangent[cot_index];
              if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
                  int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

                  T I_nw = x[offset];
                  gix -= I_nw * (iy_se - iy) * cot;
                  giy -= I_nw * (ix_se - ix) * cot;
              }
              if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
                  int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

                  T I_ne = x[offset];
                  gix += I_ne * (iy_sw - iy) * cot;
                  giy -= I_ne * (ix - ix_sw) * cot;
              }
              if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
                  int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

                  T I_sw = x[offset];
                  gix -= I_sw * (iy - iy_ne) * cot;
                  giy += I_sw * (ix_ne - ix) * cot;
              }
              if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
                  int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

                  T I_se = x[offset];
                  gix += I_se * (iy - iy_nw) * cot;
                  giy += I_se * (ix - ix_nw) * cot;
              }
          }

          T gix_mult = W / 2;
          T giy_mult = H / 2;

          // Reduce across each simdgroup first.
          // This is much faster than relying purely on atomics.
          gix = simd_sum(gix);
          giy = simd_sum(giy);

          if (thread_index_in_simdgroup == 0) {
              atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
              atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
          }
      """
      kernel = mx.fast.metal_kernel(
          name="grid_sample_grad",
          input_names=["x", "grid", "cotangent"],
          output_names=["x_grad", "grid_grad"],
          source=source,
          atomic_outputs=True,
      )
      # pad the output channels to simd group size
      # so that our `simd_sum`s don't overlap.
      simdgroup_size = 32
      C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
      grid_size = B * gN * gM * C_padded
      outputs = kernel(
          inputs=[x, grid, cotangent],
          template=[("T", x.dtype)],
          output_shapes=[x.shape, grid.shape],
          output_dtypes=[x.dtype, x.dtype],
          grid=(grid_size, 1, 1),
          threadgroup=(256, 1, 1),
          init_value=0,
      )
      return outputs[0], outputs[1]

There's an even larger speed up for the vjp:

``676.4ms -> 16.7ms => 40x speed up``
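A hedged way to sanity-check the custom vjp is to compare gradients against the pure-MLX reference with ``mx.grad`` (small illustrative shapes; tolerances may need tuning):

.. code-block:: python

   def loss_custom(x, grid):
       return grid_sample(x, grid).sum()

   def loss_ref(x, grid):
       return grid_sample_ref(x, grid).sum()

   x = mx.random.uniform(shape=(2, 32, 32, 8))
   g = mx.random.uniform(shape=(2, 16, 16, 2)) * 2 - 1
   dx_custom = mx.grad(loss_custom)(x, g)
   dx_ref = mx.grad(loss_ref)(x, g)
   assert mx.allclose(dx_custom, dx_ref, atol=1e-4)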
@@ -1,24 +1,16 @@
-Developer Documentation
-=======================
+Custom Extensions in MLX
+========================

-MLX provides a open and flexible backend to which users may add operations
-and specialized implementations without much hassle. While the library supplies
-efficient operations that can be used and composed for any number of
-applications, there may arise cases where new functionalities or highly
-optimized implementations are needed. For such cases, you may design and
-implement your own operations that link to and build on top of :mod:`mlx.core`.
-We will introduce the inner-workings of MLX and go over a simple example to
-learn the steps involved in adding new operations to MLX with your own CPU
-and GPU implementations.
+You can extend MLX with custom operations on the CPU or GPU. This guide
+explains how to do that with a simple example.

 Introducing the Example
 -----------------------

-Let's say that you would like an operation that takes in two arrays,
-``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
-respectively, and then adds them together to get the result
-``z = alpha * x + beta * y``. Well, you can very easily do that by just
-writing out a function as follows:
+Let's say you would like an operation that takes in two arrays, ``x`` and
+``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
+and then adds them together to get the result ``z = alpha * x + beta * y``.
+You can do that in MLX directly:

 .. code-block:: python
@@ -27,44 +19,35 @@ writing out a function as follows:
 def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
     return alpha * x + beta * y

-This function performs that operation while leaving the implementations and
-differentiation to MLX.
+This function performs that operation while leaving the implementation and
+function transformations to MLX.

-However, you work with vector math libraries often and realize that the
-``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
-You would really like the part of your applications that does this operation
-on the CPU to be very fast - so you decide that you want it to rely on the
-``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
-our assumptions on to you, let's also assume that you want to learn how add
-your own implementation for the gradients of your new operation while going
-over the ins-and-outs of the MLX framework.
+However you may need to customize the underlying implementation, perhaps to
+make it faster or for custom differentiation. In this tutorial we will go
+through adding custom extensions. It will cover:

-Well, what a coincidence! You are in the right place. Over the course of this
-example, we will learn:
-
-* The structure of the MLX library from the frontend API to the backend implementations.
-* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
-* How to implement your own GPU implementation using metal.
-* How to add your own ``vjp`` and ``jvp``.
-* How to build your implementations, link them to MLX, and bind them to python.
+* The structure of the MLX library.
+* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
+* Implementing a GPU operation using metal.
+* Adding the ``vjp`` and ``jvp`` function transformation.
+* Building a custom extension and binding it to python.

 Operations and Primitives
 -------------------------

-In one sentence, operations in MLX build the computation graph, and primitives
-provide the rules for evaluation and transformations of said graph. Let's start
-by discussing operations in more detail.
+Operations in MLX build the computation graph. Primitives provide the rules for
+evaluating and transforming the graph. Let's start by discussing operations in
+more detail.

 Operations
 ^^^^^^^^^^^

-Operations are the frontend functions that operate on arrays. They are defined
-in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
-operations in the Python API (:ref:`ops`).
+Operations are the front-end functions that operate on arrays. They are defined
+in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

-We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and ``y``,
-and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
-C++ API:
+We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
+``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
+C++:

 .. code-block:: C++
@@ -83,10 +66,7 @@ C++ API:
   StreamOrDevice s = {} // Stream on which to schedule the operation
 );

-This operation itself can call other operations within it if needed. So, the
-simplest way to go about implementing this operation would be do so in terms
-of existing operations.
+The simplest way to implement this operation is in terms of existing operations:

 .. code-block:: C++
@@ -100,25 +80,23 @@ of existing operations.
   // Scale x and y on the provided stream
   auto ax = multiply(array(alpha), x, s);
   auto by = multiply(array(beta), y, s);

   // Add and return
   return add(ax, by, s);
 }

-However, as we discussed earlier, this is not our goal. The operations themselves
-do not contain the implementations that act on the data, nor do they contain the
-rules of transformations. Rather, they are an easy to use interface that build
-on top of the building blocks we call :class:`Primitive`.
+The operations themselves do not contain the implementations that act on the
+data, nor do they contain the rules of transformations. Rather, they are an
+easy to use interface that use :class:`Primitive` building blocks.

 Primitives
 ^^^^^^^^^^^

 A :class:`Primitive` is part of the computation graph of an :class:`array`. It
-defines how to create an output given a set of input :class:`array` . Further,
-a :class:`Primitive` is a class that contains rules on how it is evaluated
-on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
-``jvp``. These words on their own can be a bit abstract, so lets take a step
-back and go to our example to give ourselves a more concrete image.
+defines how to create output arrays given input arrays. Further, a
+:class:`Primitive` has methods to run on the CPU or GPU and for function
+transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
+more concrete:

 .. code-block:: C++
@@ -134,11 +112,15 @@ back and go to our example to give ourselves a more concrete image.
  * To avoid unnecessary allocations, the evaluation function
  * is responsible for allocating space for the array.
  */
-void eval_cpu(const std::vector<array>& inputs, array& out) override;
-void eval_gpu(const std::vector<array>& inputs, array& out) override;
+void eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) override;
+void eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) override;

 /** The Jacobian-vector product. */
-array jvp(
+std::vector<array> jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) override;
@@ -147,7 +129,8 @@ back and go to our example to give ourselves a more concrete image.
 std::vector<array> vjp(
     const std::vector<array>& primals,
     const array& cotan,
-    const std::vector<int>& argnums) override;
+    const std::vector<int>& argnums,
+    const std::vector<array>& outputs) override;

 /**
  * The primitive must know how to vectorize itself across
@@ -155,7 +138,7 @@ back and go to our example to give ourselves a more concrete image.
  * representing the vectorized computation and the axis which
  * corresponds to the output vectorized dimension.
  */
-std::pair<array, int> vmap(
+virtual std::pair<std::vector<array>, std::vector<int>> vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) override;
@@ -175,22 +158,22 @@ back and go to our example to give ourselves a more concrete image.
   void eval(const std::vector<array>& inputs, array& out);
 };

-The :class:`Axpby` class derives from the base :class:`Primitive` class and
-follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
-``beta`` as parameters. It then provides implementations of how the array ``out``
-is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
-:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
-:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
+The :class:`Axpby` class derives from the base :class:`Primitive` class. The
+:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
+implementations of how the output array is produced given the inputs through
+:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
+of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
+:meth:`Axpby::vmap`.

-Using the Primitives
-^^^^^^^^^^^^^^^^^^^^^
+Using the Primitive
+^^^^^^^^^^^^^^^^^^^

-Operations can use this :class:`Primitive` to add a new :class:`array` to
-the computation graph. An :class:`array` can be constructed by providing its
-data type, shape, the :class:`Primitive` that computes it, and the
-:class:`array` inputs that are passed to the primitive.
+Operations can use this :class:`Primitive` to add a new :class:`array` to the
+computation graph. An :class:`array` can be constructed by providing its data
+type, shape, the :class:`Primitive` that computes it, and the :class:`array`
+inputs that are passed to the primitive.

-Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
+Let's reimplement our operation now in terms of our :class:`Axpby` primitive.

 .. code-block:: C++
@@ -223,7 +206,7 @@ Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
     /* const std::vector<int>& shape = */ out_shape,
     /* Dtype dtype = */ out_dtype,
     /* std::unique_ptr<Primitive> primitive = */
-    std::make_unique<Axpby>(to_stream(s), alpha, beta),
+    std::make_shared<Axpby>(to_stream(s), alpha, beta),
     /* const std::vector<array>& inputs = */ broadcasted_inputs);
 }
@@ -238,27 +221,26 @@ This operation now handles the following:
 Implementing the Primitive
 --------------------------

-No computation happens when we call the operation alone. In effect, the
-operation only builds the computation graph. When we evaluate the output
-array, MLX schedules the execution of the computation graph, and calls
-:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
-stream/device specified by the user.
+No computation happens when we call the operation alone. The operation only
+builds the computation graph. When we evaluate the output array, MLX schedules
+the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
+:meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.

 .. warning::
   When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
   no memory has been allocated for the output array. It falls on the implementation
-  of these functions to allocate memory as needed
+  of these functions to allocate memory as needed.

-Implementing the CPU Backend
+Implementing the CPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Let's start by trying to implement a naive and generic version of
+Let's start by implementing a naive and generic version of
 :meth:`Axpby::eval_cpu`. We declared this as a private member function of
 :class:`Axpby` earlier called :meth:`Axpby::eval`.

 Our naive method will go over each element of the output array, find the
 corresponding input elements of ``x`` and ``y`` and perform the operation
-pointwise. This is captured in the templated function :meth:`axpby_impl`.
+point-wise. This is captured in the templated function :meth:`axpby_impl`.

 .. code-block:: C++
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Now, we would like our implementation to be able to do this pointwise operation
|
Our implementation should work for all incoming floating point arrays.
|
||||||
for all incoming floating point arrays. Accordingly, we add dispatches for
|
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
|
||||||
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
|
``complex64``. We throw an error if we encounter an unexpected type.
|
||||||
if we encounter an unexpected type.
|
|
||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
/** Fall back implementation for evaluation on CPU */
|
||||||
void Axpby::eval(const std::vector<array>& inputs, array& out) {
|
void Axpby::eval(
|
||||||
// Check the inputs (registered in the op while constructing the out array)
|
const std::vector<array>& inputs,
|
||||||
assert(inputs.size() == 2);
|
const std::vector<array>& outputs) {
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
|
auto& out = outputs[0];
|
||||||
|
|
||||||
// Dispatch to the correct dtype
|
// Dispatch to the correct dtype
|
||||||
if (out.dtype() == float32) {
|
if (out.dtype() == float32) {
|
||||||
@@ -321,28 +303,26 @@ if we encounter an unexpected type.
     return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
   } else {
     throw std::runtime_error(
-        "Axpby is only supported for floating point types.");
+        "[Axpby] Only supports floating point types.");
   }
 }

-We have a fallback implementation! Now, to do what we are really here to do.
-Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
-framework? Well, there are 3 complications to keep in mind:
+This is good as a fallback implementation. We can use the ``axpby`` routine
+provided by the Accelerate_ framework for a faster implementation in certain
+cases:

-#. Accelerate does not provide implementations of ``axpby`` for half precision
-   floats. We can only direct to it for ``float32`` types
-#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
-   have fixed strides between them. Possibly due to broadcasts and transposes,
-   we aren't guaranteed that the inputs fit this requirement. We can
-   only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
-   column contiguous.
-#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
-   MLX expects to write out the answer to a new array. We must copy the elements
-   of ``y`` into the output array and use that as an input to ``axpby``
+#. Accelerate does not provide implementations of ``axpby`` for half precision
+   floats. We can only use it for ``float32`` types.
+#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
+   elements have fixed strides between them. We only direct to Accelerate
+   if both ``x`` and ``y`` are row contiguous or column contiguous.
+#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
+   MLX expects to write the output to a new array. We must copy the elements
+   of ``y`` into the output and use that as an input to ``axpby``.

-Let's write out an implementation that uses Accelerate in the right conditions.
-It must simply allocate data for the output, copy elements of ``y`` into it,
-and then call the :meth:`catlas_saxpby` from accelerate.
+Let's write an implementation that uses Accelerate in the right conditions.
+It allocates data for the output, copies ``y`` into it, and then calls the
+:func:`catlas_saxpby` from accelerate.

 .. code-block:: C++
||||||
@@ -356,17 +336,7 @@ and then call the :meth:`catlas_saxpby` from accelerate.
|
|||||||
// Accelerate library provides catlas_saxpby which does
|
// Accelerate library provides catlas_saxpby which does
|
||||||
// Y = (alpha * X) + (beta * Y) in place
|
// Y = (alpha * X) + (beta * Y) in place
|
||||||
// To use it, we first copy the data in y over to the output array
|
// To use it, we first copy the data in y over to the output array
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
// This specialization requires both x and y be contiguous in the same mode
|
|
||||||
// i.e: corresponding linear indices in both point to corresponding elements
|
|
||||||
// The data in the output array is allocated to match the strides in y
|
|
||||||
// such that x, y, and out are contiguous in the same mode and
|
|
||||||
// no transposition is needed
|
|
||||||
out.set_data(
|
|
||||||
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
|
|
||||||
y.data_size(),
|
|
||||||
y.strides(),
|
|
||||||
y.flags());
|
|
||||||
|
|
||||||
// We then copy over the elements using the contiguous vector specialization
|
// We then copy over the elements using the contiguous vector specialization
|
||||||
copy_inplace(y, out, CopyType::Vector);
|
copy_inplace(y, out, CopyType::Vector);
|
||||||
@@ -389,18 +359,20 @@ and then call the :meth:`catlas_saxpby` from accelerate.
       /* INCY = */ 1);
 }

-Great! But what about the inputs that do not fit the criteria for accelerate?
-Luckily, we can always just direct back to :meth:`Axpby::eval`.
-
-With this in mind, lets finally implement our :meth:`Axpby::eval_cpu`.
+For inputs that do not fit the criteria for accelerate, we fall back to
+:meth:`Axpby::eval`. With this in mind, let's finish our
+:meth:`Axpby::eval_cpu`.

 .. code-block:: C++

 /** Evaluate primitive on CPU using accelerate specializations */
-void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
+void Axpby::eval_cpu(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs) {
   assert(inputs.size() == 2);
   auto& x = inputs[0];
   auto& y = inputs[1];
+  auto& out = outputs[0];

   // Accelerate specialization for contiguous single precision float arrays
   if (out.dtype() == float32 &&
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to common backend if specializations are not available
|
// Fall back to common back-end if specializations are not available
|
||||||
eval(inputs, out);
|
eval(inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
We have now hit a milestone! Just this much is enough to run the operation
|
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
|
||||||
:meth:`axpby` on a CPU stream!
|
you do not plan on running the operation on the GPU or using transforms on
|
||||||
|
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||||
|
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
||||||
|
|
||||||
If you do not plan on running the operation on the GPU or using transforms on
|
Implementing the GPU Back-end
|
||||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
|
||||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
|
||||||
|
|
||||||
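Once the extension is built and bound, calling it from Python is a one-liner. A hedged sketch, with a hypothetical module name for the binding:

import mlx.core as mx
from mlx_sample_extensions import axpby  # hypothetical binding module

x = mx.ones((3, 4))
y = mx.ones((3, 4))
z = axpby(x, y, 2.0, 0.5, stream=mx.cpu)  # z = 2.0 * x + 0.5 * y on the CPU stream
print(z)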
-Implementing the GPU Backend
+Implementing the GPU Back-end
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 Apple silicon devices address their GPUs using the Metal_ shading language, and
-all GPU kernels in MLX are written using metal.
+GPU kernels in MLX are written using Metal.

 .. note::

-  Here are some helpful resources if you are new to metal!
+  Here are some helpful resources if you are new to Metal:

   * A walkthrough of the metal compute pipeline: `Metal Example`_
   * Documentation for metal shading language: `Metal Specification`_
   * Using metal from C++: `Metal-cpp`_

-Let's keep the GPU algorithm simple. We will launch exactly as many threads
-as there are elements in the output. Each thread will pick the element it needs
-from ``x`` and ``y``, do the pointwise operation, and then update its assigned
-element in the output.
+Let's keep the GPU kernel simple. We will launch exactly as many threads as
+there are elements in the output. Each thread will pick the element it needs
+from ``x`` and ``y``, do the point-wise operation, and update its assigned
+element in the output.

 .. code-block:: C++
|
|||||||
// Convert linear indices to offsets in array
|
// Convert linear indices to offsets in array
|
||||||
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
|
||||||
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
|
||||||
|
|
||||||
// Do the operation and update the output
|
// Do the operation and update the output
|
||||||
out[index] =
|
out[index] =
|
||||||
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
|
||||||
}
|
}
|
||||||
|
|
||||||
We then need to instantiate this template for all floating point types and give
|
We then need to instantiate this template for all floating point types and give
|
||||||
each instantiation a unique host name so we can identify the right kernel for
|
each instantiation a unique host name so we can identify it.
|
||||||
each data type.
|
|
||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
|
|
||||||

@@ -488,29 +457,21 @@ each data type.

   instantiate_axpby(bfloat16, bfloat16_t);
   instantiate_axpby(complex64, complex64_t);

The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU is contained in :meth:`Axpby::eval_gpu` as shown
below.

.. code-block:: C++
   /** Evaluate primitive on GPU */
   void Axpby::eval_gpu(
       const std::vector<array>& inputs,
       std::vector<array>& outputs) {
     // Prepare inputs
     assert(inputs.size() == 2);
     auto& x = inputs[0];
     auto& y = inputs[1];
     auto& out = outputs[0];

     // Each primitive carries the stream it should execute on
     // and each stream carries its device identifiers

@@ -518,23 +479,22 @@ below.

     // We get the needed metal device using the stream
     auto& d = metal::device(s.device);

     // Allocate output memory
     out.set_data(allocator::malloc_or_wait(out.nbytes()));

     // Resolve name of kernel
     std::ostringstream kname;
     kname << "axpby_" << "general_" << type_to_name(out);

     // Make sure the metal library is available
     d.register_library("mlx_ext");

     // Make a kernel from this metal library
     auto kernel = d.get_kernel(kname.str(), "mlx_ext");

     // Prepare to encode kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
     compute_encoder.set_compute_pipeline_state(kernel);

     // Kernel parameters are registered with buffer indices corresponding to
     // those in the kernel declaration at axpby.metal

@@ -542,21 +502,21 @@ below.

     size_t nelem = out.size();

     // Encode input arrays to kernel
     compute_encoder.set_input_array(x, 0);
     compute_encoder.set_input_array(y, 1);

     // Encode output arrays to kernel
     compute_encoder.set_output_array(out, 2);

     // Encode alpha and beta
     compute_encoder.set_bytes(alpha_, 3);
     compute_encoder.set_bytes(beta_, 4);

     // Encode shape, strides and ndim
     compute_encoder.set_vector_bytes(x.shape(), 5);
     compute_encoder.set_vector_bytes(x.strides(), 6);
     compute_encoder.set_vector_bytes(y.strides(), 7);
     compute_encoder.set_bytes(ndim, 8);

     // We launch 1 thread for each input and make sure that the number of
     // threads in any given threadgroup is not higher than the max allowed

@@ -570,33 +530,30 @@ below.

     // Launch the grid with the given number of threads divided among
     // the given threadgroups
     compute_encoder.dispatch_threads(grid_dims, group_dims);
   }

We can now call the :meth:`axpby` operation on both the CPU and the GPU!
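
For example (a minimal sketch, assuming the Python bindings described below
are built and installed):

.. code-block:: python

   import mlx.core as mx
   from mlx_sample_extensions import axpby

   x = mx.ones(4)
   y = mx.ones(4)

   # The same primitive dispatches to eval_cpu or eval_gpu
   # depending on the stream it is scheduled on
   print(axpby(x, y, alpha=4.0, beta=2.0, stream=mx.cpu))
   print(axpby(x, y, alpha=4.0, beta=2.0, stream=mx.gpu))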

A few things to note about MLX and Metal before moving on. MLX keeps track of
the active ``command_buffer`` and the ``MTLCommandBuffer`` to which it is
associated. We rely on :meth:`d.get_command_encoder` to give us the active
metal compute command encoder instead of building a new one and calling
:meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
pipelines) to the active command buffer until some specified limit is hit or
the command buffer needs to be flushed for synchronization.
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^

Next, let's add implementations for transformations in a :class:`Primitive`.
These transformations can be built on top of other operations, including the
one we just defined:

.. code-block:: C++
   /** The Jacobian-vector product. */
   std::vector<array> Axpby::jvp(
       const std::vector<array>& primals,
       const std::vector<array>& tangents,
       const std::vector<int>& argnums) {

@@ -611,12 +568,12 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.

     if (argnums.size() < 2) {
       auto scale = argnums[0] == 0 ? alpha_ : beta_;
       auto scale_arr = array(scale, tangents[0].dtype());
       return {multiply(scale_arr, tangents[0], stream())};
     }
     // If argnums = {0, 1}, we take contributions from both
     // which gives us jvp = tangent_x * alpha + tangent_y * beta
     else {
       return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
     }
   }

@@ -625,34 +582,35 @@ us the following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.

   /** The vector-Jacobian product. */
   std::vector<array> Axpby::vjp(
       const std::vector<array>& primals,
       const std::vector<array>& cotangents,
       const std::vector<int>& argnums,
       const std::vector<array>& /* unused */) {
     // Reverse mode diff
     std::vector<array> vjps;
     for (auto arg : argnums) {
       auto scale = arg == 0 ? alpha_ : beta_;
       auto scale_arr = array(scale, cotangents[0].dtype());
       vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
     }
     return vjps;
   }
Note, a transformation does not need to be fully defined to start using
the :class:`Primitive`.

.. code-block:: C++

   /** Vectorize primitive along given axis */
   std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
       const std::vector<array>& inputs,
       const std::vector<int>& axes) {
     throw std::runtime_error("[Axpby] vmap not implemented.");
   }
Building and Binding
--------------------

Let's look at the overall directory structure first.

| extensions
| ├── axpby

@@ -666,40 +624,39 @@ Let's look at the overall directory structure first.

| └── setup.py

* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
  associated Python package
* ``extensions/bindings.cpp`` provides Python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
  Python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
  the Python package
Binding to Python
^^^^^^^^^^^^^^^^^

We use nanobind_ to build a Python API for the C++ library. Since bindings for
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
already provided, adding our :meth:`axpby` is simple.

.. code-block:: C++

   NB_MODULE(_ext, m) {
     m.doc() = "Sample extension for MLX";

     m.def(
         "axpby",
         &axpby,
         "x"_a,
         "y"_a,
         "alpha"_a,
         "beta"_a,
         nb::kw_only(),
         "stream"_a = nb::none(),
         R"(
           Scale and sum two vectors element-wise
           ``z = alpha * x + beta * y``

           Follows numpy-style broadcasting between ``x`` and ``y``
           Inputs are upcasted to floats if needed

@@ -711,17 +668,17 @@ are already provided, adding our :meth:`axpby` becomes very simple!

           Returns:
             array: ``alpha * x + beta * y``
         )");
   }

Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.

.. warning::

   :mod:`mlx.core` must be imported before importing
   :mod:`mlx_sample_extensions` as defined by the nanobind module above to
   ensure that the casters for :mod:`mlx.core` components like
   :class:`mlx.core.array` are available.
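
In practice that just means writing the imports in this order (a minimal
sketch):

.. code-block:: python

   import mlx.core  # registers the casters for mlx.core components
   import mlx_sample_extensions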

.. _Building with CMake:

@@ -729,8 +686,8 @@ whistles such as the literal names and doc-strings.

Building with CMake
^^^^^^^^^^^^^^^^^^^

Building the C++ extension library only requires that you ``find_package(MLX
CONFIG)`` and then link it to your library.

.. code-block:: cmake

@@ -752,12 +709,12 @@ Building the C++ extension library itself is simple, it only requires that you

   # Link to mlx
   target_link_libraries(mlx_ext PUBLIC mlx)

We also need to build the attached Metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with the MLX package).

Here is what that looks like in practice:

.. code-block:: cmake

@@ -779,27 +736,29 @@ Here is what that looks like in practice!

   endif()

Finally, we build the nanobind_ bindings:

.. code-block:: cmake

   nanobind_add_module(
     _ext
     NB_STATIC STABLE_ABI LTO NOMINSIZE
     NB_DOMAIN mlx
     ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
   )
   target_link_libraries(_ext PRIVATE mlx_ext)

   if(BUILD_SHARED_LIBS)
     target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
   endif()

Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
||||||
Once we have set out the CMake build rules as described above, we can use the
|
Once we have set out the CMake build rules as described above, we can use the
|
||||||
build utilities defined in :mod:`mlx.extension` for a simple build process.
|
build utilities defined in :mod:`mlx.extension`:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from mlx import extension
|
from mlx import extension
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||

@@ -809,48 +768,50 @@ build utilities defined in :mod:`mlx.extension` for a simple build process.

       name="mlx_sample_extensions",
       version="0.0.0",
       description="Sample C++ and Metal extensions for MLX primitives.",
       ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
       cmdclass={"build_ext": extension.CMakeBuild},
       packages=["mlx_sample_extensions"],
       package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
       extras_require={"dev": []},
       zip_safe=False,
       python_requires=">=3.8",
   )

.. note::
   We treat ``extensions/mlx_sample_extensions`` as the package directory
   even though it only contains a ``__init__.py`` to ensure the following:

   * :mod:`mlx.core` must be imported before importing :mod:`_ext`
   * The C++ extension library and the Metal library are co-located with the
     Python bindings and copied together if the package is installed

To build the package, first install the build dependencies with ``pip install
-r requirements.txt``. You can then build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``).
This results in the directory structure:

| extensions
| ├── mlx_sample_extensions
| │   ├── __init__.py
| │   ├── libmlx_ext.dylib # C++ extension library
| │   ├── mlx_ext.metallib # Metal library
| │   └── _ext.cpython-3x-darwin.so # Python Binding
| ...

When you try to install using the command ``python -m pip install .`` (in
``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and Metal library will be
copied along with the Python binding since they are specified as
``package_data``.
Usage
-----

After installing the extension as described above, you should be able to simply
import the Python package and play with it as you would any other MLX operation.

Let's look at a simple script and its results:

.. code-block:: python

@@ -863,7 +824,7 @@ Let's looks at a simple script and it's results!

   print(f"c shape: {c.shape}")
   print(f"c dtype: {c.dtype}")
   print(f"c correct: {mx.all(c == 6.0).item()}")

Output:

@@ -874,12 +835,12 @@ Output:

   c correct: True

Results
^^^^^^^

Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined on the CPU.

.. code-block:: python

   import mlx.core as mx
   from mlx_sample_extensions import axpby

@@ -898,7 +859,7 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.

   alpha = 4.0
   beta = 2.0

   mx.eval(x, y)

   def bench(f):
       # Warm up

@@ -919,30 +880,23 @@ with the naive :meth:`simple_axpby` we defined at first on the CPU.

   print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")

The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away!

This operation is now good to be used to build other operations, in
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
:meth:`grad`.
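
For instance (a minimal sketch; since ``axpby`` is linear in ``x``, the
gradient is just ``alpha`` everywhere):

.. code-block:: python

   import mlx.core as mx
   from mlx_sample_extensions import axpby

   def f(x):
       return axpby(x, mx.ones(3), 4.0, 2.0).sum()

   # Uses the Axpby::vjp we defined above
   print(mx.grad(f)(mx.zeros(3)))  # array([4, 4, 4], dtype=float32)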

Scripts
-------

.. admonition:: Download the code

   The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.

.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _nanobind: https://nanobind.readthedocs.io/en/latest/

docs/src/dev/metal_debugger.rst (new file)
@@ -0,0 +1,68 @@

Metal Debugger
==============

.. currentmodule:: mlx.core

Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and
optimization workflow. The ``MLX_METAL_DEBUG`` debug option:

* Records source during Metal compilation, for later inspection while
  debugging.
* Labels Metal objects such as command queues, improving capture readability.

To build with debugging enabled in Python prepend
``CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"`` to the build call.

The :func:`metal.start_capture` function initiates a capture of all MLX GPU
work.

.. note::

   To capture a GPU trace you must run the application with
   ``MTL_CAPTURE_ENABLED=1``.

.. code-block:: python

   import mlx.core as mx

   a = mx.random.uniform(shape=(512, 512))
   b = mx.random.uniform(shape=(512, 512))
   mx.eval(a, b)

   trace_file = "mlx_trace.gputrace"

   # Make sure to run with MTL_CAPTURE_ENABLED=1 and
   # that the path trace_file does not already exist.
   mx.metal.start_capture(trace_file)

   for _ in range(10):
       mx.eval(mx.add(a, b))

   mx.metal.stop_capture()

You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Check out the `Metal debugger
documentation`_ for more information.

.. image:: ../_static/metal_debugger/capture.png
   :class: dark-light

Xcode Workflow
--------------

You can skip saving to a path by running within Xcode. First, generate an
Xcode project using CMake.

.. code-block::

   mkdir build && cd build
   cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
   open mlx.xcodeproj

Select the ``metal_capture`` example schema and run.

.. image:: ../_static/metal_debugger/schema.png
   :class: dark-light

.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

@@ -15,7 +15,7 @@ module to concisely define the model architecture.

Attention layer
^^^^^^^^^^^^^^^

We will start with the Llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.

@@ -64,7 +64,7 @@ set:

Next, set up the problem parameters and load the data. To load the data, you need our
`mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as ``mnist``.

.. code-block:: python

@@ -41,7 +41,9 @@ are the CPU and GPU.

   usage/indexing
   usage/saving_and_loading
   usage/function_transforms
   usage/compile
   usage/numpy
   usage/distributed
   usage/using_streams

.. toctree::

@@ -57,14 +59,18 @@ are the CPU and GPU.

   :maxdepth: 1

   python/array
   python/data_types
   python/devices_and_streams
   python/ops
   python/random
   python/transforms
   python/fast
   python/fft
   python/linalg
   python/metal
   python/nn
   python/optimizers
   python/distributed
   python/tree_utils

.. toctree::

@@ -78,3 +84,5 @@ are the CPU and GPU.

   :maxdepth: 1

   dev/extensions
   dev/metal_debugger
   dev/custom_metal_kernels

@@ -14,11 +14,11 @@ silicon computer is

To install from PyPI you must meet the following requirements:

- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
- macOS >= 13.5

.. note::
   MLX is only available on devices running macOS >= 13.5
   It is highly recommended to use macOS 14 (Sonoma)

@@ -54,7 +54,7 @@ Build Requirements

- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 15.0 and macOS SDK >= 14.0

.. note::
   Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If

@@ -70,39 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from

   git clone git@github.com:ml-explore/mlx.git mlx && cd mlx

Then simply build and install MLX using pip:

.. code-block:: shell

   CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .

For developing, install the package with development dependencies, and use an
editable install:

.. code-block:: shell

   CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"

Once the development dependencies are installed, you can build faster with:

.. code-block:: shell

   CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace

Run the tests with:

.. code-block:: shell

   python -m unittest discover python/tests

Optional: Install stubs to enable auto completions and type checking from your
IDE:

.. code-block:: shell

   python setup.py generate_stubs

C++ API

@@ -123,7 +120,7 @@ Create a build directory and run CMake and make:

.. code-block:: shell

   mkdir -p build && cd build
   cmake .. && make -j

Run tests with:

@@ -142,7 +139,7 @@ directory as the executable statically linked to ``libmlx.a`` or the

preprocessor constant ``METAL_PATH`` should be defined at build time and it
should point to the path to the built metal library.

.. list-table:: Build Options
   :widths: 25 8
   :header-rows: 1

@@ -156,31 +153,67 @@ should point to the path to the built metal library.

     - OFF
   * - MLX_BUILD_METAL
     - ON
   * - MLX_BUILD_CPU
     - ON
   * - MLX_BUILD_PYTHON_BINDINGS
     - OFF
   * - MLX_METAL_DEBUG
     - OFF
   * - MLX_BUILD_SAFETENSORS
     - ON
   * - MLX_BUILD_GGUF
     - ON
   * - MLX_METAL_JIT
     - OFF

.. note::

   If you have multiple Xcode installations and wish to use
   a specific one while building, you can do so by adding the
   following environment variable before building

   .. code-block:: shell

      export DEVELOPER_DIR="/path/to/Xcode.app/Contents/Developer/"

   Further, you can use the following command to find out which
   macOS SDK will be used

   .. code-block:: shell

      xcrun -sdk macosx --show-sdk-version

Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~

To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.

The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:

.. code-block:: shell

   cmake .. \
     -DCMAKE_BUILD_TYPE=MinSizeRel \
     -DBUILD_SHARED_LIBS=ON \
     -DMLX_BUILD_CPU=OFF \
     -DMLX_BUILD_SAFETENSORS=OFF \
     -DMLX_BUILD_GGUF=OFF \
     -DMLX_METAL_JIT=ON

The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anywhere from a few hundred milliseconds to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.

Troubleshooting
^^^^^^^^^^^^^^^

Metal not found
~~~~~~~~~~~~~~~

@@ -202,12 +235,12 @@ Then set the active developer directory:

   sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer

x86 Shell
~~~~~~~~~

.. _build shell:

If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.

To fix this, find the application in Finder (``/Applications`` for iTerm,

@@ -231,4 +264,4 @@ Also check that cmake is using the correct architecture:

If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cache with ``rm -rf build/`` and try again.

@@ -10,27 +10,39 @@ Array

   array
   array.astype
   array.at
   array.item
   array.tolist
   array.dtype
   array.itemsize
   array.nbytes
   array.ndim
   array.shape
   array.size
   array.abs
   array.all
   array.any
   array.argmax
   array.argmin
   array.conj
   array.cos
   array.cummax
   array.cummin
   array.cumprod
   array.cumsum
   array.diag
   array.diagonal
   array.exp
   array.flatten
   array.log
   array.log10
   array.log1p
   array.log2
   array.logsumexp
   array.max
   array.mean
   array.min
   array.moveaxis
   array.prod
   array.reciprocal
   array.reshape

@@ -40,7 +52,11 @@ Array

   array.split
   array.sqrt
   array.square
   array.squeeze
   array.std
   array.sum
   array.swapaxes
   array.transpose
   array.T
   array.var
   array.view

@@ -1,7 +1,5 @@

.. _data_types:

Data Types
==========

@@ -44,9 +42,27 @@ The default floating point type is ``float32`` and the default integer type is

   * - ``int64``
     - 8
     - 64-bit signed integer
   * - ``bfloat16``
     - 2
     - 16-bit brain float (e8, m7)
   * - ``float16``
     - 2
     - 16-bit IEEE float (e5, m10)
   * - ``float32``
     - 4
     - 32-bit float
   * - ``complex64``
     - 8
     - 64-bit complex float

Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
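
For example (a quick sketch; ``mx.floating`` is one of the built-in
:obj:`DtypeCategory` values):

.. code-block:: python

   import mlx.core as mx

   # float16 is in the floating point category ...
   print(mx.issubdtype(mx.float16, mx.floating))  # True
   # ... but int32 is not
   print(mx.issubdtype(mx.int32, mx.floating))    # False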

.. autosummary::
   :toctree: _autosummary

   Dtype
   DtypeCategory
   issubdtype

@@ -9,9 +9,11 @@ Devices and Streams

   :toctree: _autosummary

   Device
   Stream
   default_device
   set_default_device
   default_stream
   new_stream
   set_default_stream
   stream
   synchronize

docs/src/python/distributed.rst (new file)
@@ -0,0 +1,22 @@

.. _distributed:

.. currentmodule:: mlx.core.distributed

Distributed Communication
=========================

MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.

.. autosummary::
   :toctree: _autosummary

   Group
   is_available
   init
   all_sum
   all_gather
   send
   recv
   recv_like
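
For example (a minimal sketch; run it under ``mpirun`` with several processes
for the group to have more than one member):

.. code-block:: python

   import mlx.core as mx

   world = mx.distributed.init()  # gracefully degrades to a group of size 1
   x = mx.ones(4)

   # Element-wise sum of x across every process in the group
   y = mx.distributed.all_sum(x)
   print(world.rank(), y)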

docs/src/python/fast.rst (new file)
@@ -0,0 +1,15 @@

.. _fast:

Fast
====

.. currentmodule:: mlx.core.fast

.. autosummary::
   :toctree: _autosummary

   rms_norm
   layer_norm
   rope
   scaled_dot_product_attention
   metal_kernel

@@ -8,4 +8,13 @@ Linear Algebra

.. autosummary::
   :toctree: _autosummary

   inv
   tri_inv
   norm
   cholesky
   cholesky_inv
   cross
   qr
   svd
   eigvalsh
   eigh
|
20
docs/src/python/metal.rst
Normal file
20
docs/src/python/metal.rst
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
Metal
|
||||||
|
=====
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.metal
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
is_available
|
||||||
|
device_info
|
||||||
|
get_active_memory
|
||||||
|
get_peak_memory
|
||||||
|
reset_peak_memory
|
||||||
|
get_cache_memory
|
||||||
|
set_memory_limit
|
||||||
|
set_cache_limit
|
||||||
|
set_wired_limit
|
||||||
|
clear_cache
|
||||||
|
start_capture
|
||||||
|
stop_capture
|

@@ -173,6 +173,7 @@ In detail:

   :toctree: _autosummary

   value_and_grad
   quantize

.. toctree::

@@ -180,3 +181,4 @@ In detail:

   nn/layers
   nn/functions
   nn/losses
   nn/init

@@ -12,12 +12,28 @@ simple functions.

   :toctree: _autosummary_functions
   :template: nn-module-template.rst

   elu
   celu
   gelu
   gelu_approx
   gelu_fast_approx
   glu
   hard_shrink
   hard_tanh
   hardswish
   leaky_relu
   log_sigmoid
   log_softmax
   mish
   prelu
   relu
   relu6
   selu
   sigmoid
   silu
   softmax
   softmin
   softplus
   softshrink
   step
   tanh

docs/src/python/nn/init.rst (new file)
@@ -0,0 +1,45 @@

.. _init:

.. currentmodule:: mlx.nn.init

Initializers
------------

The ``mlx.nn.init`` package contains commonly used initializers for neural
network parameters. Initializers return a function which can be applied to any
input :obj:`mlx.core.array` to produce an initialized output.

For example:

.. code:: python

   import mlx.core as mx
   import mlx.nn as nn

   init_fn = nn.init.uniform()

   # Produces a [2, 2] uniform matrix
   param = init_fn(mx.zeros((2, 2)))

To re-initialize all the parameters in an :obj:`mlx.nn.Module` from, say, a
uniform distribution, you can do:

.. code:: python

   import mlx.nn as nn
   model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
   init_fn = nn.init.uniform(low=-0.1, high=0.1)
   model.apply(init_fn)

.. autosummary::
   :toctree: _autosummary

   constant
   normal
   uniform
   identity
   glorot_normal
   glorot_uniform
   he_normal
   he_uniform

@@ -10,28 +10,60 @@ Layers

   :template: nn-module-template.rst

   ALiBi
   AvgPool1d
   AvgPool2d
   AvgPool3d
   BatchNorm
   CELU
   Conv1d
   Conv2d
   Conv3d
   ConvTranspose1d
   ConvTranspose2d
   ConvTranspose3d
   Dropout
   Dropout2d
   Dropout3d
   Embedding
   ELU
   GELU
   GLU
   GroupNorm
   GRU
   HardShrink
   HardTanh
   Hardswish
   InstanceNorm
   LayerNorm
   LeakyReLU
   Linear
   LogSigmoid
   LogSoftmax
   LSTM
   MaxPool1d
   MaxPool2d
   MaxPool3d
   Mish
   MultiHeadAttention
   PReLU
   QuantizedEmbedding
   QuantizedLinear
   RMSNorm
   ReLU
   ReLU6
   RNN
   RoPE
   SELU
   Sequential
   Sigmoid
   SiLU
   SinusoidalPositionalEncoding
   Softmin
   Softshrink
   Softsign
   Softmax
   Softplus
   Step
   Tanh
   Transformer
   Upsample

@@ -18,6 +18,7 @@ Loss Functions

   kl_div_loss
   l1_loss
   log_cosh_loss
   margin_ranking_loss
   mse_loss
   nll_loss
   smooth_l1_loss

@@ -11,6 +11,7 @@ Module

   :toctree: _autosummary

   Module.training
   Module.state

.. rubric:: Methods

@@ -29,6 +30,7 @@ Module

   Module.named_modules
   Module.parameters
   Module.save_weights
   Module.set_dtype
   Module.train
   Module.trainable_parameters
   Module.unfreeze

@@ -5,13 +5,14 @@ Operations

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   abs
   add
   addmm
   all
   allclose
   any
   arange
   arccos

@@ -19,42 +20,76 @@ Operations

   arcsin
   arcsinh
   arctan
   arctan2
   arctanh
   argmax
   argmin
   argpartition
   argsort
   array_equal
   as_strided
   atleast_1d
   atleast_2d
   atleast_3d
   bitwise_and
   bitwise_or
   bitwise_xor
   block_masked_mm
   broadcast_to
   ceil
   clip
   concatenate
   conj
   conjugate
   convolve
   conv1d
   conv2d
   conv3d
   conv_transpose1d
   conv_transpose2d
   conv_transpose3d
   conv_general
   cos
   cosh
   cummax
   cummin
   cumprod
   cumsum
   degrees
   dequantize
   diag
   diagonal
   divide
   divmod
   einsum
   einsum_path
   equal
   erf
   erfinv
   exp
   expm1
   expand_dims
   eye
   flatten
   floor
   floor_divide
   full
   gather_mm
   gather_qmm
   greater
   greater_equal
   hadamard_transform
   identity
   imag
   inner
   isfinite
   isclose
   isinf
   isnan
   isneginf
   isposinf
   issubdtype
   left_shift
   less
   less_equal
   linspace

@@ -72,22 +107,32 @@ Operations

   max
   maximum
   mean
   meshgrid
   min
   minimum
   moveaxis
   multiply
   nan_to_num
   negative
   not_equal
   ones
   ones_like
   outer
   partition
   pad
   power
   prod
   put_along_axis
   quantize
   quantized_matmul
   radians
   real
   reciprocal
   remainder
   repeat
   reshape
   right_shift
   roll
   round
   rsqrt
   save

@@ -106,6 +151,7 @@ Operations

   square
   squeeze
   stack
   std
   stop_gradient
   subtract
   sum

@@ -115,11 +161,15 @@ Operations

   tan
   tanh
   tensordot
   tile
   topk
   trace
   transpose
   tri
   tril
   triu
   var
   view
   where
   zeros
   zeros_like

@@ -1,5 +1,7 @@

.. _optimizers:

.. currentmodule:: mlx.optimizers

Optimizers
==========

@@ -29,19 +31,48 @@ model's parameters and the **optimizer state**.

   # Compute the new parameters but also the optimizer state.
   mx.eval(model.parameters(), optimizer.state)

Saving and Loading
------------------

To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:

.. code-block:: python

   import mlx.core as mx
   from mlx.utils import tree_flatten, tree_unflatten
   import mlx.optimizers as optim

   optimizer = optim.Adam(learning_rate=1e-2)

   # Perform some updates with the optimizer
   model = {"w": mx.zeros((5, 5))}
   grads = {"w": mx.ones((5, 5))}
   optimizer.update(model, grads)

   # Save the state
   state = tree_flatten(optimizer.state)
   mx.save_safetensors("optimizer.safetensors", dict(state))

   # Later on, for example when loading from a checkpoint,
   # recreate the optimizer and load the state
   optimizer = optim.Adam(learning_rate=1e-2)

   state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
   optimizer.state = state

Note, not every optimizer configuration parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is if the parameter can be scheduled
then it will be included in the optimizer state.
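
So, after setting the saved state, the loaded learning rate takes effect (a
quick sketch building on the previous snippet; the ``1e-3`` constructor value
is only there to show it gets overridden):

.. code-block:: python

   # The learning rate comes from the saved state, not the constructor
   optimizer = optim.Adam(learning_rate=1e-3)
   optimizer.state = state
   print(optimizer.learning_rate)  # 0.01, restored from the checkpoint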
|
|
||||||
|
.. toctree::
|
||||||
|
|
||||||
|
optimizers/optimizer
|
||||||
|
optimizers/common_optimizers
|
||||||
|
optimizers/schedulers
|
||||||
|
|
||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
:template: optimizers-template.rst
|
|
||||||
|
|
||||||
OptimizerState
|
clip_grad_norm
|
||||||
Optimizer
|
|
||||||
SGD
|
|
||||||
RMSprop
|
|
||||||
Adagrad
|
|
||||||
AdaDelta
|
|
||||||
Adam
|
|
||||||
AdamW
|
|
||||||
Adamax
|
|
||||||
Lion
|
|
||||||
|
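A quick way to check the rule of thumb in the hunk above is to inspect the
optimizer state directly. A minimal sketch; the exact set of state keys is an
assumption and may vary across MLX versions:

.. code-block:: python

   import mlx.core as mx
   import mlx.optimizers as optim

   optimizer = optim.Adam(learning_rate=1e-2)
   optimizer.init({"w": mx.zeros((2, 2))})

   # If the rule of thumb holds, a schedulable entry such as the learning
   # rate shows up here, while betas/eps remain constructor-only arguments
   print(optimizer.state.keys())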
20  docs/src/python/optimizers/common_optimizers.rst  Normal file
@@ -0,0 +1,20 @@
+.. _common_optimizers:
+
+Common Optimizers
+=================
+
+.. currentmodule:: mlx.optimizers
+
+.. autosummary::
+   :toctree: _autosummary
+   :template: optimizers-template.rst
+
+   SGD
+   RMSprop
+   Adagrad
+   Adafactor
+   AdaDelta
+   Adam
+   AdamW
+   Adamax
+   Lion
23  docs/src/python/optimizers/optimizer.rst  Normal file
@@ -0,0 +1,23 @@
+Optimizer
+=========
+
+.. currentmodule:: mlx.optimizers
+
+.. autoclass:: Optimizer
+
+
+.. rubric:: Attributes
+
+.. autosummary::
+   :toctree: _autosummary
+
+   Optimizer.state
+
+.. rubric:: Methods
+
+.. autosummary::
+   :toctree: _autosummary
+
+   Optimizer.apply_gradients
+   Optimizer.init
+   Optimizer.update
15  docs/src/python/optimizers/schedulers.rst  Normal file
@@ -0,0 +1,15 @@
+.. _schedulers:
+
+Schedulers
+==========
+
+.. currentmodule:: mlx.optimizers
+
+.. autosummary::
+   :toctree: _autosummary
+
+   cosine_decay
+   exponential_decay
+   join_schedules
+   linear_schedule
+   step_decay
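The schedulers listed above are callables that can be passed directly as an
optimizer's ``learning_rate``. A minimal sketch; the decay arguments are
illustrative, not taken from the diff:

.. code-block:: python

   import mlx.optimizers as optim

   # Decay the learning rate from 1e-1 over 1000 steps
   lr_schedule = optim.cosine_decay(1e-1, 1000)

   # The optimizer evaluates the schedule at its current step on each update
   optimizer = optim.SGD(learning_rate=lr_schedule)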
@@ -38,8 +38,11 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
    gumbel
    key
    normal
+   multivariate_normal
    randint
    seed
    split
    truncated_normal
    uniform
+   laplace
+   permutation
@@ -9,9 +9,12 @@ Transforms
    :toctree: _autosummary

    eval
+   compile
+   custom_function
+   disable_compile
+   enable_compile
    grad
    value_and_grad
    jvp
    vjp
    vmap
-   simplify
@@ -19,3 +19,5 @@ return python trees will be using the default python ``dict``, ``list`` and
    tree_flatten
    tree_unflatten
    tree_map
+   tree_map_with_path
+   tree_reduce
423  docs/src/usage/compile.rst  Normal file
@@ -0,0 +1,423 @@
+.. _compile:
+
+Compilation
+===========
+
+.. currentmodule:: mlx.core
+
+MLX has a :func:`compile` function transformation which compiles computation
+graphs. Function compilation results in smaller graphs by merging common work
+and fusing certain operations. In many cases this can lead to big improvements
+in run-time and memory use.
+
+Getting started with :func:`compile` is simple, but there are some edge cases
+that are good to be aware of for more complex graphs and advanced usage.
+
+Basics of Compile
+-----------------
+
+Let's start with a simple example:
+
+.. code-block:: python
+
+   def fun(x, y):
+       return mx.exp(-x) + y
+
+   x = mx.array(1.0)
+   y = mx.array(2.0)
+
+   # Regular call, no compilation
+   # Prints: array(2.36788, dtype=float32)
+   print(fun(x, y))
+
+   # Compile the function
+   compiled_fun = mx.compile(fun)
+
+   # Prints: array(2.36788, dtype=float32)
+   print(compiled_fun(x, y))
+
+The output of both the regular function and the compiled function is the same
+up to numerical precision.
+
+The first time you call a compiled function, MLX will build the compute
+graph, optimize it, and generate and compile code. This can be relatively
+slow. However, MLX will cache compiled functions, so calling a compiled
+function multiple times will not initiate a new compilation. This means you
+should typically compile functions that you plan to use more than once.
+
+.. code-block:: python
+
+   def fun(x, y):
+       return mx.exp(-x) + y
+
+   x = mx.array(1.0)
+   y = mx.array(2.0)
+
+   compiled_fun = mx.compile(fun)
+
+   # Compiled here
+   compiled_fun(x, y)
+
+   # Not compiled again
+   compiled_fun(x, y)
+
+   # Not compiled again
+   mx.compile(fun)(x, y)
+
+There are some important cases to be aware of that can cause a function to
+be recompiled:
+
+* Changing the shape or number of dimensions
+* Changing the type of any of the inputs
+* Changing the number of inputs to the function
+
+In certain cases only some of the compilation stack will be rerun (for
+example when changing the shapes) and in other cases the full compilation
+stack will be rerun (for example when changing the types). In general you
+should avoid compiling functions too frequently.
+
+Another idiom to watch out for is compiling functions which get created and
+destroyed frequently. This can happen, for example, when compiling an anonymous
+function in a loop:
+
+.. code-block:: python
+
+   a = mx.array(1.0)
+   # Don't do this, compiles lambda at each iteration
+   for _ in range(5):
+       mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
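A minimal sketch of the corrected idiom, hoisting the compilation out of the
loop so it happens only once:

.. code-block:: python

   a = mx.array(1.0)

   # Compile once, outside the loop...
   fun = mx.compile(lambda x: mx.exp(mx.abs(x)))

   # ...so every iteration hits the compilation cache
   for _ in range(5):
       fun(a)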
+Example Speedup
+---------------
+
+:func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
+Transformer-based models. The implementation involves several unary and binary
+element-wise operations:
+
+.. code-block:: python
+
+   def gelu(x):
+       return x * (1 + mx.erf(x / math.sqrt(2))) / 2
+
+If you use this function with small arrays, it will be overhead bound. If you
+use it with large arrays it will be memory bandwidth bound. However, all of
+the operations in the ``gelu`` are fusible into a single kernel with
+:func:`compile`. This can speed up both cases considerably.
+
+Let's compare the runtime of the regular function versus the compiled
+function. We'll use the following timing helper which does a warm up and
+handles synchronization:
+
+.. code-block:: python
+
+   import time
+
+   def timeit(fun, x):
+       # warm up
+       for _ in range(10):
+           mx.eval(fun(x))
+
+       tic = time.perf_counter()
+       for _ in range(100):
+           mx.eval(fun(x))
+       toc = time.perf_counter()
+       tpi = 1e3 * (toc - tic) / 100
+       print(f"Time per iteration {tpi:.3f} (ms)")
+
+Now make an array, and benchmark both functions:
+
+.. code-block:: python
+
+   x = mx.random.uniform(shape=(32, 1000, 4096))
+   timeit(nn.gelu, x)
+   timeit(mx.compile(nn.gelu), x)
+
+On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
+five times faster.
+
+Debugging
+---------
+
+When a compiled function is first called, it is traced with placeholder
+inputs. This means you can't evaluate arrays (for example to print their
+contents) inside compiled functions.
+
+.. code-block:: python
+
+   @mx.compile
+   def fun(x):
+       z = -x
+       print(z)  # Crash
+       return mx.exp(z)
+
+   fun(mx.array(5.0))
+
+For debugging, inspecting arrays can be helpful. One way to do that is to
+globally disable compilation using the :func:`disable_compile` function or
+``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
+``fun`` is compiled:
+
+.. code-block:: python
+
+   @mx.compile
+   def fun(x):
+       z = -x
+       print(z)  # Okay
+       return mx.exp(z)
+
+   mx.disable_compile()
+   fun(mx.array(5.0))
+
+Pure Functions
+--------------
+
+Compiled functions are intended to be *pure*; that is they should not have side
+effects. For example:
+
+.. code-block:: python
+
+   state = []
+
+   @mx.compile
+   def fun(x, y):
+       z = x + y
+       state.append(z)
+       return mx.exp(z)
+
+   fun(mx.array(1.0), mx.array(2.0))
+   # Crash!
+   print(state)
+
+After the first call of ``fun``, the ``state`` list will hold a placeholder
+array. The placeholder does not have any data; it is only used to build the
+computation graph. Printing such an array results in a crash.
+
+You have two options to deal with this. The first option is to simply return
+``state`` as an output:
+
+.. code-block:: python
+
+   state = []
+
+   @mx.compile
+   def fun(x, y):
+       z = x + y
+       state.append(z)
+       return mx.exp(z), state
+
+   _, state = fun(mx.array(1.0), mx.array(2.0))
+   # Prints [array(3, dtype=float32)]
+   print(state)
+
+In some cases returning updated state can be pretty inconvenient. Hence,
+:func:`compile` has a parameter to capture implicit outputs:
+
+.. code-block:: python
+
+   from functools import partial
+
+   state = []
+
+   # Tell compile to capture state as an output
+   @partial(mx.compile, outputs=state)
+   def fun(x, y):
+       z = x + y
+       state.append(z)
+       return mx.exp(z), state
+
+   fun(mx.array(1.0), mx.array(2.0))
+   # Prints [array(3, dtype=float32)]
+   print(state)
+
+This is particularly useful for compiling a function which includes an update
+to a container of arrays, as is commonly done when training the parameters of a
+:class:`mlx.nn.Module`.
+
+Compiled functions will also treat any inputs not in the parameter list as
+constants. For example:
+
+.. code-block:: python
+
+   state = [mx.array(1.0)]
+
+   @mx.compile
+   def fun(x):
+       return x + state[0]
+
+   # Prints array(2, dtype=float32)
+   print(fun(mx.array(1.0)))
+
+   # Update state
+   state[0] = mx.array(5.0)
+
+   # Still prints array(2, dtype=float32)
+   print(fun(mx.array(1.0)))
+
+In order to have the change of state reflected in the outputs of ``fun`` you
+again have two options. The first option is to simply pass ``state`` as input
+to the function. In some cases this can be pretty inconvenient. Hence,
+:func:`compile` also has a parameter to capture implicit inputs:
+
+.. code-block:: python
+
+   from functools import partial
+
+   state = [mx.array(1.0)]
+
+   # Tell compile to capture state as an input
+   @partial(mx.compile, inputs=state)
+   def fun(x):
+       return x + state[0]
+
+   # Prints array(2, dtype=float32)
+   print(fun(mx.array(1.0)))
+
+   # Update state
+   state[0] = mx.array(5.0)
+
+   # Prints array(6, dtype=float32)
+   print(fun(mx.array(1.0)))
+
+Compiling Training Graphs
+-------------------------
+
+This section will step through how to use :func:`compile` with a simple example
+of a common setup: training a model with :obj:`mlx.nn.Module` using an
+:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
+full forward, backward, and update with :func:`compile`.
+
+To start, here is the simple example without any compilation:
+
+.. code-block:: python
+
+   import mlx.core as mx
+   import mlx.nn as nn
+   import mlx.optimizers as optim
+
+   # 4 examples with 10 features each
+   x = mx.random.uniform(shape=(4, 10))
+
+   # 0, 1 targets
+   y = mx.array([0, 1, 0, 1])
+
+   # Simple linear model
+   model = nn.Linear(10, 1)
+
+   # SGD with momentum
+   optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
+
+   def loss_fn(model, x, y):
+       logits = model(x).squeeze()
+       return nn.losses.binary_cross_entropy(logits, y)
+
+   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+
+   # Perform 10 steps of gradient descent
+   for it in range(10):
+       loss, grads = loss_and_grad_fn(model, x, y)
+       optimizer.update(model, grads)
+       mx.eval(model.parameters(), optimizer.state)
+
+To compile the update we can put it all in a function and compile it with the
+appropriate input and output captures. Here's the same example but compiled:
+
+.. code-block:: python
+
+   import mlx.core as mx
+   import mlx.nn as nn
+   import mlx.optimizers as optim
+   from functools import partial
+
+   # 4 examples with 10 features each
+   x = mx.random.uniform(shape=(4, 10))
+
+   # 0, 1 targets
+   y = mx.array([0, 1, 0, 1])
+
+   # Simple linear model
+   model = nn.Linear(10, 1)
+
+   # SGD with momentum
+   optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
+
+   def loss_fn(model, x, y):
+       logits = model(x).squeeze()
+       return nn.losses.binary_cross_entropy(logits, y)
+
+   # The state that will be captured as input and output
+   state = [model.state, optimizer.state]
+
+   @partial(mx.compile, inputs=state, outputs=state)
+   def step(x, y):
+       loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+       loss, grads = loss_and_grad_fn(model, x, y)
+       optimizer.update(model, grads)
+       return loss
+
+   # Perform 10 steps of gradient descent
+   for it in range(10):
+       loss = step(x, y)
+       # Evaluate the model and optimizer state
+       mx.eval(state)
+       print(loss)
+
+.. note::
+
+   If you are using a module which performs random sampling such as
+   :func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
+   ``state`` captured by :func:`compile`, i.e. ``state = [model.state,
+   optimizer.state, mx.random.state]``.
+
+.. note::
+
+   For more examples of compiling full training graphs check out the `MLX
+   Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
+
+Transformations with Compile
+----------------------------
+
+In MLX function transformations are composable. You can apply any function
+transformation to the output of any other function transformation. For more on
+this, see the documentation on :ref:`function transforms
+<function_transforms>`.
+
+Compiling transformed functions works just as expected:
+
+.. code-block:: python
+
+   grad_fn = mx.grad(mx.exp)
+
+   compiled_grad_fn = mx.compile(grad_fn)
+
+   # Prints: array(2.71828, dtype=float32)
+   print(grad_fn(mx.array(1.0)))
+
+   # Also prints: array(2.71828, dtype=float32)
+   print(compiled_grad_fn(mx.array(1.0)))
+
+.. note::
+
+   In order to compile as much as possible, a transformation of a compiled
+   function will not by default be compiled. To compile the transformed
+   function simply pass it through :func:`compile`.
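A minimal sketch of the point in the note above, assuming nothing beyond
:func:`compile` and :func:`grad`:

.. code-block:: python

   cfun = mx.compile(mx.exp)

   grad_fn = mx.grad(cfun)                 # gradient of a compiled function; not itself compiled
   compiled_grad_fn = mx.compile(grad_fn)  # pass through compile to compile the transformation too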
+You can also compile functions which themselves call compiled functions. A
+good practice is to compile the outermost function to give :func:`compile`
+the most opportunity to optimize the computation graph:
+
+.. code-block:: python
+
+   @mx.compile
+   def inner(x):
+       return mx.exp(-mx.abs(x))
+
+   def outer(x):
+       return inner(inner(x))
+
+   # Compiling the outer function is good to do as it will likely
+   # be faster even though the inner functions are compiled
+   fun = mx.compile(outer)
166  docs/src/usage/distributed.rst  Normal file
@@ -0,0 +1,166 @@
+.. _usage_distributed:
+
+Distributed Communication
+=========================
+
+.. currentmodule:: mlx.core.distributed
+
+MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
+provide distributed communication operations that allow the computational cost
+of training or inference to be shared across many physical machines. You can
+see a list of the supported operations in the :ref:`API docs<distributed>`.
+
+.. note::
+   A lot of operations may not be supported or not as fast as they should be.
+   We are adding more and tuning the ones we have as we are figuring out the
+   best way to do distributed computing on Macs using MLX.
+
+Getting Started
+---------------
+
+MLX already comes with the ability to "talk" to MPI if it is installed on the
+machine. The minimal distributed program in MLX is as simple as:
+
+.. code:: python
+
+   import mlx.core as mx
+
+   world = mx.distributed.init()
+   x = mx.distributed.all_sum(mx.ones(10))
+   print(world.rank(), x)
+
+The program above sums the array ``mx.ones(10)`` across all
+distributed processes. If simply run with ``python``, however, only one
+process is launched and no distributed communication takes place.
+
+To launch the program in distributed mode we need to use ``mpirun`` or
+``mpiexec`` depending on the MPI installation. The simplest possible way is the
+following:
+
+.. code:: shell
+
+   $ mpirun -np 2 python test.py
+   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
+   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
+
+The above launches two processes on the same (local) machine and we can see
+both standard output streams. The processes send the array of 1s to each other
+and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
+print 4 etc.
+
+Installing MPI
+--------------
+
+MPI can be installed with Homebrew, using the Anaconda package manager or
+compiled from source. Most of our testing is done using ``openmpi`` installed
+with the Anaconda package manager as follows:
+
+.. code:: shell
+
+   $ conda install openmpi
+
+Installing with Homebrew may require specifying the location of ``libmpi.dylib``
+so that MLX can find it and load it at runtime. This can simply be achieved by
+passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
+
+.. code:: shell
+
+   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
+
+Setting up Remote Hosts
+-----------------------
+
+MPI can automatically connect to remote hosts and set up the communication over
+the network if the remote hosts can be accessed via ssh. A good checklist to
+debug connectivity issues is the following:
+
+* ``ssh hostname`` works from all machines to all machines without asking for
+  password or host confirmation
+* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
+  full path to force all machines to use a specific path.
+* Ensure that the ``hostname`` used by MPI is the one that you have configured
+  in the ``.ssh/config`` files on all machines.
+
+.. note::
+   For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
+   the hostname passed to ssh if the current hostname matches ``*.bar.com``.
+
+An easy way to pass the host names to MPI is using a host file. A host file
+looks like the following, where ``host1`` and ``host2`` should be the fully
+qualified domain names or IPs for these hosts.
+
+.. code::
+
+   host1 slots=1
+   host2 slots=1
+
+When using MLX, it is very likely that you want to use 1 slot per host, i.e.
+one process per host. The host file also needs to contain the current
+host if you want to run on the local host. Passing the host file to
+``mpirun`` is simply done using the ``--hostfile`` command line argument.
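Once the hosts are reachable, a quick way to confirm that every rank joined the
group is a sketch like the following; ``rank()`` and ``size()`` are the group
accessors used elsewhere in this document:

.. code:: python

   import mlx.core as mx

   world = mx.distributed.init()

   # Every process should report a distinct rank out of the same total size
   print(f"rank {world.rank()} of {world.size()}")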
+Training Example
+----------------
+
+In this section we will adapt an MLX training loop to support data parallel
+distributed training. Namely, we will average the gradients across a set of
+hosts before applying them to the model.
+
+Our training loop looks like the following code snippet if we omit the model,
+dataset and optimizer initialization.
+
+.. code:: python
+
+   model = ...
+   optimizer = ...
+   dataset = ...
+
+   def step(model, x, y):
+       loss, grads = loss_grad_fn(model, x, y)
+       optimizer.update(model, grads)
+       return loss
+
+   for x, y in dataset:
+       loss = step(model, x, y)
+       mx.eval(loss, model.parameters())
+
+All we have to do to average the gradients across machines is perform an
+:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
+have to :func:`mlx.utils.tree_map` the gradients with the following function.
+
+.. code:: python
+
+   def all_avg(x):
+       return mx.distributed.all_sum(x) / mx.distributed.init().size()
+
+Putting everything together our training loop step looks as follows with
+everything else remaining the same.
+
+.. code:: python
+
+   from mlx.utils import tree_map
+
+   def all_reduce_grads(grads):
+       N = mx.distributed.init().size()
+       if N == 1:
+           return grads
+       return tree_map(
+           lambda x: mx.distributed.all_sum(x) / N,
+           grads)
+
+   def step(model, x, y):
+       loss, grads = loss_grad_fn(model, x, y)
+       grads = all_reduce_grads(grads)  # <--- This line was added
+       optimizer.update(model, grads)
+       return loss
+
+Tuning All Reduce
+-----------------
+
+We are working on improving the performance of all reduce on MLX but for now
+the two main things one can do to extract the most out of distributed training with MLX are:
+
+1. Perform a few large reductions instead of many small ones to improve
+   bandwidth and latency (see the sketch after this list)
+2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
+   connections between each host to improve bandwidth
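A minimal sketch of point (1) above, assuming the per-parameter gradients have
already been collected into a Python list; ``averaged_flat_grads`` is a
hypothetical helper, and splitting the buffer back into per-parameter shapes is
omitted:

.. code:: python

   import mlx.core as mx

   def averaged_flat_grads(grads_list):
       world = mx.distributed.init()
       # One large reduction over a flattened buffer instead of many small ones
       flat = mx.concatenate([g.reshape(-1) for g in grads_list])
       return mx.distributed.all_sum(flat) / world.size()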
@@ -5,9 +5,12 @@ Function Transforms

 .. currentmodule:: mlx.core

-MLX uses composable function transformations for automatic differentiation and
-vectorization. The key idea behind composable function transformations is that
-every transformation returns a function which can be further transformed.
+MLX uses composable function transformations for automatic differentiation,
+vectorization, and compute graph optimizations. To see the complete list of
+function transformations check-out the :ref:`API documentation <transforms>`.
+
+The key idea behind composable function transformations is that every
+transformation returns a function which can be further transformed.

 Here is a simple example:

@@ -22,7 +25,7 @@ Here is a simple example:

 The output of :func:`grad` on :func:`sin` is simply another function. In this
 case it is the gradient of the sine function which is exactly the cosine
 function. To get the second derivative you can do:

 .. code-block:: shell

@@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep
 getting higher order derivatives.

 Any of the MLX function transformations can be composed in any order to any
-depth. To see the complete list of function transformations check-out the
-:ref:`API documentation <transforms>`. See the following sections for more
-information on :ref:`automatic differentiaion <auto diff>` and
-:ref:`automatic vectorization <vmap>`.
+depth. See the following sections for more information on :ref:`automatic
+differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
+For more information on :func:`compile` see the :ref:`compile documentation <compile>`.

 Automatic Differentiation
 -------------------------

@@ -47,7 +50,7 @@ Automatic Differentiation
 .. _auto diff:

 Automatic differentiation in MLX works on functions rather than on implicit
 graphs.

 .. note::

@@ -111,7 +114,7 @@ way to do that is the following:

    def loss_fn(params, x, y):
       w, b = params["weight"], params["bias"]
       h = w * x + b
       return mx.mean(mx.square(h - y))

    params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}

@@ -129,7 +132,7 @@ way to do that is the following:

 Notice the tree structure of the parameters is preserved in the gradients.

 In some cases you may want to stop gradients from propagating through a
 part of the function. You can use the :func:`stop_gradient` for that.

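A minimal sketch of :func:`stop_gradient`; the function and value are
illustrative, not from the diff:

.. code-block:: python

   import mlx.core as mx

   def fn(x):
       # The second factor is treated as a constant by grad
       return x * mx.stop_gradient(mx.square(x))

   # d/dx [x * c] evaluated at x = 3 gives c = 9
   print(mx.grad(fn)(mx.array(3.0)))  # array(9, dtype=float32)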
@@ -158,19 +161,19 @@ A naive way to add the elements from two sets of vectors is with a loop:
    ys = mx.random.uniform(shape=(100, 4096))

    def naive_add(xs, ys):
-       return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
+       return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

 Instead you can use :func:`vmap` to automatically vectorize the addition:

 .. code-block:: python

    # Vectorize over the second dimension of x and the
    # first dimension of y
-   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
+   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

 The ``in_axes`` parameter can be used to specify which dimensions of the
 corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
 where the vectorized axes should be in the outputs.

 Let's time these two different versions:

@@ -181,8 +184,8 @@ Let's time these two different versions:
    print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
    print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))

-On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
-vectorized version takes only ``0.025`` seconds, more than ten times faster.
+On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
+vectorized version takes only ``0.024`` seconds, more than 200 times faster.

 Of course, this operation is quite contrived. A better approach is to simply do
 ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
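A minimal sketch of ``out_axes``, mentioned above; placing the mapped axis
second in the output instead of first is an illustrative choice:

.. code-block:: python

   # Same mapping as above, but the mapped axis lands at position 1 of the output
   vmap_add_t = mx.vmap(lambda x, y: x + y, in_axes=(0, 1), out_axes=1)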
@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:

 .. code-block:: shell

    >>> arr = mx.arange(10)
    >>> idx = mx.array([5, 7])
    >>> arr[idx]
    array([5, 7], dtype=int32)

@@ -77,12 +77,12 @@ from the GPU. Performing bounds checking for array indices before launching the
 kernel would be extremely inefficient.

 Indexing with boolean masks is something that MLX may support in the future. In
-general, MLX has limited support for operations for which outputs
+general, MLX has limited support for operations for which output
 *shapes* are dependent on input *data*. Other examples of these types of
 operations which MLX does not yet support include :func:`numpy.nonzero` and the
 single input version of :func:`numpy.where`.

 In Place Updates
 ----------------

 In place updates to indexed arrays are possible in MLX. For example:
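The example itself is elided in this hunk; a minimal sketch of an in-place
indexed update:

.. code-block:: python

   a = mx.array([1, 2, 3])
   a[2] = 0
   print(a)  # array([1, 2, 0], dtype=int32)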
@@ -13,14 +13,14 @@ compute graph is recorded. The actual computation only happens if an
 :func:`eval` is performed.

 MLX uses lazy evaluation because it has some nice features, some of which we
 describe below.

 Transforming Compute Graphs
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Lazy evaluation let's us record a compute graph without actually doing any
+Lazy evaluation lets us record a compute graph without actually doing any
 computations. This is useful for function transformations like :func:`grad` and
-:func:`vmap` and graph optimizations like :func:`simplify`.
+:func:`vmap` and graph optimizations.

 Currently, MLX does not compile and rerun compute graphs. They are all
 generated dynamically. However, lazy evaluation makes it much easier to

@@ -109,14 +109,14 @@ Here is a concrete example:

 An important behavior to be aware of is when the graph will be implicitly
 evaluated. Anytime you ``print`` an array, convert it to an
-:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
+:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
 the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
 saving functions) will also evaluate the array.

 Calling :func:`array.item` on a scalar array will also evaluate it. In the
 example above, printing the loss (``print(loss)``) or adding the loss scalar to
 a list (``losses.append(loss.item())``) would cause a graph evaluation. If
 these lines are before ``mx.eval(loss, model.parameters())`` then this
 will be a partial evaluation, computing only the forward pass.
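A one-line sketch of the implicit evaluation triggered by :func:`array.item`:

.. code-block:: python

   loss = mx.exp(mx.array(0.0))  # lazy; nothing has been computed yet
   v = loss.item()               # forces evaluation and returns a Python scalar (1.0)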
@@ -3,7 +3,11 @@
 Conversion to NumPy and Other Frameworks
 ========================================

-MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
+MLX array supports conversion between other frameworks with either:
+
+* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
+* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.

 Let's convert an array to NumPy and back.

 .. code-block:: python
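The snippet is elided in this hunk; a minimal sketch of the round trip via the
buffer protocol:

.. code-block:: python

   import numpy as np
   import mlx.core as mx

   a = mx.arange(3)
   b = np.array(a)  # MLX -> NumPy via the buffer protocol
   c = mx.array(b)  # NumPy -> MLX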
@@ -62,7 +66,7 @@ even though no in-place operations on MLX memory are executed.
 PyTorch
 -------

 .. warning::

    PyTorch Support for :obj:`memoryview` is experimental and can break for
    multi-dimensional arrays. Casting to NumPy first is advised for now.

@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
 and :func:`jvp` for Jacobian-vector products.

 Use :func:`value_and_grad` to efficiently compute both a function's output and
 gradient with respect to the function's input.
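A minimal sketch of :func:`value_and_grad`; the loss function is illustrative:

.. code-block:: python

   def loss_fn(x):
       return mx.sum(mx.square(x))

   # One call returns both the loss value and its gradient
   loss, grad = mx.value_and_grad(loss_fn)(mx.array([1.0, 2.0]))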
@@ -8,33 +8,33 @@ Saving and Loading Arrays
 MLX supports multiple array serialization formats.

 .. list-table:: Serialization Formats
    :widths: 20 8 25 25
    :header-rows: 1

    * - Format
      - Extension
      - Function
      - Notes
    * - NumPy
      - ``.npy``
      - :func:`save`
      - Single arrays only
    * - NumPy archive
      - ``.npz``
      - :func:`savez` and :func:`savez_compressed`
      - Multiple arrays
    * - Safetensors
      - ``.safetensors``
      - :func:`save_safetensors`
      - Multiple arrays
    * - GGUF
      - ``.gguf``
      - :func:`save_gguf`
      - Multiple arrays

 The :func:`load` function will load any of the supported serialization
 formats. It determines the format from the extensions. The output of
 :func:`load` depends on the format.

 Here's an example of saving a single array to a file:

@@ -49,7 +49,7 @@ it will be added. You can load the array with:

 .. code-block:: shell

-   >>> mx.load("array.npy", a)
+   >>> mx.load("array.npy")
    array([1], dtype=float32)

 Here's an example of saving several arrays to a single file:
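The multi-array snippet is elided in this hunk; a minimal sketch using
:func:`savez` (the file and array names are illustrative):

.. code-block:: python

   a = mx.array([1.0])
   b = mx.array([2.0])

   mx.savez("arrays.npz", a=a, b=b)
   arrays = mx.load("arrays.npz")  # a dict mapping the names to the arrays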
@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.

 In MLX, rather than moving arrays to devices, you specify the device when you
 run the operation. Any device can perform any operation on ``a`` and ``b``
 without needing to move them from one memory location to another. For example:

 .. code-block:: python
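The example body is elided in this hunk; a minimal sketch of specifying the
device per operation:

.. code-block:: python

   a = mx.random.normal((100,))
   b = mx.random.normal((100,))

   mx.add(a, b, stream=mx.cpu)  # run this add on the CPU
   mx.add(a, b, stream=mx.gpu)  # run this add on the GPU; no copies needed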
@@ -8,3 +8,5 @@ endfunction(build_example)
 build_example(tutorial.cpp)
 build_example(linear_regression.cpp)
 build_example(logistic_regression.cpp)
+build_example(metal_capture.cpp)
+build_example(distributed.cpp)
22  examples/cpp/distributed.cpp  Normal file
@@ -0,0 +1,22 @@
+// Copyright © 2024 Apple Inc.
+
+#include <iostream>
+
+#include "mlx/mlx.h"
+
+using namespace mlx::core;
+
+int main() {
+  if (!distributed::is_available()) {
+    std::cout << "No communication backend found" << std::endl;
+    return 1;
+  }
+
+  auto global_group = distributed::init();
+  std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
+
+  array x = ones({10});
+  array out = distributed::all_sum(x, global_group);
+
+  std::cout << out << std::endl;
+}
31  examples/cpp/metal_capture.cpp  Normal file
@@ -0,0 +1,31 @@
+// Copyright © 2024 Apple Inc.
+
+#include <cassert>
+#include <iostream>
+
+#include "mlx/mlx.h"
+
+using namespace mlx::core;
+
+int main() {
+  // To use Metal debugging and profiling:
+  // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
+  // 2. Run with MTL_CAPTURE_ENABLED=1.
+  metal::start_capture("mlx_trace.gputrace");
+
+  // Start at index two because the default GPU and CPU streams have indices
+  // zero and one, respectively. This naming matches the label assigned to each
+  // stream's command queue.
+  auto s2 = new_stream(Device::gpu);
+  auto s3 = new_stream(Device::gpu);
+
+  auto a = arange(1.f, 10.f, 1.f, float32, s2);
+  auto b = arange(1.f, 10.f, 1.f, float32, s3);
+  auto x = add(a, a, s2);
+  auto y = add(b, b, s3);
+
+  // The multiply will happen on the default stream.
+  std::cout << multiply(x, y) << std::endl;
+
+  metal::stop_capture();
+}
@@ -89,8 +89,8 @@ void automatic_differentiation() {
   // dfdx is 2 * x

   // Get the second derivative by composing grad with grad
-  auto df2dx2 = grad(grad(fn))(x);
-  // df2dx2 is 2
+  auto d2fdx2 = grad(grad(fn))(x);
+  // d2fdx2 is 2
 }

 int main() {
@@ -1,6 +1,6 @@
-cmake_minimum_required(VERSION 3.24)
+cmake_minimum_required(VERSION 3.27)

-project(mlx_sample_extensions LANGUAGES CXX)
+project(_ext LANGUAGES CXX)

 # ----------------------------- Setup -----------------------------
 set(CMAKE_CXX_STANDARD 17)

@@ -11,8 +11,16 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

 # ----------------------------- Dependencies -----------------------------
 find_package(MLX CONFIG REQUIRED)
-find_package(Python COMPONENTS Interpreter Development)
-find_package(pybind11 CONFIG REQUIRED)
+find_package(
+  Python 3.8
+  COMPONENTS Interpreter Development.Module
+  REQUIRED)
+execute_process(
+  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+  OUTPUT_VARIABLE NB_DIR)
+list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
+find_package(nanobind CONFIG REQUIRED)

 # ----------------------------- Extensions -----------------------------

@@ -20,16 +28,10 @@ find_package(pybind11 CONFIG REQUIRED)
 add_library(mlx_ext)

 # Add sources
-target_sources(
-  mlx_ext
-  PUBLIC
-  ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
-)
+target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)

 # Add include headers
-target_include_directories(
-  mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
-)
+target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})

 # Link to mlx
 target_link_libraries(mlx_ext PUBLIC mlx)

@@ -38,29 +40,35 @@ target_link_libraries(mlx_ext PUBLIC mlx)

 # Build metallib
 if(MLX_BUILD_METAL)
-
   mlx_build_metallib(
-    TARGET mlx_ext_metallib
-    TITLE mlx_ext
-    SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
-    INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
-    OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
-  )
-
-  add_dependencies(
-    mlx_ext
-    mlx_ext_metallib
-  )
+    TARGET
+    mlx_ext_metallib
+    TITLE
+    mlx_ext
+    SOURCES
+    ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
+    INCLUDE_DIRS
+    ${PROJECT_SOURCE_DIR}
+    ${MLX_INCLUDE_DIRS}
+    OUTPUT_DIRECTORY
+    ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
+
+  add_dependencies(mlx_ext mlx_ext_metallib)

 endif()

-# ----------------------------- Pybind -----------------------------
-pybind11_add_module(
-  mlx_sample_extensions
-  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
-)
-target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
+# ----------------------------- Python Bindings -----------------------------
+nanobind_add_module(
+  _ext
+  NB_STATIC
+  STABLE_ABI
+  LTO
+  NOMINSIZE
+  NB_DOMAIN
+  mlx
+  ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
+target_link_libraries(_ext PRIVATE mlx_ext)

 if(BUILD_SHARED_LIBS)
-  target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
+  target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
 endif()
24  examples/extensions/README.md  Normal file
@@ -0,0 +1,24 @@
+
+## Build
+
+```
+pip install -e .
+```
+
+For faster builds during development, you can also pre-install the requirements:
+
+```
+pip install -r requirements.txt
+```
+
+And then run:
+
+```
+python setup.py build_ext -j8 --inplace
+```
+
+## Test
+
+```
+python test.py
+```
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

 #include <cassert>
 #include <iostream>

@@ -43,7 +43,7 @@ array axpby(
   auto promoted_dtype = promote_types(x.dtype(), y.dtype());

   // Upcast to float32 for non-floating point inputs x and y
-  auto out_dtype = is_floating_point(promoted_dtype)
+  auto out_dtype = issubdtype(promoted_dtype, float32)
       ? promoted_dtype
       : promote_types(promoted_dtype, float32);

@@ -61,7 +61,7 @@ array axpby(
       /* const std::vector<int>& shape = */ out_shape,
       /* Dtype dtype = */ out_dtype,
       /* std::unique_ptr<Primitive> primitive = */
-      std::make_unique<Axpby>(to_stream(s), alpha, beta),
+      std::make_shared<Axpby>(to_stream(s), alpha, beta),
       /* const std::vector<array>& inputs = */ broadcasted_inputs);
 }

@@ -106,12 +106,12 @@ void axpby_impl(
 /** Fall back implementation for evaluation on CPU */
 void Axpby::eval(
     const std::vector<array>& inputs,
-    std::vector<array>& out_arr) {
-  auto out = out_arr[0];
+    std::vector<array>& outputs) {
   // Check the inputs (registered in the op while constructing the out array)
   assert(inputs.size() == 2);
   auto& x = inputs[0];
   auto& y = inputs[1];
+  auto& out = outputs[0];

   // Dispatch to the correct dtype
   if (out.dtype() == float32) {

@@ -150,11 +150,7 @@ void axpby_impl_accelerate(
   // The data in the output array is allocated to match the strides in y
   // such that x, y, and out are contiguous in the same mode and
   // no transposition is needed
-  out.set_data(
-      allocator::malloc_or_wait(y.data_size() * out.itemsize()),
-      y.data_size(),
-      y.strides(),
-      y.flags());
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));

   // We then copy over the elements using the contiguous vector specialization
   copy_inplace(y, out, CopyType::Vector);

@@ -180,11 +176,11 @@ void axpby_impl_accelerate(
 /** Evaluate primitive on CPU using accelerate specializations */
 void Axpby::eval_cpu(
     const std::vector<array>& inputs,
-    std::vector<array>& outarr) {
-  auto out = outarr[0];
+    std::vector<array>& outputs) {
   assert(inputs.size() == 2);
   auto& x = inputs[0];
   auto& y = inputs[1];
+  auto& out = outputs[0];

   // Accelerate specialization for contiguous single precision float arrays
   if (out.dtype() == float32 &&

@@ -195,7 +191,7 @@ void Axpby::eval_cpu(
   }

   // Fall back to common backend if specializations are not available
-  eval(inputs, outarr);
+  eval(inputs, outputs);
 }

 #else // Accelerate not available

@@ -203,8 +199,8 @@ void Axpby::eval_cpu(
 /** Evaluate primitive on CPU falling back to common backend */
 void Axpby::eval_cpu(
     const std::vector<array>& inputs,
-    std::vector<array>& out) {
-  eval(inputs, out);
+    const std::vector<array>& outputs) {
+  eval(inputs, outputs);
 }

 #endif
||||||
@@ -218,12 +214,12 @@ void Axpby::eval_cpu(
|
|||||||
/** Evaluate primitive on GPU */
|
/** Evaluate primitive on GPU */
|
||||||
void Axpby::eval_gpu(
|
void Axpby::eval_gpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outarr) {
|
std::vector<array>& outputs) {
|
||||||
// Prepare inputs
|
// Prepare inputs
|
||||||
auto out = outarr[0];
|
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
|
auto& out = outputs[0];
|
||||||
|
|
||||||
// Each primitive carries the stream it should execute on
|
// Each primitive carries the stream it should execute on
|
||||||
// and each stream carries its device identifiers
|
// and each stream carries its device identifiers
|
||||||
@@ -253,16 +249,15 @@ void Axpby::eval_gpu(
|
|||||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||||
kname << type_to_name(out);
|
kname << type_to_name(out);
|
||||||
|
|
||||||
// Make sure the metal library is available and look for it
|
// Make sure the metal library is available
|
||||||
// in the same folder as this executable if needed
|
d.register_library("mlx_ext");
|
||||||
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
|
|
||||||
|
|
||||||
// Make a kernel from this metal library
|
// Make a kernel from this metal library
|
||||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||||
|
|
||||||
// Prepare to encode kernel
|
// Prepare to encode kernel
|
||||||
auto compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
// Kernel parameters are registered with buffer indices corresponding to
|
// Kernel parameters are registered with buffer indices corresponding to
|
||||||
// those in the kernel declaration at axpby.metal
|
// those in the kernel declaration at axpby.metal
|
||||||
@@ -270,22 +265,22 @@ void Axpby::eval_gpu(
   size_t nelem = out.size();
 
   // Encode input arrays to kernel
-  set_array_buffer(compute_encoder, x, 0);
-  set_array_buffer(compute_encoder, y, 1);
+  compute_encoder.set_input_array(x, 0);
+  compute_encoder.set_input_array(y, 1);
 
   // Encode output arrays to kernel
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_output_array(out, 2);
 
   // Encode alpha and beta
-  compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-  compute_encoder->setBytes(&beta_, sizeof(float), 4);
+  compute_encoder.set_bytes(alpha_, 3);
+  compute_encoder.set_bytes(beta_, 4);
 
   // Encode shape, strides and ndim if needed
   if (!contiguous_kernel) {
-    compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-    compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-    compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-    compute_encoder->setBytes(&ndim, sizeof(int), 8);
+    compute_encoder.set_vector_bytes(x.shape(), 5);
+    compute_encoder.set_vector_bytes(x.strides(), 6);
+    compute_encoder.set_vector_bytes(y.strides(), 7);
+    compute_encoder.set_bytes(ndim, 8);
   }
 
   // We launch 1 thread for each input and make sure that the number of
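The encoding hunk above swaps raw `MTL::ComputeCommandEncoder` calls (`setBytes` with a pointer and an explicit byte count) for typed helpers on MLX's encoder wrapper. A minimal sketch of why the typed form is safer; the helper names follow the diff, but `EncoderSketch` and `set_raw_bytes` are stand-ins, not MLX's actual `CommandEncoder`:

#include <cstddef>
#include <cstdio>
#include <vector>

struct EncoderSketch {
  // Stand-in for MTL::ComputeCommandEncoder::setBytes(ptr, nbytes, index).
  void set_raw_bytes(const void* /*data*/, size_t nbytes, int index) {
    std::printf("buffer %d <- %zu bytes\n", index, nbytes);
  }

  // The byte count is deduced from the type, so call sites
  // can no longer pass a mismatched sizeof.
  template <typename T>
  void set_bytes(const T& v, int index) {
    set_raw_bytes(&v, sizeof(T), index);
  }

  // Vectors hand over their contents, not the vector object itself.
  template <typename T>
  void set_vector_bytes(const std::vector<T>& v, int index) {
    set_raw_bytes(v.data(), v.size() * sizeof(T), index);
  }
};

int main() {
  EncoderSketch enc;
  float alpha = 2.0f;
  std::vector<size_t> strides = {8, 1};
  enc.set_bytes(alpha, 3);           // was: setBytes(&alpha_, sizeof(float), 3)
  enc.set_vector_bytes(strides, 6);  // was: setBytes(x.strides().data(), ndim * sizeof(size_t), 6)
}

Passing a `std::vector` through the scalar overload would encode the vector object itself (its pointer and size fields) rather than its elements, which is presumably why the shape and both stride arrays go through the vector helper.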
@@ -300,7 +295,7 @@ void Axpby::eval_gpu(
 
   // Launch the grid with the given number of threads divided among
   // the given threadgroups
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
 #else // Metal is not available
@@ -372,4 +367,4 @@ bool Axpby::is_equivalent(const Primitive& other) const {
   return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
 }
 
 } // namespace mlx::core
@@ -33,7 +33,7 @@ array axpby(
 class Axpby : public Primitive {
  public:
   explicit Axpby(Stream stream, float alpha, float beta)
-      : Primitive(stream), alpha_(alpha), beta_(beta){};
+      : Primitive(stream), alpha_(alpha), beta_(beta) {};
 
   /**
    * A primitive must know how to evaluate itself on the CPU/GPU
@@ -42,9 +42,9 @@ class Axpby : public Primitive {
    * To avoid unnecessary allocations, the evaluation function
    * is responsible for allocating space for the array.
    */
-  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& out)
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
-  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& out)
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
 
   /** The Jacobian-vector product. */
@@ -83,7 +83,7 @@ class Axpby : public Primitive {
   float beta_;
 
   /** Fall back implementation for evaluation on CPU */
-  void eval(const std::vector<array>& inputs, std::vector<array>& out);
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
 };
 
 } // namespace mlx::core
@@ -2,7 +2,6 @@
 
 #include <metal_stdlib>
 
-#include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
 
 template <typename T>
@@ -19,7 +18,7 @@ template <typename T>
     uint index [[thread_position_in_grid]]) {
   auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
   auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
   out[index] =
       static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
 }
 
@@ -31,33 +30,33 @@ template <typename T>
     constant const float& alpha [[buffer(3)]],
     constant const float& beta [[buffer(4)]],
     uint index [[thread_position_in_grid]]) {
   out[index] =
       static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
 }
 
 #define instantiate_axpby(type_name, type) \
-  template [[host_name("axpby_general_" #type_name)]] \
-  [[kernel]] void axpby_general<type>( \
+  template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
+  axpby_general<type>( \
       device const type* x [[buffer(0)]], \
       device const type* y [[buffer(1)]], \
       device type* out [[buffer(2)]], \
       constant const float& alpha [[buffer(3)]], \
       constant const float& beta [[buffer(4)]], \
       constant const int* shape [[buffer(5)]], \
       constant const size_t* x_strides [[buffer(6)]], \
       constant const size_t* y_strides [[buffer(7)]], \
       constant const int& ndim [[buffer(8)]], \
       uint index [[thread_position_in_grid]]); \
-  template [[host_name("axpby_contiguous_" #type_name)]] \
-  [[kernel]] void axpby_contiguous<type>( \
+  template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
+  axpby_contiguous<type>( \
      device const type* x [[buffer(0)]], \
      device const type* y [[buffer(1)]], \
      device type* out [[buffer(2)]], \
      constant const float& alpha [[buffer(3)]], \
      constant const float& beta [[buffer(4)]], \
      uint index [[thread_position_in_grid]]);
 
 instantiate_axpby(float32, float);
 instantiate_axpby(float16, half);
 instantiate_axpby(bfloat16, bfloat16_t);
 instantiate_axpby(complex64, complex64_t);
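The two macro edits above only move a line break: `[[kernel]] void` now follows `[[host_name(...)]]` on the same line, with the template name starting the next line, so the expansion is unchanged. For one concrete case, `instantiate_axpby(float32, float)` expands the contiguous half to:

// Mechanical expansion of the axpby_contiguous half of
// instantiate_axpby(float32, float) from the macro above.
template [[host_name("axpby_contiguous_float32")]] [[kernel]] void
axpby_contiguous<float>(
    device const float* x [[buffer(0)]],
    device const float* y [[buffer(1)]],
    device float* out [[buffer(2)]],
    constant const float& alpha [[buffer(3)]],
    constant const float& beta [[buffer(4)]],
    uint index [[thread_position_in_grid]]);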
@@ -1,31 +1,30 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/variant.h>
 
 #include "axpby/axpby.h"
 
-namespace py = pybind11;
-using namespace py::literals;
+namespace nb = nanobind;
+using namespace nb::literals;
 
 using namespace mlx::core;
 
-PYBIND11_MODULE(mlx_sample_extensions, m) {
-  m.doc() = "Sample C++ and metal extensions for MLX";
+NB_MODULE(_ext, m) {
+  m.doc() = "Sample extension for MLX";
 
   m.def(
       "axpby",
       &axpby,
       "x"_a,
       "y"_a,
-      py::pos_only(),
       "alpha"_a,
       "beta"_a,
-      py::kw_only(),
-      "stream"_a = py::none(),
-      R"pbdoc(
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      R"(
         Scale and sum two vectors element-wise
         ``z = alpha * x + beta * y``
 
         Follows numpy style broadcasting between ``x`` and ``y``
         Inputs are upcasted to floats if needed
@@ -37,5 +36,5 @@ PYBIND11_MODULE(mlx_sample_extensions, m) {
 
         Returns:
             array: ``alpha * x + beta * y``
-      )pbdoc");
+      )");
 }
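A note on the port above: `NB_MODULE(_ext, m)` names the compiled module, so the shared library must be built as `_ext` and is imported package-relative (see the `__init__.py` hunk below). A minimal, self-contained sketch of the same binding pattern, where `scaled_sum` is a hypothetical function and nanobind 2.x is assumed:

#include <nanobind/nanobind.h>

namespace nb = nanobind;
using namespace nb::literals;

// Builds a module importable as `_ext` when compiled as _ext.so /
// _ext.cpython-*.so.
NB_MODULE(_ext, m) {
  m.def(
      "scaled_sum",
      [](float x, float y, float alpha) { return alpha * (x + y); },
      "x"_a,
      "y"_a,
      nb::kw_only(),  // every argument after this must be passed by keyword
      "alpha"_a = 1.0f,
      "Return alpha * (x + y).");
}

As in the diff, `nb::kw_only()` makes the trailing arguments keyword-only; the dropped `py::pos_only()` marker has no direct counterpart here, which is presumably why the port simply removes it.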
@@ -2,4 +2,4 @@
 
 import mlx.core as mx
 
-from .mlx_sample_extensions import *
+from ._ext import axpby
@@ -1,3 +1,8 @@
 [build-system]
-requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24", "mlx @ git+https://github.com/ml-explore/mlx@main"]
+requires = [
+    "setuptools>=42",
+    "cmake>=3.24",
+    "mlx>=0.18.0",
+    "nanobind==2.2.0",
+]
 build-backend = "setuptools.build_meta"
Some files were not shown because too many files have changed in this diff.